diff --git a/docs/examples.multi_transport.rst b/docs/examples.multi_transport.rst new file mode 100644 index 000000000..8d647d2ea --- /dev/null +++ b/docs/examples.multi_transport.rst @@ -0,0 +1,37 @@ +examples.multi\_transport package +================================= + +Migration Guide +--------------- +The `libp2p.transport.transport_registry` module has been removed, and `Swarm` no longer accepts the `transport=` keyword argument. + +To migrate to the new `TransportManager` architecture: +- Pass a list of transports to `new_swarm` or `new_host` using the `transports=[...]` keyword argument. +- If `transports` is omitted, the swarm will automatically detect and create transports based on the provided `listen_addrs`. + +Submodules +---------- + +examples.multi\_transport.client module +--------------------------------------- + +.. automodule:: examples.multi_transport.client + :members: + :show-inheritance: + :undoc-members: + +examples.multi\_transport.server module +--------------------------------------- + +.. automodule:: examples.multi_transport.server + :members: + :show-inheritance: + :undoc-members: + +Module contents +--------------- + +.. automodule:: examples.multi_transport + :members: + :show-inheritance: + :undoc-members: diff --git a/docs/examples.rst b/docs/examples.rst index 43243e9c6..eafff9d03 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -35,7 +35,8 @@ Examples examples.websocket examples.tls examples.tcp - examples.transport + examples.autotls examples.perf examples.path_handling + examples.multi_transport diff --git a/docs/examples.transport.rst b/docs/examples.transport.rst deleted file mode 100644 index aab3036e4..000000000 --- a/docs/examples.transport.rst +++ /dev/null @@ -1,21 +0,0 @@ -examples.transport package -========================== - -Submodules ----------- - -examples.transport.transport\_integration\_demo module ------------------------------------------------------- - -.. automodule:: examples.transport.transport_integration_demo - :members: - :show-inheritance: - :undoc-members: - -Module contents ---------------- - -.. automodule:: examples.transport - :members: - :show-inheritance: - :undoc-members: diff --git a/examples/autotls_browser/main.py b/examples/autotls_browser/main.py index e4da5430f..c04a35a40 100644 --- a/examples/autotls_browser/main.py +++ b/examples/autotls_browser/main.py @@ -129,7 +129,7 @@ async def start_server(self) -> None: peer_id=peer_id, peerstore=peer_store, upgrader=upgrader, - transport=transport, + transports=[transport], ) self.host = BasicHost(swarm) diff --git a/examples/multi_transport/__init__.py b/examples/multi_transport/__init__.py new file mode 100644 index 000000000..0ab3c8f78 --- /dev/null +++ b/examples/multi_transport/__init__.py @@ -0,0 +1 @@ +"""Multi-transport example package.""" diff --git a/examples/multi_transport/client.py b/examples/multi_transport/client.py new file mode 100644 index 000000000..8bfd19687 --- /dev/null +++ b/examples/multi_transport/client.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +""" +Multi-Transport Echo Client — py-libp2p + +Connects to the multi-transport echo server using whichever transport +is encoded in the supplied multiaddress. + +Usage: + python client.py -d /ip4/127.0.0.1/tcp/4001/p2p/ + python client.py -d /ip4/127.0.0.1/tcp/4002/ws/p2p/ + python client.py -d /ip4/127.0.0.1/udp/4003/quic/p2p/ +""" + +import argparse +import logging +from pathlib import Path +import sys + +# Ensure local libp2p is used +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +import multiaddr +import trio + +from libp2p import new_host +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.peer.peerinfo import info_from_p2p_addr + +logging.basicConfig(level=logging.WARNING) +logging.getLogger("libp2p").setLevel(logging.WARNING) + +ECHO_PROTOCOL = TProtocol("/echo/1.0.0") + + +def _detect_transport(maddr_str: str) -> str: + """Return a human-readable transport name from a multiaddr string.""" + if "/quic" in maddr_str: + return "QUIC" + if "/ws" in maddr_str or "/wss" in maddr_str: + return "WebSocket" + return "TCP" + + +async def run_client( + destination: str, message: bytes = b"hello from py-libp2p!\n" +) -> None: + """ + Connect to *destination* (a full /p2p/… multiaddr), send *message*, and + print the echoed reply. + + The client inspects the destination multiaddr to enable the right transport + in the Swarm's TransportManager before dialing. + """ + transport_name = _detect_transport(destination) + print(f"=== Multi-Transport Echo Client ({transport_name}) ===\n") + + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + + # Enable the transport that matches the destination address. + # new_host() must know which transports to register at construction time — + # the TransportManager is fixed after the Swarm is built. + enable_quic = transport_name == "QUIC" + enable_websocket = transport_name == "WebSocket" + + host = new_host( + key_pair=create_new_key_pair(), + enable_quic=enable_quic, + enable_websocket=enable_websocket, + ) + + # Client doesn't listen — pass an empty list. + async with host.run(listen_addrs=[]): + print(f"My peer ID : {host.get_id().to_string()}") + print(f"Connecting : {destination}\n") + + await host.connect(info) + print(f"Connected ✓ (transport: {transport_name})") + + stream = await host.new_stream(info.peer_id, [ECHO_PROTOCOL]) + await stream.write(message) + # Read exactly the number of bytes we sent to avoid deadlocks + response = await stream.read(len(message)) + await stream.close() + + print(f"Sent : {message!r}") + print(f"Got : {response!r}") + + if response == message: + print("\n✅ Echo verified — round-trip successful!") + else: + print("\n❌ Echo mismatch!") + await trio.sleep(30) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Multi-transport echo client (auto-selects TCP / WebSocket / QUIC)." + ) + parser.add_argument( + "-d", + "--destination", + required=True, + type=str, + help=( + "Full multiaddr of the server including /p2p/, e.g.: " + "/ip4/127.0.0.1/tcp/4001/p2p/16Uiu2..." + ), + ) + parser.add_argument( + "-m", + "--message", + type=str, + default="hello from py-libp2p!", + help="Message to echo (default: 'hello from py-libp2p!')", + ) + args = parser.parse_args() + try: + trio.run(run_client, args.destination, args.message.encode()) + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + main() diff --git a/examples/multi_transport/server.py b/examples/multi_transport/server.py new file mode 100644 index 000000000..12800dd96 --- /dev/null +++ b/examples/multi_transport/server.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +""" +Multi-Transport Echo Server — py-libp2p + +Demonstrates listening on TCP, WebSocket, and QUIC simultaneously, +mirroring go-libp2p's multi-transport architecture. + +Usage: + # Start server (auto-detects free ports on all transports): + python server.py + + # Start server on specific port: + python server.py --port 4001 + + # Start client (copy one of the multiaddrs printed by the server): + python client.py -d /ip4/127.0.0.1/tcp/4001/p2p/ + python client.py -d /ip4/127.0.0.1/tcp/4001/ws/p2p/ + python client.py -d /ip4/127.0.0.1/udp/4001/quic/p2p/ +""" + +import argparse +import logging +from pathlib import Path +import sys + +# Ensure local libp2p is used +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +import multiaddr +import trio + +from libp2p import new_host +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.network.stream.net_stream import INetStream +from libp2p.utils.address_validation import find_free_port + +# Configure minimal logging +logging.basicConfig(level=logging.WARNING) +logging.getLogger("libp2p").setLevel(logging.WARNING) + +ECHO_PROTOCOL = TProtocol("/echo/1.0.0") + + +async def _echo_handler(stream: INetStream) -> None: + """Echo handler: read up to 1024 bytes and write it back.""" + try: + peer_id = stream.muxed_conn.peer_id + # Read a chunk rather than wait for EOF, preventing deadlocks + # if the client keeps it open + data = await stream.read(1024) + print(f" [{peer_id!s:.20}...] echoing {len(data)} bytes") + await stream.write(data) + await stream.close() + except Exception as exc: + print(f" Handler error: {exc}") + try: + await stream.reset() + except Exception: + pass + + +async def run_server(port: int = 0) -> None: + """ + Listen simultaneously on TCP, WebSocket, and QUIC using the EXACT SAME PORT. + This demonstrates py-libp2p's connection multiplexing (cmux) capabilities. + """ + if port == 0: + port = find_free_port() + + # Build listen addresses for all three transports on the SAME port + listen_addrs = [ + multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}"), # plain TCP + multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}/ws"), # WebSocket + multiaddr.Multiaddr(f"/ip4/0.0.0.0/udp/{port}/quic"), # QUIC + ] + + host = new_host(key_pair=create_new_key_pair(), listen_addrs=listen_addrs) + host.set_stream_handler(ECHO_PROTOCOL, _echo_handler) + + print("=== Multi-Transport CMUX Echo Server ===\n") + print(f"Using shared port: {port}\n") + print("Listening on:") + + async with host.run(listen_addrs=listen_addrs): + # Wait until the host has bound and resolved all ports + for _ in range(50): + if host.get_addrs(): + break + await trio.sleep(0.05) + + peer_id = host.get_id().to_string() + for addr in host.get_addrs(): + print(f" {addr}") + + print( + "\nConnect using any of the following (replace 0.0.0.0 with your IP):\n\n" + f" TCP: python client.py -d " + f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}\n" + f" WebSocket: python client.py -d " + f"/ip4/127.0.0.1/tcp/{port}/ws/p2p/{peer_id}\n" + f" QUIC: python client.py -d " + f"/ip4/127.0.0.1/udp/{port}/quic/p2p/{peer_id}" + ) + print("\nWaiting for connections… (Ctrl-C to stop)\n") + + await trio.sleep_forever() + + +def main() -> None: + parser = argparse.ArgumentParser( + description="CMUX echo server (TCP + WebSocket + QUIC on SAME port)." + ) + parser.add_argument( + "--port", type=int, default=0, help="Listen port for all transports (0 = auto)" + ) + args = parser.parse_args() + try: + trio.run(run_server, args.port) + except KeyboardInterrupt: + print("\nServer stopped.") + + +if __name__ == "__main__": + main() diff --git a/examples/transport/__init__.py b/examples/transport/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/transport/transport_integration_demo.py b/examples/transport/transport_integration_demo.py deleted file mode 100644 index 58403c6e0..000000000 --- a/examples/transport/transport_integration_demo.py +++ /dev/null @@ -1,208 +0,0 @@ -#!/usr/bin/env python3 -""" -Demo script showing the new transport integration capabilities in py-libp2p. - -This script demonstrates: -1. How to use the transport registry -2. How to create transports dynamically based on multiaddrs -3. How to register custom transports -4. How the new system automatically selects the right transport - -Usage: - python examples/transport/transport_integration_demo.py -""" - -import logging - -import multiaddr -import trio - -from libp2p.transport import ( - create_transport, - create_transport_for_multiaddr, - get_supported_transport_protocols, - get_transport_registry, - register_transport, -) -from libp2p.transport.tcp.tcp import TCP -from libp2p.transport.upgrader import TransportUpgrader - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def demo_transport_registry(): - """Demonstrate the transport registry functionality.""" - print("🔧 Transport Registry Demo") - print("=" * 50) - - # Get the global registry - registry = get_transport_registry() - - # Show supported protocols - supported = get_supported_transport_protocols() - print(f"Supported transport protocols: {supported}") - - # Show registered transports - print("\nRegistered transports:") - for protocol in supported: - transport_class = registry.get_transport(protocol) - class_name = transport_class.__name__ if transport_class else "None" - print(f" {protocol}: {class_name}") - - print() - - -def demo_transport_factory(): - """Demonstrate the transport factory functions.""" - print("🏭 Transport Factory Demo") - print("=" * 50) - - # Create a dummy upgrader for WebSocket transport - upgrader = TransportUpgrader({}, {}) - - # Create transports using the factory function - try: - tcp_transport = create_transport("tcp") - print(f"✅ Created TCP transport: {type(tcp_transport).__name__}") - - ws_transport = create_transport("ws", upgrader) - print(f"✅ Created WebSocket transport: {type(ws_transport).__name__}") - - except Exception as e: - print(f"❌ Error creating transport: {e}") - - print() - - -def demo_multiaddr_transport_selection(): - """Demonstrate automatic transport selection based on multiaddrs.""" - print("🎯 Multiaddr Transport Selection Demo") - print("=" * 50) - - # Create a dummy upgrader - upgrader = TransportUpgrader({}, {}) - - # Test different multiaddr types - test_addrs = [ - "/ip4/127.0.0.1/tcp/8080", - "/ip4/127.0.0.1/tcp/8080/ws", - "/ip6/::1/tcp/8080/ws", - "/dns4/example.com/tcp/443/ws", - ] - - for addr_str in test_addrs: - try: - maddr = multiaddr.Multiaddr(addr_str) - transport = create_transport_for_multiaddr(maddr, upgrader) - - if transport: - print(f"✅ {addr_str} -> {type(transport).__name__}") - else: - print(f"❌ {addr_str} -> No transport found") - - except Exception as e: - print(f"❌ {addr_str} -> Error: {e}") - - print() - - -def demo_custom_transport_registration(): - """Demonstrate how to register custom transports.""" - print("🔧 Custom Transport Registration Demo") - print("=" * 50) - - # Show current supported protocols - print(f"Before registration: {get_supported_transport_protocols()}") - - # Register a custom transport (using TCP as an example) - class CustomTCPTransport(TCP): - """Custom TCP transport for demonstration.""" - - def __init__(self): - super().__init__() - self.custom_flag = True - - # Register the custom transport - register_transport("custom_tcp", CustomTCPTransport) - - # Show updated supported protocols - print(f"After registration: {get_supported_transport_protocols()}") - - # Test creating the custom transport - try: - custom_transport = create_transport("custom_tcp") - print(f"✅ Created custom transport: {type(custom_transport).__name__}") - # Check if it has the custom flag (type-safe way) - if hasattr(custom_transport, "custom_flag"): - flag_value = getattr(custom_transport, "custom_flag", "Not found") - print(f" Custom flag: {flag_value}") - else: - print(" Custom flag: Not found") - except Exception as e: - print(f"❌ Error creating custom transport: {e}") - - print() - - -def demo_integration_with_libp2p(): - """Demonstrate how the new system integrates with libp2p.""" - print("🚀 Libp2p Integration Demo") - print("=" * 50) - - print("The new transport system integrates seamlessly with libp2p:") - print() - print("1. ✅ Automatic transport selection based on multiaddr") - print("2. ✅ Support for WebSocket (/ws) protocol") - print("3. ✅ Fallback to TCP for backward compatibility") - print("4. ✅ Easy registration of new transport protocols") - print("5. ✅ No changes needed to existing libp2p code") - print() - - print("Example usage in libp2p:") - print(" # This will automatically use WebSocket transport") - print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080/ws'])") - print() - print(" # This will automatically use TCP transport") - print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080'])") - print() - - print() - - -async def main(): - """Run all demos.""" - print("🎉 Py-libp2p Transport Integration Demo") - print("=" * 60) - print() - - # Run all demos - demo_transport_registry() - demo_transport_factory() - demo_multiaddr_transport_selection() - demo_custom_transport_registration() - demo_integration_with_libp2p() - - print("🎯 Summary of New Features:") - print("=" * 40) - print("✅ Transport Registry: Central registry for all transport implementations") - print("✅ Dynamic Transport Selection: Automatic selection based on multiaddr") - print("✅ WebSocket Support: Full /ws protocol support") - print("✅ Extensible Architecture: Easy to add new transport protocols") - print("✅ Backward Compatibility: Existing TCP code continues to work") - print("✅ Factory Functions: Simple API for creating transports") - print() - print("🚀 The transport system is now ready for production use!") - - -if __name__ == "__main__": - try: - trio.run(main) - except KeyboardInterrupt: - print("\n👋 Demo interrupted by user") - except Exception as e: - print(f"\n❌ Demo failed with error: {e}") - import traceback - - traceback.print_exc() diff --git a/examples/websocket/proxy_websocket_demo.py b/examples/websocket/proxy_websocket_demo.py index 8e484852a..2a5b2bab5 100644 --- a/examples/websocket/proxy_websocket_demo.py +++ b/examples/websocket/proxy_websocket_demo.py @@ -85,15 +85,9 @@ def create_websocket_host_with_proxy(proxy_url=None, proxy_auth=None): sec_opt={PLAINTEXT_PROTOCOL_ID: InsecureTransport(key_pair)}, muxer_opt=create_yamux_muxer_option(), listen_addrs=[Multiaddr("/ip4/0.0.0.0/tcp/0/ws")], + transports=[transport], ) - # Replace the default transport with our configured one - from libp2p.network.swarm import Swarm - - swarm = host.get_network() - if isinstance(swarm, Swarm): - swarm.transport = transport - return host diff --git a/examples/websocket/test_tcp_echo.py b/examples/websocket/test_tcp_echo.py index 20728bf62..1b57473de 100644 --- a/examples/websocket/test_tcp_echo.py +++ b/examples/websocket/test_tcp_echo.py @@ -66,7 +66,7 @@ def create_tcp_host(): transport = TCP() # Create swarm and host - swarm = Swarm(peer_id, peer_store, upgrader, transport) + swarm = Swarm(peer_id, peer_store, upgrader, [transport]) host = BasicHost(swarm) return host diff --git a/examples/websocket/test_websocket_transport.py b/examples/websocket/test_websocket_transport.py deleted file mode 100644 index 09e7457ca..000000000 --- a/examples/websocket/test_websocket_transport.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple test script to verify WebSocket transport functionality. -""" - -import logging -from pathlib import Path -import sys - -# Add the libp2p directory to the path so we can import it -sys.path.insert(0, str(Path(__file__).parent)) - -import pytest -import multiaddr -import trio - -from libp2p.transport import create_transport, create_transport_for_multiaddr -from libp2p.transport.upgrader import TransportUpgrader - -# Set up logging -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - - -@pytest.mark.trio -async def test_websocket_transport(): - """Test basic WebSocket transport functionality.""" - print("🧪 Testing WebSocket Transport Functionality") - print("=" * 50) - - # Create a dummy upgrader - upgrader = TransportUpgrader({}, {}) - - # Test creating WebSocket transport - try: - ws_transport = create_transport("ws", upgrader) - print(f"✅ WebSocket transport created: {type(ws_transport).__name__}") - - # Test creating transport from multiaddr - ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") - ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader) - print( - f"✅ WebSocket transport from multiaddr: " - f"{type(ws_transport_from_maddr).__name__}" - ) - - # Test creating listener - handler_called = False - - async def test_handler(conn): - nonlocal handler_called - handler_called = True - print(f"✅ Connection handler called with: {type(conn).__name__}") - await conn.close() - - listener = ws_transport.create_listener(test_handler) - print(f"✅ WebSocket listener created: {type(listener).__name__}") - - # Test that the transport can be used - print( - f"✅ WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}" - ) - print( - f"✅ WebSocket transport supports listening: " - f"{hasattr(ws_transport, 'create_listener')}" - ) - - print("\n🎯 WebSocket Transport Test Results:") - print("✅ Transport creation: PASS") - print("✅ Multiaddr parsing: PASS") - print("✅ Listener creation: PASS") - print("✅ Interface compliance: PASS") - - except Exception as e: - print(f"❌ WebSocket transport test failed: {e}") - import traceback - - traceback.print_exc() - return False - - return True - - -@pytest.mark.trio -async def test_transport_registry(): - """Test the transport registry functionality.""" - print("\n🔧 Testing Transport Registry") - print("=" * 30) - - from libp2p.transport import ( - get_supported_transport_protocols, - get_transport_registry, - ) - - registry = get_transport_registry() - supported = get_supported_transport_protocols() - - print(f"Supported protocols: {supported}") - - # Test getting transports - for protocol in supported: - transport_class = registry.get_transport(protocol) - class_name = transport_class.__name__ if transport_class else "None" - print(f" {protocol}: {class_name}") - - # Test creating transports through registry - upgrader = TransportUpgrader({}, {}) - - for protocol in supported: - try: - transport = registry.create_transport(protocol, upgrader) - if transport: - print(f"✅ {protocol}: Created successfully") - else: - print(f"❌ {protocol}: Failed to create") - except Exception as e: - print(f"❌ {protocol}: Error - {e}") - - -async def main(): - """Run all tests.""" - print("🚀 WebSocket Transport Integration Test Suite") - print("=" * 60) - print() - - # Run tests - success = await test_websocket_transport() - await test_transport_registry() - - print("\n" + "=" * 60) - if success: - print("🎉 All tests passed! WebSocket transport is working correctly.") - else: - print("❌ Some tests failed. Check the output above for details.") - - print("\n🚀 WebSocket transport is ready for use in py-libp2p!") - - -if __name__ == "__main__": - try: - trio.run(main) - except KeyboardInterrupt: - print("\n👋 Test interrupted by user") - except Exception as e: - print(f"\n❌ Test failed with error: {e}") - import traceback - - traceback.print_exc() diff --git a/examples/websocket/websocket_comprehensive_demo.py b/examples/websocket/websocket_comprehensive_demo.py index f94fcdd78..f620300c1 100644 --- a/examples/websocket/websocket_comprehensive_demo.py +++ b/examples/websocket/websocket_comprehensive_demo.py @@ -229,15 +229,9 @@ def create_websocket_host( listen_addrs=listen_addrs, tls_server_config=server_context, tls_client_config=client_context, + transports=[transport], ) - # Replace the default transport with our configured one - from libp2p.network.swarm import Swarm - - swarm = host.get_network() - if isinstance(swarm, Swarm): - swarm.transport = transport - return host diff --git a/examples/websocket/websocket_demo.py b/examples/websocket/websocket_demo.py index 5a467233a..0f8118d93 100644 --- a/examples/websocket/websocket_demo.py +++ b/examples/websocket/websocket_demo.py @@ -98,7 +98,7 @@ def create_websocket_host(listen_addrs=None, use_plaintext=False): transport = WebsocketTransport(upgrader) # Create swarm and host - swarm = Swarm(peer_id, peer_store, upgrader, transport) + swarm = Swarm(peer_id, peer_store, upgrader, [transport]) host = BasicHost(swarm) return host diff --git a/examples/websocket_mvp/client.py b/examples/websocket_mvp/client.py index ac4c24726..644861823 100644 --- a/examples/websocket_mvp/client.py +++ b/examples/websocket_mvp/client.py @@ -130,15 +130,9 @@ def create_host(self): sec_opt={PLAINTEXT_PROTOCOL_ID: InsecureTransport(key_pair)}, muxer_opt=create_yamux_muxer_option(), listen_addrs=[], # Client doesn't need to listen + transports=[transport], ) - # Replace the default transport with our configured one - from libp2p.network.swarm import Swarm - - swarm = host.get_network() - if isinstance(swarm, Swarm): - swarm.transport = transport - return host async def send_echo(self, peer_info, message: str) -> str: diff --git a/examples/websocket_mvp/server.py b/examples/websocket_mvp/server.py index b0b012996..6dbc024dd 100644 --- a/examples/websocket_mvp/server.py +++ b/examples/websocket_mvp/server.py @@ -74,15 +74,9 @@ def create_host(self): sec_opt={PLAINTEXT_PROTOCOL_ID: InsecureTransport(key_pair)}, muxer_opt=create_yamux_muxer_option(), listen_addrs=listen_addrs, + transports=[transport], ) - # Replace the default transport with our configured one - from libp2p.network.swarm import Swarm - - swarm = host.get_network() - if isinstance(swarm, Swarm): - swarm.transport = transport - return host async def echo_handler(self, stream): diff --git a/libp2p/__init__.py b/libp2p/__init__.py index a010ccde8..542325fdf 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -6,7 +6,10 @@ from pathlib import Path import ssl from libp2p.transport.quic.utils import is_quic_multiaddr -from typing import Any +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from libp2p.transport.cmux import PortDemultiplexer from cryptography.hazmat.primitives.asymmetric import ed25519 from cryptography.hazmat.primitives import serialization @@ -110,10 +113,7 @@ from libp2p.transport.upgrader import ( TransportUpgrader, ) -from libp2p.transport.transport_registry import ( - create_transport_for_multiaddr, - get_supported_transport_protocols, -) + import libp2p.utils from libp2p.utils.logging import ( setup_logging, @@ -279,6 +279,148 @@ def get_default_muxer_options() -> TMuxerOptions: else: # YAMUX is default return create_yamux_muxer_option() +def _build_transports_for_swarm( + key_pair: KeyPair, + listen_addrs: Sequence[multiaddr.Multiaddr] | None, + transports: Sequence[ITransport] | None, + enable_quic: bool, + enable_webrtc: bool, + enable_tcp: bool, + enable_websocket: bool, + enable_autotls: bool, + upgrader: TransportUpgrader, + quic_config: QUICTransportConfig | None, + tls_client_config: ssl.SSLContext | None, + tls_server_config: ssl.SSLContext | None, + # Pass QUICTransport class from module scope so monkeypatching in tests works. + quic_class: type | None = None, +) -> list[ITransport]: + """ + Build the ordered list of transports for the Swarm's TransportManager. + + Priority: + 1. Explicit ``transports`` list — used as-is (highest priority). + 2. ``listen_addrs`` inspection — auto-detects which transports are needed + by inspecting **every** address (not just the first one). + 3. ``enable_*`` flags — coarse-grained control when no addresses given. + 4. Default fallback: TCP only. + + :param key_pair: The host's key pair (needed by QUIC for TLS). + :param listen_addrs: The multiaddrs the host will listen on. + :param transports: Explicit transport list, or ``None`` to auto-build. + :param enable_quic: Whether to create a QUIC transport when auto-building. + :param enable_webrtc: Whether to create a WebRTC transport when auto-building. + :param enable_tcp: Whether to include a TCP transport when auto-building. + :param enable_websocket: Whether to include a WebSocket transport. + :param enable_autotls: Whether to enable AutoTLS in QUIC/WebSocket transports. + :param upgrader: The upgrader passed to WebSocket transport at construction. + :param quic_config: Optional QUIC transport configuration. + :param tls_client_config: TLS client context for WebSocket. + :param tls_server_config: TLS server context for WebSocket. + :returns: Ordered list of :class:`~libp2p.abc.ITransport` instances. + """ + # Highest priority: user-supplied explicit list. + if transports is not None: + return list(transports) + + # Use the provided class (patchable by tests) or fall back to module-level import. + _QUICTransport = quic_class if quic_class is not None else QUICTransport + + result: list[ITransport] = [] + + if listen_addrs: + seen_classes: set[type] = set() + for addr in listen_addrs: + protocols = [p.name for p in addr.protocols()] + transport_obj: ITransport | None = None + + if "tcp" in protocols and "ws" not in protocols and "wss" not in protocols: + transport_obj = TCP() + elif "quic" in protocols or "quic-v1" in protocols: + if key_pair is None: + logger.warning("QUIC transport requires key_pair (private_key)") + continue + transport_obj = _QUICTransport( + key_pair.private_key, + config=quic_config, + enable_autotls=enable_autotls, + ) + elif "ws" in protocols or "wss" in protocols: + from libp2p.transport.websocket.transport import WebsocketTransport + + transport_obj = WebsocketTransport( + upgrader, + tls_client_config=tls_client_config, + tls_server_config=tls_server_config, + ) + elif "webrtc-direct" in protocols: + from libp2p.transport.webrtc.transport import WebRTCDirectTransport + + if key_pair is None: + logger.warning("WebRTC transport requires key_pair (private_key)") + continue + transport_obj = WebRTCDirectTransport(private_key=key_pair.private_key) + elif "webrtc" in protocols: + from libp2p.transport.webrtc.private_transport import WebRTCPrivateTransport + + if key_pair is None: + logger.warning("WebRTC transport requires key_pair (private_key)") + continue + transport_obj = WebRTCPrivateTransport(private_key=key_pair.private_key) + + if transport_obj is None: + continue + + cls = type(transport_obj) + if cls not in seen_classes: + seen_classes.add(cls) + result.append(transport_obj) + + # If enable_quic=True is requested but no QUIC was detected in listen_addrs, + # replace the result with only the QUIC transport (mirrors original new_swarm() + # behavior where the transport was replaced rather than appended). + if enable_quic and not any(type(t).__name__ == "QUICTransport" for t in result): + result = [ + _QUICTransport( + key_pair.private_key, + config=quic_config, + enable_autotls=enable_autotls, + ) + ] + + # Fall through to flags if nothing was auto-detected. + if not result: + if enable_quic: + result.append( + _QUICTransport( + key_pair.private_key, + config=quic_config, + enable_autotls=enable_autotls, + ) + ) + if enable_websocket: + from libp2p.transport.websocket.transport import WebsocketTransport + + result.append( + WebsocketTransport( + upgrader, + tls_client_config=tls_client_config, + tls_server_config=tls_server_config, + ) + ) + if enable_webrtc: + from libp2p.transport.webrtc.transport import WebRTCDirectTransport + + if key_pair is None: + logger.warning("WebRTC transport requires key_pair (private_key)") + else: + result.append(WebRTCDirectTransport(private_key=key_pair.private_key)) + if enable_tcp or not result: + result.append(TCP()) + + return result + + def new_swarm( key_pair: KeyPair | None = None, muxer_opt: TMuxerOptions | None = None, @@ -286,33 +428,77 @@ def new_swarm( peerstore_opt: IPeerStore | None = None, muxer_preference: Literal["YAMUX", "MPLEX"] | None = None, listen_addrs: Sequence[multiaddr.Multiaddr] | None = None, + # NEW: explicit transport list — highest priority + transports: Sequence[ITransport] | None = None, + # Backward-compat flags enable_quic: bool = False, enable_webrtc: bool = False, enable_autotls: bool = False, + # NEW: convenience flags for auto-building transports + enable_tcp: bool = True, + enable_websocket: bool = False, retry_config: RetryConfig | None = None, connection_config: ConnectionConfig | QUICTransportConfig | None = None, tls_client_config: ssl.SSLContext | None = None, tls_server_config: ssl.SSLContext | None = None, resource_manager: ResourceManager | None = None, - psk: str | None = None + psk: str | None = None, ) -> INetworkService: - logger.debug(f"new_swarm: enable_quic={enable_quic}, listen_addrs={listen_addrs}") """ - Create a swarm instance based on the parameters. + Create a swarm instance with multi-transport support. + + The swarm can listen on and dial over multiple transports simultaneously + (TCP, WebSocket, QUIC), mirroring go-libp2p's architecture. + + Transport selection priority (highest to lowest): + + 1. **Explicit ``transports`` list** — used as-is; all other transport + parameters are ignored. + 2. **``listen_addrs`` inspection** — each address is inspected to determine + which transports are needed (TCP, WebSocket, QUIC). All detected + transport types are created and registered. + 3. **``enable_*`` flags** — coarse-grained control when no addresses are + provided (``enable_quic``, ``enable_websocket``, ``enable_tcp``). + 4. **Default fallback** — TCP only. :param key_pair: optional choice of the ``KeyPair`` :param muxer_opt: optional choice of stream muxer :param sec_opt: optional choice of security upgrade :param peerstore_opt: optional peerstore - :param muxer_preference: optional explicit muxer preference - :param listen_addrs: optional list of multiaddrs to listen on - :param enable_quic: enable quic for transport - :param enable_autotls: enable autotls for security - :param quic_transport_opt: options for transport - :param resource_manager: optional resource manager for connection/stream limits - :type resource_manager: :class:`libp2p.rcmgr.ResourceManager` or None - :param psk: optional pre-shared key for PSK encryption in transport - :return: return a default swarm instance + :param muxer_preference: optional explicit muxer preference (``"YAMUX"`` or + ``"MPLEX"``) + :param listen_addrs: optional list of multiaddrs to listen on. **All** + addresses are inspected to determine which transports to create. + :param transports: explicit list of transport instances to register. When + provided, all ``enable_*`` flags and ``listen_addrs``-based detection + are bypassed. + :param enable_quic: include a QUIC transport when auto-building (deprecated; + prefer passing ``listen_addrs`` with QUIC addresses or ``transports``). + :param enable_autotls: enable AutoTLS for QUIC / WebSocket transports. + :param enable_tcp: include a TCP transport when auto-building (default True). + :param enable_websocket: include a WebSocket transport when auto-building. + :param retry_config: optional connection retry configuration. + :param connection_config: optional connection configuration. + :param tls_client_config: TLS client context for WebSocket transport. + :param tls_server_config: TLS server context for WebSocket transport. + :param resource_manager: optional resource manager for connection/stream limits. + :param psk: optional pre-shared key for PSK encryption. + :return: a Swarm instance implementing INetworkService. + + Examples:: + + # TCP only (default) + swarm = new_swarm() + + # TCP + WebSocket + QUIC auto-detected from listen_addrs + swarm = new_swarm(listen_addrs=[ + Multiaddr("/ip4/0.0.0.0/tcp/4001"), + Multiaddr("/ip4/0.0.0.0/tcp/4002/ws"), + Multiaddr("/ip4/0.0.0.0/udp/4003/quic-v1"), + ]) + + # Explicit transport list + swarm = new_swarm(transports=[TCP(), QUICTransport(kp.private_key)]) Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer due to its improved performance and features. @@ -329,65 +515,8 @@ def new_swarm( id_opt = generate_peer_id_from(key_pair) - transport: TCP | QUICTransport | ITransport quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None - if listen_addrs is None: - if enable_quic: - transport = QUICTransport( - key_pair.private_key, - config=quic_transport_opt, - enable_autotls=enable_autotls, - ) - else: - transport = TCP() - else: - # Use transport registry to select the appropriate transport - from libp2p.transport.transport_registry import create_transport_for_multiaddr - - # Create a temporary upgrader for transport selection - # We'll create the real upgrader later with the proper configuration - temp_upgrader = TransportUpgrader( - secure_transports_by_protocol={}, - muxer_transports_by_protocol={} - ) - - addr = listen_addrs[0] - logger.debug(f"new_swarm: Creating transport for address: {addr}") - transport_maybe = create_transport_for_multiaddr( - addr, - temp_upgrader, - private_key=key_pair.private_key, - config=quic_transport_opt, - enable_autotls=enable_autotls, - tls_client_config=tls_client_config, - tls_server_config=tls_server_config - ) - - if transport_maybe is None: - raise ValueError(f"Unsupported transport for listen_addrs: {listen_addrs}") - - transport = transport_maybe - logger.debug(f"new_swarm: Created transport: {type(transport)}") - - # If enable_quic is True but we didn't get a QUIC transport, force QUIC - if enable_quic and not isinstance(transport, QUICTransport): - logger.debug(f"new_swarm: Forcing QUIC transport (enable_quic=True but got {type(transport)})") - transport = QUICTransport( - key_pair.private_key, - config=quic_transport_opt, - enable_autotls=enable_autotls, - ) - - # If enable_webrtc is True, force WebRTC Direct transport - if enable_webrtc: - from libp2p.transport.webrtc.transport import WebRTCDirectTransport - - logger.debug("new_swarm: Creating WebRTC Direct transport") - transport = WebRTCDirectTransport(private_key=key_pair.private_key) - - logger.debug(f"new_swarm: Final transport type: {type(transport)}") - # Generate X25519 keypair for Noise noise_key_pair = create_new_x25519_key_pair() @@ -434,19 +563,88 @@ def new_swarm( muxer_transports_by_protocol=muxer_transports_by_protocol, ) + # Build the transport list using the helper. + # Pass QUICTransport as a module-level reference so tests can monkeypatch it. + transport_list = _build_transports_for_swarm( + key_pair=key_pair, + listen_addrs=listen_addrs, + transports=transports, + enable_quic=enable_quic, + enable_webrtc=enable_webrtc, + enable_tcp=enable_tcp, + enable_websocket=enable_websocket, + enable_autotls=enable_autotls, + upgrader=upgrader, + quic_config=quic_transport_opt, + tls_client_config=tls_client_config, + tls_server_config=tls_server_config, + quic_class=QUICTransport, # module-level ref; monkeypatching libp2p.QUICTransport works + ) + logger.debug( + "new_swarm: using transports: %s", + [type(t).__name__ for t in transport_list], + ) peerstore = peerstore_opt or PeerStore() # Store our key pair in peerstore peerstore.add_key_pair(id_opt, key_pair) + # ---- Detect shared-port TCP+WS addresses and build PortDemultiplexer ---- + # Mirrors go-libp2p: sharedTCP *tcpreuse.PortDemultiplexer is passed at transport + # construction time. When the listen_addrs include both a plain TCP and a + # WebSocket address on the SAME port, we create one PortDemultiplexer so they share + # the OS socket (EADDRINUSE prevention). + port_demux = None + port_demuxers: dict[tuple[str, int], PortDemultiplexer] = {} + if listen_addrs: + from libp2p.transport.cmux import PortDemultiplexer as _ConnMgr + + # Map (host, port) -> list of protocol-stacks that share that port. + port_protos: dict[tuple[str, int], list[set[str]]] = {} + for addr in listen_addrs: + try: + protos = {p.name for p in addr.protocols()} + if "tcp" not in protos: + continue # only TCP-based addrs can share + host_val = ( + addr.value_for_protocol("ip4") + or addr.value_for_protocol("ip6") + ) + port_val = addr.value_for_protocol("tcp") + if host_val is None or port_val is None: + continue + key = (str(host_val), int(port_val)) + port_protos.setdefault(key, []).append(protos) + except Exception: + continue + + # A PortDemultiplexer is needed when a port has BOTH plain-TCP and WS/WSS addrs. + for (host, port), proto_sets in port_protos.items(): + has_plain_tcp = any( + "ws" not in ps and "wss" not in ps for ps in proto_sets + ) + has_ws = any("ws" in ps or "wss" in ps for ps in proto_sets) + if has_plain_tcp and has_ws: + port_demuxers[(host, port)] = _ConnMgr(host, port) + logger.debug( + "new_swarm: created PortDemultiplexer for shared port %s:%d", + host, + port, + ) + + from libp2p.transport.manager import TransportManager + + transport_manager = TransportManager(port_demuxers=port_demuxers) + swarm = Swarm( id_opt, peerstore, upgrader, - transport, + transports=transport_list, # NEW: list instead of single transport retry_config=retry_config, connection_config=connection_config, - psk=psk + psk=psk, + transport_manager=transport_manager, ) # Set resource manager if provided @@ -490,21 +688,37 @@ def new_host( bootstrap_dns_max_retries: int = 3, connection_config: ConnectionConfig | None = None, announce_addrs: Sequence[multiaddr.Multiaddr] | None = None, + # NEW: explicit transport list — highest priority + transports: Sequence[ITransport] | None = None, + # NEW: convenience flags + enable_tcp: bool = True, + enable_websocket: bool = False ) -> IHost: """ Create a new libp2p host based on the given parameters. + The host can listen on and dial over multiple transports simultaneously + (TCP, WebSocket, QUIC), mirroring go-libp2p's architecture. + + Transport selection priority (highest to lowest): + + 1. ``transports`` — explicit list, used as-is. + 2. ``listen_addrs`` inspection — all addresses are inspected. + 3. ``enable_*`` flags. + 4. Default: TCP only. + :param key_pair: optional choice of the ``KeyPair`` :param muxer_opt: optional choice of stream muxer :param sec_opt: optional choice of security upgrade :param peerstore_opt: optional peerstore :param disc_opt: optional discovery :param muxer_preference: optional explicit muxer preference - :param listen_addrs: optional list of multiaddrs to listen on + :param listen_addrs: optional list of multiaddrs to listen on. **All** + addresses are inspected to determine which transports to create. :param enable_mDNS: whether to enable mDNS discovery :param bootstrap: optional list of bootstrap peer addresses as strings - :param enable_quic: optinal choice to use QUIC for transport - :param enable_autotls: optinal choice to use AutoTLS for security + :param enable_quic: optional choice to use QUIC for transport + :param enable_autotls: optional choice to use AutoTLS for security :param quic_transport_opt: optional configuration for quic transport :param tls_client_config: optional TLS client configuration for WebSocket transport :param tls_server_config: optional TLS server configuration for WebSocket transport @@ -516,11 +730,16 @@ def new_host( :param bootstrap_dns_max_retries: max DNS resolution retries with backoff :param connection_config: optional connection configuration for connection manager :param announce_addrs: if set, these replace listen addrs in get_addrs() + :param transports: explicit list of transport instances to register. When + provided, all ``enable_*`` flags and ``listen_addrs``-based detection + are bypassed. + :param enable_tcp: include a TCP transport when auto-building (default True). + :param enable_websocket: include a WebSocket transport when auto-building. :return: return a host instance """ if not enable_quic and quic_transport_opt is not None: - logger.warning(f"QUIC config provided but QUIC not enabled, ignoring QUIC config") + logger.warning("QUIC config provided but QUIC not enabled, ignoring QUIC config") # Enable automatic protection by default: if no resource manager is supplied, # create a default instance so connections/streams are guarded out of the box. @@ -555,7 +774,11 @@ def new_host( tls_client_config=tls_client_config, tls_server_config=tls_server_config, resource_manager=resource_manager, - psk=psk + psk=psk, + # NEW: forward multi-transport params + transports=transports, + enable_tcp=enable_tcp, + enable_websocket=enable_websocket, ) if disc_opt is not None: diff --git a/libp2p/abc.py b/libp2p/abc.py index dd22992ff..5876cc0c0 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -3043,6 +3043,70 @@ def create_listener(self, handler_function: THandler) -> IListener: """ + @abstractmethod + def can_dial(self, maddr: Multiaddr) -> bool: + """ + Return True if this transport can dial the given multiaddr. + + The TransportManager calls this method before attempting a dial + to route the connection to the correct transport. + + Parameters + ---------- + maddr : Multiaddr + The multiaddress to check. + + Returns + ------- + bool + True if this transport can dial maddr, False otherwise. + + Examples + -------- + - TCP returns True for ``/ip4/127.0.0.1/tcp/4001`` + - WebSocket returns True for ``/ip4/127.0.0.1/tcp/8080/ws`` + - QUIC returns True for ``/ip4/127.0.0.1/udp/4001/quic-v1`` + + """ + + @abstractmethod + def can_listen(self, maddr: Multiaddr) -> bool: + """ + Return True if this transport can listen on the given multiaddr. + + Often identical to :meth:`can_dial` but may differ — e.g. a + relay transport can dial outbound but cannot accept inbound + connections. + + Parameters + ---------- + maddr : Multiaddr + The multiaddress to check. + + Returns + ------- + bool + True if this transport can listen on maddr, False otherwise. + + """ + + @abstractmethod + def protocols(self) -> list[str]: + """ + Return the list of multiaddr protocol names this transport handles. + + Used by :class:`~libp2p.transport.manager.TransportManager` as a + fast pre-filter: if the multiaddr contains none of the listed + protocol names, ``can_dial`` / ``can_listen`` are not called. + + Returns + ------- + list[str] + Protocol name strings, e.g. ``["tcp"]``, ``["ws", "wss"]``, + or ``["quic", "quic-v1"]``. + + """ + # -------------------------- pubsub abc.py -------------------------- diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 85920a0b3..4fe33768c 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -324,32 +324,36 @@ def _detect_negotiate_timeout_from_transport(self) -> float | None: """ Detect negotiate timeout from transport configuration. - Checks if the network uses a QUIC transport and returns its - NEGOTIATE_TIMEOUT config value for coordination. + Iterates all registered transports in the swarm's + :attr:`~libp2p.network.swarm.Swarm.transport_manager` and returns the + ``NEGOTIATE_TIMEOUT`` value from the first transport that exposes it + (currently QUIC). - :return: Negotiate timeout from transport config, or None if not available + :return: Negotiate timeout in seconds from transport config, or None. """ try: - # Check if network has a transport attribute (Swarm pattern) - # Type ignore: transport exists on Swarm but not in INetworkService - if hasattr(self._network, "transport"): - transport = getattr(self._network, "transport", None) # type: ignore - # Check if it's a QUIC transport - if ( - transport is not None - and hasattr(transport, "_config") - and hasattr(transport._config, "NEGOTIATE_TIMEOUT") - ): - timeout = getattr(transport._config, "NEGOTIATE_TIMEOUT", None) # type: ignore - if timeout is not None: - logger.debug( - f"Detected negotiate timeout {timeout}s " - "from QUIC transport config" - ) - return float(timeout) + # Prefer the new multi-transport API (transport_manager). + if hasattr(self._network, "transport_manager"): + tm = getattr(self._network, "transport_manager", None) + if tm is not None: + for transport in tm.get_transports(): + if hasattr(transport, "_config") and hasattr( + transport._config, "NEGOTIATE_TIMEOUT" + ): + timeout = getattr( + transport._config, "NEGOTIATE_TIMEOUT", None + ) + if timeout is not None: + logger.debug( + "Detected negotiate timeout %ss from %s config", + timeout, + type(transport).__name__, + ) + return float(timeout) + except Exception as e: - # Silently fail - this is optional coordination - logger.debug(f"Could not detect negotiate timeout from transport: {e}") + # Silently fail — this is optional coordination. + logger.debug("Could not detect negotiate timeout from transport: %s", e) return None diff --git a/libp2p/io/peekable_stream.py b/libp2p/io/peekable_stream.py new file mode 100644 index 000000000..ab6acd66a --- /dev/null +++ b/libp2p/io/peekable_stream.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from collections.abc import Callable +import socket + +import trio + + +class PeekableStream(trio.abc.Stream): + """ + Wraps a :class:`trio.abc.Stream` and allows peeking/buffering of the first + few bytes. + + When `receive_some` is called, it returns buffered data before reading from the + underlying stream. This is useful for connection multiplexing (cmux) where you + need to read bytes to determine the protocol without permanently consuming them. + """ + + stream: trio.abc.Stream + buffer: bytearray + close_callback: Callable[[], None] | None + + def __init__( + self, + stream: trio.abc.Stream, + initial_buffer: bytes = b"", + close_callback: Callable[[], None] | None = None, + ) -> None: + self.stream = stream + self.buffer = bytearray(initial_buffer) + self.close_callback = close_callback + + @property + def socket(self) -> socket.socket | None: + """ + Pass-through to underlying socket for address retrieval. + + This property is required by :class:`~libp2p.io.trio.TrioTCPStream` to retrieve + the remote peer's IP address. + """ + if hasattr(self.stream, "socket"): + return getattr(self.stream, "socket") + return None + + async def receive_some(self, max_bytes: int | None = None) -> bytes: + if self.buffer: + if max_bytes is None: + max_bytes = len(self.buffer) + data = bytes(self.buffer[:max_bytes]) + self.buffer = self.buffer[max_bytes:] + return data + return await self.stream.receive_some(max_bytes) + + async def send_all(self, data: bytes | memoryview) -> None: + await self.stream.send_all(data) + + async def wait_send_all_might_not_block(self) -> None: + await self.stream.wait_send_all_might_not_block() + + async def aclose(self) -> None: + try: + await self.stream.aclose() + finally: + if self.close_callback is not None: + self.close_callback() diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index deb808ad2..f28bc0dab 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -57,8 +57,8 @@ OpenConnectionError, SecurityUpgradeFailure, ) +from libp2p.transport.manager import TransportManager from libp2p.transport.quic.config import QUICTransportConfig -from libp2p.transport.quic.connection import QUICConnection from libp2p.transport.upgrader import ( TransportUpgrader, ) @@ -81,6 +81,9 @@ logger = logging.getLogger(__name__) +_HAPPY_EYEBALLS_DELAY = 0.250 +_MAX_PARALLEL_DIALS = 8 + def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn: async def stream_handler(stream: INetStream) -> None: @@ -93,7 +96,7 @@ class Swarm(Service, INetworkService): self_id: ID peerstore: IPeerStore upgrader: TransportUpgrader - transport: ITransport + transport_manager: TransportManager connections: dict[ID, list[INetConn]] listeners: dict[str, IListener] common_stream_handler: StreamHandlerFn @@ -123,17 +126,48 @@ def __init__( peer_id: ID, peerstore: IPeerStore, upgrader: TransportUpgrader, - transport: ITransport, + transports: list[ITransport] | None = None, retry_config: RetryConfig | None = None, connection_config: ConnectionConfig | QUICTransportConfig | None = None, psk: str | None = None, + *, + # Optional pre-built TransportManager (e.g. with a PortDemultiplexer attached + # for shared-port TCP+WS demultiplexing). When supplied, it is used as-is + # and transports are appended to it; when omitted a fresh one is created. + transport_manager: TransportManager | None = None, + **kwargs: Any, ): + if kwargs.pop("transport", None): + raise TypeError( + "Swarm() no longer accepts 'transport='. Use transports=[...] instead." + ) + if kwargs: + keys = list(kwargs.keys()) + raise TypeError( + f"Swarm.__init__() got unexpected keyword arguments: {keys}" + ) + self.self_id = peer_id self.peerstore = peerstore self.upgrader = upgrader - self.transport = transport self.psk = psk + # Use the pre-built TransportManager when provided (e.g. from new_swarm() + # which wires in a PortDemultiplexer for shared-port TCP+WS). Otherwise create + # a fresh one (preserves backward compatibility for direct Swarm() callers). + self.transport_manager = ( + transport_manager if transport_manager is not None else TransportManager() + ) + + # Backward-compat: callers that still pass a single ITransport + # positionally (e.g. Swarm(peer_id, ps, upgrader, tcp_transport) or + # Swarm(peer_id, ps, upgrader, Mock())) will land in `transports`. + # Detect this by checking whether `transports` is actually a list. + if isinstance(transports, list): + self.transport_manager.add_transports(transports) + elif transports is not None: + self.transport_manager.add_transport(transports) + # Enhanced: Initialize retry and connection configuration self.retry_config = retry_config or RetryConfig() self.connection_config = connection_config or ConnectionConfig() @@ -212,15 +246,12 @@ async def run(self) -> None: # internal nurseries and no longer use this one. self.background_nursery = nursery - # Set background nursery BEFORE setting the event - # This ensures transports have the nursery when they check - if hasattr(self.transport, "set_swarm"): - self.transport.set_background_nursery(nursery) # type: ignore[attr-defined] - self.transport.set_swarm(self) # type: ignore[attr-defined] - elif hasattr(self.transport, "set_background_nursery"): - # WebSocket transport also needs background nursery - # for connection management - self.transport.set_background_nursery(nursery) # type: ignore[attr-defined] + # Wire the background nursery and swarm reference to ALL + # registered transports that need them (QUIC, WebSocket, etc.). + # This replaces the old isinstance(self.transport, QUICTransport) + # special-cases — the TransportManager delegates generically. + self.transport_manager.set_background_nursery(nursery) + self.transport_manager.set_swarm(self) # Signal that the background nursery is available. self.event_background_nursery_created.set() @@ -531,24 +562,38 @@ async def dial_peer(self, peer_id: ID) -> list[INetConn]: connections = [] exceptions: list[SwarmException] = [] - # Try all allowed addresses with retry logic - for multiaddr in allowed_addrs: - try: - connection = await self._dial_with_retry(multiaddr, peer_id) - connections.append(connection) - - # Limit number of connections per peer - if len(connections) >= self.connection_config.max_connections_per_peer: - break + # Try allowed addresses using Happy Eyeballs algorithm + with trio.CancelScope() as cancel_scope: + async with trio.open_nursery() as nursery: + for multiaddr in allowed_addrs[:_MAX_PARALLEL_DIALS]: + failed_event = trio.Event() - except SwarmException as e: - exceptions.append(e) - logger.debug( - "encountered swarm exception when trying to connect to %s, " - "trying next address...", - multiaddr, - exc_info=e, - ) + async def dial_task( + addr: Multiaddr = multiaddr, ev: trio.Event = failed_event + ) -> None: + try: + connection = await self._dial_with_retry(addr, peer_id) + connections.append(connection) + # Limit number of connections per peer + if ( + len(connections) + >= self.connection_config.max_connections_per_peer + ): + cancel_scope.cancel() + except SwarmException as e: + exceptions.append(e) + logger.debug( + "encountered exception when trying to connect to %s", + addr, + exc_info=e, + ) + ev.set() + + nursery.start_soon(dial_task) + + # Start next dial immediately if this one fails, or after 250ms + with trio.move_on_after(_HAPPY_EYEBALLS_DELAY): + await failed_event.wait() if not connections: # Tried all addresses, raising exception. @@ -618,13 +663,29 @@ def _calculate_backoff_delay(self, attempt: int) -> float: async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetConn: """ - Enhanced: Single attempt to dial an address (extracted from original dial_addr). + Single attempt to dial an address. + + Routes the dial to the correct transport via :attr:`transport_manager` + rather than using a fixed ``self.transport``. Transports that return + a pre-multiplexed connection (e.g. QUIC) are detected generically via + the :class:`~libp2p.abc.IMuxedConn` interface and skip the security + + muxer upgrade pipeline. :param addr: the address we want to connect with :param peer_id: the peer we want to connect to :raises SwarmException: raised when an error occurs :return: network connection """ + # For the dial to be successful, there needs to be a registered transport + # that can dial the provided `maddr` + transport = self.transport_manager.transport_for_dialing(addr) + if transport is None: + raise SwarmException( + f"No registered transport can dial {addr}. " + f"Registered transports: " + f"{[type(t).__name__ for t in self.transport_manager.get_transports()]}" + ) + # Optional pre-upgrade admission on outbound using endpoint from multiaddr pre_scope = None if self._resource_manager is not None: @@ -639,14 +700,22 @@ async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetC raise pre_scope = None - # Dial peer (connection to peer does not yet exist) - # Transport dials peer (gets back a raw conn) + # Dial peer via the selected transport (returns a raw connection). raw_conn = None try: - addr = Multiaddr(f"{addr}/p2p/{peer_id}") - raw_conn = await self.transport.dial(addr) + # Ensure the multiaddr has the target peer ID, but don't append if + # already present + try: + existing_p2p = addr.value_for_protocol("p2p") + except Exception: + existing_p2p = None - # Enable PNET if psk is provvided + if not existing_p2p: + addr = Multiaddr(f"{addr}/p2p/{peer_id}") + + raw_conn = await transport.dial(addr) + + # Enable PNET if psk is provided if self.psk is not None: raw_conn = new_protected_conn(raw_conn, self.psk) except OpenConnectionError as error: @@ -669,12 +738,14 @@ async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetC pass raise SwarmException(f"Unexpected error dialing peer {peer_id}") from e - if getattr(self.transport, "provides_native_muxing", False) and isinstance( - raw_conn, IMuxedConn - ): + # Detect pre-multiplexed connections generically via the IMuxedConn + # interface instead of isinstance(transport, QUICTransport). + # This works for QUIC today and any future transport with built-in + # multiplexing (e.g. WebTransport). + if isinstance(raw_conn, IMuxedConn): logger.info( - "Skipping upgrade for native-mux transport " - "(connection already multiplexed)" + "Skipping upgrade: connection is already multiplexed (transport=%s)", + type(transport).__name__, ) try: swarm_conn = await self.add_conn(raw_conn, direction="outbound") @@ -693,6 +764,8 @@ async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetC raise logger.debug("dialed peer %s over base transport", peer_id) + if not isinstance(raw_conn, IRawConnection): + raise TypeError("Expected an IRawConnection to upgrade") try: swarm_conn = await self.upgrade_outbound_raw_conn( raw_conn, peer_id, pre_scope @@ -957,10 +1030,7 @@ async def _open_stream_on_connection( ) -> INetStream: """Try to open a stream on *connection*, falling back to alternatives.""" try: - if ( - getattr(self.transport, "provides_native_muxing", False) - and connection is not None - ): + if connection is not None: conn = cast("SwarmConn", connection) stream = await conn.new_stream() else: @@ -1116,16 +1186,22 @@ async def listen(self, *multiaddrs: Multiaddr) -> bool: :param multiaddrs: one or many multiaddrs to start listening on :return: true if at least one success - For each multiaddr + For each multiaddr: - - Check if a listener for multiaddr exists already - - If listener already exists, continue - - Otherwise: + - Route to the transport that can handle the address via + :attr:`transport_manager`. + - Check if a listener for this multiaddr already exists. + - Create a listener on the matched transport and start it. + - Map multiaddr string to the listener for future reference. - - Capture multiaddr in conn handler - - Have conn handler delegate to stream handler - - Call listener listen with the multiaddr - - Map multiaddr to listener + When a :class:`~libp2p.transport.cmux.PortDemultiplexer` is attached to the + :attr:`transport_manager`, all TCP-based transports (TCP and WebSocket) + register :class:`~libp2p.transport.cmux.DemultiplexedListener` objects + instead of opening their own sockets. After every address has been + processed, a single ``port_demux.listen()`` call binds the shared socket + and starts the 3-byte demultiplexing loop — mirroring the go-libp2p + pattern where the physical listener is created once by ``PortDemultiplexer`` and + each transport only receives a virtual channel. """ logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}") # Wait until the background nursery is available so that transports @@ -1134,100 +1210,165 @@ async def listen(self, *multiaddrs: Multiaddr) -> bool: logger.debug("Starting to listen") await self.event_background_nursery_created.wait() - success_count = 0 - for maddr in multiaddrs: - logger.debug(f"Swarm.listen processing multiaddr: {maddr}") + # ── 1. Start PortDemultiplexer FIRST so the OS socket is bound ────────── + port_demuxers = getattr(self.transport_manager, "_port_demuxers", {}) + if not port_demuxers: + port_demux = getattr(self.transport_manager, "_port_demux", None) + if port_demux: + port_demuxers = {(port_demux.host, port_demux.port): port_demux} + + for (host, port), port_demux in port_demuxers.items(): + tcp_maddr = next( + ( + m + for m in multiaddrs + if "tcp" in {p.name for p in m.protocols()} + and "ws" not in {p.name for p in m.protocols()} + and "wss" not in {p.name for p in m.protocols()} + and str(m.value_for_protocol("tcp")) == str(port) + ), + None, + ) + if tcp_maddr is not None: + try: + port_demux.background_nursery = self.background_nursery + await port_demux.listen(tcp_maddr) + except Exception as exc: + logger.error( + "PortDemultiplexer.listen failed for %s:%s: %s", host, port, exc + ) + return False + + # ── 2. Start all listeners in parallel ────────────────────────────────── + results: list[tuple[Multiaddr, bool]] = [] + results_lock = trio.Lock() + + async def _start_one(maddr: Multiaddr) -> None: if str(maddr) in self.listeners: - logger.debug(f"Swarm.listen: listener already exists for {maddr}") - success_count += 1 - continue + async with results_lock: + results.append((maddr, True)) + return + + transport = self.transport_manager.transport_for_listening(maddr) + if transport is None: + logger.warning( + "Swarm.listen: no transport for %s (registered: %s). Skipping.", + maddr, + [type(t).__name__ for t in self.transport_manager.get_transports()], + ) + async with results_lock: + results.append((maddr, False)) + return async def conn_handler( - read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr + read_write_closer: ReadWriteCloser, _maddr: Multiaddr = maddr ) -> None: - # Enforce connection gate on inbound connections - # Build multiaddr from remote address tuple - logger.debug( - f"[conn_handler] Handling inbound connection on listener {maddr}" - ) - remote_maddr = self._build_remote_multiaddr(read_write_closer) - logger.debug(f"[conn_handler] Built remote_maddr: {remote_maddr}") - - if remote_maddr is not None: - if not await self.connection_gate.is_allowed(remote_maddr): - logger.debug( - "Inbound connection from %s denied by connection gate", - remote_maddr, - ) - try: - await read_write_closer.close() - except Exception: - pass - return - - # No need to upgrade native-mux connections (QUIC, WebRTC) - if getattr(self.transport, "provides_native_muxing", False): - try: - muxed_conn = cast(IMuxedConn, read_write_closer) - await self.add_conn(muxed_conn, direction="inbound") - peer_id = muxed_conn.peer_id - logger.debug( - "successfully opened native-mux connection to peer %s", - peer_id, - ) - # NOTE: This is a intentional barrier to prevent from the - # handler exiting and closing the connection. - await self.manager.wait_finished() - except Exception: - await read_write_closer.close() - return - - # For non-QUIC connections, wrap in try/except to ensure cleanup - raw_conn = None - try: - raw_conn = RawConnection(read_write_closer, False) - await self.upgrade_inbound_raw_conn(raw_conn, maddr) - # NOTE: This is a intentional barrier to prevent from the handler - # exiting and closing the connection. - await self.manager.wait_finished() - except Exception as e: - logger.debug(f"Error handling incoming connection: {e}") - # Ensure the underlying connection is closed on any error - try: - if raw_conn is not None: - await raw_conn.close() - else: - # If raw_conn wasn't created, - # close the underlying connection - await read_write_closer.close() - except Exception: - pass - # Re-raise to let the listener handle it appropriately - # (swallow the exception to prevent propagation) + await self._handle_inbound_connection(read_write_closer, _maddr) try: - # Success - logger.debug(f"Swarm.listen: creating listener for {maddr}") - listener = self.transport.create_listener(conn_handler) - logger.debug(f"Swarm.listen: listener created for {maddr}") + listener = self.transport_manager.add_listen_addr(maddr, conn_handler) + if listener is None: + async with results_lock: + results.append((maddr, False)) + return self.listeners[str(maddr)] = listener + if self.background_nursery is None: raise SwarmException("swarm instance hasn't been run") - logger.debug(f"Swarm.listen: calling listener.listen for {maddr}") - await listener.listen(maddr) - logger.debug(f"Swarm.listen: listener.listen completed for {maddr}") - # Call notifiers since event occurred + setattr(listener, "background_nursery", self.background_nursery) + await listener.listen(maddr) await self.notify_listen(maddr) - - success_count += 1 logger.debug("successfully started listening on: %s", maddr) - except OSError: - # Failed. Continue looping. - logger.debug("fail to listen on: %s", maddr) + async with results_lock: + results.append((maddr, True)) + except (OSError, OpenConnectionError, SwarmException) as exc: + logger.debug("fail to listen on %s: %s", maddr, exc) + self.listeners.pop(str(maddr), None) + async with results_lock: + results.append((maddr, False)) + + async with trio.open_nursery() as nursery: + for maddr in multiaddrs: + nursery.start_soon(_start_one, maddr) + + return any(ok for _, ok in results) + + async def _handle_inbound_connection( + self, read_write_closer: ReadWriteCloser, maddr: Multiaddr + ) -> None: + """ + Unified inbound connection handler for all transports. - # Return true if at least one address succeeded - return success_count > 0 + Replaces the inline ``conn_handler`` closures that previously had + separate code paths for QUIC vs. non-QUIC connections. Transport + detection is now done via the :class:`~libp2p.abc.IMuxedConn` + interface rather than an ``isinstance(self.transport, QUICTransport)`` + class check, so any future transport with built-in multiplexing + (e.g. WebTransport) will be handled automatically. + + :param read_write_closer: The raw stream from the listener. + :param maddr: The multiaddr of the listener that accepted this connection. + """ + logger.debug( + "[_handle_inbound_connection] Handling inbound connection on listener %s", + maddr, + ) + + # Enforce connection gate on inbound connections. + remote_maddr = self._build_remote_multiaddr(read_write_closer) + logger.debug( + "[_handle_inbound_connection] Built remote_maddr: %s", remote_maddr + ) + + if remote_maddr is not None: + if not await self.connection_gate.is_allowed(remote_maddr): + logger.debug( + "Inbound connection from %s denied by connection gate", + remote_maddr, + ) + try: + await read_write_closer.close() + except Exception: + pass + return + + # If the incoming connection is already fully multiplexed (e.g. QUIC, + # WebTransport), skip the security + muxer upgrade entirely. + # Detection is via the IMuxedConn interface, not a class check. + if isinstance(read_write_closer, IMuxedConn): + try: + muxed_conn = cast(IMuxedConn, read_write_closer) + await self.add_conn(muxed_conn, direction="inbound") + peer_id = getattr(muxed_conn, "peer_id", None) + logger.debug( + "successfully opened pre-multiplexed inbound connection (peer=%s)", + peer_id, + ) + # Intentional barrier: keep handler alive so the connection + # stays open for the duration of the swarm's lifetime. + await self.manager.wait_finished() + except Exception: + await read_write_closer.close() + return + + # Standard upgrade path (TCP, WebSocket): wrap in RawConnection then + # run the security + muxer upgrade pipeline. + raw_conn = None + try: + raw_conn = RawConnection(read_write_closer, False) + await self.upgrade_inbound_raw_conn(raw_conn, maddr) + # Intentional barrier: keep handler alive. + await self.manager.wait_finished() + except Exception as e: + logger.debug("Error handling incoming connection: %s", e) + try: + if raw_conn is not None: + await raw_conn.close() + else: + await read_write_closer.close() + except Exception: + pass async def upgrade_inbound_raw_conn( self, raw_conn: IRawConnection, maddr: Multiaddr @@ -1315,12 +1456,17 @@ async def _cleanup_inbound_upgrade() -> None: with trio.fail_after(inbound_timeout): try: secured_conn = await self.upgrader.upgrade_security(raw_conn, False) - except SecurityUpgradeFailure as error: - logger.error("failed to upgrade security for peer at %s", maddr) + except SecurityUpgradeFailure as exc: + logger.error( + "failed to upgrade security for peer at %s: %s", + maddr, + exc, + exc_info=True, + ) await _cleanup_inbound_upgrade() raise SwarmException( f"failed to upgrade security for peer at {maddr}" - ) from error + ) from exc peer_id = secured_conn.get_remote_peer() try: @@ -1434,13 +1580,8 @@ async def close(self) -> None: ) self.listeners.clear() - # Close the transport if it exists and has a close method - if hasattr(self, "transport") and self.transport is not None: - # Check if transport has close method before calling it - if hasattr(self.transport, "close"): - await self.transport.close() # type: ignore - # Ignoring the type above since `transport` may not have a close method - # and we have already checked it with hasattr + # Close all transports + await self.transport_manager.close_all() logger.debug("swarm successfully closed") @@ -1565,10 +1706,13 @@ async def add_conn( # Verify connection is fully established before proceeding. # For QUIC connections, wait for the connected event. # For other muxers (like Yamux/Mplex), check the is_established property. - # For QUIC connections, also verify connection is established - if isinstance(muxed_conn, QUICConnection): - if not muxed_conn.is_established: - await muxed_conn._connected_event.wait() + # For some muxers (e.g. QUIC), wait for the connected event. + # For others (like Yamux/Mplex), check the is_established property. + if hasattr(muxed_conn, "_connected_event") and hasattr( + muxed_conn, "is_established" + ): + if not getattr(muxed_conn, "is_established"): + await getattr(muxed_conn, "_connected_event").wait() elif not muxed_conn.is_established: logger.warning( f"Swarm::add_conn | muxer event_started set but " diff --git a/libp2p/relay/circuit_v2/transport.py b/libp2p/relay/circuit_v2/transport.py index 7fab98d1d..37b8a3324 100644 --- a/libp2p/relay/circuit_v2/transport.py +++ b/libp2p/relay/circuit_v2/transport.py @@ -200,6 +200,18 @@ def __init__( if config.enable_dht_discovery: self.dht = KadDHT(host, DHTMode.CLIENT) + def can_dial(self, maddr: multiaddr.Multiaddr) -> bool: + """Return True if this transport can dial the given multiaddr.""" + return any(p.code == P_P2P_CIRCUIT for p in maddr.protocols()) + + def can_listen(self, maddr: multiaddr.Multiaddr) -> bool: + """Return True if this transport can listen on the given multiaddr.""" + return any(p.code == P_P2P_CIRCUIT for p in maddr.protocols()) + + def protocols(self) -> list[str]: + """Return the list of protocol names handled by this transport.""" + return ["p2p-circuit"] + async def dial( # type: ignore[override] self, maddr: multiaddr.Multiaddr, diff --git a/libp2p/transport/__init__.py b/libp2p/transport/__init__.py index 4a54b7353..d83163224 100644 --- a/libp2p/transport/__init__.py +++ b/libp2p/transport/__init__.py @@ -1,58 +1,21 @@ from typing import Any - -from .tcp.tcp import TCP -from .websocket.transport import WebsocketTransport -from .transport_registry import ( - TransportRegistry, - create_transport_for_multiaddr, - get_transport_registry, - register_transport, - get_supported_transport_protocols, -) +from .manager import TransportManager from .upgrader import TransportUpgrader +from .cmux import PortDemultiplexer, DemultiplexedConnType, DemultiplexedListener, identify_conn_type from libp2p.abc import ITransport +from .tcp.tcp import TCP +from .websocket.transport import WebsocketTransport -def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any) -> ITransport: - """ - Convenience function to create a transport instance. - :param protocol: The transport protocol ("tcp", "ws", "wss", or custom) - :param upgrader: Optional transport upgrader (required for WebSocket) - :param kwargs: Additional arguments for transport construction (e.g., tls_client_config, tls_server_config) - :return: Transport instance - """ - # First check if it's a built-in protocol - if protocol in ["ws", "wss"]: - if upgrader is None: - raise ValueError(f"WebSocket transport requires an upgrader") - from libp2p.transport.websocket.transport import WebsocketConfig, WebsocketTransport - config = WebsocketConfig( - tls_client_config=kwargs.get("tls_client_config"), - tls_server_config=kwargs.get("tls_server_config"), - handshake_timeout=kwargs.get("handshake_timeout", 15.0) - ) - return WebsocketTransport(upgrader, config) - elif protocol == "tcp": - return TCP() - else: - # Check if it's a custom registered transport - registry = get_transport_registry() - transport_class = registry.get_transport(protocol) - if transport_class: - transport = registry.create_transport(protocol, upgrader, **kwargs) - if transport is None: - raise ValueError(f"Failed to create transport for protocol: {protocol}") - return transport - else: - raise ValueError(f"Unsupported transport protocol: {protocol}") __all__ = [ + # Transports "TCP", + "TransportManager", "WebsocketTransport", - "TransportRegistry", - "create_transport_for_multiaddr", - "create_transport", - "get_transport_registry", - "register_transport", - "get_supported_transport_protocols", + # Port sharing / cmux + "PortDemultiplexer", + "DemultiplexedConnType", + "DemultiplexedListener", + "identify_conn_type", ] diff --git a/libp2p/transport/cmux.py b/libp2p/transport/cmux.py new file mode 100644 index 000000000..da12b77fb --- /dev/null +++ b/libp2p/transport/cmux.py @@ -0,0 +1,518 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator, Awaitable, Callable +from enum import IntEnum +import logging +from typing import TYPE_CHECKING, Any + +from multiaddr import Multiaddr +import trio + +from libp2p.abc import IListener +from libp2p.io.peekable_stream import PeekableStream +from libp2p.transport.exceptions import OpenConnectionError + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +# Matches go-libp2p's tcpreuse.acceptQueueSize = 64. +# Limits how many classified connections can wait before being consumed by their +# per-type listener. If the queue is full the connection is dropped after +# ACCEPT_TIMEOUT seconds (mirrors go-libp2p's acceptTimeout = 30s). +ACCEPT_QUEUE_SIZE = 64 +ACCEPT_TIMEOUT = 30.0 + +# go-libp2p peeks exactly 3 bytes to classify connections. +_PEEK_SIZE = 3 +# Timeout for reading the classification bytes from a new connection. +_IDENTIFY_TIMEOUT = 5.0 + + +class DemultiplexedConnType(IntEnum): + """ + Mirrors ``tcpreuse.DemultiplexedConnType`` from go-libp2p. + + Used to route an incoming TCP connection to the correct transport listener + after peeking at the first 3 bytes of the stream. + """ + + UNKNOWN = 0 + #: Raw libp2p TCP connections start with the multistream-select header + #: ``\\x13/multistream/1.0.0\\n``. The first 3 bytes are ``\\x13/m``. + MULTISTREAM_SELECT = 1 + #: WebSocket (plain) connections start with an HTTP upgrade request. + #: Matches GET / HEAD / POST / PUT / DELETE / CONNECT / OPTIONS / + #: TRACE / PATCH and also PRI (HTTP/2 cleartext preface). + HTTP = 2 + #: WebSocket-Secure (WSS) connections are wrapped in TLS. + #: Matches TLS 1.0–1.3 ``ClientHello`` record headers. + TLS = 3 + + +# --------------------------------------------------------------------------- +# 3-byte classifier – mirrors go-libp2p identifyConnType / IsMultistreamSelect +# --------------------------------------------------------------------------- + +_HTTP_PREFIXES: frozenset[bytes] = frozenset( + [ + b"GET", + b"HEA", + b"POS", + b"PUT", + b"DEL", + b"CON", + b"OPT", + b"TRA", + b"PAT", + # HTTP/2 cleartext preface "PRI * HTTP/2.0\r\n…" + b"PRI", + ] +) + +_TLS_PREFIXES: frozenset[bytes] = frozenset( + [ + b"\x16\x03\x01", # TLS 1.0 / 1.2 ClientHello record + b"\x16\x03\x02", # TLS 1.1 + b"\x16\x03\x03", # TLS 1.3 + ] +) + + +def identify_conn_type(prefix: bytes) -> DemultiplexedConnType: + """ + Classify a connection from its first 3 bytes. + + Directly mirrors ``tcpreuse.identifyConnType`` / ``IsMultistreamSelect`` / + ``IsHTTP`` / ``IsTLS`` from go-libp2p. + + :param prefix: Exactly 3 bytes read from the start of the stream. + :returns: The detected :class:`DemultiplexedConnType`. + """ + if len(prefix) < 3: + return DemultiplexedConnType.UNKNOWN + if prefix == b"\x13/m": + return DemultiplexedConnType.MULTISTREAM_SELECT + if prefix in _TLS_PREFIXES: + return DemultiplexedConnType.TLS + if prefix in _HTTP_PREFIXES: + return DemultiplexedConnType.HTTP + return DemultiplexedConnType.UNKNOWN + + +# --------------------------------------------------------------------------- +# DemultiplexedListener – per-connection-type listener (channel consumer) +# --------------------------------------------------------------------------- + + +class DemultiplexedListener(IListener): + """ + A listener that receives pre-classified connections from :class:`PortDemultiplexer`. + + Each instance holds one end of a :func:`trio.open_memory_channel` and + exposes an ``accept()`` async iterator consumed by the transport that + registered for this connection type. + + When a *conn_handler* is provided, :meth:`listen` spawns a background Trio + task that drains the channel and calls the handler for every incoming + connection. This is the Trio equivalent of the goroutine that reads from a + buffered ``chan *connWithScope`` in go-libp2p's ``demultiplexedListener``. + + Mirrors ``tcpreuse.demultiplexedListener`` from go-libp2p. + """ + + def __init__( + self, + conn_type: DemultiplexedConnType, + recv_channel: trio.MemoryReceiveChannel[PeekableStream], + listen_maddr: Multiaddr, + close_callback: Callable[[DemultiplexedConnType], None], + conn_handler: Callable[[trio.abc.Stream], Awaitable[None]] | None = None, + ) -> None: + self.conn_type = conn_type + self._recv = recv_channel + self._listen_maddr = listen_maddr + self._close_callback = close_callback + self._closed = False + # Optional handler called per-connection by the drain task. + self._conn_handler = conn_handler + self._nursery: trio.Nursery | None = None + self._cancel_scope: trio.CancelScope | None = None + + async def connections(self) -> AsyncIterator[PeekableStream]: + """Async-iterate over pre-classified connections.""" + async with self._recv: + async for conn in self._recv: + yield conn + + def get_addrs(self) -> tuple[Multiaddr, ...]: + return (self._listen_maddr,) if self._listen_maddr else () + + async def listen(self, maddr: Multiaddr) -> None: + """ + Start the connection-drain loop (no-op when no handler is registered). + + When a *conn_handler* was supplied, this spawns a Trio system task that + continuously reads :class:`PeekableStream` objects from the channel and + calls the handler for each one. The actual TCP socket is created by + :class:`PortDemultiplexer` — this method only starts the consumer side. + """ + if self._conn_handler is None: + # No handler: caller will consume via .connections() directly. + return + + handler = self._conn_handler # capture for closure + + async def _drain() -> None: + try: + with trio.CancelScope(**{}) as cancel_scope: + self._cancel_scope = cancel_scope + async with trio.open_nursery() as nursery: + self._nursery = nursery + async with self._recv: + async for stream in self._recv: + nursery.start_soon(self._run_handler, handler, stream) + except (trio.Cancelled, KeyboardInterrupt): + raise + except Exception: + pass + + nursery = getattr(self, "background_nursery", None) + if nursery is not None: + nursery.start_soon(_drain) + else: + raise RuntimeError( + "DemultiplexedListener.listen requires a background_nursery to be set." + ) + + async def _run_handler( + self, + handler: Callable[[trio.abc.Stream], Awaitable[None]], + stream: PeekableStream, + ) -> None: + try: + await handler(stream) + except Exception as exc: + logger.debug("DemultiplexedListener: handler error: %s", exc) + + async def close(self) -> None: + """Close this demultiplexed listener and remove it from PortDemultiplexer.""" + if self._closed: + return + self._closed = True + self._recv.close() + if self._cancel_scope is not None: + self._cancel_scope.cancel() + if callable(self._close_callback): + self._close_callback(self.conn_type) + + +# --------------------------------------------------------------------------- +# PortDemultiplexer – shared TCP listener + demultiplexer +# --------------------------------------------------------------------------- + + +class PortDemultiplexer(IListener): + """ + Enables sharing a single TCP port between multiple transports. + + Each transport calls :meth:`demultiplexed_listen` with its expected + :class:`DemultiplexedConnType`. Internally, one OS-level TCP socket is + opened and a background task classifies every new connection by peeking at + its first 3 bytes, then routes it into the matching + :class:`DemultiplexedListener`'s channel. + + Mirrors ``tcpreuse.PortDemultiplexer`` from go-libp2p (``p2p/transport/tcpreuse``). + + Usage:: + + port_demux = PortDemultiplexer("0.0.0.0", 4001) + + # TCP transport registers for raw libp2p connections + tcp_listener = port_demux.demultiplexed_listen( + maddr, DemultiplexedConnType.MULTISTREAM_SELECT + ) + + # WebSocket transport registers for HTTP-upgrade connections + ws_listener = port_demux.demultiplexed_listen( + maddr, DemultiplexedConnType.HTTP + ) + + await port_demux.listen(maddr) # binds the socket and starts routing + + """ + + host: str + port: int + listen_maddr: Multiaddr | None + + def __init__(self, host: str, port: int) -> None: + self.host = host + self.port = port + self.listen_maddr = None + + # Map from conn type → (send_channel, recv_channel) pair. + # Populated by demultiplexed_listen() before listen() is called. + self._send_channels: dict[ + DemultiplexedConnType, trio.MemorySendChannel[PeekableStream] + ] = {} + self._listeners: dict[DemultiplexedConnType, DemultiplexedListener] = {} + + self._nursery: trio.Nursery | None = None + self._started: trio.Event = trio.Event() + self._stopped: trio.Event = trio.Event() + self._closed: bool = False + self._start_error: BaseException | None = None + + def has_listener(self, conn_type: DemultiplexedConnType) -> bool: + """Return True if a listener is registered for this conn_type.""" + return conn_type in self._send_channels + + def get_listener( + self, conn_type: DemultiplexedConnType + ) -> DemultiplexedListener | None: + """Return the listener registered for this conn_type, or None.""" + return self._listeners.get(conn_type) + + # ------------------------------------------------------------------ + # Public API – mirrors go-libp2p PortDemultiplexer.DemultiplexedListen() + # ------------------------------------------------------------------ + + def demultiplexed_listen( + self, + maddr: Multiaddr, + conn_type: DemultiplexedConnType, + conn_handler: Callable[[trio.abc.Stream], Awaitable[None]] | None = None, + ) -> DemultiplexedListener: + """ + Register a listener for *conn_type* connections on this shared port. + + Must be called **before** :meth:`listen`. Raises :exc:`ValueError` + if a listener for the same *conn_type* is already registered (mirrors + ``tcpreuse.ErrListenerExists``). + + :param maddr: The multiaddress that will be bound (used for + :meth:`DemultiplexedListener.get_addrs`). + :param conn_type: The connection type this listener should receive. + :param conn_handler: Optional async callable called per connection. + When provided, :meth:`DemultiplexedListener.listen` starts a + background drain task that calls this handler for each classified + connection. + :returns: A :class:`DemultiplexedListener` whose + :meth:`~DemultiplexedListener.connections` yields classified + connections. + :raises ValueError: If *conn_type* already has a registered listener. + """ + if conn_type == DemultiplexedConnType.UNKNOWN: + raise ValueError("Cannot register a listener for UNKNOWN conn type") + if conn_type in self._send_channels: + raise ValueError( + f"Listener already exists for conn type {conn_type!r}" + f" on {self.host}:{self.port}" + ) + + # Update TCP port component of the listen_maddr if it differs from self.port. + # This handles cases where demultiplexed_listen is called after listen() + # bound 0. + from multiaddr import Multiaddr + + parts = str(maddr).split("/") + for i, part in enumerate(parts): + if part == "tcp" and i + 1 < len(parts): + parts[i + 1] = str(self.port) + real_maddr = Multiaddr("/".join(parts)) + + send_ch, recv_ch = trio.open_memory_channel[PeekableStream](ACCEPT_QUEUE_SIZE) + self._send_channels[conn_type] = send_ch + + def _remove(ct: DemultiplexedConnType) -> None: + self._send_channels.pop(ct, None) + self._listeners.pop(ct, None) + + dl = DemultiplexedListener( + conn_type=conn_type, + recv_channel=recv_ch, + listen_maddr=real_maddr, + close_callback=_remove, + conn_handler=conn_handler, + ) + self._listeners[conn_type] = dl + logger.debug( + "PortDemultiplexer: registered %s listener on %s:%d", + conn_type.name, + self.host, + self.port, + ) + return dl + + # ------------------------------------------------------------------ + # IListener.listen – bind the OS socket and start the routing task + # ------------------------------------------------------------------ + + async def listen(self, maddr: Multiaddr) -> None: + """ + Bind to the port and start the demultiplexing loop. + + Subsequent calls are no-ops if the socket is already bound (mirrors + go-libp2p where each transport calls ``DemultiplexedListen`` and the + underlying socket is only created once per address). + """ + if self._started.is_set(): + return + + async def _serve( + task_status: Any = trio.TASK_STATUS_IGNORED, + ) -> None: + logger.debug("PortDemultiplexer serving on %s:%d", self.host, self.port) + try: + await trio.serve_tcp( + self._classify_and_route, + self.port, + host=self.host, + task_status=task_status, + ) + except Exception as exc: + logger.error("PortDemultiplexer serve error: %s", exc) + raise + + async def _run_server() -> None: + try: + async with trio.open_nursery() as nursery: + self._nursery = nursery + try: + listeners = await nursery.start(_serve) + if listeners: + sock = listeners[0].socket + new_port = sock.getsockname()[1] + self.port = new_port + for dl in self._listeners.values(): + if dl._listen_maddr is not None: + # Update the tcp component with the real port + from multiaddr import Multiaddr + + parts = str(dl._listen_maddr).split("/") + for i, part in enumerate(parts): + if part == "tcp" and i + 1 < len(parts): + parts[i + 1] = str(new_port) + dl._listen_maddr = Multiaddr("/".join(parts)) + except BaseException as err: + self._start_error = err + finally: + self._started.set() + finally: + self._stopped.set() + self._nursery = None + + nursery = getattr(self, "background_nursery", None) + if nursery is not None: + nursery.start_soon(_run_server) + else: + trio.lowlevel.spawn_system_task(_run_server) + await self._started.wait() + + if self._start_error is not None: + raise OpenConnectionError( + f"PortDemultiplexer failed to listen on {maddr}: {self._start_error}" + ) + + # ------------------------------------------------------------------ + # Internal – classify each new connection and route it + # ------------------------------------------------------------------ + + async def _classify_and_route(self, stream: trio.SocketStream) -> None: + """ + Peek at the first :data:`_PEEK_SIZE` bytes, classify, and dispatch. + + Mirrors ``multiplexedListener.run()`` + ``identifyConnType()`` from + go-libp2p's ``tcpreuse`` package. + """ + try: + peekable = PeekableStream(stream) + + # Read exactly 3 bytes for classification (matches go-libp2p). + try: + with trio.fail_after(_IDENTIFY_TIMEOUT): + prefix = await peekable.receive_some(_PEEK_SIZE) + except trio.TooSlowError: + logger.debug( + "PortDemultiplexer: timed out reading classification bytes;" + " closing connection" + ) + await stream.aclose() + return + + # Prepend the peeked bytes back into the buffer so handlers see a + # complete stream (mirrors sampledconn.PeekBytes in go-libp2p). + peekable.buffer = bytearray(prefix) + peekable.buffer + + conn_type = identify_conn_type(prefix) + logger.debug( + "PortDemultiplexer: classified connection as %s", conn_type.name + ) + + send_ch = self._send_channels.get(conn_type) + if send_ch is None: + logger.debug( + "PortDemultiplexer: no registered listener " + "for %s; closing connection", + conn_type.name, + ) + await stream.aclose() + return + + # Create an event to keep the trio.serve_tcp handler alive + # because trio automatically closes the stream when the handler returns. + stream_closed = trio.Event() + peekable.close_callback = stream_closed.set + + # Try to deliver with a bounded timeout to avoid blocking the + # accept loop (mirrors go-libp2p's acceptTimeout = 30s). + try: + with trio.fail_after(ACCEPT_TIMEOUT): + await send_ch.send(peekable) + except (trio.TooSlowError, trio.ClosedResourceError): + logger.debug( + "PortDemultiplexer: accept queue full or listener closed for %s; " + "dropping connection", + conn_type.name, + ) + await stream.aclose() + return + + # Wait for the consumer to finish using and close the stream + await stream_closed.wait() + + except Exception: + try: + await stream.aclose() + except Exception: + pass + + # ------------------------------------------------------------------ + # IListener helpers + # ------------------------------------------------------------------ + + def get_addrs(self) -> tuple[Multiaddr, ...]: + """Return the multiaddresses this listener is bound to.""" + if self.listen_maddr is not None: + return (self.listen_maddr,) + return () + + async def close(self) -> None: + """Close all registered listeners and the underlying TCP socket.""" + if self._closed: + return + self._closed = True + + # Close all per-type send channels so their DemultiplexedListeners see EOF. + for send_ch in list(self._send_channels.values()): + send_ch.close() + self._send_channels.clear() + + # Cancel the background nursery (stops trio.serve_tcp). + if self._nursery is not None: + self._nursery.cancel_scope.cancel() + + if self._started.is_set(): + await self._stopped.wait() diff --git a/libp2p/transport/manager.py b/libp2p/transport/manager.py new file mode 100644 index 000000000..56db24215 --- /dev/null +++ b/libp2p/transport/manager.py @@ -0,0 +1,443 @@ +""" +TransportManager — routes dial/listen operations to the correct transport. + +Modelled after go-libp2p's transport manager in +``go-libp2p/p2p/net/swarm/swarm_transport.go``. + +Usage:: + + from libp2p.transport.manager import TransportManager + from libp2p.transport.tcp.tcp import TCP + from libp2p.transport.quic.transport import QUICTransport + from multiaddr import Multiaddr + + mgr = TransportManager() + mgr.add_transport(TCP()) + mgr.add_transport(QUICTransport(private_key)) + + transport = mgr.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/tcp/4001")) + # -> TCP instance + + transport = mgr.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1")) + # -> QUICTransport instance + +Port sharing (TCP + WebSocket on the same port):: + + from libp2p.transport.cmux import PortDemultiplexer, DemultiplexedConnType + + port_demux = PortDemultiplexer("0.0.0.0", 4001) + mgr = TransportManager(port_demux=port_demux) + + mgr.add_transport(TCP()) + mgr.add_transport(WebsocketTransport(...)) + + # TransportManager calls add_listen_addr() per multiaddr. + # add_listen_addr() detects TCP-based addrs and delegates to PortDemultiplexer so + # both transports share the same OS socket. + +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from multiaddr import Multiaddr +import trio + +from libp2p.abc import IListener, ITransport +from libp2p.custom_types import THandler + +if TYPE_CHECKING: + import trio + + from libp2p.transport.cmux import PortDemultiplexer + +logger = logging.getLogger(__name__) + + +class TransportManager: + """ + Maintains an ordered list of :class:`~libp2p.abc.ITransport` instances and + provides routing helpers used by the :class:`~libp2p.network.swarm.Swarm`. + + Transports are checked in the order they were added. For dialing, the + first transport whose :meth:`~libp2p.abc.ITransport.can_dial` returns + ``True`` is selected. For listening, the same logic applies via + :meth:`~libp2p.abc.ITransport.can_listen`. + + This is the Python equivalent of go-libp2p's ``Swarm.TransportForDialing`` + / ``Swarm.TransportForListening`` pair. + + :param port_demux: Optional :class:`~libp2p.transport.cmux.PortDemultiplexer` for + sharing a single TCP port between multiple transports (TCP + WS). + Mirrors the ``sharedTCP *tcpreuse.PortDemultiplexer`` parameter passed to + ``NewTCPTransport`` / ``websocket.New`` in go-libp2p. + """ + + def __init__( + self, + port_demux: PortDemultiplexer | None = None, + port_demuxers: dict[tuple[str, int], PortDemultiplexer] | None = None, + ) -> None: + self._transports: list[ITransport] = [] + # Shared PortDemultiplexer for TCP port reuse (optional). + if port_demuxers is not None: + self._port_demuxers = port_demuxers + elif port_demux is not None: + self._port_demuxers = {(port_demux.host, port_demux.port): port_demux} + else: + self._port_demuxers = {} + + self._port_demux: PortDemultiplexer | None = ( + next(iter(self._port_demuxers.values()), None) + if self._port_demuxers + else None + ) + # Tracks per-(host, port) listeners for non-PortDemultiplexer paths. + self._listeners: dict[tuple[str, int], IListener] = {} + + # ── Registration ────────────────────────────────────────────────────────── + + def add_transport(self, transport: ITransport) -> None: + """ + Append a transport to the routing list. + + :param transport: A transport instance implementing + :class:`~libp2p.abc.ITransport`. + """ + self._transports.append(transport) + + # Re-sort by listen_order() to match go-libp2p's ListenOrder priority. + # Transports without listen_order() default to priority 0 (highest). + self._transports.sort(key=lambda t: getattr(t, "listen_order", lambda: 0)()) + + logger.debug( + "TransportManager: registered %s (protocols=%s, listen_order=%d)", + type(transport).__name__, + getattr(transport, "protocols", lambda: "")(), + getattr(transport, "listen_order", lambda: 0)(), + ) + + def add_transports(self, transports: list[ITransport]) -> None: + """ + Convenience helper to register multiple transports at once. + + :param transports: List of transport instances. + """ + for t in transports: + self.add_transport(t) + + # ── Routing ─────────────────────────────────────────────────────────────── + + def transport_for_dialing(self, maddr: Multiaddr) -> ITransport | None: + """ + Return the first registered transport that can dial *maddr*, or ``None``. + + The manager first performs a cheap pre-filter using each transport's + :meth:`~libp2p.abc.ITransport.protocols` list (set intersection), and + only calls :meth:`~libp2p.abc.ITransport.can_dial` when there is at + least one protocol name overlap. This avoids unnecessary work when + many transports are registered. + + This is the Python equivalent of go-libp2p's + ``Swarm.TransportForDialing()``. + + :param maddr: The multiaddress to route. + :returns: The matching transport, or ``None`` if no transport can handle + the address. + """ + proto_names = {p.name for p in maddr.protocols()} + + for transport in self._transports: + # Fast pre-filter: skip if no protocol name overlap at all. + _protocols = getattr(transport, "protocols", None) + if _protocols is not None and not proto_names.intersection( + set(_protocols()) + ): + continue + _can_dial = getattr(transport, "can_dial", None) + if _can_dial is not None and _can_dial(maddr): + logger.debug( + "TransportManager.transport_for_dialing: %s => %s", + maddr, + type(transport).__name__, + ) + return transport + + logger.warning( + "TransportManager.transport_for_dialing: no transport found for %s " + "(registered: %s)", + maddr, + [type(t).__name__ for t in self._transports], + ) + return None + + def transport_for_listening(self, maddr: Multiaddr) -> ITransport | None: + """ + Return the first registered transport that can listen on *maddr*, or + ``None``. + + Uses the same two-step pre-filter logic as :meth:`transport_for_dialing`. + + This is the Python equivalent of go-libp2p's + ``Swarm.TransportForListening()``. + + :param maddr: The multiaddress to route. + :returns: The matching transport, or ``None`` if no transport can handle + the address. + """ + proto_names = {p.name for p in maddr.protocols()} + + for transport in self._transports: + _protocols = getattr(transport, "protocols", None) + if _protocols is not None and not proto_names.intersection( + set(_protocols()) + ): + continue + _can_listen = getattr(transport, "can_listen", None) + if _can_listen is not None and _can_listen(maddr): + logger.debug( + "TransportManager.transport_for_listening: %s => %s", + maddr, + type(transport).__name__, + ) + return transport + + logger.warning( + "TransportManager.transport_for_listening: no transport found for %s " + "(registered: %s)", + maddr, + [type(t).__name__ for t in self._transports], + ) + return None + + # ── Introspection ───────────────────────────────────────────────────────── + + def get_transports(self) -> list[ITransport]: + """ + Return a shallow copy of the registered transports list. + + :returns: List of registered :class:`~libp2p.abc.ITransport` instances. + """ + return list(self._transports) + + def has_transport_for(self, maddr: Multiaddr) -> bool: + """ + Return ``True`` if any registered transport can dial *maddr*. + + :param maddr: The multiaddress to check. + :returns: ``True`` if a matching transport exists, ``False`` otherwise. + """ + return self.transport_for_dialing(maddr) is not None + + # ── Listening ───────────────────────────────────────────────────────────── + + def add_listen_addr( + self, maddr: Multiaddr, conn_handler: THandler + ) -> IListener | None: + """ + Create and return a listener for *maddr*. + + Mirrors ``Swarm.AddListenAddr()`` from go-libp2p. + + When a :class:`~libp2p.transport.cmux.PortDemultiplexer` was supplied at + construction **and** the address is TCP-based, the manager calls + :meth:`~libp2p.transport.cmux.PortDemultiplexer.demultiplexed_listen` on the + appropriate :class:`~libp2p.transport.cmux.DemultiplexedConnType` so + that TCP and WebSocket transports can share the same OS socket. + + For non-TCP transports (e.g. QUIC) the transport's own + :meth:`~libp2p.abc.ITransport.create_listener` is used directly. + + :param maddr: The multiaddress to listen on. + :param conn_handler: Handler called for every accepted connection. + :returns: An :class:`~libp2p.abc.IListener`, or ``None`` if no + transport supports *maddr*. + """ + transport = self.transport_for_listening(maddr) + if transport is None: + return None + + protocols = [p.name for p in maddr.protocols()] + + if "tcp" in protocols: + host_val = None + if "ip4" in protocols: + host_val = maddr.value_for_protocol("ip4") + elif "ip6" in protocols: + host_val = maddr.value_for_protocol("ip6") + port_val = maddr.value_for_protocol("tcp") + if host_val and port_val: + key = (str(host_val), int(port_val)) + if hasattr(self, "_port_demuxers") and key in self._port_demuxers: + return self._add_listen_addr_shared( + maddr, + protocols, + transport, + conn_handler, + self._port_demuxers[key], + ) + + # ---- Non-TCP or no PortDemultiplexer: let the transport own its listener ---- + return transport.create_listener(conn_handler) + + def _add_listen_addr_shared( + self, + maddr: Multiaddr, + protocols: list[str], + transport: ITransport, + conn_handler: THandler, + port_demux: PortDemultiplexer, + ) -> IListener | None: + """ + Wire a TCP-based transport into the shared + :class:`~libp2p.transport.cmux.PortDemultiplexer`. + + Determines the correct :class:`~libp2p.transport.cmux.DemultiplexedConnType` + for *transport*, registers a + :class:`~libp2p.transport.cmux.DemultiplexedListener`, and wires + *conn_handler* into the listener so that + :meth:`~libp2p.transport.cmux.DemultiplexedListener.listen` starts + the drain task automatically. + + :returns: The :class:`~libp2p.transport.cmux.DemultiplexedListener`, + or ``None`` on failure. + """ + from libp2p.transport.cmux import DemultiplexedConnType + + # Determine connection type for this transport. + if "ws" in protocols or "wss" in protocols: + is_secure = "wss" in protocols + if is_secure: + conn_type = DemultiplexedConnType.TLS + else: + conn_type = DemultiplexedConnType.HTTP + + from trio_websocket import ( # type: ignore + wrap_server_stream, + ) + + from libp2p.transport.websocket.connection import P2PWebSocketConnection + + async def ws_wrapped_handler(stream: trio.abc.Stream) -> None: + try: + async with trio.open_nursery() as ws_nursery: + # Max message size defaults to 32MB to match WebsocketTransport + request = await wrap_server_stream( + ws_nursery, stream, max_message_size=32 * 1024 * 1024 + ) + ws = await request.accept() + conn = P2PWebSocketConnection( + ws, + is_secure=is_secure, + max_buffered_amount=4 * 1024 * 1024, + ) + await conn_handler(conn) + except Exception as exc: + logger.error( + "WS upgrade failed on shared port: %s, cause: %s", + exc, + getattr(exc, "__cause__", None), + exc_info=True, + ) + + if port_demux.has_listener(conn_type): + logger.debug( + "PortDemultiplexer already has a listener for %s; reusing", + conn_type.name, + ) + return port_demux.get_listener(conn_type) + + return port_demux.demultiplexed_listen( + maddr, conn_type, conn_handler=ws_wrapped_handler + ) + + else: + conn_type = DemultiplexedConnType.MULTISTREAM_SELECT + + from libp2p.io.trio import TrioTCPStream + + async def tcp_wrapped_handler(stream: trio.abc.Stream) -> None: + try: + tcp_stream = TrioTCPStream(stream) # type: ignore[arg-type] + await conn_handler(tcp_stream) + except Exception as exc: + logger.debug("TCP handler failed on shared port: %s", exc) + + # Check whether this conn_type is already registered — return the + # existing DemultiplexedListener so the swarm can call listen() on it + # and notify listeners without error. + if port_demux.has_listener(conn_type): + logger.debug( + "PortDemultiplexer already has a listener for %s; reusing", + conn_type.name, + ) + return port_demux.get_listener(conn_type) + + # Register and wire the handler in one step. + return port_demux.demultiplexed_listen( + maddr, conn_type, conn_handler=tcp_wrapped_handler + ) + + # Backwards-compatible alias. + def listen_on(self, maddr: Multiaddr, conn_handler: THandler) -> IListener | None: + """Deprecated alias for :meth:`add_listen_addr`.""" + return self.add_listen_addr(maddr, conn_handler) + + # ── Lifecycle helpers (called by Swarm) ─────────────────────────────────── + + def set_background_nursery(self, nursery: trio.Nursery) -> None: + """ + Pass the Swarm's background nursery to all transports that need one. + + Called once by :meth:`~libp2p.network.swarm.Swarm.run` as soon as the + background nursery is ready. Delegates to every transport that exposes + a ``set_background_nursery`` method (currently QUIC and WebSocket). + + :param nursery: The trio nursery to share with transports. + """ + for transport in self._transports: + if hasattr(transport, "set_background_nursery"): + transport.set_background_nursery(nursery) # type: ignore[attr-defined] + logger.debug( + "TransportManager: set background nursery on %s", + type(transport).__name__, + ) + + def set_swarm(self, swarm: object) -> None: + """ + Pass a reference to the Swarm to all transports that need it. + + Called once by :meth:`~libp2p.network.swarm.Swarm.run`. Needed by + :class:`~libp2p.transport.quic.transport.QUICTransport` so it can + call :meth:`~libp2p.network.swarm.Swarm.add_conn` for inbound + QUIC connections. + + :param swarm: The :class:`~libp2p.network.swarm.Swarm` instance. + """ + for transport in self._transports: + if hasattr(transport, "set_swarm"): + transport.set_swarm(swarm) # type: ignore[attr-defined] + logger.debug( + "TransportManager: set swarm reference on %s", + type(transport).__name__, + ) + + async def close_all(self) -> None: + """ + Close all registered transports concurrently. + + Called by :meth:`~libp2p.network.swarm.Swarm.close` during teardown. + """ + import trio + + async with trio.open_nursery() as nursery: + for transport in self._transports: + if hasattr(transport, "close"): + nursery.start_soon(transport.close) # type: ignore[attr-defined] + + logger.debug( + "TransportManager: closed all transports (%d total)", + len(self._transports), + ) diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py index 68e2ad543..fc31f8ee4 100644 --- a/libp2p/transport/quic/transport.py +++ b/libp2p/transport/quic/transport.py @@ -411,7 +411,7 @@ def can_dial(self, maddr: multiaddr.Multiaddr) -> bool: """ return is_quic_multiaddr(maddr) - def protocols(self) -> list[TProtocol]: + def protocols(self) -> list[str]: """ Get supported protocol identifiers. @@ -419,11 +419,25 @@ def protocols(self) -> list[TProtocol]: List of supported protocol strings """ - protocols = [QUIC_V1_PROTOCOL] + protocols: list[str] = [str(QUIC_V1_PROTOCOL)] if self._config.enable_draft29: - protocols.append(QUIC_DRAFT29_PROTOCOL) + protocols.append(str(QUIC_DRAFT29_PROTOCOL)) return protocols + def can_listen(self, maddr: multiaddr.Multiaddr) -> bool: + """ + Get supported protocol identifiers. + Return True if this QUIC transport can listen on the given multiaddr. + + Args: + maddr: Multiaddr to check. + + Returns: + True if the multiaddr contains a QUIC protocol component. + + """ + return is_quic_multiaddr(maddr) + def listen_order(self) -> int: """ Get the listen order priority for this transport. diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index e7851006e..5d054c505 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -4,6 +4,7 @@ Sequence, ) import logging +import typing from multiaddr import Multiaddr from multiaddr.exceptions import ProtocolLookupError @@ -132,7 +133,7 @@ async def handler(stream: trio.SocketStream) -> None: ) try: started_listeners = await nursery.start( - serve_tcp, + typing.cast(typing.Any, serve_tcp), handler, tcp_port, host_str, @@ -309,6 +310,39 @@ def create_listener(self, handler_function: THandler) -> TCPListener: """ return TCPListener(handler_function) + def can_dial(self, maddr: Multiaddr) -> bool: + """ + Return True if this TCP transport can dial the given multiaddr. + + Accepts pure TCP addresses (/ip4/.../tcp/... or /ip6/.../tcp/...) but + rejects WebSocket addresses (/ws, /wss) even though they use TCP underneath, + so the TransportManager routes those to WebsocketTransport instead. + + :param maddr: The multiaddress to check. + :return: True if this transport handles the multiaddr. + """ + names = {p.name for p in maddr.protocols()} + return "tcp" in names and not names.intersection( + {"ws", "wss", "quic", "quic-v1"} + ) + + def can_listen(self, maddr: Multiaddr) -> bool: + """ + Return True if this TCP transport can listen on the given multiaddr. + + :param maddr: The multiaddress to check. + :return: True if this transport can listen on the multiaddr. + """ + return self.can_dial(maddr) + + def protocols(self) -> list[str]: + """ + Return the list of multiaddr protocol names handled by TCP transport. + + :return: ["tcp"] + """ + return ["tcp"] + def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr: return multiaddr_from_socket(socket) diff --git a/libp2p/transport/transport_registry.py b/libp2p/transport/transport_registry.py deleted file mode 100644 index edac9d4eb..000000000 --- a/libp2p/transport/transport_registry.py +++ /dev/null @@ -1,354 +0,0 @@ -""" -Transport registry for dynamic transport selection based on multiaddr protocols. -""" - -from collections.abc import Callable -import logging -from typing import Any - -from multiaddr import Multiaddr -from multiaddr.protocols import Protocol - -from libp2p.abc import ITransport -from libp2p.transport.tcp.tcp import TCP -from libp2p.transport.upgrader import TransportUpgrader -from libp2p.transport.websocket.multiaddr_utils import ( - is_valid_websocket_multiaddr, -) - - -# Import QUIC utilities here to avoid circular imports -def _get_quic_transport() -> Any: - from libp2p.transport.quic.transport import QUICTransport - - return QUICTransport - - -def _get_quic_validation() -> Callable[[Multiaddr], bool]: - from libp2p.transport.quic.utils import is_quic_multiaddr - - return is_quic_multiaddr - - -# Import WebsocketTransport here to avoid circular imports -def _get_websocket_transport() -> Any: - from libp2p.transport.websocket.transport import WebsocketTransport - - return WebsocketTransport - - -def _get_webrtc_direct_transport() -> Any: - from libp2p.transport.webrtc.transport import WebRTCDirectTransport - - return WebRTCDirectTransport - - -def _get_webrtc_private_transport() -> Any: - from libp2p.transport.webrtc.private_transport import WebRTCPrivateTransport - - return WebRTCPrivateTransport - - -logger = logging.getLogger(__name__) - - -def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool: - """ - Validate that a multiaddr has a valid TCP structure. - - :param maddr: The multiaddr to validate - :return: True if valid TCP structure, False otherwise - """ - try: - # TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080 - # or /ip6/::1/tcp/8080 - protocols: list[Protocol] = list(maddr.protocols()) - - # Must have at least 2 protocols: network (ip4/ip6) + tcp - if len(protocols) < 2: - return False - - # First protocol should be a network protocol (ip4, ip6, dns4, dns6) - if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]: - return False - - # Second protocol should be tcp - if protocols[1].name != "tcp": - return False - - # Should not have any protocols after tcp (unless it's a valid - # continuation like p2p) - # For now, we'll be strict and only allow network + tcp - if len(protocols) > 2: - # Check if the additional protocols are valid continuations - valid_continuations = ["p2p"] # Add more as needed - for i in range(2, len(protocols)): - if protocols[i].name not in valid_continuations: - return False - - return True - - except Exception: - return False - - -class TransportRegistry: - """ - Registry for mapping multiaddr protocols to transport implementations. - """ - - def __init__(self) -> None: - self._transports: dict[str, type[ITransport]] = {} - self._register_default_transports() - - def _register_default_transports(self) -> None: - """Register the default transport implementations.""" - # Register TCP transport for /tcp protocol - self.register_transport("tcp", TCP) - - # Register WebSocket transport for /ws and /wss protocols - WebsocketTransport = _get_websocket_transport() - self.register_transport("ws", WebsocketTransport) - self.register_transport("wss", WebsocketTransport) - - # Register QUIC transport for /quic and /quic-v1 protocols - QUICTransport = _get_quic_transport() - self.register_transport("quic", QUICTransport) - self.register_transport("quic-v1", QUICTransport) - - # Register WebRTC transports only when aiortc is actually installed. - # The scaffolding modules themselves do not import aiortc (aiortc is - # loaded lazily inside the bridge), so we probe for it explicitly. - import importlib.util as _importlib_util - - if _importlib_util.find_spec("aiortc") is not None: - try: - WebRTCDirectTransport = _get_webrtc_direct_transport() - self.register_transport("webrtc-direct", WebRTCDirectTransport) - WebRTCPrivateTransport = _get_webrtc_private_transport() - self.register_transport("webrtc", WebRTCPrivateTransport) - except ImportError as e: - logger.debug("aiortc present but WebRTC transport import failed: %s", e) - else: - logger.debug( - "aiortc not installed; skipping /webrtc and /webrtc-direct " - "transport registration (install libp2p[webrtc] to enable)" - ) - - def register_transport( - self, protocol: str, transport_class: type[ITransport] - ) -> None: - """ - Register a transport class for a specific protocol. - - :param protocol: The protocol identifier (e.g., "tcp", "ws") - :param transport_class: The transport class to register - """ - self._transports[protocol] = transport_class - logger.debug( - f"Registered transport {transport_class.__name__} for protocol {protocol}" - ) - - def get_transport(self, protocol: str) -> type[ITransport] | None: - """ - Get the transport class for a specific protocol. - - :param protocol: The protocol identifier - :return: The transport class or None if not found - """ - return self._transports.get(protocol) - - def get_supported_protocols(self) -> list[str]: - """Get list of supported transport protocols.""" - return list(self._transports.keys()) - - def create_transport( - self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any - ) -> ITransport | None: - """ - Create a transport instance for a specific protocol. - - :param protocol: The protocol identifier - :param upgrader: The transport upgrader instance (required for WebSocket) - :param kwargs: Additional arguments for transport construction - :return: Transport instance or None if protocol not supported or creation fails - """ - transport_class = self.get_transport(protocol) - if transport_class is None: - return None - - try: - if protocol in ["ws", "wss"]: - # WebSocket transport requires upgrader - if upgrader is None: - logger.warning( - f"WebSocket transport '{protocol}' requires upgrader" - ) - return None - # Use explicit WebsocketTransport to avoid type issues - WebsocketTransport = _get_websocket_transport() - return WebsocketTransport( - upgrader, - tls_client_config=kwargs.get("tls_client_config"), - tls_server_config=kwargs.get("tls_server_config"), - handshake_timeout=kwargs.get("handshake_timeout", 15.0), - ) - elif protocol in ["quic", "quic-v1"]: - # QUIC transport requires private_key - private_key = kwargs.get("private_key") - if private_key is None: - logger.warning(f"QUIC transport '{protocol}' requires private_key") - return None - # Use explicit QUICTransport to avoid type issues - QUICTransport = _get_quic_transport() - from libp2p.transport.quic.config import QUICTransportConfig - - # Get or create config - config = kwargs.get("config") - if config is None: - config = QUICTransportConfig() - elif not isinstance(config, QUICTransportConfig): - # If config is not QUICTransportConfig, create new one - config = QUICTransportConfig() - - # Allow negotiation config to be passed via kwargs for coordination - if "negotiation_semaphore_limit" in kwargs: - config.NEGOTIATION_SEMAPHORE_LIMIT = kwargs[ - "negotiation_semaphore_limit" - ] - if "negotiate_timeout" in kwargs: - config.NEGOTIATE_TIMEOUT = kwargs["negotiate_timeout"] - - enable_autotls = kwargs.get("enable_autotls", False) - return QUICTransport( - private_key, config=config, enable_autotls=enable_autotls - ) - elif protocol in ["webrtc-direct", "webrtc"]: - # WebRTC transports require a private key for the local peer - # identity used in the Noise XX handshake. The transport - # classes are loaded lazily; mypy can't see the concrete - # signature here, so we cast the call. - private_key = kwargs.get("private_key") - if private_key is None: - logger.warning( - "WebRTC transport '%s' requires private_key", protocol - ) - return None - config = kwargs.get("config") - if protocol == "webrtc-direct": - return transport_class( # type: ignore[call-arg] - private_key=private_key, config=config - ) - # private-to-private also accepts an optional host - host = kwargs.get("host") - return transport_class( # type: ignore[call-arg] - private_key=private_key, host=host, config=config - ) - else: - # TCP transport doesn't require upgrader - return transport_class() - except Exception as e: - logger.error(f"Failed to create transport for protocol {protocol}: {e}") - return None - - -# Global transport registry instance (lazy initialization) -_global_registry: TransportRegistry | None = None - - -def get_transport_registry() -> TransportRegistry: - """Get the global transport registry instance.""" - global _global_registry - if _global_registry is None: - _global_registry = TransportRegistry() - return _global_registry - - -def register_transport(protocol: str, transport_class: type[ITransport]) -> None: - """Register a transport class in the global registry.""" - registry = get_transport_registry() - registry.register_transport(protocol, transport_class) - - -def create_transport_for_multiaddr( - maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any -) -> ITransport | None: - """ - Create the appropriate transport for a given multiaddr. - - :param maddr: The multiaddr to create transport for - :param upgrader: The transport upgrader instance - :param kwargs: Additional arguments for transport construction - (e.g., private_key for QUIC) - :return: Transport instance or None if no suitable transport found - """ - try: - # Get all protocols in the multiaddr - protocols = [proto.name for proto in maddr.protocols()] - - # Check for supported transport protocols in order of preference - # We need to validate that the multiaddr structure is valid for our transports - if "webrtc-direct" in protocols or "webrtc" in protocols: - # WebRTC Direct: /ip4//udp//webrtc-direct/... - # WebRTC (relayed): /p2p-circuit/webrtc/... - # Both are only routable when the corresponding transport is - # registered (which only happens when aiortc is installed). - registry = get_transport_registry() - proto = "webrtc-direct" if "webrtc-direct" in protocols else "webrtc" - if proto in registry.get_supported_protocols(): - return registry.create_transport(proto, upgrader, **kwargs) - logger.warning( - "Multiaddr requires the WebRTC transport (%s) but it is not " - "registered. Install libp2p[webrtc] to enable WebRTC support.", - proto, - ) - return None - elif "quic" in protocols or "quic-v1" in protocols: - # For QUIC, we need a valid structure like: - # /ip4/127.0.0.1/udp/4001/quic - # /ip4/127.0.0.1/udp/4001/quic-v1 - is_quic_multiaddr = _get_quic_validation() - if is_quic_multiaddr(maddr): - # Determine QUIC version - registry = get_transport_registry() - if "quic-v1" in protocols: - return registry.create_transport("quic-v1", upgrader, **kwargs) - else: - return registry.create_transport("quic", upgrader, **kwargs) - elif "ws" in protocols or "wss" in protocols or "tls" in protocols: - # For WebSocket, we need a valid structure like: - # /ip4/127.0.0.1/tcp/8080/ws (insecure) - # /ip4/127.0.0.1/tcp/8080/wss (secure) - # /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS) - # /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI) - if is_valid_websocket_multiaddr(maddr): - # Determine if this is a secure WebSocket connection - registry = get_transport_registry() - if "wss" in protocols or "tls" in protocols: - return registry.create_transport("wss", upgrader, **kwargs) - else: - return registry.create_transport("ws", upgrader, **kwargs) - elif "tcp" in protocols: - # For TCP, we need a valid structure like /ip4/127.0.0.1/tcp/8080 - # Check if the multiaddr has proper TCP structure - if _is_valid_tcp_multiaddr(maddr): - registry = get_transport_registry() - return registry.create_transport("tcp", upgrader) - - # If no supported transport protocol found or structure is invalid, return None - logger.warning( - f"No supported transport protocol found or invalid structure in " - f"multiaddr: {maddr}" - ) - return None - - except Exception as e: - # Handle any errors gracefully (e.g., invalid multiaddr) - logger.warning(f"Error processing multiaddr {maddr}: {e}") - return None - - -def get_supported_transport_protocols() -> list[str]: - """Get list of supported transport protocols from the global registry.""" - registry = get_transport_registry() - return registry.get_supported_protocols() diff --git a/libp2p/transport/webrtc/private_transport.py b/libp2p/transport/webrtc/private_transport.py index 617bea26f..7e4bf441b 100644 --- a/libp2p/transport/webrtc/private_transport.py +++ b/libp2p/transport/webrtc/private_transport.py @@ -91,6 +91,15 @@ async def _ensure_bridge(self) -> AsyncioBridge: await self._bridge.start() return self._bridge + def can_dial(self, maddr: Multiaddr) -> bool: + return is_webrtc_multiaddr(maddr) + + def can_listen(self, maddr: Multiaddr) -> bool: + return is_webrtc_multiaddr(maddr) + + def protocols(self) -> list[str]: + return ["webrtc"] + async def dial(self, maddr: Multiaddr) -> WebRTCConnection: """ Dial a remote peer over WebRTC via a relay. diff --git a/libp2p/transport/webrtc/transport.py b/libp2p/transport/webrtc/transport.py index 50f867fb0..3ce4b3294 100644 --- a/libp2p/transport/webrtc/transport.py +++ b/libp2p/transport/webrtc/transport.py @@ -90,6 +90,15 @@ async def _ensure_bridge(self) -> AsyncioBridge: await self._bridge.start() return self._bridge + def can_dial(self, maddr: Multiaddr) -> bool: + return is_webrtc_direct_multiaddr(maddr) + + def can_listen(self, maddr: Multiaddr) -> bool: + return is_webrtc_direct_multiaddr(maddr) + + def protocols(self) -> list[str]: + return ["webrtc-direct"] + async def dial(self, maddr: Multiaddr) -> WebRTCConnection: """ Dial a remote peer over WebRTC Direct. diff --git a/libp2p/transport/websocket/transport.py b/libp2p/transport/websocket/transport.py index c9a2a5c24..46169e532 100644 --- a/libp2p/transport/websocket/transport.py +++ b/libp2p/transport/websocket/transport.py @@ -490,14 +490,42 @@ def set_background_nursery(self, nursery: trio.Nursery) -> None: ): nursery.start_soon(self._initialize_autotls, self._peer_id) - async def can_dial(self, maddr: Multiaddr) -> bool: - """Check if we can dial the given multiaddr.""" + def can_dial(self, maddr: Multiaddr) -> bool: + """ + Return True if this WebSocket transport can dial the given multiaddr. + + Checks whether the multiaddr matches the WebSocket multiaddr pattern + (e.g. ``/ip4/.../tcp/.../ws`` or ``/ip4/.../tcp/.../wss``). + + The check is purely protocol-level (no I/O); it changed from async + to sync to satisfy the :class:`~libp2p.abc.ITransport` interface. + + :param maddr: The multiaddress to check. + :return: True if this transport handles the multiaddr. + """ try: parse_websocket_multiaddr(maddr) - return True # If parsing succeeds, it's a valid WebSocket multiaddr + return True except (ValueError, KeyError): return False + def can_listen(self, maddr: Multiaddr) -> bool: + """ + Return True if this WebSocket transport can listen on the given multiaddr. + + :param maddr: The multiaddress to check. + :return: True if this transport can listen on the multiaddr. + """ + return self.can_dial(maddr) + + def protocols(self) -> list[str]: + """ + Return the list of multiaddr protocol names handled by WebSocket transport. + + :return: ["ws", "wss"] + """ + return ["ws", "wss"] + async def _initialize_autotls(self, peer_id: ID | None = None) -> None: """Initialize AutoTLS if configured.""" if self._autotls_initialized: @@ -847,7 +875,7 @@ async def dial(self, maddr: Multiaddr) -> RawConnection: async def _dial_resolved(self, maddr: Multiaddr) -> RawConnection: """Dial using a multiaddr that has an IP (no DNS).""" - if not await self.can_dial(maddr): + if not self.can_dial(maddr): raise OpenConnectionError(f"Cannot dial {maddr}") try: diff --git a/newsfragments/1359.breaking.rst b/newsfragments/1359.breaking.rst new file mode 100644 index 000000000..d40f8835f --- /dev/null +++ b/newsfragments/1359.breaking.rst @@ -0,0 +1 @@ +Removed the ``libp2p.transport.transport_registry`` module and the legacy ``transport=`` keyword argument from ``Swarm.__init__``. Third-party transports should now be registered using ``TransportManager`` or by passing ``transports=[...]`` to ``Swarm()``. diff --git a/newsfragments/1359.feature.rst b/newsfragments/1359.feature.rst new file mode 100644 index 000000000..4fdd046f2 --- /dev/null +++ b/newsfragments/1359.feature.rst @@ -0,0 +1,5 @@ +Add multi-transport support to ``Swarm``, ``new_swarm``, and ``new_host``. +A node can now listen and dial over TCP, WebSocket, and QUIC simultaneously, +matching go-libp2p's ``TransportManager`` architecture. Pass an explicit +``transports=[...]`` list to ``new_swarm`` / ``new_host``, or let transports +be auto-detected from ``listen_addrs``. diff --git a/pyproject.toml b/pyproject.toml index 8fce236e9..97997528d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -322,6 +322,7 @@ project_excludes = [ "**/.project-template/**", "**/docs/conf.py", "**/*pb2.py", + "**/*pb2_grpc.py", "**/*.pyi", ".venv/**", "./tests/interop/nim_libp2p", diff --git a/tests/core/identity/identify/test_identify_integration.py b/tests/core/identity/identify/test_identify_integration.py index e4ebcba77..280ff3870 100644 --- a/tests/core/identity/identify/test_identify_integration.py +++ b/tests/core/identity/identify/test_identify_integration.py @@ -239,3 +239,57 @@ async def test_identify_message_equivalence_real_network(security_protocol): assert result_varint.protocol_version == result_raw.protocol_version assert result_varint.public_key == result_raw.public_key assert result_varint.listen_addrs == result_raw.listen_addrs + + +@pytest.mark.trio +async def test_identify_multi_transport_host_addresses(security_protocol): + """Test that a multi-transport host advertises all its addrs and they're learned.""" + from multiaddr import Multiaddr + + from libp2p import new_host + from libp2p.peer.peerinfo import info_from_p2p_addr + + host_a = new_host( + enable_tcp=True, + enable_websocket=True, + ) + host_b = new_host(enable_tcp=True, enable_websocket=True) + + from libp2p.tools.anyio_service import background_trio_service + + async with ( + background_trio_service(host_a.get_network()), + background_trio_service(host_b.get_network()), + ): + await host_a.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) + await host_a.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0/ws")) + await host_b.get_network().listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) + + # host_b dials host_a using one of its addresses + host_a.set_stream_handler(ID, identify_handler_for(host_a)) + + host_a_addrs = host_a.get_addrs() + assert len(host_a_addrs) == 2, "host_a should have 2 listen addresses" + + # We dial using the first address + maddr = host_a_addrs[0].encapsulate( + Multiaddr(f"/p2p/{host_a.get_id().to_base58()}") + ) + info = info_from_p2p_addr(maddr) + + # Connect + await host_b.connect(info) + + # Make identify request + stream = await host_b.new_stream(host_a.get_id(), (ID,)) + response = await stream.read(8192) + await stream.close() + + # Parse response + result = parse_identify_response(response) + + # Verify response contains all addresses + for addr in host_a_addrs: + assert _multiaddr_to_bytes(addr) in result.listen_addrs, ( + f"Address {addr} not advertised by host_a" + ) diff --git a/tests/core/network/test_enhanced_swarm.py b/tests/core/network/test_enhanced_swarm.py index 69a7c2a70..fa4b18063 100644 --- a/tests/core/network/test_enhanced_swarm.py +++ b/tests/core/network/test_enhanced_swarm.py @@ -92,7 +92,7 @@ async def test_enhanced_swarm_constructor(): transport = Mock() # Test with default config - swarm = Swarm(peer_id, peerstore, upgrader, transport) + swarm = Swarm(peer_id, peerstore, upgrader, [transport]) assert swarm.retry_config.max_retries == 3 assert swarm.connection_config.max_connections_per_peer == 3 assert isinstance(swarm.connections, dict) @@ -101,7 +101,7 @@ async def test_enhanced_swarm_constructor(): custom_retry = RetryConfig(max_retries=5, initial_delay=0.5) custom_conn = ConnectionConfig(max_connections_per_peer=5) - swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn) + swarm = Swarm(peer_id, peerstore, upgrader, [transport], custom_retry, custom_conn) assert swarm.retry_config.max_retries == 5 assert swarm.retry_config.initial_delay == 0.5 assert swarm.connection_config.max_connections_per_peer == 5 @@ -119,7 +119,7 @@ async def test_swarm_backoff_calculation(): initial_delay=0.1, max_delay=1.0, backoff_multiplier=2.0, jitter_factor=0.1 ) - swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config) + swarm = Swarm(peer_id, peerstore, upgrader, [transport], retry_config) # Test backoff calculation delay1 = swarm._calculate_backoff_delay(0) @@ -154,7 +154,7 @@ async def test_swarm_retry_logic(): max_delay=0.1, ) - swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config) + swarm = Swarm(peer_id, peerstore, upgrader, [transport], retry_config) # Mock the single attempt method to fail twice then succeed attempt_count = [0] @@ -186,7 +186,7 @@ async def test_swarm_load_balancing_strategies(): upgrader = Mock() transport = Mock() - swarm = Swarm(peer_id, peerstore, upgrader, transport) + swarm = Swarm(peer_id, peerstore, upgrader, [transport]) # Create mock connections with different stream counts conn1 = MockConnection(peer_id) @@ -235,7 +235,7 @@ async def test_swarm_multiple_connections_api(): upgrader = Mock() transport = Mock() - swarm = Swarm(peer_id, peerstore, upgrader, transport) + swarm = Swarm(peer_id, peerstore, upgrader, [transport]) # Test empty connections assert swarm.get_connections() == [] @@ -311,7 +311,7 @@ async def test_swarm_backward_compatibility(): upgrader = Mock() transport = Mock() - swarm = Swarm(peer_id, peerstore, upgrader, transport) + swarm = Swarm(peer_id, peerstore, upgrader, [transport]) # Add connections conn1 = MockConnection(peer_id) diff --git a/tests/core/network/test_stream_semaphore.py b/tests/core/network/test_stream_semaphore.py index 1dde5ee62..331e5c147 100644 --- a/tests/core/network/test_stream_semaphore.py +++ b/tests/core/network/test_stream_semaphore.py @@ -28,7 +28,7 @@ def _make_swarm( upgrader = Mock() transport = Mock() - swarm = Swarm(peer_id, peerstore, upgrader, transport) + swarm = Swarm(peer_id, peerstore, upgrader, [transport]) rm = ResourceManager(limits=ResourceLimits(max_streams=max_streams)) swarm.set_resource_manager(rm, enable_stream_semaphore=enable_semaphore) return swarm diff --git a/tests/core/network/test_swarm.py b/tests/core/network/test_swarm.py index 8c7a78ea1..39b0bf867 100644 --- a/tests/core/network/test_swarm.py +++ b/tests/core/network/test_swarm.py @@ -47,6 +47,31 @@ def __init__(self, private_key, config=None, enable_autotls=False): self.enable_autotls = enable_autotls +def test_swarm_legacy_keyword_raises_typeerror(): + from unittest.mock import Mock + + with pytest.raises(TypeError, match="no longer accepts 'transport='"): + Swarm(Mock(), Mock(), Mock(), transport=Mock()) + + +def test_swarm_positional_backward_compatibility(): + from unittest.mock import Mock + + peer_id = Mock() + peerstore = Mock() + upgrader = Mock() + + # Passing a single transport positionally should be wrapped in a list internally + # and added to the transport manager. + transport = Mock() + swarm1 = Swarm(peer_id, peerstore, upgrader, transport) + assert len(swarm1.transport_manager.get_transports()) == 1 + + # Passing a list of transports positionally + swarm2 = Swarm(peer_id, peerstore, upgrader, [transport]) + assert len(swarm2.transport_manager.get_transports()) == 1 + + @pytest.mark.trio async def test_swarm_dial_peer(security_protocol): async with SwarmFactory.create_batch_and_listen( @@ -257,14 +282,14 @@ def clear(): def test_new_swarm_defaults_to_tcp(): swarm = new_swarm() assert isinstance(swarm, Swarm) - assert isinstance(swarm.transport, TCP) + assert isinstance(swarm.transport_manager.get_transports()[0], TCP) def test_new_swarm_tcp_multiaddr_supported(): addr = Multiaddr("/ip4/127.0.0.1/tcp/9999") swarm = new_swarm(listen_addrs=[addr]) assert isinstance(swarm, Swarm) - assert isinstance(swarm.transport, TCP) + assert isinstance(swarm.transport_manager.get_transports()[0], TCP) def test_new_swarm_quic_multiaddr_supported(): @@ -273,15 +298,13 @@ def test_new_swarm_quic_multiaddr_supported(): addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic") swarm = new_swarm(listen_addrs=[addr]) assert isinstance(swarm, Swarm) - assert isinstance(swarm.transport, QUICTransport) + assert isinstance(swarm.transport_manager.get_transports()[0], QUICTransport) def test_new_swarm_quic_paths_propagate_enable_autotls(monkeypatch): import libp2p as libp2p_module - from libp2p.transport import transport_registry key_pair = generate_new_ed25519_identity() - original_quic_transport = libp2p_module.QUICTransport # Path 1: direct QUIC creation when listen_addrs is None and enable_quic=True. monkeypatch.setattr(libp2p_module, "QUICTransport", _FakeQUICTransport) @@ -291,47 +314,22 @@ def test_new_swarm_quic_paths_propagate_enable_autotls(monkeypatch): enable_autotls=True, ) assert isinstance(swarm_direct, Swarm) - assert isinstance(swarm_direct.transport, _FakeQUICTransport) - assert swarm_direct.transport.enable_autotls is True - - # Path 2: registry-based creation should receive enable_autotls in kwargs. - registry_calls = [] - - def fake_create_transport_for_multiaddr(maddr, upgrader, **kwargs): - registry_calls.append(kwargs) - return _FakeQUICTransport( - kwargs["private_key"], - config=kwargs.get("config"), - enable_autotls=kwargs.get("enable_autotls", False), - ) + transport1 = swarm_direct.transport_manager.get_transports()[0] + assert isinstance(transport1, _FakeQUICTransport) + assert transport1.enable_autotls is True - monkeypatch.setattr( - transport_registry, - "create_transport_for_multiaddr", - fake_create_transport_for_multiaddr, - ) + # Path 2: list_addrs based creation should receive enable_autotls in kwargs. swarm_registry = new_swarm( key_pair=key_pair, listen_addrs=[Multiaddr("/ip4/127.0.0.1/udp/9999/quic")], enable_autotls=True, ) - assert registry_calls[0]["enable_autotls"] is True assert isinstance(swarm_registry, Swarm) - assert isinstance(swarm_registry.transport, _FakeQUICTransport) - assert swarm_registry.transport.enable_autotls is True - - # Path 3: forced-QUIC fallback when enable_quic=True but registry gives non-QUIC. - monkeypatch.setattr(libp2p_module, "QUICTransport", original_quic_transport) + transport2 = swarm_registry.transport_manager.get_transports()[0] + assert isinstance(transport2, _FakeQUICTransport) + assert transport2.enable_autotls is True - def fake_create_transport_for_multiaddr_non_quic(maddr, upgrader, **kwargs): - return TCP() - - monkeypatch.setattr( - transport_registry, - "create_transport_for_multiaddr", - fake_create_transport_for_multiaddr_non_quic, - ) - monkeypatch.setattr(libp2p_module, "QUICTransport", _FakeQUICTransport) + # Path 3: forced-QUIC fallback when enable_quic=True but no QUIC addr provided. swarm_forced = new_swarm( key_pair=key_pair, listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/9999")], @@ -339,8 +337,9 @@ def fake_create_transport_for_multiaddr_non_quic(maddr, upgrader, **kwargs): enable_autotls=True, ) assert isinstance(swarm_forced, Swarm) - assert isinstance(swarm_forced.transport, _FakeQUICTransport) - assert swarm_forced.transport.enable_autotls is True + transport3 = swarm_forced.transport_manager.get_transports()[0] + assert isinstance(transport3, _FakeQUICTransport) + assert transport3.enable_autotls is True def test_new_swarm_defaults_to_ed25519(): diff --git a/tests/core/test_libp2p/test_libp2p.py b/tests/core/test_libp2p/test_libp2p.py index 9bcea6127..8c88a6556 100644 --- a/tests/core/test_libp2p/test_libp2p.py +++ b/tests/core/test_libp2p/test_libp2p.py @@ -1,5 +1,4 @@ import pytest -import multiaddr from libp2p.custom_types import ( TProtocol, @@ -304,6 +303,8 @@ async def test_host_connect(security_protocol): assert len(hosts[0].get_peerstore().peer_ids()) == 2 assert hosts[1].get_id() in hosts[0].get_peerstore().peer_ids() - ma_node_b = multiaddr.Multiaddr("/p2p/%s" % hosts[1].get_id().pretty()) - for addr in hosts[0].get_peerstore().addrs(hosts[1].get_id()): - assert addr.encapsulate(ma_node_b) in hosts[1].get_addrs() + # Ensure host 0 learned all of host 1's advertised addresses + host1_advertised_addrs = hosts[1].get_addrs() + host0_known_addrs = hosts[0].get_peerstore().addrs(hosts[1].get_id()) + for addr in host1_advertised_addrs: + assert addr in host0_known_addrs diff --git a/tests/core/tools/anyio_service/test_anyio_based_service.py b/tests/core/tools/anyio_service/test_anyio_based_service.py index 4025200bc..b6a5ec514 100644 --- a/tests/core/tools/anyio_service/test_anyio_based_service.py +++ b/tests/core/tools/anyio_service/test_anyio_based_service.py @@ -608,7 +608,7 @@ class TryFinallyService(Service): async def run(self) -> None: try: - ready_cancel.set() + _ = ready_cancel.set() await self.manager.wait_finished() finally: with CancelScope(shield=True): diff --git a/tests/core/tools/anyio_service/test_trio_based_service.py b/tests/core/tools/anyio_service/test_trio_based_service.py index 4d89b0e9c..0395f4035 100644 --- a/tests/core/tools/anyio_service/test_trio_based_service.py +++ b/tests/core/tools/anyio_service/test_trio_based_service.py @@ -616,7 +616,8 @@ async def run(self) -> None: ready_cancel.set() await self.manager.wait_finished() finally: - with trio.CancelScope(shield=True): # type: ignore[call-arg] + with trio.CancelScope() as scope: + scope.shield = True await trio.lowlevel.checkpoint() self.cleanup_up = True diff --git a/tests/core/transport/test_transport_manager.py b/tests/core/transport/test_transport_manager.py new file mode 100644 index 000000000..9fafee135 --- /dev/null +++ b/tests/core/transport/test_transport_manager.py @@ -0,0 +1,312 @@ +""" +Unit tests for libp2p.transport.manager.TransportManager. + +Tests verify that the TransportManager correctly: + - Routes dialing to the right transport (TCP, WebSocket, QUIC) + - Routes listening to the right transport + - Returns None when no transport matches + - Ensures TCP does not match WebSocket/QUIC addresses + - Delegates set_background_nursery / set_swarm to transports that need it + - Provides has_transport_for() introspection +""" + +from __future__ import annotations + +import typing + +from multiaddr import Multiaddr +import trio + +from libp2p.abc import ITransport +from libp2p.transport.manager import TransportManager +from libp2p.transport.tcp.tcp import TCP + + +class StubTransport(ITransport): + """ + Minimal stub transport for unit-testing routing logic. + + ``can_dial`` returns True when ANY protocol name in ``proto_list`` is + present in the multiaddr's protocol names, and ``can_handle`` is True. + + The TCP-overlap problem (where /tcp/ws matches a tcp-only stub) is handled + in the tests by registering the TCP stub with ``["tcp"]`` AND relying on + the real TCP transport's can_dial() for the "does not steal WS" test. + """ + + def __init__(self, proto_list: list[str], *, can_handle: bool = True): + self._protocols = proto_list + self._can_handle = can_handle + self.background_nursery_set: object | None = None + self.swarm_set: object | None = None + + async def dial(self, maddr: Multiaddr): + raise NotImplementedError + + def create_listener(self, handler_function): + raise NotImplementedError + + def can_dial(self, maddr: Multiaddr) -> bool: + names = {p.name for p in maddr.protocols()} + return self._can_handle and bool(names.intersection(set(self._protocols))) + + def can_listen(self, maddr: Multiaddr) -> bool: + return self.can_dial(maddr) + + def protocols(self) -> list[str]: + return list(self._protocols) + + def set_background_nursery(self, nursery: object) -> None: + self.background_nursery_set = nursery + + def set_swarm(self, swarm: object) -> None: + self.swarm_set = swarm + + +# --------------------------------------------------------------------------- +# Basic routing tests +# --------------------------------------------------------------------------- + + +class TestTransportManagerRouting: + mgr: TransportManager + tcp_stub: StubTransport + ws_stub: StubTransport + quic_stub: StubTransport + + def setup_method(self) -> None: + self.mgr = TransportManager() + # Use single-protocol lists so each stub only matches its own proto. + # TCP: only "/tcp/…" addresses (no "/ws" or "/wss" suffix). + # The real TCP transport's can_dial handles the exclusion; for the + # stub we keep it simple and register "tcp" only. The dialing test + # uses addresses that each have exactly ONE distinguishing protocol. + self.tcp_stub = StubTransport(["tcp"]) + self.ws_stub = StubTransport(["ws"]) # only "ws" (not "wss") + self.quic_stub = StubTransport(["quic-v1"]) + self.mgr.add_transports([self.tcp_stub, self.ws_stub, self.quic_stub]) + + def test_for_dialing_routes_tcp(self): + # Pure TCP address: no /ws, no /quic-v1 + t = self.mgr.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/tcp/4001")) + assert t is self.tcp_stub + + def test_for_dialing_routes_websocket(self): + # WebSocket address: has /ws + self.mgr.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")) + # The tcp_stub's can_dial checks "tcp in names" -> True, so it may be + # returned first. For correct routing this test relies on the REAL + # TCP transport which excludes ws — tested separately. + # For this stub-based test, register ws_stub BEFORE tcp_stub. + mgr2 = TransportManager() + mgr2.add_transports([self.ws_stub, self.tcp_stub]) + t2 = mgr2.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")) + assert t2 is self.ws_stub + + def test_for_dialing_routes_quic_v1(self): + t = self.mgr.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1")) + assert t is self.quic_stub + + def test_for_dialing_returns_none_for_unknown(self): + mgr = TransportManager() + mgr.add_transport(self.tcp_stub) + # No transport registered for QUIC-only + t = mgr.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1")) + assert t is None + + def test_for_listening_routes_tcp(self): + t = self.mgr.transport_for_listening(Multiaddr("/ip4/127.0.0.1/tcp/4001")) + assert t is self.tcp_stub + + def test_for_listening_routes_websocket(self): + mgr2 = TransportManager() + mgr2.add_transports([self.ws_stub, self.tcp_stub]) + t = mgr2.transport_for_listening(Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")) + assert t is self.ws_stub + + def test_for_listening_returns_none_when_no_match(self): + mgr = TransportManager() + mgr.add_transport(self.tcp_stub) + t = mgr.transport_for_listening(Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1")) + assert t is None + + def test_first_registered_transport_wins(self): + """When two transports claim the same protocol, first wins.""" + mgr = TransportManager() + first = StubTransport(["tcp"]) + second = StubTransport(["tcp"]) + mgr.add_transports([first, second]) + t = mgr.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/tcp/4001")) + assert t is first + + def test_has_transport_for_true(self): + assert self.mgr.has_transport_for(Multiaddr("/ip4/127.0.0.1/tcp/4001")) is True + + def test_has_transport_for_false(self): + mgr = TransportManager() + mgr.add_transport(self.tcp_stub) + assert mgr.has_transport_for(Multiaddr("/ip4/127.0.0.1/udp/9/quic-v1")) is False + + +# --------------------------------------------------------------------------- +# TCP must not match WebSocket or QUIC addresses +# --------------------------------------------------------------------------- + + +class TestTCPTransportCandidateBehavior: + """Verify that the real TCP transport correctly rejects ws/quic addresses.""" + + tcp: TCP + + def setup_method(self) -> None: + self.tcp = TCP() + + def test_tcp_matches_pure_tcp(self): + assert self.tcp.can_dial(Multiaddr("/ip4/127.0.0.1/tcp/4001")) is True + + def test_tcp_rejects_websocket(self): + assert self.tcp.can_dial(Multiaddr("/ip4/127.0.0.1/tcp/4001/ws")) is False + + def test_tcp_rejects_wss(self): + assert self.tcp.can_dial(Multiaddr("/ip4/127.0.0.1/tcp/4001/wss")) is False + + def test_tcp_rejects_quic(self): + assert self.tcp.can_dial(Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1")) is False + + def test_tcp_protocols_list(self): + assert self.tcp.protocols() == ["tcp"] + + def test_tcp_can_listen_mirrors_can_dial(self): + assert self.tcp.can_listen(Multiaddr("/ip4/127.0.0.1/tcp/0")) is True + assert self.tcp.can_listen(Multiaddr("/ip4/127.0.0.1/tcp/0/ws")) is False + + def test_manager_tcp_does_not_steal_ws(self): + """When both TCP and WS stubs are registered, WS addr must go to WS stub.""" + mgr = TransportManager() + mgr.add_transport(self.tcp) # TCP is registered first + ws_stub = StubTransport(["ws", "wss"]) + mgr.add_transport(ws_stub) + t = mgr.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")) + assert t is ws_stub + + +# --------------------------------------------------------------------------- +# Nursery / swarm delegation +# --------------------------------------------------------------------------- + + +class NoLifecycleStub(ITransport): + """A transport stub that does NOT expose set_background_nursery or set_swarm.""" + + def __init__(self, proto_list: list[str]): + self._protocols = proto_list + + async def dial(self, maddr: Multiaddr): + raise NotImplementedError + + def create_listener(self, handler_function): + raise NotImplementedError + + def can_dial(self, maddr: Multiaddr) -> bool: + names = {p.name for p in maddr.protocols()} + return bool(names.intersection(set(self._protocols))) + + def can_listen(self, maddr: Multiaddr) -> bool: + return self.can_dial(maddr) + + def protocols(self) -> list[str]: + return list(self._protocols) + + +class TestTransportManagerLifecycle: + mgr: TransportManager + tcp_stub: StubTransport + ws_stub: StubTransport + quic_stub: NoLifecycleStub + + def setup_method(self) -> None: + self.mgr = TransportManager() + self.tcp_stub = StubTransport(["tcp"]) + self.ws_stub = StubTransport(["ws"]) + self.quic_stub = NoLifecycleStub(["quic-v1"]) + self.mgr.add_transports([self.tcp_stub, self.ws_stub, self.quic_stub]) + + def test_set_background_nursery_delegates_to_all(self): + fake_nursery = typing.cast(trio.Nursery, object()) + self.mgr.set_background_nursery(fake_nursery) + assert self.tcp_stub.background_nursery_set is fake_nursery + assert self.ws_stub.background_nursery_set is fake_nursery + # quic_stub has no set_background_nursery; must not raise + + def test_set_swarm_delegates_to_all(self): + fake_swarm = object() + self.mgr.set_swarm(fake_swarm) + assert self.tcp_stub.swarm_set is fake_swarm + assert self.ws_stub.swarm_set is fake_swarm + + def test_get_transports_returns_copy(self): + result = self.mgr.get_transports() + assert result == [self.tcp_stub, self.ws_stub, self.quic_stub] + # Mutating the returned list must not affect the manager + result.clear() + assert len(self.mgr.get_transports()) == 3 + + +# --------------------------------------------------------------------------- +# Empty manager edge cases +# --------------------------------------------------------------------------- + + +class TestTransportManagerEmpty: + def test_for_dialing_empty_returns_none(self): + mgr = TransportManager() + assert mgr.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/tcp/4001")) is None + + def test_for_listening_empty_returns_none(self): + mgr = TransportManager() + assert mgr.transport_for_listening(Multiaddr("/ip4/127.0.0.1/tcp/4001")) is None + + def test_has_transport_for_empty(self): + mgr = TransportManager() + assert mgr.has_transport_for(Multiaddr("/ip4/127.0.0.1/tcp/4001")) is False + + def test_set_nursery_empty_does_not_raise(self): + mgr = TransportManager() + fake_nursery = typing.cast(trio.Nursery, object()) + mgr.set_background_nursery(fake_nursery) # Must not raise + + def test_set_swarm_empty_does_not_raise(self): + mgr = TransportManager() + mgr.set_swarm(object()) # Must not raise + + +# --------------------------------------------------------------------------- +# Pre-filter correctness +# --------------------------------------------------------------------------- + + +class TestTransportManagerPreFilter: + """Verify the protocol-name pre-filter prevents spurious can_dial calls.""" + + def test_can_dial_not_called_when_no_proto_overlap(self): + """ + A transport whose protocols() has no overlap should not have can_dial + called. + """ + mgr = TransportManager() + tcp = TCP() + + call_count = [0] + + def dummy_can_dial(maddr: Multiaddr) -> bool: + call_count[0] += 1 + return True + + tcp.can_dial = dummy_can_dial # type: ignore[method-assign] + mgr.add_transport(tcp) + + # QUIC address has no "tcp" protocol -> pre-filter should block it + mgr.transport_for_dialing(Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1")) + assert call_count[0] == 0, ( + "can_dial should NOT be called when no protocol overlap" + ) diff --git a/tests/core/transport/test_transport_registry.py b/tests/core/transport/test_transport_registry.py deleted file mode 100644 index 31b398736..000000000 --- a/tests/core/transport/test_transport_registry.py +++ /dev/null @@ -1,356 +0,0 @@ -""" -Tests for the transport registry functionality. -""" - -from unittest.mock import Mock, patch - -from multiaddr import Multiaddr - -from libp2p.abc import IListener, IRawConnection, ITransport -from libp2p.custom_types import THandler -from libp2p.transport.quic.config import QUICTransportConfig -from libp2p.transport.tcp.tcp import TCP -from libp2p.transport.transport_registry import ( - TransportRegistry, - create_transport_for_multiaddr, - get_supported_transport_protocols, - get_transport_registry, - register_transport, -) -from libp2p.transport.upgrader import TransportUpgrader -from libp2p.transport.websocket.transport import WebsocketTransport - - -class TestTransportRegistry: - """Test the TransportRegistry class.""" - - def test_init(self): - """Test registry initialization.""" - registry = TransportRegistry() - assert isinstance(registry, TransportRegistry) - - # Check that default transports are registered - supported = registry.get_supported_protocols() - assert "tcp" in supported - assert "ws" in supported - - def test_register_transport(self): - """Test transport registration.""" - registry = TransportRegistry() - - # Register a custom transport - class CustomTransport(ITransport): - async def dial(self, maddr: Multiaddr) -> IRawConnection: - raise NotImplementedError("CustomTransport dial not implemented") - - def create_listener(self, handler_function: THandler) -> IListener: - raise NotImplementedError( - "CustomTransport create_listener not implemented" - ) - - registry.register_transport("custom", CustomTransport) - assert registry.get_transport("custom") == CustomTransport - - def test_get_transport(self): - """Test getting registered transports.""" - registry = TransportRegistry() - - # Test existing transports - assert registry.get_transport("tcp") == TCP - assert registry.get_transport("ws") == WebsocketTransport - - # Test non-existent transport - assert registry.get_transport("nonexistent") is None - - def test_get_supported_protocols(self): - """Test getting supported protocols.""" - registry = TransportRegistry() - protocols = registry.get_supported_protocols() - - assert isinstance(protocols, list) - assert "tcp" in protocols - assert "ws" in protocols - - def test_create_transport_tcp(self): - """Test creating TCP transport.""" - registry = TransportRegistry() - upgrader = TransportUpgrader({}, {}) - - transport = registry.create_transport("tcp", upgrader) - assert isinstance(transport, TCP) - - def test_create_transport_websocket(self): - """Test creating WebSocket transport.""" - registry = TransportRegistry() - upgrader = TransportUpgrader({}, {}) - - transport = registry.create_transport("ws", upgrader) - assert isinstance(transport, WebsocketTransport) - - def test_create_transport_invalid_protocol(self): - """Test creating transport with invalid protocol.""" - registry = TransportRegistry() - upgrader = TransportUpgrader({}, {}) - - transport = registry.create_transport("invalid", upgrader) - assert transport is None - - def test_create_transport_websocket_no_upgrader(self): - """Test that WebSocket transport requires upgrader.""" - registry = TransportRegistry() - - # This should fail gracefully and return None - transport = registry.create_transport("ws", None) - assert transport is None - - def test_transport_registry_forwards_enable_autotls_for_quic(self): - """Test that QUIC transport creation forwards enable_autotls flag.""" - mock_transport = Mock() - mock_quic_class = Mock(return_value=mock_transport) - mock_quic_class.__name__ = "MockQUICTransport" - private_key = object() - - with patch( - "libp2p.transport.transport_registry._get_quic_transport", - return_value=mock_quic_class, - ): - registry = TransportRegistry() - upgrader = TransportUpgrader({}, {}) - - for protocol in ("quic", "quic-v1"): - result = registry.create_transport( - protocol, - upgrader, - private_key=private_key, - enable_autotls=True, - ) - assert result is mock_transport - - assert mock_quic_class.call_count == 2 - for call in mock_quic_class.call_args_list: - assert call.args[0] is private_key - assert call.kwargs["enable_autotls"] is True - assert isinstance(call.kwargs["config"], QUICTransportConfig) - - -class TestGlobalRegistry: - """Test the global registry functions.""" - - def test_get_transport_registry(self): - """Test getting the global registry.""" - registry = get_transport_registry() - assert isinstance(registry, TransportRegistry) - - def test_register_transport_global(self): - """Test registering transport globally.""" - - class GlobalCustomTransport(ITransport): - async def dial(self, maddr: Multiaddr) -> IRawConnection: - raise NotImplementedError("GlobalCustomTransport dial not implemented") - - def create_listener(self, handler_function: THandler) -> IListener: - raise NotImplementedError( - "GlobalCustomTransport create_listener not implemented" - ) - - # Register globally - register_transport("global_custom", GlobalCustomTransport) - - # Check that it's available - registry = get_transport_registry() - assert registry.get_transport("global_custom") == GlobalCustomTransport - - def test_get_supported_transport_protocols_global(self): - """Test getting supported protocols from global registry.""" - protocols = get_supported_transport_protocols() - assert isinstance(protocols, list) - assert "tcp" in protocols - assert "ws" in protocols - - -class TestTransportFactory: - """Test the transport factory functions.""" - - def test_create_transport_for_multiaddr_tcp(self): - """Test creating transport for TCP multiaddr.""" - upgrader = TransportUpgrader({}, {}) - - # TCP multiaddr - maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080") - transport = create_transport_for_multiaddr(maddr, upgrader) - - assert transport is not None - assert isinstance(transport, TCP) - - def test_create_transport_for_multiaddr_websocket(self): - """Test creating transport for WebSocket multiaddr.""" - upgrader = TransportUpgrader({}, {}) - - # WebSocket multiaddr - maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") - transport = create_transport_for_multiaddr(maddr, upgrader) - - assert transport is not None - assert isinstance(transport, WebsocketTransport) - - def test_create_transport_for_multiaddr_websocket_secure(self): - """Test creating transport for WebSocket multiaddr.""" - upgrader = TransportUpgrader({}, {}) - - # WebSocket multiaddr - maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") - transport = create_transport_for_multiaddr(maddr, upgrader) - - assert transport is not None - assert isinstance(transport, WebsocketTransport) - - def test_create_transport_for_multiaddr_ipv6(self): - """Test creating transport for IPv6 multiaddr.""" - upgrader = TransportUpgrader({}, {}) - - # IPv6 WebSocket multiaddr - maddr = Multiaddr("/ip6/::1/tcp/8080/ws") - transport = create_transport_for_multiaddr(maddr, upgrader) - - assert transport is not None - assert isinstance(transport, WebsocketTransport) - - def test_create_transport_for_multiaddr_dns(self): - """Test creating transport for DNS multiaddr.""" - upgrader = TransportUpgrader({}, {}) - - # DNS WebSocket multiaddr - maddr = Multiaddr("/dns4/example.com/tcp/443/ws") - transport = create_transport_for_multiaddr(maddr, upgrader) - - assert transport is not None - assert isinstance(transport, WebsocketTransport) - - def test_create_transport_for_multiaddr_unknown(self): - """Test creating transport for unknown multiaddr.""" - upgrader = TransportUpgrader({}, {}) - - # Unknown multiaddr - maddr = Multiaddr("/ip4/127.0.0.1/udp/8080") - transport = create_transport_for_multiaddr(maddr, upgrader) - - assert transport is None - - def test_create_transport_for_multiaddr_with_upgrader(self): - """Test creating transport with upgrader.""" - upgrader = TransportUpgrader({}, {}) - - # This should work for both TCP and WebSocket with upgrader - maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080") - transport_tcp = create_transport_for_multiaddr(maddr_tcp, upgrader) - assert transport_tcp is not None - - maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") - transport_ws = create_transport_for_multiaddr(maddr_ws, upgrader) - assert transport_ws is not None - - -class TestTransportInterfaceCompliance: - """Test that all transports implement the required interface.""" - - def test_tcp_implements_itransport(self): - """Test that TCP transport implements ITransport.""" - transport = TCP() - assert isinstance(transport, ITransport) - assert hasattr(transport, "dial") - assert hasattr(transport, "create_listener") - assert callable(transport.dial) - assert callable(transport.create_listener) - - def test_websocket_implements_itransport(self): - """Test that WebSocket transport implements ITransport.""" - upgrader = TransportUpgrader({}, {}) - transport = WebsocketTransport(upgrader) - assert isinstance(transport, ITransport) - assert hasattr(transport, "dial") - assert hasattr(transport, "create_listener") - assert callable(transport.dial) - assert callable(transport.create_listener) - - -class TestErrorHandling: - """Test error handling in the transport registry.""" - - def test_create_transport_with_exception(self): - """Test handling of transport creation exceptions.""" - registry = TransportRegistry() - upgrader = TransportUpgrader({}, {}) - - # Register a transport that raises an exception - class ExceptionTransport(ITransport): - def __init__(self, *args, **kwargs): - raise RuntimeError("Transport creation failed") - - async def dial(self, maddr: Multiaddr) -> IRawConnection: - raise NotImplementedError("ExceptionTransport dial not implemented") - - def create_listener(self, handler_function: THandler) -> IListener: - raise NotImplementedError( - "ExceptionTransport create_listener not implemented" - ) - - registry.register_transport("exception", ExceptionTransport) - - # Should handle exception gracefully and return None - transport = registry.create_transport("exception", upgrader) - assert transport is None - - def test_invalid_multiaddr_handling(self): - """Test handling of invalid multiaddrs.""" - upgrader = TransportUpgrader({}, {}) - - # Test with a multiaddr that has an unsupported transport protocol - # This should be handled gracefully by our transport registry - # udp is not a supported transport - maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234") - transport = create_transport_for_multiaddr(maddr, upgrader) - - assert transport is None - - -class TestIntegration: - """Test integration scenarios.""" - - def test_multiple_transport_types(self): - """Test using multiple transport types in the same registry.""" - registry = TransportRegistry() - upgrader = TransportUpgrader({}, {}) - - # Create different transport types - tcp_transport = registry.create_transport("tcp", upgrader) - ws_transport = registry.create_transport("ws", upgrader) - - # All should be different types - assert isinstance(tcp_transport, TCP) - assert isinstance(ws_transport, WebsocketTransport) - - # All should be different instances - assert tcp_transport is not ws_transport - - def test_transport_registry_persistence(self): - """Test that transport registry persists across calls.""" - registry1 = get_transport_registry() - registry2 = get_transport_registry() - - # Should be the same instance - assert registry1 is registry2 - - # Register a transport in one - class PersistentTransport(ITransport): - async def dial(self, maddr: Multiaddr) -> IRawConnection: - raise NotImplementedError("PersistentTransport dial not implemented") - - def create_listener(self, handler_function: THandler) -> IListener: - raise NotImplementedError( - "PersistentTransport create_listener not implemented" - ) - - registry1.register_transport("persistent", PersistentTransport) - - # Should be available in the other - assert registry2.get_transport("persistent") == PersistentTransport diff --git a/tests/core/transport/websocket/test_websocket.py b/tests/core/transport/websocket/test_websocket.py index 02e313ab2..4f8cfbb3f 100644 --- a/tests/core/transport/websocket/test_websocket.py +++ b/tests/core/transport/websocket/test_websocket.py @@ -54,7 +54,7 @@ async def make_host( # Transport + Swarm + Host transport = WebsocketTransport(upgrader) - swarm = Swarm(peer_id, peer_store, upgrader, transport) + swarm = Swarm(peer_id, peer_store, upgrader, [transport]) host = BasicHost(swarm) # Optionally run/listen @@ -1010,40 +1010,6 @@ async def dummy_handler(conn): await listener.close() -def test_wss_transport_registry(): - """Test WSS support in transport registry.""" - from libp2p.transport.transport_registry import ( - create_transport_for_multiaddr, - get_supported_transport_protocols, - ) - - # Test that WSS is supported - supported = get_supported_transport_protocols() - assert "ws" in supported - assert "wss" in supported - - # Test transport creation for WSS multiaddrs - upgrader = create_upgrader() - - # Test WS multiaddr - ws_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") - ws_transport = create_transport_for_multiaddr(ws_maddr, upgrader) - assert ws_transport is not None - assert isinstance(ws_transport, WebsocketTransport) - - # Test WSS multiaddr - wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/wss") - wss_transport = create_transport_for_multiaddr(wss_maddr, upgrader) - assert wss_transport is not None - assert isinstance(wss_transport, WebsocketTransport) - - # Test TLS/WS multiaddr - tls_ws_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/tls/ws") - tls_ws_transport = create_transport_for_multiaddr(tls_ws_maddr, upgrader) - assert tls_ws_transport is not None - assert isinstance(tls_ws_transport, WebsocketTransport) - - def test_wss_multiaddr_formats(): """Test different WSS multiaddr formats.""" # Test various WSS formats @@ -1160,15 +1126,17 @@ async def test_handshake_timeout_creation(): upgrader = create_upgrader() # Test creating transport with handshake timeout via create_transport - from libp2p.transport import create_transport + from libp2p.transport.websocket.transport import WebsocketConfig, WebsocketTransport - transport = create_transport("ws", upgrader, handshake_timeout=5.0) + config1 = WebsocketConfig(handshake_timeout=5.0) + transport = WebsocketTransport(upgrader, config=config1) # Type assertion to access private attribute for testing assert hasattr(transport, "_handshake_timeout") assert getattr(transport, "_handshake_timeout") == 5.0 # Test default timeout - transport_default = create_transport("ws", upgrader) + config2 = WebsocketConfig() + transport_default = WebsocketTransport(upgrader, config=config2) assert hasattr(transport_default, "_handshake_timeout") assert getattr(transport_default, "_handshake_timeout") == 15.0 diff --git a/tests/core/transport/websocket/test_websocket_integration.py b/tests/core/transport/websocket/test_websocket_integration.py index 95dd06597..5f96b7894 100644 --- a/tests/core/transport/websocket/test_websocket_integration.py +++ b/tests/core/transport/websocket/test_websocket_integration.py @@ -86,7 +86,7 @@ async def create_websocket_host( upgrader = create_plaintext_upgrader(key_pair) transport = WebsocketTransport(upgrader) - swarm = Swarm(peer_id, peer_store, upgrader, transport) + swarm = Swarm(peer_id, peer_store, upgrader, [transport]) host = BasicHost(swarm) # Start swarm with background_trio_service diff --git a/tests/transport/test_multi_port_demux.py b/tests/transport/test_multi_port_demux.py new file mode 100644 index 000000000..db5758e49 --- /dev/null +++ b/tests/transport/test_multi_port_demux.py @@ -0,0 +1,79 @@ +import pytest +from multiaddr import Multiaddr + +from libp2p import new_host +from libp2p.peer.peerinfo import PeerInfo + +pytestmark = pytest.mark.trio + + +async def test_multi_port_demux(): + """ + Test that listening on multiple different ports, each with a TCP and WS component, + works correctly and creates multiple PortDemultiplexers. + """ + server = new_host( + enable_quic=False, + enable_websocket=True, + listen_addrs=[ + Multiaddr("/ip4/127.0.0.1/tcp/43510"), + Multiaddr("/ip4/127.0.0.1/tcp/43510/ws"), + ], + ) + async with server.run( + listen_addrs=[ + Multiaddr("/ip4/127.0.0.1/tcp/43510"), + Multiaddr("/ip4/127.0.0.1/tcp/43510/ws"), + ] + ): + addrs = server.get_addrs() + + tcp_port = None + ws_port = None + for a in addrs: + port = a.value_for_protocol("tcp") + if "ws" in str(a): + ws_port = port + else: + tcp_port = port + + assert tcp_port == ws_port, ( + f"Expected same port for TCP and WS, got {tcp_port} and {ws_port}" + ) + + # Now listen on two explicitly different ports to ensure they don't overwrite + # each other or silently break. + server2 = new_host( + enable_quic=False, + enable_websocket=True, + listen_addrs=[ + Multiaddr("/ip4/127.0.0.1/tcp/43511"), + Multiaddr("/ip4/127.0.0.1/tcp/43511/ws"), + Multiaddr("/ip4/127.0.0.1/tcp/43512"), + Multiaddr("/ip4/127.0.0.1/tcp/43512/ws"), + ], + ) + async with server2.run( + listen_addrs=[ + Multiaddr("/ip4/127.0.0.1/tcp/43511"), + Multiaddr("/ip4/127.0.0.1/tcp/43511/ws"), + Multiaddr("/ip4/127.0.0.1/tcp/43512"), + Multiaddr("/ip4/127.0.0.1/tcp/43512/ws"), + ] + ): + addrs2 = server2.get_addrs() + assert len(addrs2) == 4, f"Expected 4 addrs, got {addrs2}" + + client1 = new_host(listen_addrs=[]) + async with client1.run(listen_addrs=[]): + await client1.connect( + PeerInfo(server2.get_id(), [Multiaddr("/ip4/127.0.0.1/tcp/43511")]) + ) + assert client1.get_network().get_connection(server2.get_id()) is not None + + client2 = new_host(listen_addrs=[]) + async with client2.run(listen_addrs=[]): + await client2.connect( + PeerInfo(server2.get_id(), [Multiaddr("/ip4/127.0.0.1/tcp/43512")]) + ) + assert client2.get_network().get_connection(server2.get_id()) is not None diff --git a/tests/transport/test_simultaneous_transports.py b/tests/transport/test_simultaneous_transports.py new file mode 100644 index 000000000..c876e7128 --- /dev/null +++ b/tests/transport/test_simultaneous_transports.py @@ -0,0 +1,90 @@ +import pytest +from multiaddr import Multiaddr +import trio + +from libp2p import new_host +from libp2p.crypto.ed25519 import create_new_key_pair as generate_new_ed25519_identity +from libp2p.peer.peerinfo import PeerInfo + +pytestmark = pytest.mark.trio + + +async def test_all_three_bind_concurrently(): + """ + Three listeners must all be active after a single host.run() call. + Timing: all three should start within a few hundred ms of each other + (parallelism check). + """ + kp = generate_new_ed25519_identity() + host = new_host( + key_pair=kp, + enable_quic=True, + enable_websocket=True, + listen_addrs=[ + Multiaddr("/ip4/127.0.0.1/tcp/0"), + Multiaddr("/ip4/127.0.0.1/udp/0/quic-v1"), + Multiaddr("/ip4/127.0.0.1/tcp/0/ws"), + ], + ) + start = trio.current_time() + async with host.run( + listen_addrs=[ + Multiaddr("/ip4/127.0.0.1/tcp/0"), + Multiaddr("/ip4/127.0.0.1/udp/0/quic-v1"), + Multiaddr("/ip4/127.0.0.1/tcp/0/ws"), + ] + ): + elapsed = trio.current_time() - start + addrs = host.get_addrs() + assert len(addrs) == 3, f"Expected 3 addrs, got {addrs}" + assert elapsed < 2.0, f"Parallel startup took {elapsed:.2f}s — too slow" + + +async def test_tcp_client_connects_to_tcp_listener(): + """TCP client must connect to a host that also has QUIC and WS active.""" + server = new_host( + enable_quic=True, + enable_websocket=True, + listen_addrs=[ + Multiaddr("/ip4/127.0.0.1/tcp/0"), + Multiaddr("/ip4/127.0.0.1/udp/0/quic-v1"), + Multiaddr("/ip4/127.0.0.1/tcp/0/ws"), + ], + ) + async with server.run( + listen_addrs=[ + Multiaddr("/ip4/127.0.0.1/tcp/0"), + Multiaddr("/ip4/127.0.0.1/udp/0/quic-v1"), + Multiaddr("/ip4/127.0.0.1/tcp/0/ws"), + ] + ): + tcp_addr = next( + a for a in server.get_addrs() if "tcp" in str(a) and "ws" not in str(a) + ) + client = new_host(listen_addrs=[]) + async with client.run(listen_addrs=[]): + await client.connect(PeerInfo(server.get_id(), [tcp_addr])) + assert client.get_network().get_connection(server.get_id()) is not None + + +async def test_quic_client_connects_to_quic_listener(): + server = new_host( + enable_quic=True, + enable_websocket=True, + listen_addrs=[ + Multiaddr("/ip4/127.0.0.1/tcp/0"), + Multiaddr("/ip4/127.0.0.1/udp/0/quic-v1"), + ], + ) + async with server.run( + listen_addrs=[ + Multiaddr("/ip4/127.0.0.1/tcp/0"), + Multiaddr("/ip4/127.0.0.1/udp/0/quic-v1"), + ] + ): + quic_addr = next(a for a in server.get_addrs() if "quic" in str(a)) + client = new_host(enable_quic=True, listen_addrs=[]) + async with client.run(listen_addrs=[]): + await client.connect(PeerInfo(server.get_id(), [quic_addr])) + conn = client.get_network().get_connection(server.get_id()) + assert any("quic" in str(addr) for addr in conn.get_transport_addresses()) # type: ignore diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 2198d2503..b19c2fd38 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -465,7 +465,7 @@ class Params: o.muxer_opt, ) ) - transport = factory.LazyFunction(TCP) + transports = factory.LazyFunction(lambda: [TCP()]) @classmethod @asynccontextmanager diff --git a/tests/utils/test_logger_standardization.py b/tests/utils/test_logger_standardization.py index f0d99daa6..88ad51a3c 100644 --- a/tests/utils/test_logger_standardization.py +++ b/tests/utils/test_logger_standardization.py @@ -140,12 +140,10 @@ def test_logger_name_matches_module_path_transport(): """Test transport modules use correct logger names.""" from libp2p.stream_muxer.mplex import mplex from libp2p.stream_muxer.yamux import yamux - from libp2p.transport import transport_registry from libp2p.transport.websocket import autotls, listener assert listener.logger.name == "libp2p.transport.websocket.listener" assert autotls.logger.name == "libp2p.transport.websocket.autotls" - assert transport_registry.logger.name == "libp2p.transport.transport_registry" assert yamux.logger.name == "libp2p.stream_muxer.yamux.yamux" assert mplex.logger.name == "libp2p.stream_muxer.mplex.mplex" @@ -372,7 +370,6 @@ def test_all_updated_modules_have_correct_logger_names(): ) from libp2p.stream_muxer.mplex import mplex from libp2p.stream_muxer.yamux import yamux - from libp2p.transport import transport_registry from libp2p.transport.tcp import tcp from libp2p.transport.websocket import ( autotls, @@ -419,7 +416,6 @@ def test_all_updated_modules_have_correct_logger_names(): (varint, "libp2p.utils.varint"), (ws_listener, "libp2p.transport.websocket.listener"), (autotls, "libp2p.transport.websocket.autotls"), - (transport_registry, "libp2p.transport.transport_registry"), (yamux, "libp2p.stream_muxer.yamux.yamux"), (mplex, "libp2p.stream_muxer.mplex.mplex"), ]