diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 90e57d6f2..855697c87 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -5,13 +5,23 @@ import logging from pathlib import Path import ssl -from libp2p.transport.quic.utils import is_quic_multiaddr from typing import Any from cryptography.hazmat.primitives.asymmetric import ed25519 from cryptography.hazmat.primitives import serialization -from libp2p.transport.quic.transport import QUICTransport -from libp2p.transport.quic.config import QUICTransportConfig +try: + from libp2p.transport.quic.utils import is_quic_multiaddr + from libp2p.transport.quic.transport import QUICTransport + from libp2p.transport.quic.config import QUICTransportConfig + + _HAS_QUIC = True +except ImportError: + _HAS_QUIC = False + QUICTransport = None + QUICTransportConfig = None + + def is_quic_multiaddr(maddr: Any) -> bool: + return False from collections.abc import ( Mapping, Sequence, @@ -86,24 +96,58 @@ PLAINTEXT_PROTOCOL_ID, InsecureTransport, ) -from libp2p.security.noise.transport import ( - PROTOCOL_ID as NOISE_PROTOCOL_ID, - Transport as NoiseTransport, -) -from libp2p.security.tls.transport import ( - PROTOCOL_ID as TLS_PROTOCOL_ID, - TLSTransport -) -import libp2p.security.secio.transport as secio -from libp2p.stream_muxer.mplex.mplex import ( - MPLEX_PROTOCOL_ID, - Mplex, -) -from libp2p.stream_muxer.yamux.yamux import ( - PROTOCOL_ID as YAMUX_PROTOCOL_ID, - Yamux, -) +try: + from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, + ) + _HAS_NOISE = True +except ImportError: + _HAS_NOISE = False + NOISE_PROTOCOL_ID = None + NoiseTransport = None + +try: + from libp2p.security.tls.transport import ( + PROTOCOL_ID as TLS_PROTOCOL_ID, + TLSTransport, + ) + _HAS_TLS = True +except ImportError: + _HAS_TLS = False + TLS_PROTOCOL_ID = None + TLSTransport = None + +try: + import libp2p.security.secio.transport as secio + _HAS_SECIO = True +except ImportError: + _HAS_SECIO = False + secio = None + +try: + from libp2p.stream_muxer.mplex.mplex import ( + MPLEX_PROTOCOL_ID, + Mplex, + ) + _HAS_MPLEX = True +except ImportError: + _HAS_MPLEX = False + MPLEX_PROTOCOL_ID = None + Mplex = None + +try: + from libp2p.stream_muxer.yamux.yamux import ( + PROTOCOL_ID as YAMUX_PROTOCOL_ID, + Yamux, + ) + _HAS_YAMUX = True +except ImportError: + _HAS_YAMUX = False + YAMUX_PROTOCOL_ID = None + Yamux = None + from libp2p.transport.tcp.tcp import ( TCP, ) @@ -113,6 +157,45 @@ from libp2p.transport.transport_registry import ( create_transport_for_multiaddr, get_supported_transport_protocols, + transport_needs_muxer, + transport_needs_security, +) +from libp2p.capabilities import ( + ConnectionCapabilities, + NeedsSetup, + TransportCapabilities, +) +from libp2p.requirements import ( + ConnectionRequirementError, + after_connection, + check_connection_requirements, + get_after_connections, + get_required_connections, + requires_connection, +) +from libp2p.providers import ( + MuxerProvider, + ProvidesConnection, + ProvidesTransport, + ProviderRegistry, + SecurityProvider, + TransportProvider, +) +from libp2p.network.resolver import ( + AllPathsFailedError, + ConnectionResolver, + NoTransportError, + ResolutionError, + ResolvedStack, +) +from libp2p.entrypoints import ( + EP_GROUP_MUXERS, + EP_GROUP_SECURITY, + EP_GROUP_TRANSPORTS, + discover_and_register, + discover_muxers, + discover_security, + discover_transports, ) import libp2p.utils from libp2p.utils.logging import ( @@ -229,24 +312,32 @@ def create_yamux_muxer_option() -> TMuxerOptions: """ Returns muxer options with Yamux as the primary choice. + Only includes muxers whose extras are installed. + :return: Muxer options with Yamux first """ - return { - TProtocol(YAMUX_PROTOCOL_ID): Yamux, # Primary choice - TProtocol(MPLEX_PROTOCOL_ID): Mplex, # Fallback for compatibility - } + opts: dict[TProtocol, Any] = {} + if _HAS_YAMUX and Yamux is not None and YAMUX_PROTOCOL_ID is not None: + opts[TProtocol(YAMUX_PROTOCOL_ID)] = Yamux + if _HAS_MPLEX and Mplex is not None and MPLEX_PROTOCOL_ID is not None: + opts[TProtocol(MPLEX_PROTOCOL_ID)] = Mplex + return opts def create_mplex_muxer_option() -> TMuxerOptions: """ Returns muxer options with Mplex as the primary choice. + Only includes muxers whose extras are installed. + :return: Muxer options with Mplex first """ - return { - TProtocol(MPLEX_PROTOCOL_ID): Mplex, # Primary choice - TProtocol(YAMUX_PROTOCOL_ID): Yamux, # Fallback - } + opts: dict[TProtocol, Any] = {} + if _HAS_MPLEX and Mplex is not None and MPLEX_PROTOCOL_ID is not None: + opts[TProtocol(MPLEX_PROTOCOL_ID)] = Mplex + if _HAS_YAMUX and Yamux is not None and YAMUX_PROTOCOL_ID is not None: + opts[TProtocol(YAMUX_PROTOCOL_ID)] = Yamux + return opts def generate_new_rsa_identity() -> KeyPair: @@ -289,7 +380,7 @@ def new_swarm( enable_quic: bool = False, enable_autotls: bool = False, retry_config: RetryConfig | None = None, - connection_config: ConnectionConfig | QUICTransportConfig | None = None, + connection_config: ConnectionConfig | Any | None = None, tls_client_config: ssl.SSLContext | None = None, tls_server_config: ssl.SSLContext | None = None, resource_manager: ResourceManager | None = None, @@ -328,16 +419,22 @@ def new_swarm( id_opt = generate_peer_id_from(key_pair) - transport: TCP | QUICTransport | ITransport - quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None + transport: TCP | ITransport + quic_transport_opt = ( + connection_config + if _HAS_QUIC and QUICTransportConfig is not None + and isinstance(connection_config, QUICTransportConfig) + else None + ) if listen_addrs is None: if enable_quic: - transport = QUICTransport( - key_pair.private_key, - config=quic_transport_opt, - enable_autotls=enable_autotls, - ) + if not _HAS_QUIC or QUICTransport is None: + raise ImportError( + "QUIC transport is not available. " + "Install the 'quic' extra: pip install libp2p[quic]" + ) + transport = QUICTransport(key_pair.private_key, config=quic_transport_opt, enable_autotls=enable_autotls) else: transport = TCP() else: @@ -370,13 +467,13 @@ def new_swarm( logger.debug(f"new_swarm: Created transport: {type(transport)}") # If enable_quic is True but we didn't get a QUIC transport, force QUIC - if enable_quic and not isinstance(transport, QUICTransport): - logger.debug(f"new_swarm: Forcing QUIC transport (enable_quic=True but got {type(transport)})") - transport = QUICTransport( - key_pair.private_key, - config=quic_transport_opt, - enable_autotls=enable_autotls, - ) + if enable_quic and _HAS_QUIC and QUICTransport is not None: + if not isinstance(transport, QUICTransport): + logger.debug( + f"new_swarm: Forcing QUIC transport (enable_quic=True " + f"but got {type(transport)})" + ) + transport = QUICTransport(key_pair.private_key, config=quic_transport_opt, enable_autotls=enable_autotls) logger.debug(f"new_swarm: Final transport type: {type(transport)}") @@ -387,19 +484,25 @@ def new_swarm( # NOTE: Using Noise as primary for now because Python's ssl module has limitations # with mutual TLS authentication. See TLS_ANALYSIS.md for details. # TLS is still offered as a fallback option. - secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] = sec_opt or { - # TLS_PROTOCOL_ID: TLSTransport(key_pair), - NOISE_PROTOCOL_ID: NoiseTransport( - key_pair, noise_privkey=noise_key_pair.private_key - ), - TLS_PROTOCOL_ID: TLSTransport ( - key_pair, enable_autotls = enable_autotls - ), - TProtocol(secio.ID): secio.Transport(key_pair), - TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport( + # Only include transports whose optional extras are installed. + if sec_opt is not None: + secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] = sec_opt + else: + _sec_map: dict[TProtocol, ISecureTransport] = {} + if _HAS_NOISE and NoiseTransport is not None: + _sec_map[NOISE_PROTOCOL_ID] = NoiseTransport( + key_pair, noise_privkey=noise_key_pair.private_key + ) + if _HAS_TLS and TLSTransport is not None: + _sec_map[TLS_PROTOCOL_ID] = TLSTransport( + key_pair, enable_autotls=enable_autotls + ) + if _HAS_SECIO and secio is not None: + _sec_map[TProtocol(secio.ID)] = secio.Transport(key_pair) + _sec_map[TProtocol(PLAINTEXT_PROTOCOL_ID)] = InsecureTransport( key_pair, peerstore=peerstore_opt - ), - } + ) + secure_transports_by_protocol = _sec_map # Use given muxer preference if provided, otherwise use global default if muxer_preference is not None: @@ -471,7 +574,7 @@ def new_host( bootstrap: list[str] | None = None, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, enable_quic: bool = False, - quic_transport_opt: QUICTransportConfig | None = None, + quic_transport_opt: Any | None = None, tls_client_config: ssl.SSLContext | None = None, tls_server_config: ssl.SSLContext | None = None, resource_manager: ResourceManager | None = None, @@ -526,7 +629,7 @@ def new_host( # Determine the connection config to use # QUIC transport config takes precedence if QUIC is enabled - effective_config: ConnectionConfig | QUICTransportConfig | None + effective_config: ConnectionConfig | Any | None if enable_quic and quic_transport_opt is not None: effective_config = quic_transport_opt else: diff --git a/libp2p/bitswap/cid.py b/libp2p/bitswap/cid.py index 9f21d90de..fa9f223d4 100644 --- a/libp2p/bitswap/cid.py +++ b/libp2p/bitswap/cid.py @@ -24,8 +24,19 @@ import hashlib from typing import TypeAlias -from cid import CIDv0, CIDv1, V0Builder, V1Builder, from_string, make_cid -from cid.prefix import Prefix +try: + from cid import CIDv0, CIDv1, V0Builder, V1Builder, from_string, make_cid + from cid.prefix import Prefix + _HAS_CID_BUILDERS = True +except Exception: + # Older/newer py-cid variations may not expose builder/prefix helpers. + from cid import CIDv0, CIDv1, from_string, make_cid + Prefix = None # type: ignore + V0Builder = None # type: ignore + V1Builder = None # type: ignore + _HAS_CID_BUILDERS = False + +import multihash as _multihash from multicodec import Code, is_codec from multicodec.code_table import DAG_PB, RAW, SHA2_256 @@ -61,7 +72,12 @@ def _normalise_codec(codec: Code | str | int) -> Code: def compute_cid_v0_obj(data: bytes) -> CIDv0: """Compute a CIDv0 object for data.""" - return V0Builder().sum(data) + if _HAS_CID_BUILDERS and V0Builder is not None: + return V0Builder().sum(data) + + # Fallback: compute sha2-256 multihash and construct CIDv0 + mh = _multihash.digest(data, "sha2-256") + return CIDv0(mh.encode()) def compute_cid_v0(data: bytes) -> bytes: @@ -84,7 +100,13 @@ def compute_cid_v0(data: bytes) -> bytes: def compute_cid_v1_obj(data: bytes, codec: Code | str | int = CODEC_RAW) -> CIDv1: """Compute a CIDv1 object for data and codec.""" code_obj = _normalise_codec(codec) - return V1Builder(codec=str(code_obj), mh_type=str(HASH_SHA256)).sum(data) + if _HAS_CID_BUILDERS and V1Builder is not None: + return V1Builder(codec=str(code_obj), mh_type=str(HASH_SHA256)).sum(data) + + # Fallback: compute multihash and construct CIDv1 using codec name + codec_name = getattr(code_obj, "name", str(code_obj)) + mh = _multihash.digest(data, "sha2-256") + return CIDv1(codec_name, mh.encode()) def compute_cid_v1(data: bytes, codec: Code | str | int = CODEC_RAW) -> bytes: @@ -132,7 +154,26 @@ def get_cid_prefix(cid: CIDInput) -> bytes: if cid_obj.version != CID_V1: return b"" - return cid_obj.prefix().to_bytes() + # Prefer high-level Prefix helper when available + if Prefix is not None and hasattr(cid_obj, "prefix"): + try: + return cid_obj.prefix().to_bytes() + except Exception: + pass + + # Fallback: reconstruct prefix by removing the digest bytes from the + # raw CID buffer using py-multihash to determine digest length. + try: + mh_bytes = getattr(cid_obj, "multihash", None) + if mh_bytes is None: + # Some CID implementations expose different attributes + # Fall back to parsing the tail of the buffer. + return b"" + mh = _multihash.decode(mh_bytes) + digest_len = mh.length + return cid_obj.buffer[:-digest_len] + except Exception: + return b"" def reconstruct_cid_from_prefix_and_data(prefix: bytes, data: bytes) -> bytes: @@ -153,12 +194,16 @@ def reconstruct_cid_from_prefix_and_data(prefix: bytes, data: bytes) -> bytes: # No prefix means CIDv0 return compute_cid_v0(data) - try: - return Prefix.from_bytes(prefix).sum(data).buffer - except ValueError: - # Preserve previous permissive behavior for malformed prefixes. - digest = hashlib.sha256(data).digest() - return prefix + digest + if Prefix is not None: + try: + return Prefix.from_bytes(prefix).sum(data).buffer + except ValueError: + pass + + # Fallback: when Prefix helper isn't available, try to conservatively + # append the raw digest to the prefix (preserves prior permissive behavior). + digest = hashlib.sha256(data).digest() + return prefix + digest def verify_cid(cid: CIDInput, data: bytes) -> bool: @@ -187,7 +232,11 @@ def verify_cid(cid: CIDInput, data: bytes) -> bool: return False try: - recomputed = cid_obj.prefix().sum(data).buffer + # Prefer using compute helpers which work across py-cid variants. + if cid_obj.version == CID_V1: + recomputed = compute_cid_v1(data, codec=getattr(cid_obj, "codec", None)) + else: + recomputed = compute_cid_v0(data) except ValueError: logger.debug(" Failed to recompute CID from parsed prefix") return False diff --git a/libp2p/capabilities.py b/libp2p/capabilities.py new file mode 100644 index 000000000..df222aab2 --- /dev/null +++ b/libp2p/capabilities.py @@ -0,0 +1,89 @@ +""" +Transport and connection capability declarations. + +Instead of the core checking ``isinstance(transport, QUICTransport)`` to decide +whether to skip security or muxing upgrades, transports and connections *declare* +what they already provide via simple boolean properties. + +Any transport or connection that satisfies the structural protocol (duck typing) +is automatically recognised — no base-class inheritance required. + +Example — a hypothetical WebRTC transport that bundles its own DTLS + SCTP:: + + class WebRTCTransport(ITransport): + @property + def provides_security(self) -> bool: + return True + + @property + def provides_muxing(self) -> bool: + return True + +See Also +-------- +:pep:`544` — Structural subtyping (static duck typing) via ``typing.Protocol``. +""" + +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class TransportCapabilities(Protocol): + """Structural protocol implemented by transports that bundle their own + security and/or multiplexing layers. + + A transport that does **not** implement these properties is assumed to + provide neither (``getattr(t, 'provides_security', False)`` → ``False``). + """ + + @property + def provides_security(self) -> bool: + """Return ``True`` if this transport includes built-in encryption / + authentication (e.g. QUIC's integrated TLS 1.3).""" + ... + + @property + def provides_muxing(self) -> bool: + """Return ``True`` if this transport includes built-in stream + multiplexing (e.g. QUIC's native streams).""" + ... + + +@runtime_checkable +class ConnectionCapabilities(Protocol): + """Structural protocol implemented by connections that are already + secured and/or multiplexed at creation time. + + A connection that does **not** implement these properties is assumed to + be neither secure nor muxed. + """ + + @property + def is_secure(self) -> bool: + """Return ``True`` if the connection is already encrypted and the + remote peer's identity has been verified.""" + ... + + @property + def is_muxed(self) -> bool: + """Return ``True`` if the connection already supports opening + multiple independent streams.""" + ... + + +@runtime_checkable +class NeedsSetup(Protocol): + """Structural protocol for transports that require lifecycle hooks + from the swarm (e.g. a background nursery or a back-reference to + the swarm itself). + + Transports that do **not** need these hooks simply omit the methods. + """ + + def set_background_nursery(self, nursery: object) -> None: + """Receive the long-lived nursery managed by the swarm.""" + ... + + def set_swarm(self, swarm: object) -> None: + """Receive a back-reference to the owning swarm.""" + ... diff --git a/libp2p/custom_types.py b/libp2p/custom_types.py index 9bfb89ea9..492e0b6f5 100644 --- a/libp2p/custom_types.py +++ b/libp2p/custom_types.py @@ -5,17 +5,13 @@ ) from typing import TYPE_CHECKING, NewType, Union, cast -from libp2p.transport.quic.stream import QUICStream - if TYPE_CHECKING: from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport - from libp2p.transport.quic.connection import QUICConnection else: IMuxedConn = cast(type, object) INetStream = cast(type, object) ISecureTransport = cast(type, object) IMuxedStream = cast(type, object) - QUICConnection = cast(type, object) from libp2p.io.abc import ( ReadWriteCloser, @@ -38,6 +34,11 @@ AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] UnsubscribeFn = Callable[[], Awaitable[None]] -TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]] -TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]] MessageID = NewType("MessageID", str) + +# Re-export QUIC-specific types for backward compatibility. +# New code should import directly from libp2p.transport.quic.types. +from libp2p.transport.quic.types import ( # noqa: E402, F401 + TQUICConnHandlerFn as TQUICConnHandlerFn, + TQUICStreamHandlerFn as TQUICStreamHandlerFn, +) diff --git a/libp2p/entrypoints.py b/libp2p/entrypoints.py new file mode 100644 index 000000000..e65d3ec46 --- /dev/null +++ b/libp2p/entrypoints.py @@ -0,0 +1,209 @@ +""" +Entry-point-based plugin discovery. + +Allows transports, security protocols, and stream muxers to be +registered automatically via Python packaging entry points. This lets +third-party packages contribute to the libp2p stack simply by declaring +an entry-point in their ``pyproject.toml``:: + + [project.entry-points."libp2p.transports"] + quic = "libp2p.transport.quic:create_provider" + + [project.entry-points."libp2p.security"] + noise = "libp2p.security.noise:create_provider" + + [project.entry-points."libp2p.muxers"] + yamux = "libp2p.stream_muxer.yamux:create_provider" + +Each entry point must resolve to a **callable** that takes no arguments +and returns a :class:`~libp2p.providers.TransportProvider`, +:class:`~libp2p.providers.SecurityProvider`, or +:class:`~libp2p.providers.MuxerProvider`, respectively. + +The main function :func:`discover_and_register` scans all three groups +and populates a :class:`~libp2p.providers.ProviderRegistry`. + +See Also +-------- +:mod:`libp2p.providers` — provider abstractions and registry. +:mod:`libp2p.network.resolver` — resolver that consumes the registry. + +""" + +from __future__ import annotations + +import logging +import sys +from typing import Any + +from libp2p.providers import ( + MuxerProvider, + ProviderRegistry, + SecurityProvider, + TransportProvider, +) + +logger = logging.getLogger(__name__) + +EP_GROUP_TRANSPORTS = "libp2p.transports" +EP_GROUP_SECURITY = "libp2p.security" +EP_GROUP_MUXERS = "libp2p.muxers" + + +def _load_entry_points(group: str) -> list[tuple[str, Any]]: + """ + Load all entry points for *group*. + + Returns a list of ``(name, loaded_object)`` tuples. Loading errors + are logged and skipped so one broken plugin cannot take down the + whole application. + """ + if sys.version_info >= (3, 12): + from importlib.metadata import entry_points + + eps = entry_points(group=group) + else: + from importlib.metadata import entry_points + + eps = entry_points(group=group) + + results: list[tuple[str, Any]] = [] + for ep in eps: + try: + obj = ep.load() + results.append((ep.name, obj)) + except Exception: + logger.warning( + "Failed to load entry point %r from group %r", + ep.name, + group, + exc_info=True, + ) + return results + + +def discover_transports() -> list[TransportProvider]: + """ + Discover transport providers registered via entry points. + + Each entry point must be a callable that returns a + :class:`~libp2p.providers.TransportProvider`. + """ + providers: list[TransportProvider] = [] + for name, factory in _load_entry_points(EP_GROUP_TRANSPORTS): + try: + provider = factory() if callable(factory) else factory + if isinstance(provider, TransportProvider): + providers.append(provider) + logger.debug("Discovered transport provider: %s", name) + else: + logger.warning( + "Entry point %r (group %s) did not produce a " + "TransportProvider; got %s", + name, + EP_GROUP_TRANSPORTS, + type(provider).__name__, + ) + except Exception: + logger.warning( + "Error creating transport provider from entry point %r", + name, + exc_info=True, + ) + return providers + + +def discover_security() -> list[SecurityProvider]: + """ + Discover security providers registered via entry points. + + Each entry point must be a callable that returns a + :class:`~libp2p.providers.SecurityProvider`. + """ + providers: list[SecurityProvider] = [] + for name, factory in _load_entry_points(EP_GROUP_SECURITY): + try: + provider = factory() if callable(factory) else factory + if isinstance(provider, SecurityProvider): + providers.append(provider) + logger.debug("Discovered security provider: %s", name) + else: + logger.warning( + "Entry point %r (group %s) did not produce a " + "SecurityProvider; got %s", + name, + EP_GROUP_SECURITY, + type(provider).__name__, + ) + except Exception: + logger.warning( + "Error creating security provider from entry point %r", + name, + exc_info=True, + ) + return providers + + +def discover_muxers() -> list[MuxerProvider]: + """ + Discover muxer providers registered via entry points. + + Each entry point must be a callable that returns a + :class:`~libp2p.providers.MuxerProvider`. + """ + providers: list[MuxerProvider] = [] + for name, factory in _load_entry_points(EP_GROUP_MUXERS): + try: + provider = factory() if callable(factory) else factory + if isinstance(provider, MuxerProvider): + providers.append(provider) + logger.debug("Discovered muxer provider: %s", name) + else: + logger.warning( + "Entry point %r (group %s) did not produce a MuxerProvider; got %s", + name, + EP_GROUP_MUXERS, + type(provider).__name__, + ) + except Exception: + logger.warning( + "Error creating muxer provider from entry point %r", + name, + exc_info=True, + ) + return providers + + +def discover_and_register( + registry: ProviderRegistry | None = None, +) -> ProviderRegistry: + """ + Scan all entry-point groups and register discovered providers. + + Parameters + ---------- + registry: + An existing registry to populate. If ``None``, a new + :class:`~libp2p.providers.ProviderRegistry` is created. + + Returns + ------- + ProviderRegistry + The (possibly new) registry, with all discovered providers + registered. + + """ + if registry is None: + registry = ProviderRegistry() + + for tp in discover_transports(): + registry.register_transport(tp) + + for sp in discover_security(): + registry.register_security(sp) + + for mp in discover_muxers(): + registry.register_muxer(mp) + + logger.info("Entry-point discovery complete: %s", registry) + return registry diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 2b5adca27..3cc8c6201 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -63,6 +63,7 @@ ID_PUSH as IdentifyPushID, _update_peerstore_from_identify, ) +from libp2p.network.resolver import ConnectionResolver from libp2p.peer.id import ( ID, ) @@ -93,7 +94,6 @@ from libp2p.tools.anyio_service import ( background_trio_service, ) -from libp2p.transport.quic.connection import QUICConnection import libp2p.utils.paths from libp2p.utils.varint import ( read_length_prefixed_protobuf, @@ -175,6 +175,7 @@ class BasicHost(IHost): mDNS: MDNSDiscovery | None upnp: UpnpManager | None bootstrap: BootstrapDiscovery | None + _resolver: ConnectionResolver | None def __init__( self, @@ -187,11 +188,12 @@ def __init__( negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, resource_manager: ResourceManager | None = None, psk: str | None = None, - *, + resolver: ConnectionResolver | None = None, bootstrap_allow_ipv6: bool = False, bootstrap_dns_timeout: float = 10.0, bootstrap_dns_max_retries: int = 3, announce_addrs: Sequence[multiaddr.Multiaddr] | None = None, + *, ) -> None: """ Initialize a BasicHost instance. @@ -204,16 +206,23 @@ def __init__( :param negotiate_timeout: Protocol negotiation timeout :param resource_manager: Optional resource manager instance :type resource_manager: :class:`libp2p.rcmgr.ResourceManager` or None + :param resolver: Optional pull-based connection resolver. + When set, :meth:`connect` can use the resolver to build the + connection stack dynamically instead of the fixed + ``TransportUpgrader`` pipeline. :param bootstrap_allow_ipv6: If True, bootstrap uses IPv6+TCP when available. :param bootstrap_dns_timeout: DNS resolution timeout in seconds per attempt. :param bootstrap_dns_max_retries: Max DNS resolution retries (with backoff). :param announce_addrs: Optional addresses to advertise instead of listen addresses. ``None`` (default) uses listen addresses; an empty list advertises no addresses. + :type resolver: :class:`libp2p.network.resolver.ConnectionResolver` + or None """ self._network = network self._network.set_stream_handler(self._swarm_stream_handler) self.peerstore = self._network.peerstore + self._resolver = resolver # Coordinate negotiate_timeout with transport config if available # For QUIC transports, use the config value to ensure consistency @@ -281,6 +290,11 @@ def __init__( self._identified_peers: set[ID] = set() self._network.register_notifee(_IdentifyNotifee(self)) + @property + def resolver(self) -> ConnectionResolver | None: + """Return the pull-based connection resolver, if configured.""" + return self._resolver + def get_id(self) -> ID: """ :return: peer_id of host @@ -1079,7 +1093,7 @@ async def _on_notifee_connected(self, conn: INetConn) -> None: if not is_initiator: # Only the dialer (initiator) needs to actively run identify. return - if not self._is_quic_muxer(muxed_conn): + if not self._is_native_muxer(muxed_conn): return event_started = getattr(conn, "event_started", None) if event_started is not None and not event_started.is_set(): @@ -1101,15 +1115,16 @@ def _get_first_connection(self, peer_id: ID) -> INetConn | None: return connections[0] return None - def _is_quic_muxer(self, muxed_conn: IMuxedConn | None) -> bool: - return isinstance(muxed_conn, QUICConnection) + def _is_native_muxer(self, muxed_conn: IMuxedConn | None) -> bool: + """Return True if the muxed connection is natively muxed (e.g. QUIC).""" + return getattr(muxed_conn, "is_muxed", False) def _should_identify_peer(self, peer_id: ID) -> bool: connection = self._get_first_connection(peer_id) if connection is None: return False muxed_conn = getattr(connection, "muxed_conn", None) - return self._is_quic_muxer(muxed_conn) + return self._is_native_muxer(muxed_conn) # Reference: `BasicHost.newStreamHandler` in Go. async def _swarm_stream_handler(self, net_stream: INetStream) -> None: @@ -1186,6 +1201,18 @@ async def _swarm_stream_handler(self, net_stream: INetStream) -> None: await net_stream.reset() return + # Check handler connection requirements (if declared) + from libp2p.requirements import check_connection_requirements + + underlying_conn = getattr(net_stream, "muxed_conn", None) + if not check_connection_requirements(handler, underlying_conn): + logger.warning( + "Handler for protocol %s has unmet connection requirements " + "on stream from peer %s — proceeding anyway", + protocol, + net_stream.muxed_conn.peer_id, + ) + await handler(net_stream) def get_live_peers(self) -> list[ID]: diff --git a/libp2p/network/resolver.py b/libp2p/network/resolver.py new file mode 100644 index 000000000..cdc757252 --- /dev/null +++ b/libp2p/network/resolver.py @@ -0,0 +1,362 @@ +""" +Pull-based connection resolver. + +The resolver implements a *pull model*: an application (or the host) +requests a connection with certain capabilities (e.g. ``IMuxedConn``), +and the resolver builds the connection stack automatically from +registered providers. + +How it works +------------ +1. **Desired capability** — the caller asks for a connection satisfying + a target interface (e.g. ``IMuxedConn``). +2. **Transport selection** — the resolver picks a transport provider + that can dial the target multiaddr. +3. **Layer resolution** — using ordering metadata + (``@after_connection``) and the provider registry, the resolver + determines which upgrade layers are needed (security, muxer) and in + what order. +4. **Stack execution** — the resolver dials the transport, then applies + each upgrade layer in sequence, short-circuiting if the transport + already provides security / muxing (capability flags). +5. **Fallback** — if a path fails (handshake error, unsupported + protocol), the resolver tries the next registered transport. + +This module is opt-in: existing code using the fixed +``TransportUpgrader`` pipeline continues to work. The host can switch +to the resolver by calling :meth:`ConnectionResolver.resolve` instead +of the manual dial→upgrade→upgrade sequence. + +See Also +-------- +:mod:`libp2p.providers` — provider abstractions and registry. +:mod:`libp2p.capabilities` — capability flags. +:mod:`libp2p.requirements` — ordering metadata. + +""" + +from __future__ import annotations + +from dataclasses import dataclass +import logging +from typing import Any + +from multiaddr import Multiaddr + +from libp2p.abc import ( + IMuxedConn, + IRawConnection, + ISecureConn, +) +from libp2p.peer.id import ID +from libp2p.providers import ( + MuxerProvider, + ProviderRegistry, + SecurityProvider, + TransportProvider, +) +from libp2p.requirements import get_after_connections +from libp2p.transport.exceptions import ( + MuxerUpgradeFailure, + SecurityUpgradeFailure, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class ResolvedStack: + """ + The result of a successful resolution. + + Contains the final connection and metadata about how the stack was + built, useful for diagnostics and logging. + """ + + raw_conn: IRawConnection | None = None + secure_conn: ISecureConn | None = None + muxed_conn: IMuxedConn | None = None + + transport_provider: TransportProvider | None = None + security_provider: SecurityProvider | None = None + muxer_provider: MuxerProvider | None = None + + skipped_security: bool = False + skipped_muxer: bool = False + + @property + def top_connection(self) -> Any: + """Return the highest-level connection in the stack.""" + if self.muxed_conn is not None: + return self.muxed_conn + if self.secure_conn is not None: + return self.secure_conn + return self.raw_conn + + def describes(self) -> str: + """Human-readable description of the resolved stack.""" + parts: list[str] = [] + if self.transport_provider: + parts.append(f"transport={self.transport_provider.protocol_name}") + if self.skipped_security: + parts.append("security=builtin") + elif self.security_provider: + parts.append(f"security={self.security_provider.protocol_id}") + if self.skipped_muxer: + parts.append("muxer=builtin") + elif self.muxer_provider: + parts.append(f"muxer={self.muxer_provider.protocol_id}") + return " → ".join(parts) or "(empty)" + + +class ResolutionError(Exception): + """No viable connection stack could be built.""" + + +class NoTransportError(ResolutionError): + """No registered transport can dial the target multiaddr.""" + + +class AllPathsFailedError(ResolutionError): + """Every transport / upgrade path failed.""" + + def __init__(self, failures: list[tuple[str, Exception]]) -> None: + self.failures = failures + paths = "; ".join(f"{name}: {err}" for name, err in failures) + super().__init__(f"All resolution paths failed: {paths}") + + +class ConnectionResolver: + """ + Pull-based connection stack builder. + + Given a target multiaddr and a peer ID, the resolver tries every + registered transport that can dial the address, applies the required + upgrade layers (security, muxer), and returns the resulting + connection stack. + + Parameters + ---------- + registry: + The :class:`~libp2p.providers.ProviderRegistry` to consult. + + """ + + def __init__(self, registry: ProviderRegistry) -> None: + self._registry = registry + + async def resolve( + self, + maddr: Multiaddr, + peer_id: ID, + *, + is_initiator: bool = True, + ) -> ResolvedStack: + """ + Resolve a fully-upgraded connection to *peer_id* via *maddr*. + + The resolver iterates over transport providers that match + *maddr*. For each, it: + + 1. Dials the transport to get a raw connection. + 2. Checks capability flags — if the transport already provides + security and muxing, returns immediately. + 3. Otherwise, applies security then muxer upgrades using the + registered providers. + 4. If any step fails, moves to the next transport. + + Parameters + ---------- + maddr: + Target multiaddr. + peer_id: + Remote peer identity. + is_initiator: + Whether we initiated the connection (affects security + handshake direction). + + Returns + ------- + ResolvedStack + The fully-resolved connection stack. + + Raises + ------ + NoTransportError + If no transport can dial *maddr*. + AllPathsFailedError + If every transport path fails. + + """ + candidates = self._registry.get_transports_for(maddr) + if not candidates: + raise NoTransportError(f"No registered transport can dial {maddr}") + + failures: list[tuple[str, Exception]] = [] + + for tp in candidates: + try: + stack = await self._try_path(tp, maddr, peer_id, is_initiator) + logger.info( + "Resolver: connection established via %s", stack.describes() + ) + return stack + except Exception as exc: + logger.debug("Resolver: path %s failed: %s", tp.protocol_name, exc) + failures.append((tp.protocol_name, exc)) + + raise AllPathsFailedError(failures) + + async def _try_path( + self, + tp: TransportProvider, + maddr: Multiaddr, + peer_id: ID, + is_initiator: bool, + ) -> ResolvedStack: + """Attempt to build a full stack using one transport provider.""" + stack = ResolvedStack(transport_provider=tp) + + raw_conn = await tp.dial(maddr) + stack.raw_conn = raw_conn + + if tp.provides_security and tp.provides_muxing: + stack.skipped_security = True + stack.skipped_muxer = True + logger.debug( + "Transport %s provides security + muxing; skipping upgrades", + tp.protocol_name, + ) + return stack + + conn_for_muxer: Any = raw_conn + if tp.provides_security: + stack.skipped_security = True + conn_for_muxer = raw_conn + logger.debug( + "Transport %s provides security; skipping security upgrade", + tp.protocol_name, + ) + else: + sec_providers = self._registry.get_security_providers() + if not sec_providers: + raise SecurityUpgradeFailure("No security providers registered") + sec_ok = False + for sp in sec_providers: + try: + secure_conn = await sp.upgrade(raw_conn, is_initiator, peer_id) + stack.secure_conn = secure_conn + stack.security_provider = sp + conn_for_muxer = secure_conn + sec_ok = True + break + except Exception as exc: + logger.debug("Security provider %s failed: %s", sp.protocol_id, exc) + continue + if not sec_ok: + try: + await raw_conn.close() + except Exception: + pass + raise SecurityUpgradeFailure("All security providers failed") + + if tp.provides_muxing: + stack.skipped_muxer = True + logger.debug( + "Transport %s provides muxing; skipping muxer upgrade", + tp.protocol_name, + ) + else: + mux_providers = self._registry.get_muxer_providers() + if not mux_providers: + raise MuxerUpgradeFailure("No muxer providers registered") + + for mp in mux_providers: + after = get_after_connections(mp.muxer_class) + for iface in after: + if not isinstance(conn_for_muxer, iface): + logger.warning( + "Muxer %s declares @after_connection(%s) " + "but connection (%s) does not satisfy it", + mp.muxer_class.__name__, + iface.__name__, + type(conn_for_muxer).__name__, + ) + + mux_ok = False + for mp in mux_providers: + try: + muxed = await mp.upgrade(conn_for_muxer, peer_id) + stack.muxed_conn = muxed + stack.muxer_provider = mp + mux_ok = True + break + except Exception as exc: + logger.debug("Muxer provider %s failed: %s", mp.protocol_id, exc) + continue + if not mux_ok: + try: + await conn_for_muxer.close() + except Exception: + pass + raise MuxerUpgradeFailure("All muxer providers failed") + + return stack + + async def upgrade_inbound( + self, + raw_conn: IRawConnection, + *, + transport_has_security: bool = False, + transport_has_muxing: bool = False, + ) -> ResolvedStack: + """ + Upgrade an inbound (listener-accepted) raw connection. + + Unlike :meth:`resolve`, we already have the raw connection — we + just need to apply security and muxer layers. + """ + stack = ResolvedStack(raw_conn=raw_conn) + + if transport_has_security and transport_has_muxing: + stack.skipped_security = True + stack.skipped_muxer = True + return stack + + conn_for_muxer: Any = raw_conn + + # Security + if transport_has_security: + stack.skipped_security = True + conn_for_muxer = raw_conn + else: + sec_providers = self._registry.get_security_providers() + if sec_providers: + sp = sec_providers[0] + secure_conn = await sp.upgrade(raw_conn, is_initiator=False) + stack.secure_conn = secure_conn + stack.security_provider = sp + conn_for_muxer = secure_conn + else: + raise SecurityUpgradeFailure( + "No security providers registered for inbound" + ) + + if transport_has_muxing: + stack.skipped_muxer = True + else: + mux_providers = self._registry.get_muxer_providers() + if mux_providers: + peer_id = ( + stack.secure_conn.get_remote_peer() + if stack.secure_conn is not None + else ID(b"\x00") + ) + mp = mux_providers[0] + muxed = await mp.upgrade(conn_for_muxer, peer_id) + stack.muxed_conn = muxed + stack.muxer_provider = mp + else: + raise MuxerUpgradeFailure("No muxer providers registered for inbound") + + return stack diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 74533250c..b989816aa 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -38,6 +38,7 @@ ) from libp2p.network.auto_connector import AutoConnector from libp2p.network.config import ConnectionConfig, RetryConfig +from libp2p.transport.quic.config import QUICTransportConfig from libp2p.network.connection_gate import ConnectionGate from libp2p.network.connection_pruner import ConnectionPruner from libp2p.network.tag_store import TagInfo, TagStore @@ -57,9 +58,6 @@ OpenConnectionError, SecurityUpgradeFailure, ) -from libp2p.transport.quic.config import QUICTransportConfig -from libp2p.transport.quic.connection import QUICConnection -from libp2p.transport.quic.transport import QUICTransport from libp2p.transport.upgrader import ( TransportUpgrader, ) @@ -215,13 +213,10 @@ async def run(self) -> None: # Set background nursery BEFORE setting the event # This ensures transports have the nursery when they check - if isinstance(self.transport, QUICTransport): - self.transport.set_background_nursery(nursery) - self.transport.set_swarm(self) - elif hasattr(self.transport, "set_background_nursery"): - # WebSocket transport also needs background nursery - # for connection management + if hasattr(self.transport, "set_background_nursery"): self.transport.set_background_nursery(nursery) # type: ignore[attr-defined] + if hasattr(self.transport, "set_swarm"): + self.transport.set_swarm(self) # type: ignore[attr-defined] # Signal that the background nursery is available. self.event_background_nursery_created.set() @@ -670,11 +665,11 @@ async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetC pass raise SwarmException(f"Unexpected error dialing peer {peer_id}") from e - if isinstance(self.transport, QUICTransport) and isinstance( - raw_conn, IMuxedConn + if getattr(self.transport, 'provides_muxing', False) and getattr( + raw_conn, 'is_muxed', False ): logger.info( - "Skipping upgrade for QUIC, QUIC connections are already multiplexed" + "Skipping upgrade — transport already provides security + muxing" ) try: swarm_conn = await self.add_conn(raw_conn, direction="outbound") @@ -931,6 +926,7 @@ async def new_stream(self, peer_id: ID) -> INetStream: f"Failed to get a valid connection for peer {peer_id}" ) + if getattr(self.transport, 'provides_muxing', False) and connection is not None: net_stream = await self._open_stream_on_connection( connection, connections, peer_id ) @@ -1167,14 +1163,16 @@ async def conn_handler( pass return - # No need to upgrade QUIC Connection - if isinstance(self.transport, QUICTransport): + # No need to upgrade connections from transports with built-in muxing + if getattr(self.transport, 'provides_muxing', False): try: - quic_conn = cast(QUICConnection, read_write_closer) - await self.add_conn(quic_conn, direction="inbound") - peer_id = quic_conn.peer_id + # The connection is already muxed; add it directly. + muxed_conn = cast(IMuxedConn, read_write_closer) + await self.add_conn(muxed_conn, direction="inbound") + peer_id = muxed_conn.peer_id logger.debug( - f"successfully opened quic connection to peer {peer_id}" + "successfully opened native-muxed connection " + f"to peer {peer_id}" ) # NOTE: This is a intentional barrier to prevent from the # handler exiting and closing the connection. @@ -1560,6 +1558,13 @@ async def add_conn( logger.debug("Swarm::add_conn | starting muxed connection") self.manager.run_task(muxed_conn.start) await muxed_conn.event_started.wait() + # For connections that need an explicit handshake-completion step, + # wait until the connection reports itself as established. + if hasattr(muxed_conn, 'is_established') and hasattr( + muxed_conn, '_connected_event' + ): + if not muxed_conn.is_established: # type: ignore[attr-defined] + await muxed_conn._connected_event.wait() # type: ignore[attr-defined] logger.debug( f"Swarm::add_conn | event_started received for peer {muxed_conn.peer_id}" ) diff --git a/libp2p/providers.py b/libp2p/providers.py new file mode 100644 index 000000000..e008dcf21 --- /dev/null +++ b/libp2p/providers.py @@ -0,0 +1,281 @@ +""" +Connection-layer provider abstractions. + +A *connection provider* is any component that can transform one kind of +connection into another (e.g. ``IRawConnection → ISecureConn``) or that +can produce connections from scratch (e.g. a transport that dials a peer). + +The resolver (:mod:`libp2p.network.resolver`) uses this registry to +build a connection stack dynamically at dial/listen time. + +Terminology +----------- +* **Transport** — incoming IO; bottom touches the network, top exposes a + connection. Implements :class:`ProvidesTransport`. +* **Connection layer** — transitive IO; consumes one connection and + produces a higher-level one. Implements :class:`ProvidesConnection`. +* **Protocol** — outgoing IO; runs on a fully-upgraded connection. + +See Also +-------- +:mod:`libp2p.capabilities` — capability flags (``provides_security``, + ``provides_muxing``). +:mod:`libp2p.requirements` — ``@after_connection`` ordering metadata. + +""" + +from __future__ import annotations + +import logging +from typing import Any, Protocol, runtime_checkable + +from libp2p.abc import ( + IMuxedConn, + IRawConnection, + ISecureConn, + ISecureTransport, + ITransport, +) +from libp2p.custom_types import ( + TMuxerClass, + TMuxerOptions, + TProtocol, + TSecurityOptions, +) +from libp2p.peer.id import ID + +logger = logging.getLogger(__name__) + + +@runtime_checkable +class ProvidesTransport(Protocol): + """ + Structural protocol for components that can produce raw connections + from a network address. + + Every :class:`~libp2p.abc.ITransport` is automatically a + ``ProvidesTransport`` — the protocol simply formalises the concept so + the resolver can discover and iterate over transport providers + without hard-coding transport classes. + """ + + def can_dial(self, maddr: Any) -> bool: + """Return ``True`` if this provider can dial the given multiaddr.""" + ... + + async def dial(self, maddr: Any) -> IRawConnection: + """Dial the multiaddr and return a raw connection.""" + ... + + +@runtime_checkable +class ProvidesConnection(Protocol): + """ + Structural protocol for components that can *upgrade* an existing + connection to a higher-level one. + + Examples + -------- + * A security transport consumes ``IRawConnection`` and produces + ``ISecureConn`` → provides ``ISecureConn``. + * A muxer consumes ``ISecureConn`` and produces ``IMuxedConn`` + → provides ``IMuxedConn``. + + """ + + @property + def provides_interface(self) -> type: + """ + The connection interface this provider produces + (e.g. ``ISecureConn``, ``IMuxedConn``). + """ + ... + + @property + def requires_interface(self) -> type: + """ + The connection interface this provider consumes + (e.g. ``IRawConnection``, ``ISecureConn``). + """ + ... + + +class SecurityProvider: + """ + Wraps an :class:`~libp2p.abc.ISecureTransport` as a + :class:`ProvidesConnection`. + + Consumes ``IRawConnection``, produces ``ISecureConn``. + """ + + def __init__( + self, + protocol_id: TProtocol, + transport: ISecureTransport, + ) -> None: + self.protocol_id = protocol_id + self.transport = transport + + @property + def provides_interface(self) -> type: + return ISecureConn + + @property + def requires_interface(self) -> type: + return IRawConnection + + async def upgrade( + self, + conn: IRawConnection, + is_initiator: bool, + peer_id: ID | None = None, + ) -> ISecureConn: + if is_initiator: + if peer_id is None: + raise ValueError("peer_id required for outbound security upgrade") + return await self.transport.secure_outbound(conn, peer_id) + return await self.transport.secure_inbound(conn) + + def __repr__(self) -> str: + return f"SecurityProvider({self.protocol_id!r})" + + +class MuxerProvider: + """ + Wraps a muxer class as a :class:`ProvidesConnection`. + + Consumes ``ISecureConn``, produces ``IMuxedConn``. + """ + + def __init__( + self, + protocol_id: TProtocol, + muxer_class: TMuxerClass, + ) -> None: + self.protocol_id = protocol_id + self.muxer_class = muxer_class + + @property + def provides_interface(self) -> type: + return IMuxedConn + + @property + def requires_interface(self) -> type: + return ISecureConn + + async def upgrade( + self, + conn: ISecureConn, + peer_id: ID, + ) -> IMuxedConn: + return self.muxer_class(conn, peer_id) + + def __repr__(self) -> str: + return f"MuxerProvider({self.protocol_id!r})" + + +class TransportProvider: + """ + Wraps an :class:`~libp2p.abc.ITransport` as a + :class:`ProvidesTransport`. + + Optionally supports multiaddr matching so the resolver can select the + right transport for a dial target. + """ + + def __init__( + self, + protocol_name: str, + transport: ITransport, + *, + matcher: Any | None = None, + ) -> None: + self.protocol_name = protocol_name + self.transport = transport + self._matcher = matcher + + def can_dial(self, maddr: Any) -> bool: + if self._matcher is not None: + return bool(self._matcher(maddr)) + try: + protocols = [p.name for p in maddr.protocols()] + return self.protocol_name in protocols + except Exception: + return False + + async def dial(self, maddr: Any) -> IRawConnection: + return await self.transport.dial(maddr) + + @property + def provides_security(self) -> bool: + return getattr(self.transport, "provides_security", False) + + @property + def provides_muxing(self) -> bool: + return getattr(self.transport, "provides_muxing", False) + + def __repr__(self) -> str: + return f"TransportProvider({self.protocol_name!r})" + + +class ProviderRegistry: + """ + Central registry of transport and connection-layer providers. + + The :class:`~libp2p.network.resolver.ConnectionResolver` consults + this registry when building connection stacks. + """ + + def __init__(self) -> None: + self._transport_providers: list[TransportProvider] = [] + self._security_providers: list[SecurityProvider] = [] + self._muxer_providers: list[MuxerProvider] = [] + + def register_transport(self, provider: TransportProvider) -> None: + self._transport_providers.append(provider) + logger.debug("Registered transport provider: %s", provider) + + def register_security(self, provider: SecurityProvider) -> None: + self._security_providers.append(provider) + logger.debug("Registered security provider: %s", provider) + + def register_muxer(self, provider: MuxerProvider) -> None: + self._muxer_providers.append(provider) + logger.debug("Registered muxer provider: %s", provider) + + def register_security_options(self, opts: TSecurityOptions) -> None: + """Populate from the legacy ``TSecurityOptions`` mapping.""" + for proto, transport in opts.items(): + self.register_security(SecurityProvider(proto, transport)) + + def register_muxer_options(self, opts: TMuxerOptions) -> None: + """Populate from the legacy ``TMuxerOptions`` mapping.""" + for proto, muxer_cls in opts.items(): + self.register_muxer(MuxerProvider(proto, muxer_cls)) + + def get_transports(self) -> list[TransportProvider]: + return list(self._transport_providers) + + def get_transports_for(self, maddr: Any) -> list[TransportProvider]: + """Return providers that can dial *maddr*.""" + return [tp for tp in self._transport_providers if tp.can_dial(maddr)] + + def get_security_providers(self) -> list[SecurityProvider]: + return list(self._security_providers) + + def get_muxer_providers(self) -> list[MuxerProvider]: + return list(self._muxer_providers) + + def has_security(self) -> bool: + return len(self._security_providers) > 0 + + def has_muxer(self) -> bool: + return len(self._muxer_providers) > 0 + + def __repr__(self) -> str: + return ( + f"ProviderRegistry(" + f"transports={len(self._transport_providers)}, " + f"security={len(self._security_providers)}, " + f"muxers={len(self._muxer_providers)})" + ) diff --git a/libp2p/requirements.py b/libp2p/requirements.py new file mode 100644 index 000000000..78711799c --- /dev/null +++ b/libp2p/requirements.py @@ -0,0 +1,164 @@ +""" +Protocol handler requirement declarations. + +Decorators that let protocol handlers and connection layers express what +they need at runtime — without changing how existing code works. + +``@requires_connection`` — declares that a protocol handler needs a +connection satisfying one or more interfaces (e.g. ``ISecureConn``). + +``@after_connection`` — declares that a connection layer (e.g. a muxer) +should be stacked *after* another layer (e.g. security) in the upgrade +pipeline. + +``get_required_connections`` / ``get_after_connections`` — introspect the +metadata attached by those decorators. + +Examples +-------- +:: + + from libp2p.requirements import requires_connection, after_connection + from libp2p.abc import ISecureConn, IMuxedConn + + @requires_connection(ISecureConn) + async def my_protocol_handler(stream): + '''Only run me on a secured connection.''' + ... + + @after_connection(ISecureConn) + class Yamux(IMuxedConn): + '''Yamux should be stacked AFTER a security layer.''' + ... + +""" + +from __future__ import annotations + +from collections.abc import Sequence +import logging +from typing import Any, TypeVar + +logger = logging.getLogger(__name__) + +F = TypeVar("F") + + +def requires_connection(*interfaces: type) -> Any: + """ + Mark a protocol handler as requiring certain connection interfaces. + + Parameters + ---------- + *interfaces: + One or more ABC / Protocol types that the underlying connection + must satisfy (e.g. ``ISecureConn``, ``IMuxedConn``). + + Returns + ------- + decorator + Wraps the handler function, attaching ``_required_connections``. + + Example + ------- + :: + + @requires_connection(ISecureConn) + async def echo_handler(stream): + ... + + """ + + def decorator(fn: F) -> F: + fn._required_connections = interfaces + return fn + + return decorator + + +def get_required_connections(fn: Any) -> Sequence[type]: + """Return the connection interfaces required by *fn*, or ``()``.""" + return getattr(fn, "_required_connections", ()) + + +def after_connection(*interfaces: type) -> Any: + """ + Declare that a connection layer must be applied *after* certain + other layers are present in the stack. + + This is ordering metadata only — it does **not** mean the listed + interfaces *must* be present (use ``@requires_connection`` for that). + + Parameters + ---------- + *interfaces: + One or more ABC / Protocol types that must appear earlier in the + connection upgrade pipeline. + + Example + ------- + :: + + @after_connection(ISecureConn) + class Yamux(IMuxedConn): + ... + + """ + + def decorator(cls: F) -> F: + cls._after_connections = interfaces + return cls + + return decorator + + +def get_after_connections(cls: Any) -> Sequence[type]: + """Return the interfaces that *cls* must come after, or ``()``.""" + return getattr(cls, "_after_connections", ()) + + +def check_connection_requirements( + handler: Any, + connection: Any, + *, + raise_on_failure: bool = False, +) -> bool: + """ + Verify that *connection* satisfies the requirements declared on *handler*. + + Parameters + ---------- + handler: + A callable (possibly decorated with ``@requires_connection``). + connection: + The underlying connection object to check. + raise_on_failure: + If ``True``, raise ``ConnectionRequirementError`` instead of + returning ``False``. + + Returns + ------- + bool + ``True`` if all requirements are met (or none were declared). + + """ + required = get_required_connections(handler) + if not required: + return True + + for iface in required: + if not isinstance(connection, iface): + msg = ( + f"Handler {getattr(handler, '__name__', handler)} requires " + f"{iface.__name__}, but the connection {type(connection).__name__} " + f"does not satisfy it." + ) + if raise_on_failure: + raise ConnectionRequirementError(msg) + logger.warning(msg) + return False + return True + + +class ConnectionRequirementError(Exception): + """Raised when a protocol handler's connection requirements are not met.""" diff --git a/libp2p/security/noise/transport.py b/libp2p/security/noise/transport.py index d64e5e3d2..2f5066318 100644 --- a/libp2p/security/noise/transport.py +++ b/libp2p/security/noise/transport.py @@ -12,16 +12,19 @@ ISecureConn, ISecureTransport, ) +from libp2p.crypto.ed25519 import create_new_key_pair as create_ed25519_kp from libp2p.crypto.keys import ( KeyPair, PrivateKey, ) +from libp2p.crypto.x25519 import create_new_key_pair as create_x25519_kp from libp2p.custom_types import ( TProtocol, ) from libp2p.peer.id import ( ID, ) +from libp2p.providers import SecurityProvider from .early_data import EarlyDataHandler, EarlyDataManager from .patterns import ( @@ -161,3 +164,22 @@ def get_cached_static_key(self, peer_id: ID) -> bytes | None: def clear_static_key_cache(self) -> None: """Clear the static key cache.""" self._static_key_cache.clear() + + +def _create_noise_provider() -> "SecurityProvider": + """ + Entry-point factory for Noise security discovery. + + Returns a :class:`~libp2p.providers.SecurityProvider` wrapping a + fresh Noise transport instance. + + .. note:: + + Both an Ed25519 identity key and an X25519 static key are + generated; callers needing specific keys should construct + the provider manually. + """ + id_kp = create_ed25519_kp() + noise_kp = create_x25519_kp() + transport = Transport(id_kp, noise_privkey=noise_kp.private_key) + return SecurityProvider(PROTOCOL_ID, transport) diff --git a/libp2p/security/tls/transport.py b/libp2p/security/tls/transport.py index 54aa9af7d..4829d6b83 100644 --- a/libp2p/security/tls/transport.py +++ b/libp2p/security/tls/transport.py @@ -8,9 +8,11 @@ import libp2p from libp2p.abc import IRawConnection, ISecureConn, ISecureTransport +from libp2p.crypto.ed25519 import create_new_key_pair from libp2p.crypto.keys import KeyPair, PrivateKey from libp2p.custom_types import TProtocol from libp2p.peer.id import ID +from libp2p.providers import SecurityProvider from libp2p.security.secure_session import SecureSession from libp2p.security.tls.certificate import ( ALPN_PROTOCOL, @@ -578,3 +580,20 @@ def create_tls_transport( """ return TLSTransport(libp2p_keypair, early_data, muxers, identity_config) + + +def _create_tls_provider() -> "SecurityProvider": + """ + Entry-point factory for TLS security discovery. + + Returns a :class:`~libp2p.providers.SecurityProvider` wrapping a + fresh TLS transport instance. + + .. note:: + + An Ed25519 identity key is generated; callers needing a specific + key should construct the provider manually. + """ + kp = create_new_key_pair() + transport = TLSTransport(kp) + return SecurityProvider(PROTOCOL_ID, transport) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index b52ba4453..ec8de7a52 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -33,6 +33,8 @@ from libp2p.peer.id import ( ID, ) +from libp2p.providers import MuxerProvider +from libp2p.requirements import after_connection from libp2p.utils import ( decode_uvarint_from_stream, encode_uvarint, @@ -60,6 +62,7 @@ logger = logging.getLogger(__name__) +@after_connection(ISecureConn) class Mplex(IMuxedConn): """ reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go @@ -437,3 +440,13 @@ def get_connection_type(self) -> ConnectionType: Get connection type by delegating to secured_conn. """ return self.secured_conn.get_connection_type() + + +def _create_mplex_provider() -> "MuxerProvider": + """ + Entry-point factory for Mplex muxer discovery. + + Returns a :class:`~libp2p.providers.MuxerProvider` wrapping the + :class:`Mplex` class. + """ + return MuxerProvider(MPLEX_PROTOCOL_ID, Mplex) diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index 8acc1a9ea..50196b186 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -32,6 +32,7 @@ IMuxedStream, ISecureConn, ) +from libp2p.custom_types import TProtocol from libp2p.io.exceptions import ( ConnectionClosedError, IncompleteReadError, @@ -46,6 +47,8 @@ from libp2p.peer.id import ( ID, ) +from libp2p.providers import MuxerProvider +from libp2p.requirements import after_connection from libp2p.stream_muxer.exceptions import ( MuxedConnUnavailable, MuxedStreamEOF, @@ -416,6 +419,7 @@ def get_remote_address(self) -> tuple[str, int] | None: return None +@after_connection(ISecureConn) class Yamux(IMuxedConn): def __init__( self, @@ -1134,3 +1138,13 @@ async def _cleanup_on_error(self) -> None: self.on_close() except Exception as callback_error: logger.error(f"Error in on_close callback: {callback_error}") + + +def _create_yamux_provider() -> "MuxerProvider": + """ + Entry-point factory for Yamux muxer discovery. + + Returns a :class:`~libp2p.providers.MuxerProvider` wrapping the + :class:`Yamux` class. + """ + return MuxerProvider(TProtocol(PROTOCOL_ID), Yamux) diff --git a/libp2p/transport/quic/connection.py b/libp2p/transport/quic/connection.py index f0cd45369..04c529527 100644 --- a/libp2p/transport/quic/connection.py +++ b/libp2p/transport/quic/connection.py @@ -290,6 +290,18 @@ def is_peer_verified(self) -> bool: """Check if peer identity has been verified.""" return self._peer_verified + # -- Capability declarations (see libp2p/capabilities.py) ---------------- + + @property + def is_secure(self) -> bool: + """QUIC connections are always secured via TLS 1.3.""" + return True + + @property + def is_muxed(self) -> bool: + """QUIC connections natively support stream multiplexing.""" + return True + def multiaddr(self) -> multiaddr.Multiaddr: """Get the multiaddr for this connection.""" return self._maddr diff --git a/libp2p/transport/quic/listener.py b/libp2p/transport/quic/listener.py index bf25e4e26..c7ebb3ce2 100644 --- a/libp2p/transport/quic/listener.py +++ b/libp2p/transport/quic/listener.py @@ -19,8 +19,8 @@ from libp2p.abc import IListener from libp2p.custom_types import ( TProtocol, - TQUICConnHandlerFn, ) +from libp2p.transport.quic.types import TQUICConnHandlerFn from libp2p.transport.quic.security import ( LIBP2P_TLS_EXTENSION_OID, QUICTLSConfigManager, diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py index 0572fcfb9..96c891037 100644 --- a/libp2p/transport/quic/transport.py +++ b/libp2p/transport/quic/transport.py @@ -20,14 +20,17 @@ from libp2p.abc import ( ITransport, ) +from libp2p.crypto.ed25519 import create_new_key_pair from libp2p.crypto.keys import ( PrivateKey, ) -from libp2p.custom_types import TProtocol, TQUICConnHandlerFn +from libp2p.custom_types import TProtocol from libp2p.peer.id import ( ID, ) +from libp2p.providers import TransportProvider from libp2p.transport.quic.security import QUICTLSSecurityConfig +from libp2p.transport.quic.types import TQUICConnHandlerFn from libp2p.transport.quic.utils import ( create_client_config_from_base, create_server_config_from_base, @@ -117,6 +120,18 @@ def __init__( f"Initialized QUIC transport with security for peer {self._peer_id}" ) + # -- Capability declarations (see libp2p/capabilities.py) ---------------- + + @property + def provides_security(self) -> bool: + """QUIC has built-in TLS 1.3 — no separate security upgrade needed.""" + return True + + @property + def provides_muxing(self) -> bool: + """QUIC has built-in stream multiplexing — no separate muxer needed.""" + return True + def set_background_nursery(self, nursery: trio.Nursery) -> None: """Set the nursery to use for background tasks (called by swarm).""" self._background_nursery = nursery @@ -496,3 +511,21 @@ def get_listener_socket(self) -> trio.socket.SocketType | None: if listener.is_listening() and listener._socket: return listener._socket return None + + +def _create_quic_provider() -> "TransportProvider": + """ + Entry-point factory for QUIC transport discovery. + + Returns a :class:`~libp2p.providers.TransportProvider` wrapping a + default :class:`QUICTransport` instance. + + .. note:: + + QUIC requires a private key at construction time. This factory + generates a fresh Ed25519 identity; callers that need a specific + identity should construct the provider manually. + """ + key_pair = create_new_key_pair() + transport = QUICTransport(key_pair.private_key) + return TransportProvider("quic", transport) diff --git a/libp2p/transport/quic/types.py b/libp2p/transport/quic/types.py new file mode 100644 index 000000000..537562acc --- /dev/null +++ b/libp2p/transport/quic/types.py @@ -0,0 +1,19 @@ +""" +QUIC-specific type aliases. + +These live here (rather than in ``libp2p/custom_types.py``) so that the +core package does not need to import any QUIC internals. +""" + +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, cast + +from libp2p.transport.quic.stream import QUICStream + +if TYPE_CHECKING: + from libp2p.transport.quic.connection import QUICConnection +else: + QUICConnection = cast(type, object) + +TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]] +TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]] diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index e7851006e..d64775090 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -27,6 +27,7 @@ from libp2p.network.connection.raw_connection import ( RawConnection, ) +from libp2p.providers import TransportProvider from libp2p.transport.exceptions import ( OpenConnectionError, ) @@ -312,3 +313,13 @@ def create_listener(self, handler_function: THandler) -> TCPListener: def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr: return multiaddr_from_socket(socket) + + +def _create_tcp_provider() -> "TransportProvider": + """ + Entry-point factory for TCP transport discovery. + + Returns a :class:`~libp2p.providers.TransportProvider` wrapping a + default :class:`TCP` instance. + """ + return TransportProvider("tcp", TCP()) diff --git a/libp2p/transport/transport_registry.py b/libp2p/transport/transport_registry.py index 3cbef4c70..14caa0110 100644 --- a/libp2p/transport/transport_registry.py +++ b/libp2p/transport/transport_registry.py @@ -1,5 +1,9 @@ """ Transport registry for dynamic transport selection based on multiaddr protocols. + +The registry also exposes capability-aware queries that let callers discover +which registered transports bundle their own security or multiplexing, so the +upgrade pipeline can decide at runtime whether to skip certain steps. """ from collections.abc import Callable @@ -131,6 +135,82 @@ def get_supported_protocols(self) -> list[str]: """Get list of supported transport protocols.""" return list(self._transports.keys()) + @staticmethod + def _class_has_capability(transport_class: type, attr: str) -> bool: + """ + Check whether *transport_class* declares a boolean capability. + + Works correctly both for concrete attributes and ``@property`` + descriptors on the class — we inspect the descriptor to see if + the getter returns ``True`` on a best-effort basis. + """ + obj = getattr(transport_class, attr, None) + if obj is None: + return False + if isinstance(obj, bool): + return obj + if isinstance(obj, property) and obj.fget is not None: + try: + sentinel = object.__new__(transport_class) + return bool(obj.fget(sentinel)) + except Exception: + return True + return False + + def transport_provides_security(self, protocol: str) -> bool: + """ + Return ``True`` if the transport registered for *protocol* + declares built-in security (``provides_security`` property). + + Returns ``False`` for unknown protocols or transports that do not + declare the capability. + """ + transport_class = self.get_transport(protocol) + if transport_class is None: + return False + return self._class_has_capability(transport_class, "provides_security") + + def transport_provides_muxing(self, protocol: str) -> bool: + """ + Return ``True`` if the transport registered for *protocol* + declares built-in multiplexing (``provides_muxing`` property). + + Returns ``False`` for unknown protocols or transports that do not + declare the capability. + """ + transport_class = self.get_transport(protocol) + if transport_class is None: + return False + return self._class_has_capability(transport_class, "provides_muxing") + + def needs_security_upgrade(self, protocol: str) -> bool: + """ + Return ``True`` if the transport for *protocol* does **not** + provide built-in security and therefore requires the standard + security upgrade step. + """ + return not self.transport_provides_security(protocol) + + def needs_muxer_upgrade(self, protocol: str) -> bool: + """ + Return ``True`` if the transport for *protocol* does **not** + provide built-in multiplexing and therefore requires the standard + muxer upgrade step. + """ + return not self.transport_provides_muxing(protocol) + + def get_self_upgrading_protocols(self) -> list[str]: + """ + Return protocols whose transports provide *both* security and + muxing — i.e. transports that need no additional upgrades. + """ + return [ + proto + for proto in self._transports + if self.transport_provides_security(proto) + and self.transport_provides_muxing(proto) + ] + def create_transport( self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any ) -> ITransport | None: @@ -285,3 +365,21 @@ def get_supported_transport_protocols() -> list[str]: """Get list of supported transport protocols from the global registry.""" registry = get_transport_registry() return registry.get_supported_protocols() + + +def transport_needs_security(protocol: str) -> bool: + """ + Check whether the transport for *protocol* requires a security upgrade. + + Convenience wrapper around the global registry. + """ + return get_transport_registry().needs_security_upgrade(protocol) + + +def transport_needs_muxer(protocol: str) -> bool: + """ + Check whether the transport for *protocol* requires a muxer upgrade. + + Convenience wrapper around the global registry. + """ + return get_transport_registry().needs_muxer_upgrade(protocol) diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index 8b2e41cb0..d221b7b84 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -1,3 +1,5 @@ +import logging + from libp2p.abc import ( IMuxedConn, IRawConnection, @@ -17,6 +19,7 @@ from libp2p.protocol_muxer.multiselect import ( DEFAULT_NEGOTIATE_TIMEOUT, ) +from libp2p.requirements import get_after_connections from libp2p.security.exceptions import ( HandshakeFailure, ) @@ -31,6 +34,8 @@ SecurityUpgradeFailure, ) +logger = logging.getLogger(__name__) + class TransportUpgrader: security_multistream: SecurityMultistream @@ -81,10 +86,43 @@ async def upgrade_security( ) from error async def upgrade_connection(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: - """Upgrade secured connection to a muxed connection.""" + """ + Upgrade secured connection to a muxed connection. + + Before negotiating the muxer, this method verifies that the + connection satisfies any ordering requirements declared by the + registered muxer classes (via ``@after_connection``). + """ + self._verify_muxer_ordering(conn) + try: return await self.muxer_multistream.new_conn(conn, peer_id) except (MultiselectError, MultiselectClientError) as error: raise MuxerUpgradeFailure( "failed to negotiate the multiplexer protocol" ) from error + + def _verify_muxer_ordering(self, conn: ISecureConn) -> None: + """ + Check that *conn* satisfies the ``@after_connection`` requirements + declared on every registered muxer class. + + If a muxer declares ``@after_connection(ISecureConn)`` the + connection handed to it must be an ``ISecureConn`` instance. + A mismatch is logged as a warning (non-fatal) so that existing + code keeps working while giving operators clear diagnostics. + """ + for protocol, muxer_class in self.muxer_multistream.transports.items(): + after = get_after_connections(muxer_class) + if not after: + continue + for iface in after: + if not isinstance(conn, iface): + logger.warning( + "Muxer %s (protocol %s) declares @after_connection(%s) " + "but the connection (%s) does not satisfy it", + muxer_class.__name__, + protocol, + iface.__name__, + type(conn).__name__, + ) diff --git a/pyproject.toml b/pyproject.toml index a12ceb021..76d4bc045 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,47 @@ classifiers = [ [project.urls] Homepage = "https://github.com/libp2p/py-libp2p" +# Optional dependency groups for modular packaging. +# Users can install only the transports / security / muxers they need: +# pip install libp2p[quic] +# pip install libp2p[noise] +# pip install libp2p[defaults] (everything) +[project.optional-dependencies] +quic = [ + "aioquic>=1.2.0", +] +noise = [ + "noiseprotocol>=0.3.0", +] +tls = [] +secio = [ + "pycryptodome>=3.9.2", +] +yamux = [] +mplex = [] +defaults = [ + "libp2p[quic,noise,tls,secio,yamux,mplex]", +] + +# Entry-point groups for plugin discovery. +# Third-party transport/security/muxer packages declare entry points in +# these groups so that libp2p.entrypoints.discover_and_register() can +# find and register them automatically. +# +# Built-in transports/security/muxers are listed here so they are +# discoverable through the same mechanism. +[project.entry-points."libp2p.transports"] +tcp = "libp2p.transport.tcp.tcp:_create_tcp_provider" +quic = "libp2p.transport.quic.transport:_create_quic_provider" + +[project.entry-points."libp2p.security"] +noise = "libp2p.security.noise.transport:_create_noise_provider" +tls = "libp2p.security.tls.transport:_create_tls_provider" + +[project.entry-points."libp2p.muxers"] +yamux = "libp2p.stream_muxer.yamux.yamux:_create_yamux_provider" +mplex = "libp2p.stream_muxer.mplex.mplex:_create_mplex_provider" + [project.scripts] chat-demo = "examples.chat.chat:main" echo-demo = "examples.echo.echo:main" diff --git a/tests/core/test_capabilities.py b/tests/core/test_capabilities.py new file mode 100644 index 000000000..ab3a83f0b --- /dev/null +++ b/tests/core/test_capabilities.py @@ -0,0 +1,238 @@ +""" +Tests for the capability protocol declarations. + +Validates that ``TransportCapabilities``, ``ConnectionCapabilities``, and +``NeedsSetup`` structural protocols correctly detect conforming / +non-conforming objects at runtime. +""" + +import pytest + +from libp2p.capabilities import ( + ConnectionCapabilities, + NeedsSetup, + TransportCapabilities, +) + + +class _FullCapabilityTransport: + """Transport that provides both security and muxing.""" + + @property + def provides_security(self) -> bool: + return True + + @property + def provides_muxing(self) -> bool: + return True + + +class _SecurityOnlyTransport: + """Transport that provides only security.""" + + @property + def provides_security(self) -> bool: + return True + + @property + def provides_muxing(self) -> bool: + return False + + +class _MuxingOnlyTransport: + """Transport that provides only muxing.""" + + @property + def provides_security(self) -> bool: + return False + + @property + def provides_muxing(self) -> bool: + return True + + +class _PlainTransport: + """Transport with no capability properties at all.""" + + pass + + +class _FullCapabilityConnection: + """Connection that is both secure and muxed.""" + + @property + def is_secure(self) -> bool: + return True + + @property + def is_muxed(self) -> bool: + return True + + +class _InsecureConnection: + """Connection that is muxed but not secure.""" + + @property + def is_secure(self) -> bool: + return False + + @property + def is_muxed(self) -> bool: + return True + + +class _PlainConnection: + """Connection with no capability properties.""" + + pass + + +class _SetupTransport: + """Transport that needs lifecycle hooks.""" + + def set_background_nursery(self, nursery: object) -> None: + self._nursery = nursery + + def set_swarm(self, swarm: object) -> None: + self._swarm = swarm + + +class _PartialSetupTransport: + """Transport that only implements set_swarm (not full NeedsSetup).""" + + def set_swarm(self, swarm: object) -> None: + self._swarm = swarm + + +class TestTransportCapabilities: + """Verify TransportCapabilities structural protocol detection.""" + + def test_full_capability_transport_is_instance(self): + t = _FullCapabilityTransport() + assert isinstance(t, TransportCapabilities) + + def test_security_only_transport_is_instance(self): + t = _SecurityOnlyTransport() + assert isinstance(t, TransportCapabilities) + + def test_muxing_only_transport_is_instance(self): + t = _MuxingOnlyTransport() + assert isinstance(t, TransportCapabilities) + + def test_plain_transport_is_not_instance(self): + t = _PlainTransport() + assert not isinstance(t, TransportCapabilities) + + def test_provides_security_value(self): + assert _FullCapabilityTransport().provides_security is True + assert _SecurityOnlyTransport().provides_security is True + assert _MuxingOnlyTransport().provides_security is False + + def test_provides_muxing_value(self): + assert _FullCapabilityTransport().provides_muxing is True + assert _SecurityOnlyTransport().provides_muxing is False + assert _MuxingOnlyTransport().provides_muxing is True + + def test_getattr_fallback_for_plain(self): + """The canonical usage pattern: getattr(t, 'provides_security', False).""" + t = _PlainTransport() + assert getattr(t, "provides_security", False) is False + assert getattr(t, "provides_muxing", False) is False + + +class TestConnectionCapabilities: + """Verify ConnectionCapabilities structural protocol detection.""" + + def test_full_capability_connection_is_instance(self): + c = _FullCapabilityConnection() + assert isinstance(c, ConnectionCapabilities) + + def test_insecure_connection_is_instance(self): + c = _InsecureConnection() + assert isinstance(c, ConnectionCapabilities) + + def test_plain_connection_is_not_instance(self): + c = _PlainConnection() + assert not isinstance(c, ConnectionCapabilities) + + def test_is_secure_value(self): + assert _FullCapabilityConnection().is_secure is True + assert _InsecureConnection().is_secure is False + + def test_is_muxed_value(self): + assert _FullCapabilityConnection().is_muxed is True + assert _InsecureConnection().is_muxed is True + + def test_getattr_fallback_for_plain(self): + c = _PlainConnection() + assert getattr(c, "is_secure", False) is False + assert getattr(c, "is_muxed", False) is False + + +class TestNeedsSetup: + """Verify NeedsSetup structural protocol detection.""" + + def test_setup_transport_is_instance(self): + t = _SetupTransport() + assert isinstance(t, NeedsSetup) + + def test_partial_setup_transport_is_not_instance(self): + t = _PartialSetupTransport() + assert not isinstance(t, NeedsSetup) + + def test_plain_transport_is_not_instance(self): + t = _PlainTransport() + assert not isinstance(t, NeedsSetup) + + def test_hasattr_detection(self): + """The canonical usage in swarm.py: hasattr(t, 'set_background_nursery').""" + s = _SetupTransport() + assert hasattr(s, "set_background_nursery") + assert hasattr(s, "set_swarm") + + p = _PlainTransport() + assert not hasattr(p, "set_background_nursery") + assert not hasattr(p, "set_swarm") + + def test_set_background_nursery_stores_value(self): + t = _SetupTransport() + sentinel = object() + t.set_background_nursery(sentinel) + assert t._nursery is sentinel + + def test_set_swarm_stores_value(self): + t = _SetupTransport() + sentinel = object() + t.set_swarm(sentinel) + assert t._swarm is sentinel + + +class TestQUICConformance: + """Verify that the real QUIC transport and connection satisfy the protocols.""" + + def test_quic_transport_satisfies_transport_capabilities(self): + try: + from libp2p.transport.quic.transport import QUICTransport + except ImportError: + pytest.skip("aioquic not installed") + + assert hasattr(QUICTransport, "provides_security") + assert hasattr(QUICTransport, "provides_muxing") + + def test_quic_connection_satisfies_connection_capabilities(self): + try: + from libp2p.transport.quic.connection import QUICConnection + except ImportError: + pytest.skip("aioquic not installed") + + assert hasattr(QUICConnection, "is_secure") + assert hasattr(QUICConnection, "is_muxed") + + def test_quic_transport_satisfies_needs_setup(self): + try: + from libp2p.transport.quic.transport import QUICTransport + except ImportError: + pytest.skip("aioquic not installed") + + assert hasattr(QUICTransport, "set_background_nursery") + assert hasattr(QUICTransport, "set_swarm") diff --git a/tests/core/test_entrypoints.py b/tests/core/test_entrypoints.py new file mode 100644 index 000000000..dc637ee00 --- /dev/null +++ b/tests/core/test_entrypoints.py @@ -0,0 +1,274 @@ +""" +Tests for entry-point discovery and conditional imports. + +Validates: +- ``discover_and_register()`` populates a registry from entry points +- ``discover_transports / discover_security / discover_muxers`` individual scanners +- Entry-point factory functions produce correct provider types +- Conditional import guards (``_HAS_QUIC``, ``_HAS_NOISE``, etc.) +- Built-in factory functions (``_create_tcp_provider``, etc.) +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from libp2p.custom_types import TProtocol +from libp2p.entrypoints import ( + EP_GROUP_MUXERS, + EP_GROUP_SECURITY, + EP_GROUP_TRANSPORTS, + discover_and_register, + discover_muxers, + discover_security, + discover_transports, +) +from libp2p.providers import ( + MuxerProvider, + ProviderRegistry, + SecurityProvider, + TransportProvider, +) + + +class TestEntryPointGroups: + def test_transport_group(self) -> None: + assert EP_GROUP_TRANSPORTS == "libp2p.transports" + + def test_security_group(self) -> None: + assert EP_GROUP_SECURITY == "libp2p.security" + + def test_muxer_group(self) -> None: + assert EP_GROUP_MUXERS == "libp2p.muxers" + + +class TestTcpFactory: + """_create_tcp_provider should return a TransportProvider.""" + + def test_creates_transport_provider(self) -> None: + from libp2p.transport.tcp.tcp import _create_tcp_provider + + tp = _create_tcp_provider() + assert isinstance(tp, TransportProvider) + assert tp.protocol_name == "tcp" + + +class TestYamuxFactory: + def test_creates_muxer_provider(self) -> None: + from libp2p.stream_muxer.yamux.yamux import _create_yamux_provider + + mp = _create_yamux_provider() + assert isinstance(mp, MuxerProvider) + assert "/yamux" in mp.protocol_id + + +class TestMplexFactory: + def test_creates_muxer_provider(self) -> None: + from libp2p.stream_muxer.mplex.mplex import _create_mplex_provider + + mp = _create_mplex_provider() + assert isinstance(mp, MuxerProvider) + assert "/mplex" in mp.protocol_id + + +class TestNoiseFactory: + def test_creates_security_provider(self) -> None: + from libp2p.security.noise.transport import _create_noise_provider + + sp = _create_noise_provider() + assert isinstance(sp, SecurityProvider) + assert "/noise" in sp.protocol_id + + +class TestTlsFactory: + def test_creates_security_provider(self) -> None: + from libp2p.security.tls.transport import _create_tls_provider + + sp = _create_tls_provider() + assert isinstance(sp, SecurityProvider) + assert "/tls" in sp.protocol_id + + +def _make_ep(name: str, factory: object) -> MagicMock: + """Create a mock entry point.""" + ep = MagicMock() + ep.name = name + ep.load.return_value = factory + return ep + + +class TestDiscoverAndRegister: + """Tests using mocked importlib.metadata.entry_points.""" + + def test_discovers_all_types(self) -> None: + """Simulates entry points returning one of each type.""" + tp = TransportProvider("fake-tcp", MagicMock()) + sp = SecurityProvider(TProtocol("/fake-noise"), MagicMock()) + mp = MuxerProvider(TProtocol("/fake-yamux"), MagicMock()) + + def mock_entry_points(group: str) -> list[MagicMock]: + if group == EP_GROUP_TRANSPORTS: + return [_make_ep("fake-tcp", lambda: tp)] + elif group == EP_GROUP_SECURITY: + return [_make_ep("fake-noise", lambda: sp)] + elif group == EP_GROUP_MUXERS: + return [_make_ep("fake-yamux", lambda: mp)] + return [] + + with patch( + "libp2p.entrypoints.entry_points", + side_effect=mock_entry_points, + create=True, + ): + with patch("libp2p.entrypoints._load_entry_points") as mock_load: + mock_load.side_effect = lambda group: [ + (ep.name, ep.load()) for ep in mock_entry_points(group) + ] + + reg = discover_and_register() + + assert len(reg.get_transports()) == 1 + assert len(reg.get_security_providers()) == 1 + assert len(reg.get_muxer_providers()) == 1 + + def test_uses_existing_registry(self) -> None: + """When given an existing registry, populates it.""" + reg = ProviderRegistry() + reg.register_transport(TransportProvider("existing", MagicMock())) + + with patch("libp2p.entrypoints._load_entry_points", return_value=[]): + result = discover_and_register(reg) + + assert result is reg + assert len(result.get_transports()) == 1 + + def test_creates_new_registry_when_none(self) -> None: + with patch("libp2p.entrypoints._load_entry_points", return_value=[]): + reg = discover_and_register() + assert isinstance(reg, ProviderRegistry) + + def test_skips_bad_entry_point(self) -> None: + """A broken entry point is skipped, not fatal.""" + + def bad_load(group: str) -> list[tuple[str, object]]: + if group == EP_GROUP_TRANSPORTS: + + def _bad() -> None: + raise RuntimeError("broken plugin") + + return [("broken", _bad)] + return [] + + with patch("libp2p.entrypoints._load_entry_points", side_effect=bad_load): + reg = discover_and_register() + + assert len(reg.get_transports()) == 0 + + def test_skips_wrong_type(self) -> None: + """An entry point that returns the wrong type is skipped.""" + + def wrong_type_load(group: str) -> list[tuple[str, object]]: + if group == EP_GROUP_TRANSPORTS: + return [("wrong", lambda: "not a provider")] + return [] + + with patch( + "libp2p.entrypoints._load_entry_points", side_effect=wrong_type_load + ): + reg = discover_and_register() + + assert len(reg.get_transports()) == 0 + + +class TestConditionalImportFlags: + """Verify that _HAS_* flags are defined in libp2p.__init__.""" + + def test_has_quic_flag_exists(self) -> None: + import libp2p + + assert hasattr(libp2p, "_HAS_QUIC") + assert isinstance(libp2p._HAS_QUIC, bool) + + def test_has_noise_flag_exists(self) -> None: + import libp2p + + assert hasattr(libp2p, "_HAS_NOISE") + assert isinstance(libp2p._HAS_NOISE, bool) + + def test_has_tls_flag_exists(self) -> None: + import libp2p + + assert hasattr(libp2p, "_HAS_TLS") + assert isinstance(libp2p._HAS_TLS, bool) + + def test_has_yamux_flag_exists(self) -> None: + import libp2p + + assert hasattr(libp2p, "_HAS_YAMUX") + assert isinstance(libp2p._HAS_YAMUX, bool) + + def test_has_mplex_flag_exists(self) -> None: + import libp2p + + assert hasattr(libp2p, "_HAS_MPLEX") + assert isinstance(libp2p._HAS_MPLEX, bool) + + def test_all_flags_true_in_full_install(self) -> None: + """In a full install, all extras should be available.""" + import libp2p + + assert libp2p._HAS_QUIC is True + assert libp2p._HAS_NOISE is True + assert libp2p._HAS_TLS is True + assert libp2p._HAS_YAMUX is True + assert libp2p._HAS_MPLEX is True + + +class TestIndividualDiscovery: + def test_discover_transports_empty(self) -> None: + with patch("libp2p.entrypoints._load_entry_points", return_value=[]): + result = discover_transports() + assert result == [] + + def test_discover_security_empty(self) -> None: + with patch("libp2p.entrypoints._load_entry_points", return_value=[]): + result = discover_security() + assert result == [] + + def test_discover_muxers_empty(self) -> None: + with patch("libp2p.entrypoints._load_entry_points", return_value=[]): + result = discover_muxers() + assert result == [] + + def test_discover_transports_with_factory(self) -> None: + tp = TransportProvider("mock", MagicMock()) + + with patch( + "libp2p.entrypoints._load_entry_points", + return_value=[("mock", lambda: tp)], + ): + result = discover_transports() + assert len(result) == 1 + assert result[0] is tp + + def test_discover_security_with_factory(self) -> None: + sp = SecurityProvider(TProtocol("/mock"), MagicMock()) + + with patch( + "libp2p.entrypoints._load_entry_points", + return_value=[("mock", lambda: sp)], + ): + result = discover_security() + assert len(result) == 1 + assert result[0] is sp + + def test_discover_muxers_with_factory(self) -> None: + mp = MuxerProvider(TProtocol("/mock"), MagicMock()) + + with patch( + "libp2p.entrypoints._load_entry_points", + return_value=[("mock", lambda: mp)], + ): + result = discover_muxers() + assert len(result) == 1 + assert result[0] is mp diff --git a/tests/core/test_providers.py b/tests/core/test_providers.py new file mode 100644 index 000000000..8ef9ea09a --- /dev/null +++ b/tests/core/test_providers.py @@ -0,0 +1,371 @@ +""" +Tests for the provider abstractions. + +Validates: +- ``ProvidesTransport`` / ``ProvidesConnection`` structural protocols +- ``SecurityProvider``, ``MuxerProvider``, ``TransportProvider`` wrappers +- ``ProviderRegistry`` registration, bulk-import, and query methods +""" + +from __future__ import annotations + +from collections import OrderedDict + +import pytest + +from libp2p.abc import ( + IMuxedConn, + IRawConnection, + ISecureConn, +) +from libp2p.custom_types import TMuxerOptions, TProtocol, TSecurityOptions +from libp2p.peer.id import ID +from libp2p.providers import ( + MuxerProvider, + ProviderRegistry, + ProvidesConnection, + ProvidesTransport, + SecurityProvider, + TransportProvider, +) + + +class _StubRawConn: + """Minimal stub implementing enough for SecurityProvider.upgrade.""" + + async def close(self) -> None: + pass + + +class _StubSecureConn: + """Minimal stub for ISecureConn-like objects.""" + + def get_remote_peer(self) -> ID: + return ID(b"\x01\x02") + + async def close(self) -> None: + pass + + +class _StubMuxedConn: + """Minimal stub for IMuxedConn-like objects.""" + + pass + + +class _StubSecureTransport: + """Minimal stub implementing ISecureTransport's relevant methods.""" + + async def secure_outbound(self, conn: object, peer_id: ID) -> _StubSecureConn: + return _StubSecureConn() + + async def secure_inbound(self, conn: object) -> _StubSecureConn: + return _StubSecureConn() + + +class _StubMuxerClass: + """Callable stub that simulates a muxer constructor.""" + + def __init__(self, conn: object, peer_id: ID) -> None: + self.conn = conn + self.peer_id = peer_id + + +class _StubTransport: + """Minimal ITransport-like stub.""" + + async def dial(self, maddr: object) -> _StubRawConn: + return _StubRawConn() + + async def create_listener(self, handler: object) -> None: + pass + + +class _CapableTransport(_StubTransport): + """Transport that provides security + muxing (like QUIC).""" + + @property + def provides_security(self) -> bool: + return True + + @property + def provides_muxing(self) -> bool: + return True + + +class _ConformingProvider: + """A class that structurally conforms to ProvidesTransport.""" + + def can_dial(self, maddr: object) -> bool: + return True + + async def dial(self, maddr: object) -> IRawConnection: ... + + +class _NonConformingObject: + """Object that does NOT conform to ProvidesTransport.""" + + pass + + +class _ConnectionLayerConforming: + """Conforms to ProvidesConnection.""" + + @property + def provides_interface(self) -> type: + return ISecureConn + + @property + def requires_interface(self) -> type: + return IRawConnection + + +class _ProtoStub: + """Simple object with a .name attribute.""" + + def __init__(self, name: str) -> None: + self.name = name + + +class _FakeMaddr: + """Fake multiaddr for testing transport matching.""" + + def __init__(self, proto_names: list[str]) -> None: + self._proto_names = proto_names + + def protocols(self) -> list[_ProtoStub]: + return [_ProtoStub(n) for n in self._proto_names] + + +class TestProvidesTransportProtocol: + """Tests for the ProvidesTransport structural protocol.""" + + def test_conforming_class_is_instance(self) -> None: + assert isinstance(_ConformingProvider(), ProvidesTransport) + + def test_transport_provider_is_instance(self) -> None: + tp = TransportProvider("tcp", _StubTransport()) + assert isinstance(tp, ProvidesTransport) + + def test_non_conforming_not_instance(self) -> None: + assert not isinstance(_NonConformingObject(), ProvidesTransport) + + def test_plain_object_not_instance(self) -> None: + assert not isinstance(42, ProvidesTransport) + + +class TestProvidesConnectionProtocol: + """Tests for the ProvidesConnection structural protocol.""" + + def test_conforming_class_is_instance(self) -> None: + assert isinstance(_ConnectionLayerConforming(), ProvidesConnection) + + def test_security_provider_is_instance(self) -> None: + sp = SecurityProvider(TProtocol("/noise"), _StubSecureTransport()) + assert isinstance(sp, ProvidesConnection) + + def test_muxer_provider_is_instance(self) -> None: + mp = MuxerProvider(TProtocol("/yamux/1.0.0"), _StubMuxerClass) + assert isinstance(mp, ProvidesConnection) + + def test_non_conforming_not_instance(self) -> None: + assert not isinstance(_NonConformingObject(), ProvidesConnection) + + +class TestSecurityProvider: + """Tests for the SecurityProvider wrapper.""" + + def test_provides_and_requires_interfaces(self) -> None: + sp = SecurityProvider(TProtocol("/noise"), _StubSecureTransport()) + assert sp.provides_interface is ISecureConn + assert sp.requires_interface is IRawConnection + + @pytest.mark.trio + async def test_upgrade_outbound(self) -> None: + sp = SecurityProvider(TProtocol("/noise"), _StubSecureTransport()) + result = await sp.upgrade( + _StubRawConn(), is_initiator=True, peer_id=ID(b"\x01") + ) + assert isinstance(result, _StubSecureConn) + + @pytest.mark.trio + async def test_upgrade_inbound(self) -> None: + sp = SecurityProvider(TProtocol("/noise"), _StubSecureTransport()) + result = await sp.upgrade(_StubRawConn(), is_initiator=False) + assert isinstance(result, _StubSecureConn) + + @pytest.mark.trio + async def test_upgrade_outbound_requires_peer_id(self) -> None: + sp = SecurityProvider(TProtocol("/noise"), _StubSecureTransport()) + with pytest.raises(ValueError, match="peer_id required"): + await sp.upgrade(_StubRawConn(), is_initiator=True) + + def test_repr(self) -> None: + sp = SecurityProvider(TProtocol("/noise"), _StubSecureTransport()) + assert "/noise" in repr(sp) + + +class TestMuxerProvider: + """Tests for the MuxerProvider wrapper.""" + + def test_provides_and_requires_interfaces(self) -> None: + mp = MuxerProvider(TProtocol("/yamux/1.0.0"), _StubMuxerClass) + assert mp.provides_interface is IMuxedConn + assert mp.requires_interface is ISecureConn + + @pytest.mark.trio + async def test_upgrade(self) -> None: + mp = MuxerProvider(TProtocol("/yamux/1.0.0"), _StubMuxerClass) + result = await mp.upgrade(_StubSecureConn(), ID(b"\x01")) + assert isinstance(result, _StubMuxerClass) + assert result.peer_id == ID(b"\x01") + + def test_repr(self) -> None: + mp = MuxerProvider(TProtocol("/yamux/1.0.0"), _StubMuxerClass) + assert "/yamux/1.0.0" in repr(mp) + + +class TestTransportProvider: + """Tests for the TransportProvider wrapper.""" + + def test_can_dial_with_protocol_name_match(self) -> None: + tp = TransportProvider("tcp", _StubTransport()) + maddr = _FakeMaddr(["ip4", "tcp"]) + assert tp.can_dial(maddr) + + def test_can_dial_no_match(self) -> None: + tp = TransportProvider("tcp", _StubTransport()) + maddr = _FakeMaddr(["ip4", "udp", "quic"]) + assert not tp.can_dial(maddr) + + def test_can_dial_with_custom_matcher(self) -> None: + tp = TransportProvider("custom", _StubTransport(), matcher=lambda _: True) + assert tp.can_dial(object()) + + def test_can_dial_custom_matcher_false(self) -> None: + tp = TransportProvider("custom", _StubTransport(), matcher=lambda _: False) + assert not tp.can_dial(object()) + + @pytest.mark.trio + async def test_dial(self) -> None: + tp = TransportProvider("tcp", _StubTransport()) + result = await tp.dial(object()) + assert isinstance(result, _StubRawConn) + + def test_provides_security_false_by_default(self) -> None: + tp = TransportProvider("tcp", _StubTransport()) + assert not tp.provides_security + + def test_provides_muxing_false_by_default(self) -> None: + tp = TransportProvider("tcp", _StubTransport()) + assert not tp.provides_muxing + + def test_provides_security_from_transport(self) -> None: + tp = TransportProvider("quic", _CapableTransport()) + assert tp.provides_security + + def test_provides_muxing_from_transport(self) -> None: + tp = TransportProvider("quic", _CapableTransport()) + assert tp.provides_muxing + + def test_repr(self) -> None: + tp = TransportProvider("tcp", _StubTransport()) + assert "tcp" in repr(tp) + + +class TestProviderRegistry: + """Tests for the ProviderRegistry.""" + + def test_empty_registry(self) -> None: + reg = ProviderRegistry() + assert reg.get_transports() == [] + assert reg.get_security_providers() == [] + assert reg.get_muxer_providers() == [] + assert not reg.has_security() + assert not reg.has_muxer() + + def test_register_transport(self) -> None: + reg = ProviderRegistry() + tp = TransportProvider("tcp", _StubTransport()) + reg.register_transport(tp) + assert len(reg.get_transports()) == 1 + assert reg.get_transports()[0] is tp + + def test_register_security(self) -> None: + reg = ProviderRegistry() + sp = SecurityProvider(TProtocol("/noise"), _StubSecureTransport()) + reg.register_security(sp) + assert len(reg.get_security_providers()) == 1 + assert reg.has_security() + + def test_register_muxer(self) -> None: + reg = ProviderRegistry() + mp = MuxerProvider(TProtocol("/yamux/1.0.0"), _StubMuxerClass) + reg.register_muxer(mp) + assert len(reg.get_muxer_providers()) == 1 + assert reg.has_muxer() + + def test_get_transports_for_matching(self) -> None: + reg = ProviderRegistry() + tcp_tp = TransportProvider("tcp", _StubTransport()) + quic_tp = TransportProvider("quic", _CapableTransport()) + reg.register_transport(tcp_tp) + reg.register_transport(quic_tp) + + tcp_maddr = _FakeMaddr(["ip4", "tcp"]) + matches = reg.get_transports_for(tcp_maddr) + assert len(matches) == 1 + assert matches[0] is tcp_tp + + def test_get_transports_for_no_match(self) -> None: + reg = ProviderRegistry() + tcp_tp = TransportProvider("tcp", _StubTransport()) + reg.register_transport(tcp_tp) + + quic_maddr = _FakeMaddr(["ip4", "udp", "quic"]) + assert reg.get_transports_for(quic_maddr) == [] + + def test_register_security_options_bulk(self) -> None: + reg = ProviderRegistry() + sec_opts: TSecurityOptions = OrderedDict( + { + TProtocol("/noise"): _StubSecureTransport(), + TProtocol("/tls/1.0.0"): _StubSecureTransport(), + } + ) + reg.register_security_options(sec_opts) + assert len(reg.get_security_providers()) == 2 + assert reg.get_security_providers()[0].protocol_id == TProtocol("/noise") + assert reg.get_security_providers()[1].protocol_id == TProtocol("/tls/1.0.0") + + def test_register_muxer_options_bulk(self) -> None: + reg = ProviderRegistry() + mux_opts: TMuxerOptions = OrderedDict( + { + TProtocol("/yamux/1.0.0"): _StubMuxerClass, + TProtocol("/mplex/6.7.0"): _StubMuxerClass, + } + ) + reg.register_muxer_options(mux_opts) + assert len(reg.get_muxer_providers()) == 2 + assert reg.get_muxer_providers()[0].protocol_id == TProtocol("/yamux/1.0.0") + + def test_repr(self) -> None: + reg = ProviderRegistry() + reg.register_transport(TransportProvider("tcp", _StubTransport())) + reg.register_security( + SecurityProvider(TProtocol("/noise"), _StubSecureTransport()) + ) + r = repr(reg) + assert "transports=1" in r + assert "security=1" in r + assert "muxers=0" in r + + def test_multiple_transports_same_protocol(self) -> None: + """Registering multiple transports with the same name is allowed.""" + reg = ProviderRegistry() + tp1 = TransportProvider("tcp", _StubTransport()) + tp2 = TransportProvider("tcp", _StubTransport()) + reg.register_transport(tp1) + reg.register_transport(tp2) + assert len(reg.get_transports()) == 2 diff --git a/tests/core/test_requirements.py b/tests/core/test_requirements.py new file mode 100644 index 000000000..8470fe75c --- /dev/null +++ b/tests/core/test_requirements.py @@ -0,0 +1,352 @@ +""" +Tests for the requirement decorators and runtime enforcement. + +Validates ``@requires_connection``, ``@after_connection``, introspection +helpers, and ``check_connection_requirements``. +""" + +import pytest + +from libp2p.abc import IMuxedConn, ISecureConn +from libp2p.requirements import ( + ConnectionRequirementError, + after_connection, + check_connection_requirements, + get_after_connections, + get_required_connections, + requires_connection, +) + + +class _FakeSecureConn: + """ + Minimal stub that *is* an ISecureConn stand-in for isinstance checks. + + We register it with ISecureConn at module level so ``isinstance`` works. + """ + + pass + + +class _FakePlainConn: + """A plain connection that satisfies no interfaces.""" + + pass + + +class TestRequiresConnection: + """Tests for the @requires_connection decorator.""" + + def test_attaches_metadata_single_interface(self): + @requires_connection(ISecureConn) + async def handler(stream): + pass + + assert hasattr(handler, "_required_connections") + assert handler._required_connections == (ISecureConn,) + + def test_attaches_metadata_multiple_interfaces(self): + @requires_connection(ISecureConn, IMuxedConn) + async def handler(stream): + pass + + assert handler._required_connections == (ISecureConn, IMuxedConn) + + def test_no_arguments(self): + @requires_connection() + async def handler(stream): + pass + + assert handler._required_connections == () + + def test_preserves_function_identity(self): + async def original(stream): + pass + + decorated = requires_connection(ISecureConn)(original) + assert decorated is original + + def test_preserves_function_name(self): + @requires_connection(ISecureConn) + async def my_handler(stream): + pass + + assert my_handler.__name__ == "my_handler" + + def test_works_on_sync_function(self): + @requires_connection(ISecureConn) + def sync_handler(stream): + pass + + assert sync_handler._required_connections == (ISecureConn,) + + def test_works_on_class(self): + @requires_connection(ISecureConn) + class MyHandler: + pass + + assert MyHandler._required_connections == (ISecureConn,) + + +class TestGetRequiredConnections: + """Tests for the get_required_connections introspection helper.""" + + def test_returns_interfaces_for_decorated(self): + @requires_connection(ISecureConn) + async def handler(stream): + pass + + result = get_required_connections(handler) + assert result == (ISecureConn,) + + def test_returns_empty_for_undecorated(self): + async def handler(stream): + pass + + result = get_required_connections(handler) + assert result == () + + def test_returns_empty_for_none(self): + result = get_required_connections(None) + assert result == () + + def test_returns_empty_for_arbitrary_object(self): + result = get_required_connections(42) + assert result == () + + +class TestAfterConnection: + """Tests for the @after_connection decorator.""" + + def test_attaches_metadata_single_interface(self): + @after_connection(ISecureConn) + class MyMuxer: + pass + + assert hasattr(MyMuxer, "_after_connections") + assert MyMuxer._after_connections == (ISecureConn,) + + def test_attaches_metadata_multiple_interfaces(self): + @after_connection(ISecureConn, IMuxedConn) + class MyMuxer: + pass + + assert MyMuxer._after_connections == (ISecureConn, IMuxedConn) + + def test_no_arguments(self): + @after_connection() + class MyMuxer: + pass + + assert MyMuxer._after_connections == () + + def test_preserves_class_identity(self): + class Original: + pass + + decorated = after_connection(ISecureConn)(Original) + assert decorated is Original + + def test_preserves_class_name(self): + @after_connection(ISecureConn) + class MyMuxer: + pass + + assert MyMuxer.__name__ == "MyMuxer" + + def test_works_on_function(self): + @after_connection(ISecureConn) + def setup_fn(): + pass + + assert setup_fn._after_connections == (ISecureConn,) + + +class TestGetAfterConnections: + """Tests for the get_after_connections introspection helper.""" + + def test_returns_interfaces_for_decorated(self): + @after_connection(ISecureConn) + class MyMuxer: + pass + + result = get_after_connections(MyMuxer) + assert result == (ISecureConn,) + + def test_returns_empty_for_undecorated(self): + class MyMuxer: + pass + + result = get_after_connections(MyMuxer) + assert result == () + + def test_returns_empty_for_none(self): + result = get_after_connections(None) + assert result == () + + +class TestRealMuxerMetadata: + """Verify that Mplex and Yamux carry @after_connection(ISecureConn) metadata.""" + + def test_mplex_has_after_connection(self): + from libp2p.stream_muxer.mplex.mplex import Mplex + + after = get_after_connections(Mplex) + assert ISecureConn in after + + def test_yamux_has_after_connection(self): + from libp2p.stream_muxer.yamux.yamux import Yamux + + after = get_after_connections(Yamux) + assert ISecureConn in after + + +class TestCheckConnectionRequirements: + """Tests for the runtime enforcement helper.""" + + def test_returns_true_when_no_requirements(self): + async def handler(stream): + pass + + result = check_connection_requirements(handler, _FakePlainConn()) + assert result is True + + def test_returns_true_when_requirement_satisfied(self): + """ + ISecureConn is an ABC — we can't easily make a fake satisfy it. + But we can test with a real QUIC connection class or by using + a protocol that's satisfied structurally. + """ + + @requires_connection() + async def handler(stream): + pass + + result = check_connection_requirements(handler, _FakePlainConn()) + assert result is True + + def test_returns_false_when_requirement_not_satisfied(self): + @requires_connection(ISecureConn) + async def handler(stream): + pass + + result = check_connection_requirements(handler, _FakePlainConn()) + assert result is False + + def test_raises_on_failure(self): + @requires_connection(ISecureConn) + async def handler(stream): + pass + + with pytest.raises(ConnectionRequirementError, match="requires ISecureConn"): + check_connection_requirements( + handler, _FakePlainConn(), raise_on_failure=True + ) + + def test_error_message_contains_handler_name(self): + @requires_connection(ISecureConn) + async def my_echo_handler(stream): + pass + + with pytest.raises(ConnectionRequirementError, match="my_echo_handler"): + check_connection_requirements( + my_echo_handler, _FakePlainConn(), raise_on_failure=True + ) + + def test_error_message_contains_connection_type(self): + @requires_connection(ISecureConn) + async def handler(stream): + pass + + with pytest.raises(ConnectionRequirementError, match="_FakePlainConn"): + check_connection_requirements( + handler, _FakePlainConn(), raise_on_failure=True + ) + + def test_multiple_requirements_all_fail(self): + @requires_connection(ISecureConn, IMuxedConn) + async def handler(stream): + pass + + result = check_connection_requirements(handler, _FakePlainConn()) + assert result is False + + def test_handler_with_no_decoration(self): + """An undecorated handler should always pass.""" + + async def handler(stream): + pass + + assert check_connection_requirements(handler, _FakePlainConn()) is True + assert check_connection_requirements(handler, object()) is True + + +class TestConnectionRequirementError: + """Verify the custom exception class.""" + + def test_is_exception(self): + assert issubclass(ConnectionRequirementError, Exception) + + def test_message_preserved(self): + err = ConnectionRequirementError("test message") + assert str(err) == "test message" + + def test_can_be_caught(self): + with pytest.raises(ConnectionRequirementError): + raise ConnectionRequirementError("oops") + + +class TestUpgraderMuxerOrdering: + """Verify that TransportUpgrader._verify_muxer_ordering works correctly.""" + + def test_verify_muxer_ordering_with_no_muxers(self): + """Empty muxer registry should not raise.""" + from libp2p.transport.upgrader import TransportUpgrader + + upgrader = TransportUpgrader({}, {}) + upgrader._verify_muxer_ordering(object()) + + def test_verify_muxer_ordering_with_undecorated_muxer(self): + """A muxer with no @after_connection should pass any connection.""" + from libp2p.transport.upgrader import TransportUpgrader + + class PlainMuxer: + pass + + upgrader = TransportUpgrader({}, {"/plain/1.0.0": PlainMuxer}) + upgrader._verify_muxer_ordering(object()) + + def test_verify_muxer_ordering_warns_on_mismatch(self): + """ + A muxer with @after_connection(ISecureConn) should warn when + given a plain connection. + """ + import logging + + from libp2p.transport.upgrader import TransportUpgrader + + @after_connection(ISecureConn) + class StrictMuxer: + pass + + upgrader = TransportUpgrader({}, {"/strict/1.0.0": StrictMuxer}) + + upgrader_logger = logging.getLogger("libp2p.transport.upgrader") + + captured: list[str] = [] + + class _CaptureHandler(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + captured.append(record.getMessage()) + + handler = _CaptureHandler() + handler.setLevel(logging.WARNING) + upgrader_logger.addHandler(handler) + try: + upgrader._verify_muxer_ordering(_FakePlainConn()) + finally: + upgrader_logger.removeHandler(handler) + + assert len(captured) == 1 + assert "StrictMuxer" in captured[0] + assert "ISecureConn" in captured[0] diff --git a/tests/core/test_resolver.py b/tests/core/test_resolver.py new file mode 100644 index 000000000..046ef2f6a --- /dev/null +++ b/tests/core/test_resolver.py @@ -0,0 +1,446 @@ +""" +Tests for the pull-based connection resolver. + +Validates: +- ``ConnectionResolver.resolve()`` — outbound dial with transport selection, + security + muxer upgrade, capability-based skip, fallback on failure. +- ``ConnectionResolver.upgrade_inbound()`` — inbound upgrade path. +- ``ResolvedStack`` — properties and diagnostics. +- Resolution error classes. +""" + +from __future__ import annotations + +import pytest + +from libp2p.custom_types import TProtocol +from libp2p.network.resolver import ( + AllPathsFailedError, + ConnectionResolver, + NoTransportError, + ResolutionError, + ResolvedStack, +) +from libp2p.peer.id import ID +from libp2p.providers import ( + MuxerProvider, + ProviderRegistry, + SecurityProvider, + TransportProvider, +) +from libp2p.transport.exceptions import ( + MuxerUpgradeFailure, + SecurityUpgradeFailure, +) + + +class _StubRawConn: + async def close(self) -> None: + pass + + +class _StubSecureConn: + def get_remote_peer(self) -> ID: + return ID(b"\x01\x02") + + async def close(self) -> None: + pass + + +class _StubMuxedConn: + pass + + +class _StubSecureTransport: + async def secure_outbound(self, conn: object, peer_id: ID) -> _StubSecureConn: + return _StubSecureConn() + + async def secure_inbound(self, conn: object) -> _StubSecureConn: + return _StubSecureConn() + + +class _FailingSecureTransport: + async def secure_outbound(self, conn: object, peer_id: ID) -> _StubSecureConn: + raise ConnectionError("security handshake failed") + + async def secure_inbound(self, conn: object) -> _StubSecureConn: + raise ConnectionError("security handshake failed") + + +class _StubMuxerClass: + def __init__(self, conn: object, peer_id: ID) -> None: + self.conn = conn + self.peer_id = peer_id + + +class _StubTransport: + async def dial(self, maddr: object) -> _StubRawConn: + return _StubRawConn() + + async def create_listener(self, handler: object) -> None: + pass + + +class _FailingTransport: + async def dial(self, maddr: object) -> _StubRawConn: + raise ConnectionError("dial failed") + + +class _CapableTransport: + """Transport providing security + muxing (like QUIC).""" + + @property + def provides_security(self) -> bool: + return True + + @property + def provides_muxing(self) -> bool: + return True + + async def dial(self, maddr: object) -> _StubRawConn: + return _StubRawConn() + + +class _ProtoStub: + """Simple object with a .name attribute.""" + + def __init__(self, name: str) -> None: + self.name = name + + +class _FakeMaddr: + """Fake multiaddr for testing.""" + + def __init__(self, proto_names: list[str]) -> None: + self._proto_names = proto_names + + def protocols(self) -> list[_ProtoStub]: + return [_ProtoStub(n) for n in self._proto_names] + + +def _tcp_maddr() -> _FakeMaddr: + return _FakeMaddr(["ip4", "tcp"]) + + +def _quic_maddr() -> _FakeMaddr: + return _FakeMaddr(["ip4", "udp", "quic"]) + + +def _peer_id() -> ID: + return ID(b"\x01\x02\x03") + + +def _build_registry( + *, + transport: object | None = None, + transport_name: str = "tcp", + sec_transport: object | None = None, + muxer_class: object | None = None, + capable_transport: bool = False, +) -> ProviderRegistry: + """Build a ProviderRegistry with optional components registered.""" + reg = ProviderRegistry() + + if capable_transport: + tp = TransportProvider("quic", _CapableTransport()) + reg.register_transport(tp) + elif transport is not None: + tp = TransportProvider(transport_name, transport) + reg.register_transport(tp) + + if sec_transport is not None: + sp = SecurityProvider(TProtocol("/noise"), sec_transport) + reg.register_security(sp) + + if muxer_class is not None: + mp = MuxerProvider(TProtocol("/yamux/1.0.0"), muxer_class) + reg.register_muxer(mp) + + return reg + + +class TestResolvedStack: + """Tests for the ResolvedStack dataclass.""" + + def test_top_connection_muxed(self) -> None: + stack = ResolvedStack( + raw_conn=_StubRawConn(), + secure_conn=_StubSecureConn(), + muxed_conn=_StubMuxedConn(), + ) + assert stack.top_connection is stack.muxed_conn + + def test_top_connection_secure(self) -> None: + stack = ResolvedStack( + raw_conn=_StubRawConn(), + secure_conn=_StubSecureConn(), + ) + assert stack.top_connection is stack.secure_conn + + def test_top_connection_raw(self) -> None: + stack = ResolvedStack(raw_conn=_StubRawConn()) + assert stack.top_connection is stack.raw_conn + + def test_top_connection_empty(self) -> None: + stack = ResolvedStack() + assert stack.top_connection is None + + def test_describes_full_stack(self) -> None: + stack = ResolvedStack( + transport_provider=TransportProvider("tcp", _StubTransport()), + security_provider=SecurityProvider( + TProtocol("/noise"), _StubSecureTransport() + ), + muxer_provider=MuxerProvider(TProtocol("/yamux/1.0.0"), _StubMuxerClass), + ) + desc = stack.describes() + assert "tcp" in desc + assert "/noise" in desc + assert "/yamux/1.0.0" in desc + + def test_describes_builtin_security(self) -> None: + stack = ResolvedStack( + transport_provider=TransportProvider("quic", _CapableTransport()), + skipped_security=True, + skipped_muxer=True, + ) + desc = stack.describes() + assert "builtin" in desc + + def test_describes_empty(self) -> None: + stack = ResolvedStack() + assert stack.describes() == "(empty)" + + +class TestResolutionErrors: + """Tests for resolution error classes.""" + + def test_no_transport_error_is_resolution_error(self) -> None: + assert issubclass(NoTransportError, ResolutionError) + + def test_all_paths_failed_stores_failures(self) -> None: + failures = [("tcp", ConnectionError("fail1")), ("quic", TimeoutError("fail2"))] + err = AllPathsFailedError(failures) + assert err.failures == failures + assert "tcp" in str(err) + assert "quic" in str(err) + + def test_resolution_error_is_exception(self) -> None: + assert issubclass(ResolutionError, Exception) + + +class TestResolverHappyPath: + """Tests for successful resolution.""" + + @pytest.mark.trio + async def test_resolve_full_stack_tcp(self) -> None: + """TCP transport → security upgrade → muxer upgrade.""" + reg = _build_registry( + transport=_StubTransport(), + sec_transport=_StubSecureTransport(), + muxer_class=_StubMuxerClass, + ) + resolver = ConnectionResolver(reg) + stack = await resolver.resolve(_tcp_maddr(), _peer_id()) + + assert stack.raw_conn is not None + assert stack.secure_conn is not None + assert stack.muxed_conn is not None + assert not stack.skipped_security + assert not stack.skipped_muxer + assert stack.transport_provider is not None + assert stack.security_provider is not None + assert stack.muxer_provider is not None + + @pytest.mark.trio + async def test_resolve_self_upgrading_transport(self) -> None: + """QUIC-like transport skips security + muxer.""" + reg = _build_registry(capable_transport=True) + resolver = ConnectionResolver(reg) + stack = await resolver.resolve(_quic_maddr(), _peer_id()) + + assert stack.raw_conn is not None + assert stack.skipped_security + assert stack.skipped_muxer + assert stack.secure_conn is None + assert stack.muxed_conn is None + + @pytest.mark.trio + async def test_resolve_security_only_transport(self) -> None: + """Transport provides security but not muxing.""" + + class _SecOnlyTransport: + @property + def provides_security(self) -> bool: + return True + + @property + def provides_muxing(self) -> bool: + return False + + async def dial(self, maddr: object) -> _StubRawConn: + return _StubRawConn() + + reg = ProviderRegistry() + reg.register_transport(TransportProvider("tls-tcp", _SecOnlyTransport())) + reg.register_muxer(MuxerProvider(TProtocol("/yamux/1.0.0"), _StubMuxerClass)) + + resolver = ConnectionResolver(reg) + maddr = _FakeMaddr(["ip4", "tls-tcp"]) + stack = await resolver.resolve(maddr, _peer_id()) + + assert stack.skipped_security + assert not stack.skipped_muxer + assert stack.muxed_conn is not None + + +class TestResolverErrorPaths: + """Tests for resolution failure modes.""" + + @pytest.mark.trio + async def test_no_transport_for_maddr(self) -> None: + """No transport can dial the address.""" + reg = _build_registry( + transport=_StubTransport(), + transport_name="tcp", + ) + resolver = ConnectionResolver(reg) + quic_maddr = _quic_maddr() + with pytest.raises(NoTransportError, match="No registered transport"): + await resolver.resolve(quic_maddr, _peer_id()) + + @pytest.mark.trio + async def test_empty_registry_raises_no_transport(self) -> None: + reg = ProviderRegistry() + resolver = ConnectionResolver(reg) + with pytest.raises(NoTransportError): + await resolver.resolve(_tcp_maddr(), _peer_id()) + + @pytest.mark.trio + async def test_transport_dial_failure_all_paths_failed(self) -> None: + """Transport dial fails → AllPathsFailedError.""" + reg = _build_registry(transport=_FailingTransport()) + resolver = ConnectionResolver(reg) + with pytest.raises(AllPathsFailedError) as exc_info: + await resolver.resolve(_tcp_maddr(), _peer_id()) + assert len(exc_info.value.failures) == 1 + + @pytest.mark.trio + async def test_security_upgrade_failure(self) -> None: + """Security upgrade fails for all providers → AllPathsFailedError.""" + reg = _build_registry( + transport=_StubTransport(), + sec_transport=_FailingSecureTransport(), + muxer_class=_StubMuxerClass, + ) + resolver = ConnectionResolver(reg) + with pytest.raises(AllPathsFailedError): + await resolver.resolve(_tcp_maddr(), _peer_id()) + + @pytest.mark.trio + async def test_no_security_providers_registered(self) -> None: + """No security providers → SecurityUpgradeFailure wrapped in AllPaths.""" + reg = _build_registry( + transport=_StubTransport(), + muxer_class=_StubMuxerClass, + ) + resolver = ConnectionResolver(reg) + with pytest.raises(AllPathsFailedError): + await resolver.resolve(_tcp_maddr(), _peer_id()) + + @pytest.mark.trio + async def test_no_muxer_providers_registered(self) -> None: + """No muxer providers → MuxerUpgradeFailure wrapped in AllPaths.""" + reg = _build_registry( + transport=_StubTransport(), + sec_transport=_StubSecureTransport(), + ) + resolver = ConnectionResolver(reg) + with pytest.raises(AllPathsFailedError): + await resolver.resolve(_tcp_maddr(), _peer_id()) + + +class TestResolverFallback: + """Tests for multi-transport fallback behaviour.""" + + @pytest.mark.trio + async def test_fallback_to_second_transport(self) -> None: + """First transport fails, second succeeds.""" + reg = ProviderRegistry() + reg.register_transport(TransportProvider("tcp", _FailingTransport())) + reg.register_transport(TransportProvider("tcp", _StubTransport())) + reg.register_security( + SecurityProvider(TProtocol("/noise"), _StubSecureTransport()) + ) + reg.register_muxer(MuxerProvider(TProtocol("/yamux/1.0.0"), _StubMuxerClass)) + + resolver = ConnectionResolver(reg) + stack = await resolver.resolve(_tcp_maddr(), _peer_id()) + + assert stack.raw_conn is not None + assert stack.muxed_conn is not None + + +class TestResolverInbound: + """Tests for inbound (listener-accepted) connection upgrades.""" + + @pytest.mark.trio + async def test_inbound_full_upgrade(self) -> None: + """Inbound: security + muxer applied.""" + reg = _build_registry( + sec_transport=_StubSecureTransport(), + muxer_class=_StubMuxerClass, + ) + resolver = ConnectionResolver(reg) + raw = _StubRawConn() + stack = await resolver.upgrade_inbound(raw) + + assert stack.raw_conn is raw + assert stack.secure_conn is not None + assert stack.muxed_conn is not None + assert not stack.skipped_security + assert not stack.skipped_muxer + + @pytest.mark.trio + async def test_inbound_self_upgrading_transport(self) -> None: + """Inbound from a self-upgrading transport — skip both layers.""" + reg = ProviderRegistry() + resolver = ConnectionResolver(reg) + raw = _StubRawConn() + stack = await resolver.upgrade_inbound( + raw, + transport_has_security=True, + transport_has_muxing=True, + ) + assert stack.skipped_security + assert stack.skipped_muxer + + @pytest.mark.trio + async def test_inbound_no_security_providers(self) -> None: + """Inbound with no security providers → SecurityUpgradeFailure.""" + reg = ProviderRegistry() + resolver = ConnectionResolver(reg) + with pytest.raises(SecurityUpgradeFailure): + await resolver.upgrade_inbound(_StubRawConn()) + + @pytest.mark.trio + async def test_inbound_no_muxer_providers(self) -> None: + """Inbound with no muxer providers → MuxerUpgradeFailure.""" + reg = _build_registry(sec_transport=_StubSecureTransport()) + resolver = ConnectionResolver(reg) + with pytest.raises(MuxerUpgradeFailure): + await resolver.upgrade_inbound(_StubRawConn()) + + @pytest.mark.trio + async def test_inbound_security_skip_muxer_applied(self) -> None: + """Inbound: transport has security, muxer still applied.""" + reg = _build_registry(muxer_class=_StubMuxerClass) + resolver = ConnectionResolver(reg) + raw = _StubRawConn() + stack = await resolver.upgrade_inbound( + raw, + transport_has_security=True, + ) + assert stack.skipped_security + assert not stack.skipped_muxer + assert stack.muxed_conn is not None diff --git a/tests/core/transport/test_transport_registry.py b/tests/core/transport/test_transport_registry.py index 31b398736..d0552e773 100644 --- a/tests/core/transport/test_transport_registry.py +++ b/tests/core/transport/test_transport_registry.py @@ -354,3 +354,97 @@ def create_listener(self, handler_function: THandler) -> IListener: # Should be available in the other assert registry2.get_transport("persistent") == PersistentTransport + + +class TestTransportRegistryCapabilities: + """Tests for capability-aware query methods.""" + + def test_quic_provides_security(self): + """QUIC transport declares provides_security = True.""" + registry = TransportRegistry() + assert registry.transport_provides_security("quic") is True + assert registry.transport_provides_security("quic-v1") is True + + def test_quic_provides_muxing(self): + """QUIC transport declares provides_muxing = True.""" + registry = TransportRegistry() + assert registry.transport_provides_muxing("quic") is True + assert registry.transport_provides_muxing("quic-v1") is True + + def test_tcp_does_not_provide_security(self): + """TCP transport has no capability declarations.""" + registry = TransportRegistry() + assert registry.transport_provides_security("tcp") is False + + def test_tcp_does_not_provide_muxing(self): + registry = TransportRegistry() + assert registry.transport_provides_muxing("tcp") is False + + def test_unknown_protocol_returns_false(self): + registry = TransportRegistry() + assert registry.transport_provides_security("nonexistent") is False + assert registry.transport_provides_muxing("nonexistent") is False + + def test_needs_security_upgrade(self): + registry = TransportRegistry() + assert registry.needs_security_upgrade("tcp") is True + assert registry.needs_security_upgrade("quic") is False + + def test_needs_muxer_upgrade(self): + registry = TransportRegistry() + assert registry.needs_muxer_upgrade("tcp") is True + assert registry.needs_muxer_upgrade("quic") is False + + def test_get_self_upgrading_protocols(self): + registry = TransportRegistry() + self_upgrading = registry.get_self_upgrading_protocols() + assert "quic" in self_upgrading + assert "quic-v1" in self_upgrading + assert "tcp" not in self_upgrading + + def test_needs_upgrade_unknown_protocol(self): + registry = TransportRegistry() + assert registry.needs_security_upgrade("unknown") is True + assert registry.needs_muxer_upgrade("unknown") is True + + def test_custom_capable_transport(self): + """Registering a custom transport with capabilities should be detected.""" + registry = TransportRegistry() + + class CapableTransport(ITransport): + @property + def provides_security(self) -> bool: + return True + + @property + def provides_muxing(self) -> bool: + return False + + async def dial(self, maddr: Multiaddr) -> IRawConnection: + raise NotImplementedError + + def create_listener(self, handler_function: THandler) -> IListener: + raise NotImplementedError + + registry.register_transport("custom-secure", CapableTransport) + assert registry.transport_provides_security("custom-secure") is True + assert registry.transport_provides_muxing("custom-secure") is False + assert registry.needs_security_upgrade("custom-secure") is False + assert registry.needs_muxer_upgrade("custom-secure") is True + assert "custom-secure" not in registry.get_self_upgrading_protocols() + + +class TestModuleLevelCapabilityHelpers: + """Tests for the module-level convenience functions.""" + + def test_transport_needs_security(self): + from libp2p.transport.transport_registry import transport_needs_security + + assert transport_needs_security("tcp") is True + assert transport_needs_security("quic") is False + + def test_transport_needs_muxer(self): + from libp2p.transport.transport_registry import transport_needs_muxer + + assert transport_needs_muxer("tcp") is True + assert transport_needs_muxer("quic") is False