diff --git a/python-client/pypegasus/base/ttypes.py b/python-client/pypegasus/base/ttypes.py index 6ed2e88dee..05ef2b46cf 100644 --- a/python-client/pypegasus/base/ttypes.py +++ b/python-client/pypegasus/base/ttypes.py @@ -311,7 +311,7 @@ class host_port_types(Enum): kHostTypeInvalid = 0 kHostTypeIpv4 = 1 kHostTypeGroup = 2 - + kHostTypeIpv6 = 3 class host_port: diff --git a/python-client/pypegasus/pgclient.py b/python-client/pypegasus/pgclient.py index 9960f44c2b..5ee1725486 100644 --- a/python-client/pypegasus/pgclient.py +++ b/python-client/pypegasus/pgclient.py @@ -21,13 +21,16 @@ import os import logging.config import six +import ipaddress +import socket import yaml from thrift.Thrift import TMessageType, TApplicationException from twisted.internet import defer from twisted.internet import reactor -from twisted.internet.defer import inlineCallbacks, succeed, fail +from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.protocol import ClientCreator +from twisted.names import client, dns from pypegasus import rrdb, replication from pypegasus.base.ttypes import * @@ -58,7 +61,7 @@ DEFAULT_TIMEOUT = 2000 # ms META_CHECK_INTERVAL = 2 # s MAX_TIMEOUT_THRESHOLD = 5 # times - +MAX_META_QUERY_THRESHOLD = 1 # times class BaseSession(object): @@ -253,20 +256,160 @@ class MetaSessionManager(SessionManager): def __init__(self, table_name, timeout): SessionManager.__init__(self, table_name, timeout) self.addr_list = [] + self.host_ports = [] + self.query_times = 0 + + # validate if the given string is a valid IP address + def is_valid_ip(self, address): + try: + ipaddress.ip_address(address) + return True + except ValueError: + return False + + def get_host_type(self, ip_list): + """ + Determines the host type based on the provided list of IP addresses. + - kHostTypeIpv4: Single IPv4 address + - kHostTypeGroup: Multiple IPv4 addresses (a group/list) + - kHostTypeIpv6: IPv6 address (not currently supported in server) + - kHostTypeInvalid: Invalid or unsupported configuration + """ + ipv4_count = 0 + has_ipv6 = False + + for ip_str in ip_list: + try: + ip = ipaddress.ip_address(ip_str) + if ip.version == 4: + ipv4_count += 1 + elif ip.version == 6: + has_ipv6 = True + except ValueError: + continue + if has_ipv6: + return host_port_types.kHostTypeInvalid + if ipv4_count > 1: + return host_port_types.kHostTypeGroup + elif ipv4_count == 1: + return host_port_types.kHostTypeIpv4 + else: + return host_port_types.kHostTypeInvalid + + def resolve_all_ips(self, hostname): + """ + Resolves a hostname to all associated IP addresses (both A and AAAA records). + + This function performs DNS lookups for IPv4 and IPv6 records and combines the results. + """ + def extract_ips(result, record_type): + answers, _, _ = result + ips = [] + for answer in answers: + if record_type == dns.A and isinstance(answer.payload, dns.Record_A): + ips.append(answer.payload.dottedQuad()) + elif record_type == dns.AAAA and isinstance(answer.payload, dns.Record_AAAA): + if isinstance(answer.payload.address, bytes): + ip6_bytes = answer.payload.address + if len(ip6_bytes) == 16: + ip_str = socket.inet_ntop(socket.AF_INET6, ip6_bytes) + ips.append(ip_str) + else: + logger.error('Invalid IPv6 bytes length: %s', len(ip6_bytes)) + else: + ips.append(str(answer.payload.address)) + return ips + + resolver = client.getResolver() + + # query A record (ipv4) + d_a = resolver.query(dns.Query(hostname, dns.A, dns.IN)) + d_a_ips = d_a.addCallback(lambda res: extract_ips(res, dns.A)) + + # query AAAA record (ipv6) + d_aaaa = resolver.query(dns.Query(hostname, dns.AAAA, dns.IN)) + d_aaaa_ips = d_aaaa.addCallback(lambda res: extract_ips(res, dns.AAAA)) + + # gather A record and AAAA record + return defer.gatherResults([d_a_ips, d_aaaa_ips], consumeErrors=True) \ + .addCallback(lambda results: results[0] + results[1]) + + @inlineCallbacks def add_meta_server(self, meta_addr): - rpc_addr = rpc_address() - if rpc_addr.from_string(meta_addr): - ip_port = meta_addr.split(':') - if not len(ip_port) == 2: - return False + if not isinstance(meta_addr, str): + logger.error("meta_addr must be a string: %s", meta_addr) + returnValue(False) + + try: + host, port_str = meta_addr.split(':', 1) + except ValueError: + logger.error("invalid address format (expected host:port): %s", meta_addr) + returnValue(False) + + ips = [] + if self.is_valid_ip(host): + ips = [host] + else: + try: + ips = yield self.resolve_all_ips(host) + hp = host_port() + if hp.from_string(meta_addr): + hp.type = self.get_host_type(ips) + self.host_ports.append(hp) + logger.info("resolved hostname %s to IP type:%s, addr:%s", hp.host, hp.type, ips) + except Exception as e: + logger.error("failed to resolve hostname %s: %s", host, e) + returnValue(False) + + for ip in ips: + self.addr_list.append((ip, int(port_str))) + + returnValue(True) + + @inlineCallbacks + def resolve_host(self, err): + """ + Resolves the hostname of meta server and updates the address list. + + call this function means that all the meta server is not reachable, so we need to + resolve the hostname of meta server and update the address list. then, client will + retry to query new meta server. + """ + if not self.host_ports: + returnValue(None) + + logger.error("fail to query meta server, try to resolve hostname, err: %s", err) + new_addr_list = [] + for host_port in self.host_ports: + host, port = host_port.to_host_port() + try: + ips = yield self.resolve_all_ips(host) + host_port.type = self.get_host_type(ips) + for ip in ips: + new_addr_list.append((ip, port)) + logger.info("resolved hostname %s to IP type:%s, addr:%s", host_port.host, host_port.type, ips) + except Exception as e: + logger.error("failed to resolve hostname %s: %s", host, e) + continue + + if not new_addr_list or set(new_addr_list) == set(self.addr_list): + returnValue(None) + + stale_sessions = [] + for rpc_addr in list(self.session_dict): + ip, port = rpc_addr.to_ip_port() + if (ip, port) not in new_addr_list: + stale_sessions.append(self.session_dict.pop(rpc_addr)) + logger.info("removed stale server: %s:%s", ip, port) - ip, port = ip_port[0], int(ip_port[1]) - self.addr_list.append((ip, port)) + self.addr_list = new_addr_list - return True - else: - return False + for session in stale_sessions: + if session is not None: + reactor.callFromThread(session.close) + + returnValue(None) @inlineCallbacks def query_one(self, session): @@ -284,8 +427,15 @@ def got_results(self, res): self.name, result) return result - logger.error('query partition info err. table: %s err: %s', - self.name, res) + # all meta server queries failed, maybe need to re-resolve hostname + # to get new IP addresses + self.query_times += 1 + logger.error('query partition info err. table: %s, query_times: %d, err: %s', + self.name, self.query_times, res) + + if self.query_times >= MAX_META_QUERY_THRESHOLD: + self.query_times = 0 + raise RuntimeError("all meta server queries failed, try to re-resolve hostname") def query(self): ds = [] @@ -309,6 +459,7 @@ def query(self): dlist = defer.DeferredList(ds, consumeErrors=True) dlist.addCallback(self.got_results) + dlist.addErrback(self.resolve_host) return dlist @@ -331,7 +482,7 @@ def update_cfg(self, resp): if resp.__class__.__name__ != "query_cfg_response": logger.error('table: %s, query_cfg_response is error', self.name) - return None + return defer.succeed(None) self.query_cfg_response = resp self.app_id = self.query_cfg_response.app_id @@ -649,23 +800,34 @@ def __init__(self, meta_addrs=None, table_name='', self.name = table_name self.table = Table(table_name, self, timeout) self.meta_session_manager = MetaSessionManager(table_name, timeout) - if isinstance(meta_addrs, list): - for meta_addr in meta_addrs: - self.meta_session_manager.add_meta_server(meta_addr) + self.initial_meta = False + self.meta_addrs = meta_addrs if isinstance(meta_addrs, list) else [] PegasusHash.populate_table() self.timeout_times = 0 self.update_partition = False self.timer = reactor.callLater(META_CHECK_INTERVAL, self.check_state) + @inlineCallbacks def init(self): """ Initialize the client before you can use it. :return: (DeferredList) True when initialized succeed, others when failed. """ + if not self.initial_meta: + deferreds = [] + self.initial_meta = True + for host_port in self.meta_addrs: + d = self.meta_session_manager.add_meta_server(host_port) + deferreds.append(d) + results = yield defer.DeferredList(deferreds, consumeErrors=True) + if not any(success for _, success in results): + raise Exception("all meta servers failed to initialize") + dlist = self.meta_session_manager.query() dlist.addCallback(self.table.update_cfg) - return dlist + results = yield dlist + defer.returnValue(results) def close(self): """