diff --git a/docs/examples.announce_addrs.rst b/docs/examples.announce_addrs.rst index cd9ae7a12..f262c88ae 100644 --- a/docs/examples.announce_addrs.rst +++ b/docs/examples.announce_addrs.rst @@ -49,6 +49,27 @@ This pattern is useful when: By announcing the correct external addresses, peers will successfully dial your node regardless of their network position. +Automatic discovery vs. explicit announce addresses +--------------------------------------------------- + +py-libp2p also ships with an :class:`~libp2p.host.observed_addr_manager.ObservedAddrManager` +that automatically discovers the host's externally observed addresses through +the Identify protocol. Once enough distinct peer groups confirm the same +external address, it is appended to the output of +:meth:`~libp2p.host.basic_host.BasicHost.get_addrs` -- no manual configuration +is required for the common NAT / EC2 case (see issue #1250). + +``announce_addrs`` takes priority over observed addresses: when it is set it +acts as a static ``AddrsFactory`` (matching go-libp2p's +``applyAddrsFactory`` behaviour), so only the explicitly announced list is +advertised. Observations are still recorded internally -- for example to feed +:meth:`~libp2p.host.basic_host.BasicHost.get_nat_type` -- but they are not +emitted by ``get_addrs`` when a static list has been provided. + +Use ``announce_addrs`` when you already know the exact public address(es) you +want peers to dial (e.g. a reverse proxy hostname such as ngrok). Rely on +automatic observed-address discovery otherwise. + The full source code for this example is below: .. literalinclude:: ../examples/announce_addrs/announce_addrs.py diff --git a/docs/libp2p.host.rst b/docs/libp2p.host.rst index 7529d5b29..89caa8819 100644 --- a/docs/libp2p.host.rst +++ b/docs/libp2p.host.rst @@ -36,6 +36,28 @@ libp2p.host.exceptions module :undoc-members: :show-inheritance: +libp2p.host.observed\_addr\_manager module +------------------------------------------- + +Automatic NAT address discovery. Remote peers report the address they see +us on via the Identify protocol; once enough *distinct observer groups* +(``ACTIVATION_THRESHOLD``, currently ``4``) report the same external +address, it is treated as confirmed and appended by +:meth:`libp2p.host.basic_host.BasicHost.get_addrs` so peers learn the +host's real public address (fixes issue #1250 for NAT/EC2 deployments). + +Interaction with ``announce_addrs``: when ``announce_addrs`` is passed to +:class:`~libp2p.host.basic_host.BasicHost` it is treated as an explicit +``AddrsFactory`` (mirroring go-libp2p's ``applyAddrsFactory``) and wins +over observed addresses: observations are still **recorded** (for +:meth:`~libp2p.host.basic_host.BasicHost.get_nat_type` and future +AutoNAT consumers) but are **not** advertised via ``get_addrs``. + +.. automodule:: libp2p.host.observed_addr_manager + :members: + :undoc-members: + :show-inheritance: + libp2p.host.ping module ----------------------- diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 2b5adca27..62bfab21a 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -18,7 +18,7 @@ from cryptography import x509 from cryptography.x509.oid import ExtensionOID import multiaddr -from multiaddr.exceptions import ProtocolLookupError +from multiaddr.exceptions import MultiaddrError, ProtocolLookupError import trio import libp2p @@ -50,6 +50,10 @@ from libp2p.host.exceptions import ( StreamFailure, ) +from libp2p.host.observed_addr_manager import ( + NATDeviceType, + ObservedAddrManager, +) from libp2p.host.ping import ( ID as PING_PROTOCOL_ID, ) @@ -208,8 +212,16 @@ def __init__( :param bootstrap_dns_timeout: DNS resolution timeout in seconds per attempt. :param bootstrap_dns_max_retries: Max DNS resolution retries (with backoff). :param announce_addrs: Optional addresses to advertise instead of - listen addresses. ``None`` (default) uses listen addresses; - an empty list advertises no addresses. + listen addresses. ``None`` (default) uses listen addresses + augmented with confirmed observed addresses from + :class:`~libp2p.host.observed_addr_manager.ObservedAddrManager`. + An empty list advertises no addresses. When set, this list acts + as a static ``AddrsFactory`` (mirroring go-libp2p's + ``applyAddrsFactory``) and wins over observed addresses: + observations are still **recorded** by the manager (for + :meth:`get_nat_type` and future AutoNAT consumers) but are + **not** emitted by :meth:`get_addrs`. See also + :meth:`get_addrs` for the exact composition rules. """ self._network = network self._network.set_stream_handler(self._swarm_stream_handler) @@ -255,10 +267,12 @@ def __init__( ) self.psk = psk - # Address announcement configuration + # Address announcement configuration (from #1268) self._announce_addrs = ( list(announce_addrs) if announce_addrs is not None else None ) + # Observed-address tracking (from #1284, issue #1250) + self._observed_addr_manager = ObservedAddrManager() # Cache a signed-record if the local-node in the PeerStore envelope = create_signed_peer_record( @@ -358,18 +372,35 @@ def get_addrs(self) -> list[multiaddr.Multiaddr]: """ Return the multiaddr addresses this host advertises to peers. - If ``announce_addrs`` was provided, those replace listen addresses - entirely. Otherwise listen addresses are used. - - Note: This method appends the /p2p/{peer_id} suffix to the addresses. - Use get_transport_addrs() for raw transport addresses. + Behavior (mirrors go-libp2p's ``AddrsFactory`` pipeline): + + * If ``announce_addrs`` was provided at construction time, that list + replaces everything — it is treated as a static ``AddrsFactory`` in + go-libp2p terms. Observed (NAT) addresses are **still recorded** + by :class:`~libp2p.host.observed_addr_manager.ObservedAddrManager` + (for ``get_nat_type`` and future AutoNAT consumers) but are not + emitted here, since the caller has explicitly chosen which + addresses to advertise. + * Otherwise the set of raw transport addresses is augmented with + externally observed addresses that have been confirmed by enough + distinct peer groups (see :data:`ACTIVATION_THRESHOLD`), then the + ``/p2p/{peer_id}`` suffix is appended to each. + + Use :meth:`get_transport_addrs` for the raw transport addresses + without any observed-address augmentation or ``/p2p`` suffix. """ p2p_part = multiaddr.Multiaddr(f"/p2p/{self.get_id()!s}") if self._announce_addrs is not None: addrs = list(self._announce_addrs) else: - addrs = self.get_transport_addrs() + addrs = list(self.get_transport_addrs()) + seen = {str(a) for a in addrs} + for obs_addr in self._observed_addr_manager.addrs(): + key = str(obs_addr) + if key not in seen: + seen.add(key) + addrs.append(obs_addr) result = [] for addr in addrs: @@ -391,6 +422,26 @@ def get_connected_peers(self) -> list[ID]: """ return list(self._network.connections.keys()) + def get_nat_type(self) -> tuple[NATDeviceType, NATDeviceType]: + """ + Return the classified NAT device type for TCP and UDP transports. + + Thin pass-through to + :meth:`libp2p.host.observed_addr_manager.ObservedAddrManager.get_nat_type`, + which infers NAT behaviour from the distribution of externally + observed addresses reported through Identify. Matches go-libp2p's + ``host.getNATType()`` algorithm. + + .. note:: + Experimental API. Intended primarily for AutoNAT / hole-punch + consumers; the return values, thresholds, and method name may + evolve as those subsystems land in py-libp2p. + + :return: ``(tcp_nat_type, udp_nat_type)``, each one of + :class:`~libp2p.host.observed_addr_manager.NATDeviceType`. + """ + return self._observed_addr_manager.get_nat_type() + def run( self, listen_addrs: Sequence[multiaddr.Multiaddr], @@ -1047,6 +1098,40 @@ async def _identify_peer(self, peer_id: ID, *, reason: str) -> None: identify_msg.ParseFromString(data) await _update_peerstore_from_identify(self.peerstore, peer_id, identify_msg) self._identified_peers.add(peer_id) + + if identify_msg.HasField("observed_addr") and identify_msg.observed_addr: + try: + our_observed = multiaddr.Multiaddr(identify_msg.observed_addr) + self._observed_addr_manager.record_observation( + swarm_conn, our_observed, self.get_transport_addrs() + ) + except MultiaddrError as exc: + # Malformed observed_addr bytes or unknown protocols from a + # misbehaving peer. Expected at low rates; log quietly. + logger.debug( + "ObservedAddrManager: ignoring malformed observed_addr " + "from peer %s: %s", + peer_id, + exc, + ) + except ValueError as exc: + logger.debug( + "ObservedAddrManager: ignoring invalid observed_addr " + "value from peer %s: %s", + peer_id, + exc, + ) + except Exception as exc: + # Unexpected failure: surface at warning with traceback so + # regressions don't disappear into debug logs. + logger.warning( + "ObservedAddrManager: unexpected failure recording " + "observation from peer %s: %s", + peer_id, + exc, + exc_info=True, + ) + logger.debug( "Identify[%s]: cached %s protocols for peer %s", reason, @@ -1094,6 +1179,7 @@ def _on_notifee_disconnected(self, conn: INetConn) -> None: if peer_id is None: return self._identified_peers.discard(peer_id) + self._observed_addr_manager.remove_conn(conn) def _get_first_connection(self, peer_id: ID) -> INetConn | None: connections = self._network.get_connections(peer_id) diff --git a/libp2p/host/observed_addr_manager.py b/libp2p/host/observed_addr_manager.py new file mode 100644 index 000000000..052b65cf1 --- /dev/null +++ b/libp2p/host/observed_addr_manager.py @@ -0,0 +1,500 @@ +""" +Observed Address Manager for py-libp2p. + +Tracks external addresses reported by remote peers via the Identify protocol. +When enough distinct peers confirm the same external address, it is added to +the host's advertised set — matching go-libp2p's ``p2p/host/observedaddrs``. +""" + +from __future__ import annotations + +from enum import Enum +import ipaddress +import logging +from typing import TYPE_CHECKING, cast + +from multiaddr import Multiaddr +from multiaddr.protocols import ( + P_IP4, + P_IP6, + P_P2P_CIRCUIT, + P_TCP, + P_UDP, + Protocol, +) + +if TYPE_CHECKING: + from libp2p.abc import INetConn + +logger = logging.getLogger(__name__) + +ACTIVATION_THRESHOLD = 4 +MAX_EXTERNAL_ADDRS_PER_LOCAL = 3 +_ADDR_CACHE_SIZE = 10 + +_THIN_WAIST_TRANSPORT_CODES = frozenset({P_TCP, P_UDP}) +_THIN_WAIST_IP_CODES = frozenset({P_IP4, P_IP6}) +_NAT64_PREFIX = ipaddress.IPv6Network("64:ff9b::/96") + + +class NATDeviceType(Enum): + """NAT device type classification.""" + + UNKNOWN = "unknown" + ENDPOINT_INDEPENDENT = "endpoint_independent" + ENDPOINT_DEPENDENT = "endpoint_dependent" + + +def extract_thin_waist(maddr: Multiaddr) -> tuple[Multiaddr, str] | None: + """ + Split *maddr* into a thin-waist prefix and the remaining suffix. + + The thin waist is the IP + transport portion, e.g. + ``/ip4/1.2.3.4/tcp/4001``. Everything after (``/ws``, ``/p2p/Qm…``, …) + is returned as the *rest* string. + + Returns ``None`` when the address does not contain a recognisable + thin-waist prefix. + """ + protos = cast(list[Protocol], maddr.protocols()) + + # We need at least an IP and a transport protocol. + if len(protos) < 2: + return None + if protos[0].code not in _THIN_WAIST_IP_CODES: + return None + if protos[1].code not in _THIN_WAIST_TRANSPORT_CODES: + return None + + ip_code = protos[0].code + transport_code = protos[1].code + + ip_val = maddr.value_for_protocol(ip_code) + transport_val = maddr.value_for_protocol(transport_code) + + thin_waist = Multiaddr( + f"/{protos[0].name}/{ip_val}/{protos[1].name}/{transport_val}" + ) + + # Build the rest string from protocols after the first two. + rest_parts: list[str] = [] + for p in protos[2:]: + val = maddr.value_for_protocol(p.code) + if val: + rest_parts.append(f"/{p.name}/{val}") + else: + rest_parts.append(f"/{p.name}") + rest = "".join(rest_parts) + + return thin_waist, rest + + +def observer_group(remote_addr: tuple[str, int]) -> str: + """ + Compute a grouping key from the observer's remote address. + + IPv4: full IP string. IPv6: /56 prefix. Multiple connections from the + same group count as a single observer. + """ + ip_str = remote_addr[0] + try: + addr = ipaddress.ip_address(ip_str) + except ValueError: + return ip_str + + if isinstance(addr, ipaddress.IPv6Address): + network = ipaddress.IPv6Network(f"{ip_str}/56", strict=False) + return str(network.network_address) + return ip_str + + +def is_valid_observation(observed: Multiaddr) -> bool: + """Reject loopback, relay, NAT64, and non-thin-waist observations.""" + protos = observed.protocols() + proto_codes = {p.code for p in protos} + + # Reject relay addresses. + if P_P2P_CIRCUIT in proto_codes: + return False + + # Reject loopback. + if P_IP4 in proto_codes: + ip_val = observed.value_for_protocol(P_IP4) + if ip_val and ip_val.startswith("127."): + return False + if P_IP6 in proto_codes: + ip_val = observed.value_for_protocol(P_IP6) + if ip_val == "::1": + return False + # Reject NAT64 well-known prefix (64:ff9b::/96). + if ip_val: + try: + addr = ipaddress.ip_address(ip_val) + if isinstance(addr, ipaddress.IPv6Address) and addr in _NAT64_PREFIX: + return False + except ValueError: + pass + + # Must have a recognisable thin waist. + result = extract_thin_waist(observed) + if result is None: + return False + + return True + + +class ObservedAddrManager: + """Tracks externally observed addresses reported by remote peers.""" + + def __init__(self) -> None: + # local_tw_str -> external_tw_str -> observer group key -> count + self._external_addrs: dict[str, dict[str, dict[str, int]]] = {} + # id(conn) -> (local_tw_str, external_tw_str, observer_group_key) + self._conn_observations: dict[int, tuple[str, str, str]] = {} + # local_tw_str -> set of "rest" suffixes seen on listen addresses + self._local_addr_rests: dict[str, set[str]] = {} + # Cache: full_addr_str → Multiaddr object (avoids repeated construction) + self._addr_cache: dict[str, Multiaddr] = {} + + def _cached_multiaddr(self, addr_str: str) -> Multiaddr: + """Return a cached Multiaddr for the given string, constructing if needed.""" + ma = self._addr_cache.get(addr_str) + if ma is not None: + return ma + ma = Multiaddr(addr_str) + if len(self._addr_cache) >= _ADDR_CACHE_SIZE: + # Evict the oldest entry (first inserted key in insertion-order dict). + oldest = next(iter(self._addr_cache)) + del self._addr_cache[oldest] + self._addr_cache[addr_str] = ma + return ma + + def record_observation( + self, + conn: INetConn, + observed_addr: Multiaddr, + local_addrs: list[Multiaddr], + ) -> None: + """ + Record an observed address from a remote peer. + + Parameters + ---------- + conn: + The network connection the observation came from. + observed_addr: + The address the remote peer sees us as. + local_addrs: + Our current listen/transport addresses (without ``/p2p/…``). + + """ + if getattr(conn, "is_closed", False): + return + + if not is_valid_observation(observed_addr): + return + + obs_result = extract_thin_waist(observed_addr) + if obs_result is None: + return + external_tw, _ = obs_result + + # Get remote address from connection for observer grouping. + remote = _get_remote_addr(conn) + if remote is None: + return + obs_group = observer_group(remote) + + # Match observed address to a local thin waist. + local_tw_str = _match_local_thin_waist(external_tw, local_addrs) + if local_tw_str is None: + return + + external_tw_str = str(external_tw) + conn_id = id(conn) + + # If this connection already had an observation, check if it changed. + if conn_id in self._conn_observations: + old_local, old_ext, old_obs = self._conn_observations[conn_id] + if old_local == local_tw_str and old_ext == external_tw_str: + return # Same observation, nothing to do + self._remove_observation(old_local, old_ext, old_obs) + + # Store the observation. + if local_tw_str not in self._external_addrs: + self._external_addrs[local_tw_str] = {} + if external_tw_str not in self._external_addrs[local_tw_str]: + self._external_addrs[local_tw_str][external_tw_str] = {} + observers = self._external_addrs[local_tw_str][external_tw_str] + observers[obs_group] = observers.get(obs_group, 0) + 1 + self._conn_observations[conn_id] = ( + local_tw_str, + external_tw_str, + obs_group, + ) + + # Update rest suffixes from local addresses for address inference. + self._update_local_addr_rests(local_addrs) + + def remove_conn(self, conn: INetConn) -> None: + """Clean up observations when a connection is closed.""" + conn_id = id(conn) + if conn_id not in self._conn_observations: + return + local_tw, ext_tw, obs = self._conn_observations.pop(conn_id) + self._remove_observation(local_tw, ext_tw, obs) + + def addrs(self, min_observers: int = ACTIVATION_THRESHOLD) -> list[Multiaddr]: + """ + Return confirmed external addresses. + + An address is confirmed when at least *min_observers* distinct + observer groups have reported it. Returns up to + ``MAX_EXTERNAL_ADDRS_PER_LOCAL`` addresses per local thin waist, + sorted by observer count descending with lexicographic tiebreak. + + For each confirmed external thin waist, full addresses are generated by + combining with rest suffixes seen on local listen addresses (address + inference). + """ + result: list[Multiaddr] = [] + for local_tw_str, ext_map in self._external_addrs.items(): + result.extend( + self._get_top_external_addrs(local_tw_str, ext_map, min_observers) + ) + return result + + def addrs_for( + self, listen_addr: Multiaddr, min_observers: int = ACTIVATION_THRESHOLD + ) -> list[Multiaddr]: + """ + Return confirmed external addresses for a specific listen address. + + Equivalent to Go's ``AddrsFor``. Returns one address per confirmed + external thin waist, combining it with the rest suffix of the queried + *listen_addr* (not all known rest suffixes). + """ + tw_result = extract_thin_waist(listen_addr) + if tw_result is None: + return [] + local_tw, rest = tw_result + local_tw_str = str(local_tw) + ext_map = self._external_addrs.get(local_tw_str) + if ext_map is None: + return [] + + candidates = self._select_candidates(ext_map, min_observers) + result: list[Multiaddr] = [] + for ext_tw_str, _ in candidates: + if rest: + result.append(self._cached_multiaddr(ext_tw_str + rest)) + else: + result.append(self._cached_multiaddr(ext_tw_str)) + return result + + @staticmethod + def _select_candidates( + ext_map: dict[str, dict[str, int]], min_observers: int + ) -> list[tuple[str, int]]: + """Filter, sort, and cap external address candidates.""" + candidates = [ + (ext_tw_str, len(observers)) + for ext_tw_str, observers in ext_map.items() + if len(observers) >= min_observers + ] + candidates.sort(key=lambda x: (-x[1], x[0])) + return candidates[:MAX_EXTERNAL_ADDRS_PER_LOCAL] + + def _get_top_external_addrs( + self, + local_tw_str: str, + ext_map: dict[str, dict[str, int]], + min_observers: int, + ) -> list[Multiaddr]: + """Return top external addresses for a single local thin waist.""" + candidates = self._select_candidates(ext_map, min_observers) + + result: list[Multiaddr] = [] + sorted_rests = sorted(self._local_addr_rests.get(local_tw_str, ())) + seen: set[str] = set() + for ext_tw_str, _ in candidates: + if ext_tw_str not in seen: + seen.add(ext_tw_str) + result.append(self._cached_multiaddr(ext_tw_str)) + for rest in sorted_rests: + if rest: + full = ext_tw_str + rest + if full not in seen: + seen.add(full) + result.append(self._cached_multiaddr(full)) + return result + + def get_nat_type(self) -> tuple[NATDeviceType, NATDeviceType]: + """ + Return (tcp_nat_type, udp_nat_type) based on observation distribution. + + Matches Go's ``getNATType()`` algorithm. + """ + # Gather per-protocol observation counts. + # proto_key -> list of observer counts per external address + tcp_counts: list[int] = [] + udp_counts: list[int] = [] + + for _, ext_map in self._external_addrs.items(): + # Invariant: every ``ext_tw_str`` key in a single ``ext_map`` shares + # the same transport (all ``/tcp/`` or all ``/udp/``), because the + # outer key (``local_tw_str``) is derived from a single local + # thin-waist in ``record_observation`` via ``_match_local_thin_waist`` + # and only matches external thin waists with the same IP family + + # transport (see ``has_consistent_transport``). It is therefore + # enough to classify the bucket once from the first key — this + # matches go-libp2p's ``getNATType`` behaviour in + # ``p2p/host/basic/basic_host.go``. + first_key = next(iter(ext_map), "") + is_tcp = "/tcp/" in first_key + for ext_tw_str, observers in ext_map.items(): + # Defensive skip: if a future refactor ever allows a mixed + # bucket, misclassifying as "other transport" is silently + # wrong — drop the stray entry instead. + if is_tcp and "/tcp/" not in ext_tw_str: + continue + if not is_tcp and "/udp/" not in ext_tw_str: + continue + count = len(observers) # unique observer groups, not total connections + if is_tcp: + tcp_counts.append(count) + else: + udp_counts.append(count) + + return ( + self._classify_nat(tcp_counts), + self._classify_nat(udp_counts), + ) + + @staticmethod + def _classify_nat(counts: list[int]) -> NATDeviceType: + """Classify NAT type from per-address observation counts.""" + if not counts: + return NATDeviceType.UNKNOWN + + all_total = sum(counts) + top = sorted(counts, reverse=True)[:MAX_EXTERNAL_ADDRS_PER_LOCAL] + top_total = sum(top) + + # Need enough observations to make a determination. + if all_total < 3 * MAX_EXTERNAL_ADDRS_PER_LOCAL: + return NATDeviceType.UNKNOWN + + # If top addresses cover >= 50% of all observations → cone NAT. + if top_total * 2 >= all_total: + return NATDeviceType.ENDPOINT_INDEPENDENT + return NATDeviceType.ENDPOINT_DEPENDENT + + def _remove_observation(self, local_tw: str, ext_tw: str, obs_group: str) -> None: + ext_map = self._external_addrs.get(local_tw) + if ext_map is None: + return + observers = ext_map.get(ext_tw) + if observers is None: + return + if obs_group in observers: + observers[obs_group] -= 1 + if observers[obs_group] <= 0: + del observers[obs_group] + if not observers: + del ext_map[ext_tw] + self._addr_cache.clear() # cached addrs may reference removed ext_tw + if not ext_map: + del self._external_addrs[local_tw] + + def _update_local_addr_rests(self, local_addrs: list[Multiaddr]) -> None: + """Extract rest suffixes from local addresses for later inference.""" + for addr in local_addrs: + result = extract_thin_waist(addr) + if result is None: + continue + tw, rest = result + tw_str = str(tw) + + # Also store for wildcard-matched versions. + # We match by port+transport, so store under the canonical + # local_tw that would be used by _match_local_thin_waist. + if tw_str not in self._local_addr_rests: + self._local_addr_rests[tw_str] = set() + if rest: + self._local_addr_rests[tw_str].add(rest) + + +def has_consistent_transport(tw_a: Multiaddr, tw_b: Multiaddr) -> bool: + """Check both thin waists use the same protocol codes (IP + transport).""" + protos_a = cast(list[Protocol], tw_a.protocols()) + protos_b = cast(list[Protocol], tw_b.protocols()) + if len(protos_a) < 2 or len(protos_b) < 2: + return False + return protos_a[0].code == protos_b[0].code and protos_a[1].code == protos_b[1].code + + +def _get_remote_addr(conn: INetConn) -> tuple[str, int] | None: + """Extract the remote (IP, port) from a connection.""" + muxed = getattr(conn, "muxed_conn", None) + if muxed is None: + return None + get_remote = getattr(muxed, "get_remote_address", None) + if get_remote is None: + return None + result = get_remote() + if result is not None and isinstance(result, tuple) and len(result) == 2: + return result + return None + + +def _match_local_thin_waist( + external_tw: Multiaddr, local_addrs: list[Multiaddr] +) -> str | None: + """ + Find a local thin waist that matches the external observation. + + Handles wildcard IPs (``0.0.0.0`` / ``::``) by matching on port and + transport protocol only. + """ + ext_protos = cast(list[Protocol], external_tw.protocols()) + if len(ext_protos) < 2: + return None + ext_ip_code = ext_protos[0].code + ext_transport_code = ext_protos[1].code + ext_port = external_tw.value_for_protocol(ext_transport_code) + + for addr in local_addrs: + result = extract_thin_waist(addr) + if result is None: + continue + local_tw, _ = result + local_protos = cast(list[Protocol], local_tw.protocols()) + if len(local_protos) < 2: + continue + local_ip_code = local_protos[0].code + local_transport_code = local_protos[1].code + + # Transport must match. + if local_transport_code != ext_transport_code: + continue + + # Port must match. + local_port = local_tw.value_for_protocol(local_transport_code) + if local_port != ext_port: + continue + + # IP family must be compatible. + local_ip = local_tw.value_for_protocol(local_ip_code) + is_wildcard = local_ip in ("0.0.0.0", "::") + + if is_wildcard: + # Wildcard matches any IP of the same family. + if ext_ip_code == local_ip_code: + return str(local_tw) + else: + # Exact thin waist match (same IP family, same port). + if local_ip_code == ext_ip_code: + return str(local_tw) + + return None diff --git a/newsfragments/1250.bugfix.rst b/newsfragments/1250.bugfix.rst new file mode 100644 index 000000000..d4bd9648f --- /dev/null +++ b/newsfragments/1250.bugfix.rst @@ -0,0 +1 @@ +Fixed ``BasicHost.get_addrs()`` not announcing externally observed NAT addresses, which previously caused peers behind NAT (e.g. AWS/EC2) to advertise only private/loopback addresses. Added ``ObservedAddrManager`` that collects observed addresses reported by peers via the Identify protocol and advertises them once enough distinct observer groups confirm the same address. diff --git a/tests/core/host/test_basic_host.py b/tests/core/host/test_basic_host.py index 4b8a30be2..37ecd952a 100644 --- a/tests/core/host/test_basic_host.py +++ b/tests/core/host/test_basic_host.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +import logging from unittest.mock import ( AsyncMock, MagicMock, @@ -6,6 +7,7 @@ import pytest from multiaddr import Multiaddr +from multiaddr.exceptions import MultiaddrError from libp2p import ( new_swarm, @@ -23,6 +25,10 @@ from libp2p.host.exceptions import ( StreamFailure, ) +from libp2p.identity.identify.pb.identify_pb2 import ( + Identify as IdentifyMsg, +) +from libp2p.peer.id import ID def test_default_protocols(): @@ -208,6 +214,75 @@ def test_announce_addrs_with_correct_peer_id(): assert str(addrs[0]) == f"/ip4/1.2.3.4/tcp/4001/p2p/{peer_id_str}" +def test_get_addrs_appends_observed_when_no_announce(): + """Observed addresses are appended to get_addrs when announce_addrs is None.""" + host = _make_host_with_listener(announce_addrs=None) + observed = Multiaddr("/ip4/5.6.7.8/tcp/4001") + fake_manager = MagicMock() + fake_manager.addrs.return_value = [observed] + host._observed_addr_manager = fake_manager + + addrs = host.get_addrs() + peer_id_str = str(host.get_id()) + + addr_strs = [str(a) for a in addrs] + assert f"/ip4/127.0.0.1/tcp/8000/p2p/{peer_id_str}" in addr_strs + assert f"/ip4/5.6.7.8/tcp/4001/p2p/{peer_id_str}" in addr_strs + + +def test_get_addrs_skips_observed_when_announce_set(): + """ + announce_addrs acts as a static AddrsFactory (like go-libp2p): observed + addresses are still recorded but not advertised via get_addrs. + """ + announce = [Multiaddr("/ip4/1.2.3.4/tcp/4001")] + host = _make_host_with_listener(announce_addrs=announce) + observed = Multiaddr("/ip4/5.6.7.8/tcp/4001") + fake_manager = MagicMock() + fake_manager.addrs.return_value = [observed] + host._observed_addr_manager = fake_manager + + addrs = host.get_addrs() + addr_strs = [str(a) for a in addrs] + assert len(addrs) == 1 + assert "/ip4/1.2.3.4/tcp/4001" in addr_strs[0] + assert not any("5.6.7.8" in s for s in addr_strs) + # Observed manager's addrs() must not be consulted at all in this branch. + fake_manager.addrs.assert_not_called() + + +def test_get_addrs_deduplicates_observed_matching_transport(): + """If the observed address equals a listen addr it must not be duplicated.""" + host = _make_host_with_listener(announce_addrs=None) + fake_manager = MagicMock() + fake_manager.addrs.return_value = [ + Multiaddr("/ip4/127.0.0.1/tcp/8000"), + ] + host._observed_addr_manager = fake_manager + + addrs = host.get_addrs() + assert len(addrs) == 1 + + +def test_get_nat_type_delegates_to_observed_addr_manager(): + """BasicHost.get_nat_type() is a thin pass-through to ObservedAddrManager.""" + from libp2p.host.observed_addr_manager import NATDeviceType + + host = _make_host_with_listener() + fake_manager = MagicMock() + fake_manager.get_nat_type.return_value = ( + NATDeviceType.ENDPOINT_INDEPENDENT, + NATDeviceType.UNKNOWN, + ) + host._observed_addr_manager = fake_manager + + tcp_nat, udp_nat = host.get_nat_type() + + fake_manager.get_nat_type.assert_called_once_with() + assert tcp_nat == NATDeviceType.ENDPOINT_INDEPENDENT + assert udp_nat == NATDeviceType.UNKNOWN + + @pytest.mark.trio async def test_initiate_autotls_procedure_builds_transport_aware_broker_multiaddr( monkeypatch, tmp_path @@ -378,3 +453,259 @@ async def wait_for_dns(self): assert captured_addrs broker_addr = captured_addrs[-1] assert "/ip4/11.22.33.44/udp/4001/quic-v1" in str(broker_addr) + + +# --------------------------------------------------------------------------- +# Integration tests for ObservedAddrManager wiring into BasicHost: +# * _identify_peer records observation on the manager (Gap 1) +# * _identify_peer's narrow exception handling (Gap 2) +# * _on_notifee_disconnected cleans up the manager's per-conn state (Gap 3) +# --------------------------------------------------------------------------- + + +class _BasicHostLogCollector(logging.Handler): + """Handler that simply collects records into a list.""" + + def __init__(self) -> None: + super().__init__(level=logging.DEBUG) + self.records: list[logging.LogRecord] = [] + + def emit(self, record: logging.LogRecord) -> None: + self.records.append(record) + + +@pytest.fixture +def basic_host_log_records(): + """ + Capture log records from the ``libp2p.host.basic_host`` logger directly. + + We can't rely on pytest's ``caplog`` (attached at the root logger) because + ``libp2p/utils/logging.py`` reconfigures the ``libp2p`` hierarchy at + import time based on the ``LIBP2P_DEBUG`` env var: depending on its value + it may set ``propagate=False`` on the ``libp2p`` logger (or on specific + submodule loggers like ``libp2p.host.basic_host``), which breaks + propagation to root. xdist workers in CI occasionally hit a config where + propagation is broken by the time the test runs, even if it works under a + plain ``pytest`` invocation. + + Attaching our own handler directly to the target logger sidesteps every + one of those failure modes: we don't care about propagation or parent + levels, only whether the logger itself is enabled for DEBUG — which we + force here. + """ + target = logging.getLogger("libp2p.host.basic_host") + handler = _BasicHostLogCollector() + prev_level = target.level + prev_disabled = target.disabled + target.setLevel(logging.DEBUG) + target.disabled = False + target.addHandler(handler) + try: + yield handler.records + finally: + target.removeHandler(handler) + target.setLevel(prev_level) + target.disabled = prev_disabled + + +def _prepare_identify_host( + monkeypatch: pytest.MonkeyPatch, observed_addr_bytes: bytes +) -> tuple[BasicHost, ID, MagicMock]: + """ + Build a BasicHost wired up just enough to run ``_identify_peer`` against a + canned ``IdentifyMsg`` payload for a single fake connection. + + Returns ``(host, peer_id, swarm_conn)``. ``host._observed_addr_manager`` is + left as the real instance — individual tests should replace it with a mock + if they want to assert on it. + """ + host = _make_host_with_listener(announce_addrs=None) + peer_id = host.get_id() + + swarm_conn = MagicMock() + swarm_conn.muxed_conn = MagicMock() + swarm_conn.muxed_conn.peer_id = peer_id + swarm_conn.muxed_conn.get_remote_address = MagicMock( + return_value=("10.0.0.1", 4001) + ) + swarm_conn.is_closed = False + swarm_conn.event_started = None # no gating on connect event in tests + + monkeypatch.setattr(host._network, "get_connections", lambda pid: [swarm_conn]) + + fake_stream = MagicMock() + fake_stream.reset = AsyncMock() + fake_stream.close = AsyncMock() + host.new_stream = AsyncMock(return_value=fake_stream) + + msg = IdentifyMsg(observed_addr=observed_addr_bytes) + msg_bytes = msg.SerializeToString() + + async def fake_read(stream, use_varint_format=True): + return msg_bytes + + async def fake_update(peerstore, peer_id, identify_msg): + return None + + monkeypatch.setattr( + "libp2p.host.basic_host.read_length_prefixed_protobuf", fake_read + ) + monkeypatch.setattr( + "libp2p.host.basic_host._update_peerstore_from_identify", fake_update + ) + + return host, peer_id, swarm_conn + + +@pytest.mark.trio +async def test_identify_peer_records_observation(monkeypatch): + """Gap 1: _identify_peer forwards the peer's observed_addr to the manager.""" + observed = Multiaddr("/ip4/5.6.7.8/tcp/4001") + host, peer_id, swarm_conn = _prepare_identify_host(monkeypatch, observed.to_bytes()) + fake_manager = MagicMock() + host._observed_addr_manager = fake_manager + + await host._identify_peer(peer_id, reason="test") + + fake_manager.record_observation.assert_called_once() + args, _kwargs = fake_manager.record_observation.call_args + passed_conn, passed_observed, passed_locals = args + assert passed_conn is swarm_conn + assert str(passed_observed) == "/ip4/5.6.7.8/tcp/4001" + # Third arg must be the current list of transport addrs. + assert list(passed_locals) == host.get_transport_addrs() + + +@pytest.mark.trio +async def test_identify_peer_swallows_multiaddr_error( + monkeypatch, basic_host_log_records +): + """ + Gap 2a: a ``MultiaddrError`` raised while recording the observation must + be caught and logged at DEBUG; nothing propagates out of _identify_peer. + We use a well-formed multiaddr and make ``record_observation`` raise — + real-world malformed byte payloads surface the same exception class, but + the ``multiaddr`` library constructs lazily and only raises in ``str()``. + """ + observed = Multiaddr("/ip4/5.6.7.8/tcp/4001") + host, peer_id, _ = _prepare_identify_host(monkeypatch, observed.to_bytes()) + fake_manager = MagicMock() + fake_manager.record_observation.side_effect = MultiaddrError( + "malformed observed_addr" + ) + host._observed_addr_manager = fake_manager + + # Should not raise. + await host._identify_peer(peer_id, reason="test") + + fake_manager.record_observation.assert_called_once() + matching = [ + r + for r in basic_host_log_records + if "ignoring malformed observed_addr" in r.getMessage() + ] + assert matching, ( + f"expected a DEBUG log for MultiaddrError path; got " + f"{[(r.levelname, r.getMessage()) for r in basic_host_log_records]}" + ) + assert matching[0].levelno == logging.DEBUG + + +@pytest.mark.trio +async def test_identify_peer_swallows_value_error(monkeypatch, basic_host_log_records): + """ + Gap 2b: record_observation raising ValueError must be caught and logged at + DEBUG. Nothing propagates out of _identify_peer. + """ + observed = Multiaddr("/ip4/5.6.7.8/tcp/4001") + host, peer_id, _ = _prepare_identify_host(monkeypatch, observed.to_bytes()) + fake_manager = MagicMock() + fake_manager.record_observation.side_effect = ValueError("bogus bytes") + host._observed_addr_manager = fake_manager + + await host._identify_peer(peer_id, reason="test") + + fake_manager.record_observation.assert_called_once() + matching = [ + r + for r in basic_host_log_records + if "ignoring invalid observed_addr" in r.getMessage() + ] + assert matching, ( + f"expected a DEBUG log for ValueError path; got " + f"{[(r.levelname, r.getMessage()) for r in basic_host_log_records]}" + ) + assert matching[0].levelno == logging.DEBUG + + +@pytest.mark.trio +async def test_identify_peer_warns_on_unexpected_error( + monkeypatch, basic_host_log_records +): + """ + Gap 2c: any other exception from record_observation must be caught and + surfaced at WARNING level with a traceback, not swallowed silently. + """ + observed = Multiaddr("/ip4/5.6.7.8/tcp/4001") + host, peer_id, _ = _prepare_identify_host(monkeypatch, observed.to_bytes()) + fake_manager = MagicMock() + fake_manager.record_observation.side_effect = RuntimeError("boom") + host._observed_addr_manager = fake_manager + + await host._identify_peer(peer_id, reason="test") + + fake_manager.record_observation.assert_called_once() + matching = [ + r + for r in basic_host_log_records + if "unexpected failure recording observation" in r.getMessage() + ] + assert matching, ( + f"expected a WARNING log for the generic exception path; got " + f"{[(r.levelname, r.getMessage()) for r in basic_host_log_records]}" + ) + assert matching[0].levelno == logging.WARNING + # exc_info=True must have attached the RuntimeError traceback. + assert matching[0].exc_info is not None + assert matching[0].exc_info[0] is RuntimeError + + +def test_on_notifee_disconnected_calls_remove_conn(): + """ + Gap 3: when a peer disconnects, ObservedAddrManager.remove_conn must be + invoked so per-conn observations are released. + """ + host = _make_host_with_listener(announce_addrs=None) + fake_manager = MagicMock() + host._observed_addr_manager = fake_manager + + peer_id = host.get_id() + conn = MagicMock() + conn.muxed_conn = MagicMock() + conn.muxed_conn.peer_id = peer_id + + # Preload identified_peers so we can also verify the cleanup happens. + host._identified_peers.add(peer_id) + + host._on_notifee_disconnected(conn) + + fake_manager.remove_conn.assert_called_once_with(conn) + assert peer_id not in host._identified_peers + + +def test_on_notifee_disconnected_without_peer_id_is_noop(): + """ + Defensive: if the muxed_conn has no peer_id attribute, the handler must + short-circuit without touching the manager. + """ + host = _make_host_with_listener(announce_addrs=None) + fake_manager = MagicMock() + host._observed_addr_manager = fake_manager + + conn = MagicMock() + conn.muxed_conn = MagicMock() + conn.muxed_conn.peer_id = None + + host._on_notifee_disconnected(conn) + + fake_manager.remove_conn.assert_not_called() diff --git a/tests/core/host/test_observed_addr_manager.py b/tests/core/host/test_observed_addr_manager.py new file mode 100644 index 000000000..ba63fb6f8 --- /dev/null +++ b/tests/core/host/test_observed_addr_manager.py @@ -0,0 +1,618 @@ +"""Tests for the ObservedAddrManager.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from multiaddr import Multiaddr + +from libp2p.host.observed_addr_manager import ( + _ADDR_CACHE_SIZE, + ACTIVATION_THRESHOLD, + MAX_EXTERNAL_ADDRS_PER_LOCAL, + NATDeviceType, + ObservedAddrManager, + extract_thin_waist, + has_consistent_transport, + is_valid_observation, + observer_group, +) + +# --------------------------------------------------------------------------- +# Helper: create a mock INetConn with a configurable remote address. +# --------------------------------------------------------------------------- + +# Keep all mock connections alive so that id() values are never reused +# within the test session. ObservedAddrManager keys on id(conn), which +# is only unique while the object exists. +_LIVE_CONNS: list[MagicMock] = [] + + +def _make_conn( + remote_ip: str = "10.0.0.1", remote_port: int = 12345, is_closed: bool = False +) -> MagicMock: + conn = MagicMock(spec=[]) + conn.muxed_conn = MagicMock() + conn.muxed_conn.get_remote_address = MagicMock( + return_value=(remote_ip, remote_port) + ) + conn.is_closed = is_closed + _LIVE_CONNS.append(conn) + return conn + + +# --------------------------------------------------------------------------- +# extract_thin_waist +# --------------------------------------------------------------------------- + + +def test_extract_thin_waist_ipv4_tcp(): + ma = Multiaddr( + "/ip4/1.2.3.4/tcp/4001/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC" + ) + result = extract_thin_waist(ma) + assert result is not None + tw, rest = result + assert str(tw) == "/ip4/1.2.3.4/tcp/4001" + assert "/p2p/" in rest + + +def test_extract_thin_waist_ipv6_udp(): + ma = Multiaddr("/ip6/::1/udp/4001/quic-v1") + result = extract_thin_waist(ma) + assert result is not None + tw, rest = result + assert str(tw) == "/ip6/::1/udp/4001" + assert "quic-v1" in rest + + +# --------------------------------------------------------------------------- +# observer_group +# --------------------------------------------------------------------------- + + +def test_observer_group_ipv4(): + assert observer_group(("192.168.1.1", 5000)) == "192.168.1.1" + + +def test_observer_group_ipv6(): + group = observer_group(("2001:db8:abcd:0012::1", 5000)) + # /56 prefix — the host part is zeroed out. + assert group == "2001:db8:abcd::" + + +# --------------------------------------------------------------------------- +# is_valid_observation +# --------------------------------------------------------------------------- + + +def test_loopback_rejected(): + obs = Multiaddr("/ip4/127.0.0.1/tcp/4001") + assert is_valid_observation(obs) is False + + +def test_relay_rejected(): + obs = Multiaddr("/ip4/1.2.3.4/tcp/4001/p2p-circuit") + assert is_valid_observation(obs) is False + + +def test_valid_observation(): + obs = Multiaddr("/ip4/1.2.3.4/tcp/4001") + assert is_valid_observation(obs) is True + + +# --------------------------------------------------------------------------- +# ObservedAddrManager – record / confirm / remove +# --------------------------------------------------------------------------- + + +def test_record_and_confirm(): + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + conns = [] + for i in range(ACTIVATION_THRESHOLD): + c = _make_conn(remote_ip=f"10.0.0.{i + 1}") + conns.append(c) + mgr.record_observation(c, observed, local) + + addrs = mgr.addrs() + assert any(str(a) == "/ip4/1.2.3.4/tcp/4001" for a in addrs) + + +def test_below_threshold(): + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + for i in range(ACTIVATION_THRESHOLD - 1): + c = _make_conn(remote_ip=f"10.0.0.{i + 1}") + mgr.record_observation(c, observed, local) + + assert mgr.addrs() == [] + + +def test_duplicate_observer(): + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + # Same IP, different connections — should count as 1 observer. + for _ in range(ACTIVATION_THRESHOLD): + c = _make_conn(remote_ip="10.0.0.1") + mgr.record_observation(c, observed, local) + + assert mgr.addrs() == [] + + +def test_remove_conn_decrements(): + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + conns = [] + for i in range(ACTIVATION_THRESHOLD): + c = _make_conn(remote_ip=f"10.0.0.{i + 1}") + conns.append(c) + mgr.record_observation(c, observed, local) + + # Confirmed. + assert len(mgr.addrs()) > 0 + + # Remove one — drops below threshold. + mgr.remove_conn(conns[0]) + assert mgr.addrs() == [] + + +def test_max_external_addrs(): + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + + # Register MAX + 1 different external addresses, each with enough observers. + for ext_idx in range(MAX_EXTERNAL_ADDRS_PER_LOCAL + 1): + observed = Multiaddr(f"/ip4/1.2.3.{ext_idx}/tcp/4001") + for obs_idx in range(ACTIVATION_THRESHOLD): + c = _make_conn(remote_ip=f"10.{ext_idx}.0.{obs_idx + 1}") + mgr.record_observation(c, observed, local) + + addrs = mgr.addrs() + # Each external thin waist produces exactly 1 addr (no rest suffixes), + # so we expect at most MAX_EXTERNAL_ADDRS_PER_LOCAL. + assert len(addrs) <= MAX_EXTERNAL_ADDRS_PER_LOCAL + + +def test_wildcard_matching(): + """Local 0.0.0.0 matches any IPv4 observation.""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/99.88.77.66/tcp/4001") + + for i in range(ACTIVATION_THRESHOLD): + c = _make_conn(remote_ip=f"10.0.0.{i + 1}") + mgr.record_observation(c, observed, local) + + addrs = mgr.addrs() + assert any(str(a) == "/ip4/99.88.77.66/tcp/4001" for a in addrs) + + +def test_address_inference(): + """Listen /ip4/0.0.0.0/tcp/4001/ws → infer /ip4/1.2.3.4/tcp/4001/ws.""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001/ws")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + for i in range(ACTIVATION_THRESHOLD): + c = _make_conn(remote_ip=f"10.0.0.{i + 1}") + mgr.record_observation(c, observed, local) + + addrs = mgr.addrs() + addr_strs = [str(a) for a in addrs] + assert "/ip4/1.2.3.4/tcp/4001/ws" in addr_strs + + +# --------------------------------------------------------------------------- +# Step 1: Observer count tracks connections (counter, not set) +# --------------------------------------------------------------------------- + + +def test_observer_count_tracks_connections(): + """Two conns from same subnet: disconnect 1, observer should remain.""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + # We need ACTIVATION_THRESHOLD distinct observer groups to confirm. + # Use (ACTIVATION_THRESHOLD - 1) unique IPs, plus 2 conns from same IP. + conns = [] + for i in range(ACTIVATION_THRESHOLD - 1): + c = _make_conn(remote_ip=f"10.0.0.{i + 1}") + conns.append(c) + mgr.record_observation(c, observed, local) + + # Two more connections from the same IP (same observer group). + c_dup1 = _make_conn(remote_ip=f"10.0.0.{ACTIVATION_THRESHOLD}") + c_dup2 = _make_conn(remote_ip=f"10.0.0.{ACTIVATION_THRESHOLD}") + mgr.record_observation(c_dup1, observed, local) + mgr.record_observation(c_dup2, observed, local) + + # Should be confirmed (ACTIVATION_THRESHOLD unique groups). + assert len(mgr.addrs()) > 0 + + # Disconnect one of the duplicate-group connections. + mgr.remove_conn(c_dup1) + + # Observer group should still be counted — address remains confirmed. + assert len(mgr.addrs()) > 0 + + +# --------------------------------------------------------------------------- +# Step 2: Same observation short-circuit +# --------------------------------------------------------------------------- + + +def test_same_observation_shortcircuit(): + """Same observation twice from same conn should not double-count.""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + c = _make_conn(remote_ip="10.0.0.1") + mgr.record_observation(c, observed, local) + mgr.record_observation(c, observed, local) # same observation again + + # Should only have 1 observer group with count 1. + ext_map = mgr._external_addrs.get("/ip4/0.0.0.0/tcp/4001", {}) + observers = ext_map.get("/ip4/1.2.3.4/tcp/4001", {}) + total = sum(observers.values()) + assert total == 1 + + +# --------------------------------------------------------------------------- +# Step 3: Closed connection is skipped +# --------------------------------------------------------------------------- + + +def test_is_closed_skipped(): + """Closed connection should not record observation.""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + c = _make_conn(remote_ip="10.0.0.1", is_closed=True) + mgr.record_observation(c, observed, local) + + assert mgr._external_addrs == {} + + +# --------------------------------------------------------------------------- +# Step 4: NAT64 address rejected +# --------------------------------------------------------------------------- + + +def test_nat64_rejected(): + """Addresses in 64:ff9b::/96 should be rejected.""" + # 1.2.3.4 mapped to NAT64 prefix: 64:ff9b::102:304 + obs = Multiaddr("/ip6/64:ff9b::102:304/tcp/4001") + assert is_valid_observation(obs) is False + + +# --------------------------------------------------------------------------- +# Step 5a: Configurable threshold +# --------------------------------------------------------------------------- + + +def test_addrs_with_min_observers(): + """addrs(min_observers=1) returns with fewer observers.""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + c = _make_conn(remote_ip="10.0.0.1") + mgr.record_observation(c, observed, local) + + # Default threshold should not return it. + assert mgr.addrs() == [] + # But min_observers=1 should. + assert len(mgr.addrs(min_observers=1)) > 0 + assert any(str(a) == "/ip4/1.2.3.4/tcp/4001" for a in mgr.addrs(min_observers=1)) + + +# --------------------------------------------------------------------------- +# Step 5c: addrs_for +# --------------------------------------------------------------------------- + + +def test_addrs_for(): + """Per-listen-address querying.""" + mgr = ObservedAddrManager() + local_tcp = Multiaddr("/ip4/0.0.0.0/tcp/4001") + local_udp = Multiaddr("/ip4/0.0.0.0/udp/5001") + local = [local_tcp, local_udp] + + observed_tcp = Multiaddr("/ip4/1.2.3.4/tcp/4001") + observed_udp = Multiaddr("/ip4/5.6.7.8/udp/5001") + + for i in range(ACTIVATION_THRESHOLD): + c1 = _make_conn(remote_ip=f"10.0.0.{i + 1}") + mgr.record_observation(c1, observed_tcp, local) + c2 = _make_conn(remote_ip=f"10.1.0.{i + 1}") + mgr.record_observation(c2, observed_udp, local) + + tcp_addrs = mgr.addrs_for(local_tcp) + udp_addrs = mgr.addrs_for(local_udp) + + tcp_strs = [str(a) for a in tcp_addrs] + udp_strs = [str(a) for a in udp_addrs] + + assert "/ip4/1.2.3.4/tcp/4001" in tcp_strs + assert "/ip4/5.6.7.8/udp/5001" in udp_strs + # TCP query should not return UDP addresses. + assert "/ip4/5.6.7.8/udp/5001" not in tcp_strs + + # Query with rest suffix: only that suffix should appear (Go parity). + local_ws = Multiaddr("/ip4/0.0.0.0/tcp/4001/ws") + ws_addrs = mgr.addrs_for(local_ws) + ws_strs = [str(a) for a in ws_addrs] + assert "/ip4/1.2.3.4/tcp/4001/ws" in ws_strs + # Bare TW should NOT appear when queried with a rest suffix. + assert "/ip4/1.2.3.4/tcp/4001" not in ws_strs + + +# --------------------------------------------------------------------------- +# Step 6: Sorting tiebreaker (lexicographic) +# --------------------------------------------------------------------------- + + +def test_sorting_tiebreaker(): + """Equal-count addrs should be in lexicographic order.""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + + # Register 3 external addrs with same observer count. + ext_ips = ["1.2.3.4", "1.2.3.2", "1.2.3.3"] + for ext_ip in ext_ips: + observed = Multiaddr(f"/ip4/{ext_ip}/tcp/4001") + for i in range(ACTIVATION_THRESHOLD): + c = _make_conn(remote_ip=f"10.{ext_ips.index(ext_ip)}.0.{i + 1}") + mgr.record_observation(c, observed, local) + + addrs = mgr.addrs() + addr_strs = [str(a) for a in addrs] + # All should be present and in lexicographic order of the TW string. + assert addr_strs == sorted(addr_strs) + + +# --------------------------------------------------------------------------- +# Step 7: NAT type detection +# --------------------------------------------------------------------------- + + +def test_get_nat_type_unknown(): + """Too few observations → UNKNOWN.""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + # Only a couple of observations — not enough for classification. + for i in range(2): + c = _make_conn(remote_ip=f"10.0.0.{i + 1}") + mgr.record_observation(c, observed, local) + + tcp_nat, udp_nat = mgr.get_nat_type() + assert tcp_nat == NATDeviceType.UNKNOWN + assert udp_nat == NATDeviceType.UNKNOWN + + +def test_get_nat_type_independent(): + """Concentrated observations → ENDPOINT_INDEPENDENT (cone NAT).""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + # Many observations to a single external address. + # Need >= 3 * MAX_EXTERNAL_ADDRS_PER_LOCAL = 9 total observations. + for i in range(12): + c = _make_conn(remote_ip=f"10.0.{i // 256}.{i % 256 + 1}") + mgr.record_observation(c, observed, local) + + tcp_nat, udp_nat = mgr.get_nat_type() + assert tcp_nat == NATDeviceType.ENDPOINT_INDEPENDENT + assert udp_nat == NATDeviceType.UNKNOWN # no UDP observations + + +def test_get_nat_type_dependent(): + """Spread observations → ENDPOINT_DEPENDENT (symmetric NAT).""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + + # Spread observations across many different external addresses. + # Each external addr gets 1 observer → no concentration. + for i in range(12): + observed = Multiaddr(f"/ip4/1.2.3.{i}/tcp/4001") + c = _make_conn(remote_ip=f"10.0.{i // 256}.{i % 256 + 1}") + mgr.record_observation(c, observed, local) + + tcp_nat, udp_nat = mgr.get_nat_type() + assert tcp_nat == NATDeviceType.ENDPOINT_DEPENDENT + assert udp_nat == NATDeviceType.UNKNOWN + + +def test_get_nat_type_skips_mixed_transport_bucket(): + """ + White-box test for the defensive TCP/UDP skip guard in + ``ObservedAddrManager.get_nat_type``. + + Normal flow cannot produce a mixed-transport bucket because + ``record_observation`` routes through ``_match_local_thin_waist`` + + ``has_consistent_transport``. We bypass that invariant by poking + ``_external_addrs`` directly to simulate a future refactor leaking a + stray entry of the opposite transport into a bucket. + + This test is designed so the result *differs* between guarded and + unguarded implementations: + + * Bucket is tagged TCP (first inserted key is ``/tcp/``). + * Real TCP observations are spread across 12 distinct external addresses + with 1 observer each → alone classifies as ``ENDPOINT_DEPENDENT``. + * A stray ``/udp/`` entry with 12 concentrated observers is inserted. + * Without the guard, the UDP count [12] would be absorbed into + ``tcp_counts`` and the 50% concentration rule would flip the result + to ``ENDPOINT_INDEPENDENT`` — a silent misclassification. + * With the guard, the stray UDP entry is skipped and TCP stays + ``ENDPOINT_DEPENDENT``. + """ + mgr = ObservedAddrManager() + + local_tw_str = "/ip4/0.0.0.0/tcp/4001" + bucket: dict[str, dict[str, int]] = {} + # 12 distinct TCP external addrs, 1 observer each → DEPENDENT on its own. + for i in range(12): + bucket[f"/ip4/1.2.3.{i}/tcp/4001"] = {f"10.0.0.{i + 1}": 1} + # Stray concentrated UDP entry — would flip classification without the + # guard. + bucket["/ip4/1.2.3.200/udp/4001"] = {f"10.99.0.{i + 1}": 1 for i in range(12)} + mgr._external_addrs[local_tw_str] = bucket + + tcp_nat, udp_nat = mgr.get_nat_type() + + # TCP: stays DEPENDENT because the guard filtered out the stray UDP entry. + assert tcp_nat == NATDeviceType.ENDPOINT_DEPENDENT + # UDP: no legitimate UDP bucket exists, so UDP stays UNKNOWN. + assert udp_nat == NATDeviceType.UNKNOWN + + +# --------------------------------------------------------------------------- +# Step 8: has_consistent_transport +# --------------------------------------------------------------------------- + + +def test_has_consistent_transport(): + """TCP/TCP passes, TCP/UDP fails.""" + tcp_a = Multiaddr("/ip4/1.2.3.4/tcp/4001") + tcp_b = Multiaddr("/ip4/5.6.7.8/tcp/5001") + udp_a = Multiaddr("/ip4/1.2.3.4/udp/4001") + + assert has_consistent_transport(tcp_a, tcp_b) is True + assert has_consistent_transport(tcp_a, udp_a) is False + + # Too few protocols. + short = Multiaddr("/ip4/1.2.3.4") + assert has_consistent_transport(tcp_a, short) is False + + +# --------------------------------------------------------------------------- +# Integration: BasicHost.get_addrs() includes observed addrs +# --------------------------------------------------------------------------- + + +def test_get_addrs_includes_observed(): + from libp2p import new_swarm + from libp2p.crypto.rsa import create_new_key_pair + from libp2p.host.basic_host import BasicHost + + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + # Set up a mock listener so get_transport_addrs returns something. + mock_transport = MagicMock() + mock_transport.get_addrs.return_value = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + swarm.listeners = {"tcp": mock_transport} + + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + for i in range(ACTIVATION_THRESHOLD): + c = _make_conn(remote_ip=f"10.0.0.{i + 1}") + host._observed_addr_manager.record_observation( + c, observed, host.get_transport_addrs() + ) + + addrs = host.get_addrs() + addr_strs = [str(a) for a in addrs] + assert any("/ip4/1.2.3.4/tcp/4001" in s for s in addr_strs) + + +# --------------------------------------------------------------------------- +# Multiaddr cache behavior +# --------------------------------------------------------------------------- + + +def test_addr_cache_invalidation(): + """Cache is populated on addrs() and cleared when an observation is removed.""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + observed = Multiaddr("/ip4/1.2.3.4/tcp/4001") + + conns = [] + for i in range(ACTIVATION_THRESHOLD): + c = _make_conn(remote_ip=f"10.0.0.{i + 1}") + conns.append(c) + mgr.record_observation(c, observed, local) + + # First call populates the cache. + addrs1 = mgr.addrs() + assert len(addrs1) > 0 + assert len(mgr._addr_cache) > 0 + + # Second call should return identical results (cache hit path). + addrs2 = mgr.addrs() + assert [str(a) for a in addrs1] == [str(a) for a in addrs2] + + # Remove enough connections to drop the external address entirely. + for c in conns: + mgr.remove_conn(c) + + # Cache should have been cleared. + assert len(mgr._addr_cache) == 0 + + # addrs() still works correctly after cache clear (cache miss → rebuild). + assert mgr.addrs() == [] + + +def test_old_observation_replaced(): + """When a connection reports a different observed addr, the old one is replaced.""" + mgr = ObservedAddrManager() + local = [Multiaddr("/ip4/0.0.0.0/tcp/4001")] + local_tw_str = "/ip4/0.0.0.0/tcp/4001" + observed_a = Multiaddr("/ip4/1.2.3.4/tcp/4001") + observed_b = Multiaddr("/ip4/5.6.7.8/tcp/4001") + + # Record observed_a from ACTIVATION_THRESHOLD distinct connections. + conns = [] + for i in range(ACTIVATION_THRESHOLD): + c = _make_conn(remote_ip=f"10.0.0.{i + 1}") + conns.append(c) + mgr.record_observation(c, observed_a, local) + + # Address A should be confirmed. + assert any(str(a) == "/ip4/1.2.3.4/tcp/4001" for a in mgr.addrs()) + + # Re-record on the SAME connections with a different observed address B. + for c in conns: + mgr.record_observation(c, observed_b, local) + + addrs = mgr.addrs() + addr_strs = [str(a) for a in addrs] + + # Address A should be gone; address B should be active. + assert "/ip4/1.2.3.4/tcp/4001" not in addr_strs + assert "/ip4/5.6.7.8/tcp/4001" in addr_strs + + # Internal state: _conn_observations should all point to B. + for c in conns: + _, ext, _ = mgr._conn_observations[id(c)] + assert ext == "/ip4/5.6.7.8/tcp/4001" + + # Internal state: old external entry should be cleaned up. + ext_map = mgr._external_addrs.get(local_tw_str, {}) + assert "/ip4/1.2.3.4/tcp/4001" not in ext_map + + +def test_addr_cache_size_cap(): + """Cache never exceeds _ADDR_CACHE_SIZE entries.""" + mgr = ObservedAddrManager() + # Insert more entries than the cap via the internal helper. + for i in range(_ADDR_CACHE_SIZE + 5): + mgr._cached_multiaddr(f"/ip4/1.2.3.{i}/tcp/4001") + + assert len(mgr._addr_cache) == _ADDR_CACHE_SIZE