diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 93d5da165..70c671549 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -126,5 +126,7 @@ jobs: if [[ "${{ matrix.toxenv }}" == "wheel" ]]; then python -m tox run -e windows-wheel else - python -m tox run -e py${{ matrix.python-version }}-${{ matrix.toxenv }} + tox_python="py${{ matrix.python-version }}" + tox_python="${tox_python//./}" + python -m tox run -e "${tox_python}-${{ matrix.toxenv }}" fi diff --git a/docs/examples.request_response.rst b/docs/examples.request_response.rst new file mode 100644 index 000000000..f05074c93 --- /dev/null +++ b/docs/examples.request_response.rst @@ -0,0 +1,25 @@ +Request/Response Demo +===================== + +This example demonstrates the high-level libp2p request/response helper using a +single JSON request and response over a dedicated protocol stream. + +.. code-block:: console + + $ request-response-demo + Listener ready, listening on: + ... + +Copy the printed command into another terminal, for example: + +.. code-block:: console + + $ request-response-demo -d /ip4/127.0.0.1/tcp/8000/p2p/ --message hello + Sent: hello + Received: {'message': 'hello', 'echo': 'HELLO', 'peer': ''} + +The full source code for this example is below: + +.. literalinclude:: ../examples/request_response/request_response_demo.py + :language: python + :linenos: diff --git a/docs/examples.rst b/docs/examples.rst index 09f0edc59..17c912c00 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -11,6 +11,7 @@ Examples examples.echo examples.echo_quic examples.ping + examples.request_response examples.interop examples.pubsub examples.bitswap diff --git a/docs/libp2p.request_response.rst b/docs/libp2p.request_response.rst new file mode 100644 index 000000000..391cad211 --- /dev/null +++ b/docs/libp2p.request_response.rst @@ -0,0 +1,37 @@ +libp2p.request_response package +=============================== + +Submodules +---------- + +libp2p.request_response.api module +---------------------------------- + +.. automodule:: libp2p.request_response.api + :members: + :undoc-members: + :show-inheritance: + +libp2p.request_response.codec module +------------------------------------ + +.. automodule:: libp2p.request_response.codec + :members: + :undoc-members: + :show-inheritance: + +libp2p.request_response.exceptions module +----------------------------------------- + +.. automodule:: libp2p.request_response.exceptions + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: libp2p.request_response + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/libp2p.rst b/docs/libp2p.rst index fb2ab82b0..035fb27e0 100644 --- a/docs/libp2p.rst +++ b/docs/libp2p.rst @@ -20,6 +20,7 @@ Subpackages libp2p.perf libp2p.protocol_muxer libp2p.pubsub + libp2p.request_response libp2p.rcmgr libp2p.records libp2p.relay diff --git a/examples/request_response/__init__.py b/examples/request_response/__init__.py new file mode 100644 index 000000000..0bdbc7813 --- /dev/null +++ b/examples/request_response/__init__.py @@ -0,0 +1 @@ +"""Examples for the request_response helper.""" diff --git a/examples/request_response/request_response_demo.py b/examples/request_response/request_response_demo.py new file mode 100644 index 000000000..811e892c8 --- /dev/null +++ b/examples/request_response/request_response_demo.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import argparse +import logging +import random +import secrets + +import multiaddr +import trio + +from libp2p import new_host +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.request_response import JSONCodec, RequestResponse +from libp2p.utils.address_validation import ( + find_free_port, + get_available_interfaces, + get_optimal_binding_address, +) + +logging.basicConfig(level=logging.WARNING) +logging.getLogger("multiaddr").setLevel(logging.WARNING) +logging.getLogger("libp2p").setLevel(logging.WARNING) + +PROTOCOL_ID = TProtocol("/example/request-response/1.0.0") + + +async def run( + port: int, + destination: str | None, + message: str, + seed: int | None = None, +) -> None: + if port <= 0: + port = find_free_port() + listen_addr = get_available_interfaces(port) + + if seed is not None: + random.seed(seed) + secret_number = random.getrandbits(32 * 8) + secret = secret_number.to_bytes(length=32, byteorder="big") + else: + secret = secrets.token_bytes(32) + + host = new_host(key_pair=create_new_key_pair(secret)) + rr = RequestResponse(host) + codec = JSONCodec() + + async def handler(request: dict[str, str], context) -> dict[str, str]: + return { + "message": request["message"], + "echo": request["message"].upper(), + "peer": str(context.peer_id), + } + + async with host.run(listen_addrs=listen_addr), trio.open_nursery() as nursery: + nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) + print(f"I am {host.get_id().to_string()}") + + if not destination: + rr.set_handler(PROTOCOL_ID, handler=handler, codec=codec) + peer_id = host.get_id().to_string() + print("Listener ready, listening on:\n") + for addr in listen_addr: + print(f"{addr}/p2p/{peer_id}") + + optimal_addr = get_optimal_binding_address(port) + optimal_addr_with_peer = f"{optimal_addr}/p2p/{peer_id}" + print( + "\nRun this from the same folder in another console:\n\n" + f"request-response-demo -d {optimal_addr_with_peer} --message hello\n" + ) + print("Waiting for incoming requests...") + await trio.sleep_forever() + + destination_str = destination + if destination_str is None: + raise ValueError("destination is required in dialer mode") + maddr = multiaddr.Multiaddr(destination_str) + info = info_from_p2p_addr(maddr) + await host.connect(info) + response = await rr.send_request( + peer_id=info.peer_id, + protocol_ids=[PROTOCOL_ID], + request={"message": message}, + codec=codec, + ) + print(f"Sent: {message}") + print(f"Received: {response}") + + +def main() -> None: + description = """ + Demonstrates the request/response helper with a single JSON request and response. + Run once without -d to start a listener, then run again with -d to send a request. + """ + example_maddr = ( + "/ip4/[HOST_IP]/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" + ) + parser = argparse.ArgumentParser(description=description) + parser.add_argument("-p", "--port", default=0, type=int, help="source port") + parser.add_argument( + "-d", + "--destination", + type=str, + help=f"destination multiaddr string, e.g. {example_maddr}", + ) + parser.add_argument( + "--message", + type=str, + default="hello", + help="JSON message payload to send", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + help="seed the RNG to make peer IDs reproducible", + ) + args = parser.parse_args() + try: + trio.run(run, args.port, args.destination, args.message, args.seed) + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + main() diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py index 1cf7114b3..6e5b7a982 100644 --- a/libp2p/io/msgio.py +++ b/libp2p/io/msgio.py @@ -1,34 +1,17 @@ """ ``msgio`` is an implementation of `https://github.com/libp2p/go-msgio`. -from that repo: "a simple package to r/w length-delimited slices." - -NOTE: currently missing the capability to indicate lengths by "varint" method. +From that repo: "a simple package to r/w length-delimited slices." """ -from abc import ( - abstractmethod, -) -from typing import ( - Literal, -) - -from libp2p.io.abc import ( - MsgReadWriteCloser, - Reader, - ReadWriteCloser, -) -from libp2p.io.utils import ( - read_exactly, -) -from libp2p.utils import ( - decode_uvarint_from_stream, - encode_varint_prefixed, -) - -from .exceptions import ( - MessageTooLarge, -) +from abc import abstractmethod +from typing import Literal + +from libp2p.io.abc import MsgReadWriteCloser, Reader, ReadWriteCloser +from libp2p.io.utils import read_exactly +from libp2p.utils import decode_uvarint_from_stream, encode_varint_prefixed + +from .exceptions import MessageTooLarge BYTE_ORDER: Literal["big", "little"] = "big" @@ -87,6 +70,22 @@ def encode_msg(self, msg: bytes) -> bytes: class VarIntLengthMsgReadWriter(BaseMsgReadWriter): max_msg_size: int + def __init__( + self, + read_write_closer: ReadWriteCloser, + max_msg_size: int | None = None, + ) -> None: + super().__init__(read_write_closer) + if max_msg_size is None: + if not hasattr(self, "max_msg_size"): + raise TypeError("max_msg_size is required") + effective_max_msg_size = self.max_msg_size + else: + effective_max_msg_size = max_msg_size + if effective_max_msg_size <= 0: + raise ValueError("max_msg_size must be greater than 0") + self.max_msg_size = effective_max_msg_size + async def next_msg_len(self) -> int: msg_len = await decode_uvarint_from_stream(self.read_write_closer) if msg_len > self.max_msg_size: diff --git a/libp2p/request_response/__init__.py b/libp2p/request_response/__init__.py new file mode 100644 index 000000000..269f20441 --- /dev/null +++ b/libp2p/request_response/__init__.py @@ -0,0 +1,31 @@ +from .api import RequestContext, RequestResponse, RequestResponseConfig +from .codec import BytesCodec, JSONCodec, RequestResponseCodec +from .exceptions import ( + MessageTooLargeError, + ProtocolNotSupportedError, + RequestDecodeError, + RequestEncodeError, + RequestResponseError, + RequestTimeoutError, + RequestTransportError, + ResponseDecodeError, + ResponseEncodeError, +) + +__all__ = [ + "BytesCodec", + "JSONCodec", + "MessageTooLargeError", + "ProtocolNotSupportedError", + "RequestContext", + "RequestDecodeError", + "RequestEncodeError", + "RequestResponse", + "RequestResponseCodec", + "RequestResponseConfig", + "RequestResponseError", + "RequestTimeoutError", + "RequestTransportError", + "ResponseDecodeError", + "ResponseEncodeError", +] diff --git a/libp2p/request_response/api.py b/libp2p/request_response/api.py new file mode 100644 index 000000000..4481f398c --- /dev/null +++ b/libp2p/request_response/api.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass +import logging +from typing import Generic, TypeVar, cast + +import trio + +from libp2p.abc import IHost, INetStream +from libp2p.custom_types import TProtocol +from libp2p.host.exceptions import StreamFailure +from libp2p.io.exceptions import IncompleteReadError, MessageTooLarge +from libp2p.io.msgio import VarIntLengthMsgReadWriter +from libp2p.network.stream.exceptions import StreamEOF, StreamError, StreamReset +from libp2p.peer.id import ID +from libp2p.protocol_muxer.exceptions import ( + MultiselectClientError as MultiselectClientError, + ProtocolNotSupportedError as MultiselectProtocolNotSupportedError, +) + +from .codec import RequestResponseCodec +from .exceptions import ( + MessageTooLargeError, + ProtocolNotSupportedError, + RequestDecodeError, + RequestEncodeError, + RequestResponseError, + RequestTimeoutError, + RequestTransportError, + ResponseDecodeError, + ResponseEncodeError, +) + +ReqT = TypeVar("ReqT") +RespT = TypeVar("RespT") + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, slots=True) +class RequestContext: + peer_id: ID + protocol_id: TProtocol + + +@dataclass(frozen=True, slots=True) +class RequestResponseConfig: + timeout: float = 10.0 + max_request_size: int = 1_048_576 + max_response_size: int = 10_485_760 + max_concurrent_inbound: int = 128 + + def __post_init__(self) -> None: + if self.timeout <= 0: + raise ValueError("timeout must be greater than 0") + if self.max_request_size <= 0: + raise ValueError("max_request_size must be greater than 0") + if self.max_response_size <= 0: + raise ValueError("max_response_size must be greater than 0") + if self.max_concurrent_inbound <= 0: + raise ValueError("max_concurrent_inbound must be greater than 0") + + +RequestHandler = Callable[[ReqT, RequestContext], Awaitable[RespT]] + + +@dataclass(slots=True) +class _HandlerRegistration(Generic[ReqT, RespT]): + handler: RequestHandler[ReqT, RespT] + codec: RequestResponseCodec[ReqT, RespT] + config: RequestResponseConfig + limiter: trio.Semaphore + + +class RequestResponse: + """Safe, one-shot request/response helper on top of libp2p streams.""" + + def __init__(self, host: IHost) -> None: + self.host = host + self._handlers: dict[TProtocol, _HandlerRegistration[object, object]] = {} + + def set_handler( + self, + protocol_id: TProtocol, + handler: RequestHandler[ReqT, RespT], + codec: RequestResponseCodec[ReqT, RespT], + config: RequestResponseConfig | None = None, + ) -> None: + effective_config = config or RequestResponseConfig() + registration = _HandlerRegistration( + handler=cast(RequestHandler[object, object], handler), + codec=cast(RequestResponseCodec[object, object], codec), + config=effective_config, + limiter=trio.Semaphore(effective_config.max_concurrent_inbound), + ) + self._handlers[protocol_id] = registration + self.host.set_stream_handler( + protocol_id, self._build_stream_handler(protocol_id) + ) + + def remove_handler(self, protocol_id: TProtocol) -> None: + self._handlers.pop(protocol_id, None) + self.host.remove_stream_handler(protocol_id) + + async def send_request( + self, + peer_id: ID, + protocol_ids: Sequence[TProtocol], + request: ReqT, + codec: RequestResponseCodec[ReqT, RespT], + config: RequestResponseConfig | None = None, + ) -> RespT: + if not protocol_ids: + raise ValueError("protocol_ids must not be empty") + + effective_config = config or RequestResponseConfig() + request_payload = self._encode_request(codec, request) + self._check_size( + request_payload, + effective_config.max_request_size, + "request payload exceeds configured maximum", + ) + + stream: INetStream | None = None + try: + with trio.fail_after(effective_config.timeout): + stream = await self.host.new_stream(peer_id, protocol_ids) + await self._write_message( + stream, request_payload, effective_config.max_request_size + ) + response_payload = await self._read_message( + stream, effective_config.max_response_size + ) + response = self._decode_response(codec, response_payload) + except trio.TooSlowError as error: + await self._safe_reset(stream) + raise RequestTimeoutError( + f"request timed out after {effective_config.timeout} seconds" + ) from error + except MessageTooLargeError: + await self._safe_reset(stream) + raise + except StreamFailure as error: + raise self._map_stream_failure(error) from error + except RequestResponseError: + await self._safe_reset(stream) + raise + except (IncompleteReadError, StreamEOF, StreamError, StreamReset) as error: + await self._safe_reset(stream) + raise RequestTransportError( + "request/response exchange failed while reading or writing the stream" + ) from error + except Exception as error: + await self._safe_reset(stream) + raise RequestTransportError( + "request/response exchange failed due to an unexpected transport error" + ) from error + else: + await self._safe_close(stream) + return response + + def _build_stream_handler( + self, protocol_id: TProtocol + ) -> Callable[[INetStream], Awaitable[None]]: + async def _stream_handler(stream: INetStream) -> None: + registration = self._handlers.get(protocol_id) + if registration is None: + await self._safe_reset(stream) + return + + try: + registration.limiter.acquire_nowait() + except trio.WouldBlock: + logger.warning( + "request_response inbound limit reached for protocol %s", + protocol_id, + ) + await self._safe_reset(stream) + return + + try: + await self._handle_inbound(protocol_id, registration, stream) + finally: + registration.limiter.release() + + return _stream_handler + + async def _handle_inbound( + self, + protocol_id: TProtocol, + registration: _HandlerRegistration[object, object], + stream: INetStream, + ) -> None: + context = RequestContext( + peer_id=stream.muxed_conn.peer_id, + protocol_id=protocol_id, + ) + try: + with trio.fail_after(registration.config.timeout): + request_payload = await self._read_message( + stream, registration.config.max_request_size + ) + request = self._decode_inbound_request( + registration.codec, request_payload + ) + response = await registration.handler(request, context) + response_payload = self._encode_inbound_response( + registration.codec, response + ) + self._check_size( + response_payload, + registration.config.max_response_size, + "response payload exceeds configured maximum", + ) + await self._write_message( + stream, response_payload, registration.config.max_response_size + ) + except trio.TooSlowError: + logger.warning( + "request_response handler timed out for protocol %s from peer %s", + protocol_id, + context.peer_id, + ) + await self._safe_reset(stream) + except RequestResponseError as error: + logger.warning( + "request_response handler rejected protocol %s from peer %s: %s", + protocol_id, + context.peer_id, + error, + ) + await self._safe_reset(stream) + except (IncompleteReadError, StreamEOF, StreamError, StreamReset) as error: + logger.warning( + "request_response stream error for protocol %s from peer %s: %s", + protocol_id, + context.peer_id, + error, + ) + await self._safe_reset(stream) + except Exception: + logger.exception( + ( + "request_response unexpected handler failure for protocol %s " + "from peer %s" + ), + protocol_id, + context.peer_id, + ) + await self._safe_reset(stream) + else: + await self._safe_close(stream) + + async def _read_message(self, stream: INetStream, max_msg_size: int) -> bytes: + reader = VarIntLengthMsgReadWriter(stream, max_msg_size=max_msg_size) + try: + return await reader.read_msg() + except MessageTooLarge as error: + raise MessageTooLargeError(str(error)) from error + + async def _write_message( + self, stream: INetStream, payload: bytes, max_msg_size: int + ) -> None: + writer = VarIntLengthMsgReadWriter(stream, max_msg_size=max_msg_size) + try: + await writer.write_msg(payload) + except MessageTooLarge as error: + raise MessageTooLargeError(str(error)) from error + + def _check_size(self, payload: bytes, limit: int, message: str) -> None: + if len(payload) > limit: + raise MessageTooLargeError(message) + + def _encode_request( + self, codec: RequestResponseCodec[ReqT, RespT], request: ReqT + ) -> bytes: + try: + payload = codec.encode_request(request) + return self._ensure_bytes(payload, "request payload must encode to bytes") + except Exception as error: + raise RequestEncodeError("failed to encode request payload") from error + + def _decode_response( + self, codec: RequestResponseCodec[ReqT, RespT], payload: bytes + ) -> RespT: + try: + return codec.decode_response(payload) + except Exception as error: + raise ResponseDecodeError("failed to decode response payload") from error + + def _decode_inbound_request( + self, codec: RequestResponseCodec[object, object], payload: bytes + ) -> object: + try: + return codec.decode_request(payload) + except Exception as error: + raise RequestDecodeError("failed to decode request payload") from error + + def _encode_inbound_response( + self, codec: RequestResponseCodec[object, object], response: object + ) -> bytes: + try: + payload = codec.encode_response(response) + return self._ensure_bytes(payload, "response payload must encode to bytes") + except Exception as error: + raise ResponseEncodeError("failed to encode response payload") from error + + def _ensure_bytes(self, payload: bytes, message: str) -> bytes: + if not isinstance(payload, (bytes, bytearray, memoryview)): + raise TypeError(message) + return bytes(payload) + + def _map_stream_failure(self, error: StreamFailure) -> RequestResponseError: + cause = error.__cause__ + if isinstance(cause, MultiselectProtocolNotSupportedError): + return ProtocolNotSupportedError(str(cause)) + if isinstance( + cause, MultiselectClientError + ) and "protocol not supported" in str(cause): + return ProtocolNotSupportedError(str(cause)) + return RequestTransportError(str(error)) + + async def _safe_close(self, stream: INetStream | None) -> None: + if stream is None: + return + try: + await stream.close() + except Exception: + logger.debug( + "failed to close request_response stream cleanly", exc_info=True + ) + + async def _safe_reset(self, stream: INetStream | None) -> None: + if stream is None: + return + try: + await stream.reset() + except Exception: + logger.debug( + "failed to reset request_response stream cleanly", exc_info=True + ) diff --git a/libp2p/request_response/codec.py b/libp2p/request_response/codec.py new file mode 100644 index 000000000..cc8afddaf --- /dev/null +++ b/libp2p/request_response/codec.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import json +from typing import Any, Protocol, TypeVar, runtime_checkable + +ReqT = TypeVar("ReqT") +RespT = TypeVar("RespT") + + +@runtime_checkable +class RequestResponseCodec(Protocol[ReqT, RespT]): + def encode_request(self, request: ReqT) -> bytes: ... + + def decode_request(self, payload: bytes) -> ReqT: ... + + def encode_response(self, response: RespT) -> bytes: ... + + def decode_response(self, payload: bytes) -> RespT: ... + + +class BytesCodec(RequestResponseCodec[bytes, bytes]): + """Pass-through codec for raw bytes payloads.""" + + def encode_request(self, request: bytes) -> bytes: + return bytes(request) + + def decode_request(self, payload: bytes) -> bytes: + return bytes(payload) + + def encode_response(self, response: bytes) -> bytes: + return bytes(response) + + def decode_response(self, payload: bytes) -> bytes: + return bytes(payload) + + +class JSONCodec(RequestResponseCodec[Any, Any]): + """Codec for JSON-serializable Python values.""" + + def encode_request(self, request: Any) -> bytes: + return json.dumps(request).encode("utf-8") + + def decode_request(self, payload: bytes) -> Any: + return json.loads(payload.decode("utf-8")) + + def encode_response(self, response: Any) -> bytes: + return json.dumps(response).encode("utf-8") + + def decode_response(self, payload: bytes) -> Any: + return json.loads(payload.decode("utf-8")) diff --git a/libp2p/request_response/exceptions.py b/libp2p/request_response/exceptions.py new file mode 100644 index 000000000..7ea52f0d8 --- /dev/null +++ b/libp2p/request_response/exceptions.py @@ -0,0 +1,37 @@ +from libp2p.exceptions import BaseLibp2pError + + +class RequestResponseError(BaseLibp2pError): + """Base error for the request/response helper.""" + + +class RequestTimeoutError(RequestResponseError): + """Raised when a request/response exchange exceeds the configured timeout.""" + + +class ProtocolNotSupportedError(RequestResponseError): + """Raised when the remote peer supports none of the requested protocols.""" + + +class RequestEncodeError(RequestResponseError): + """Raised when request serialization fails.""" + + +class RequestDecodeError(RequestResponseError): + """Raised when request deserialization fails on the server side.""" + + +class ResponseEncodeError(RequestResponseError): + """Raised when response serialization fails on the server side.""" + + +class ResponseDecodeError(RequestResponseError): + """Raised when response deserialization fails on the client side.""" + + +class MessageTooLargeError(RequestResponseError): + """Raised when a framed request or response exceeds configured limits.""" + + +class RequestTransportError(RequestResponseError): + """Raised when transport- or stream-level failures interrupt a request.""" diff --git a/newsfragments/1287.feature.rst b/newsfragments/1287.feature.rst new file mode 100644 index 000000000..2faec2181 --- /dev/null +++ b/newsfragments/1287.feature.rst @@ -0,0 +1 @@ +Added a new ``libp2p.request_response`` helper that provides safe one-shot request/response exchanges with framed messages, bounded payload sizes, default timeouts, pluggable codecs, and a generic demo. diff --git a/pyproject.toml b/pyproject.toml index cd7aa5c68..5cf2797f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ autotls-demo = "examples.autotls.autotls:main" identify-demo = "examples.identify.identify:main" identify-push-demo = "examples.identify_push.identify_push_demo:run_main" identify-push-listener-dialer-demo = "examples.identify_push.identify_push_listener_dialer:main" +request-response-demo = "examples.request_response.request_response_demo:main" pubsub-demo = "examples.pubsub.pubsub:main" floodsub-demo = "examples.pubsub.floodsub:main" mdns-demo = "examples.mDNS.mDNS:main" diff --git a/tests/core/io/test_msgio.py b/tests/core/io/test_msgio.py new file mode 100644 index 000000000..c52bf6761 --- /dev/null +++ b/tests/core/io/test_msgio.py @@ -0,0 +1,72 @@ +import pytest + +from libp2p.io.abc import ReadWriteCloser +from libp2p.io.exceptions import MessageTooLarge +from libp2p.io.msgio import VarIntLengthMsgReadWriter +from libp2p.utils import encode_varint_prefixed + + +class MemoryReadWriteCloser(ReadWriteCloser): + def __init__(self, initial_bytes: bytes = b"") -> None: + self._buffer = bytearray(initial_bytes) + self.written = bytearray() + self.closed = False + + async def read(self, n: int | None = None) -> bytes: + if n is None: + n = len(self._buffer) + chunk = bytes(self._buffer[:n]) + del self._buffer[:n] + return chunk + + async def write(self, data: bytes) -> None: + self.written.extend(data) + + async def close(self) -> None: + self.closed = True + + def get_remote_address(self) -> tuple[str, int] | None: + return None + + +@pytest.mark.trio +async def test_varint_msgio_round_trip() -> None: + rw = MemoryReadWriteCloser(encode_varint_prefixed(b"hello")) + msgio = VarIntLengthMsgReadWriter(rw, max_msg_size=16) + + assert await msgio.read_msg() == b"hello" + + +@pytest.mark.trio +async def test_varint_msgio_write_prefixes_and_limits() -> None: + rw = MemoryReadWriteCloser() + msgio = VarIntLengthMsgReadWriter(rw, max_msg_size=8) + + await msgio.write_msg(b"hello") + + assert bytes(rw.written) == encode_varint_prefixed(b"hello") + + +@pytest.mark.trio +async def test_varint_msgio_rejects_oversized_reads() -> None: + rw = MemoryReadWriteCloser(encode_varint_prefixed(b"toolong")) + msgio = VarIntLengthMsgReadWriter(rw, max_msg_size=4) + + with pytest.raises(MessageTooLarge): + await msgio.read_msg() + + +@pytest.mark.trio +async def test_varint_msgio_rejects_oversized_writes() -> None: + rw = MemoryReadWriteCloser() + msgio = VarIntLengthMsgReadWriter(rw, max_msg_size=4) + + with pytest.raises(MessageTooLarge): + await msgio.write_msg(b"toolong") + + +def test_varint_msgio_requires_positive_limit() -> None: + rw = MemoryReadWriteCloser() + + with pytest.raises(ValueError, match="max_msg_size"): + VarIntLengthMsgReadWriter(rw, max_msg_size=0) diff --git a/tests/core/request_response/test_api.py b/tests/core/request_response/test_api.py new file mode 100644 index 000000000..efeeebdf5 --- /dev/null +++ b/tests/core/request_response/test_api.py @@ -0,0 +1,344 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from types import SimpleNamespace + +import pytest +import trio + +from libp2p.custom_types import TProtocol +from libp2p.request_response import ( + BytesCodec, + JSONCodec, + MessageTooLargeError, + ProtocolNotSupportedError, + RequestContext, + RequestResponse, + RequestResponseConfig, + RequestTimeoutError, + RequestTransportError, + ResponseDecodeError, +) +from libp2p.tools.utils import connect +from libp2p.utils import encode_varint_prefixed +from tests.utils.factories import HostFactory + +PROTO_V1 = TProtocol("/example/request-response/1.0.0") +PROTO_V2 = TProtocol("/example/request-response/2.0.0") + + +class DummyNetStream: + def __init__( + self, + initial_bytes: bytes = b"", + *, + peer_id: str = "peer-1", + read_delay: float = 0.0, + ) -> None: + self._buffer = bytearray(initial_bytes) + self._read_delay = read_delay + self.written = bytearray() + self.closed = False + self.reset_called = False + self.muxed_conn = SimpleNamespace(peer_id=peer_id) + + async def read(self, n: int | None = None) -> bytes: + if self._read_delay: + await trio.sleep(self._read_delay) + if n is None: + n = len(self._buffer) + chunk = bytes(self._buffer[:n]) + del self._buffer[:n] + return chunk + + async def write(self, data: bytes) -> None: + self.written.extend(data) + + async def close(self) -> None: + self.closed = True + + async def reset(self) -> None: + self.reset_called = True + + +StreamHandler = Callable[[DummyNetStream], Awaitable[None]] + + +class DummyHost: + def __init__(self, stream: DummyNetStream | Exception | None = None) -> None: + self.stream = stream + self.handlers: dict[TProtocol, StreamHandler] = {} + self.new_stream_calls: list[tuple[object, tuple[TProtocol, ...]]] = [] + + async def new_stream( + self, peer_id: object, protocol_ids: list[TProtocol] | tuple[TProtocol, ...] + ): + self.new_stream_calls.append((peer_id, tuple(protocol_ids))) + if isinstance(self.stream, Exception): + raise self.stream + if self.stream is None: + raise RuntimeError("stream not configured") + return self.stream + + def set_stream_handler( + self, protocol_id: TProtocol, stream_handler: StreamHandler + ) -> None: + self.handlers[protocol_id] = stream_handler + + def remove_stream_handler(self, protocol_id: TProtocol) -> None: + self.handlers.pop(protocol_id, None) + + +@pytest.mark.trio +async def test_send_request_rejects_oversized_request_before_opening_stream() -> None: + host = DummyHost(DummyNetStream()) + rr = RequestResponse(host) # type: ignore[arg-type] + + with pytest.raises(MessageTooLargeError): + await rr.send_request( + peer_id="peer", # type: ignore[arg-type] + protocol_ids=[PROTO_V1], + request=b"12345", + codec=BytesCodec(), + config=RequestResponseConfig(max_request_size=4), + ) + + assert host.new_stream_calls == [] + + +@pytest.mark.trio +async def test_send_request_raises_decode_error_for_malformed_response() -> None: + stream = DummyNetStream(encode_varint_prefixed(b"not-json")) + host = DummyHost(stream) + rr = RequestResponse(host) # type: ignore[arg-type] + + with pytest.raises(ResponseDecodeError): + await rr.send_request( + peer_id="peer", # type: ignore[arg-type] + protocol_ids=[PROTO_V1], + request={"msg": "hello"}, + codec=JSONCodec(), + ) + + assert stream.reset_called is True + + +@pytest.mark.trio +async def test_inbound_malformed_request_resets_stream_and_skips_handler() -> None: + stream = DummyNetStream(encode_varint_prefixed(b"not-json")) + host = DummyHost() + rr = RequestResponse(host) # type: ignore[arg-type] + handler_called = False + + async def handler(request: dict[str, str], context: object) -> dict[str, str]: + nonlocal handler_called + handler_called = True + return request + + rr.set_handler(PROTO_V1, handler=handler, codec=JSONCodec()) + + stream_handler = host.handlers[PROTO_V1] + await stream_handler(stream) + + assert handler_called is False + assert stream.reset_called is True + assert stream.closed is False + + +@pytest.mark.trio +async def test_inbound_oversized_response_resets_before_write() -> None: + stream = DummyNetStream(encode_varint_prefixed(b"ping")) + host = DummyHost() + rr = RequestResponse(host) # type: ignore[arg-type] + + async def handler(request: bytes, context: object) -> bytes: + return b"toolong" + + rr.set_handler( + PROTO_V1, + handler=handler, + codec=BytesCodec(), + config=RequestResponseConfig(max_response_size=4), + ) + + stream_handler = host.handlers[PROTO_V1] + await stream_handler(stream) + + assert bytes(stream.written) == b"" + assert stream.reset_called is True + + +@pytest.mark.trio +async def test_inbound_handler_exception_resets_stream() -> None: + stream = DummyNetStream(encode_varint_prefixed(b"ping")) + host = DummyHost() + rr = RequestResponse(host) # type: ignore[arg-type] + + async def handler(request: bytes, context: object) -> bytes: + raise RuntimeError("boom") + + rr.set_handler(PROTO_V1, handler=handler, codec=BytesCodec()) + + stream_handler = host.handlers[PROTO_V1] + await stream_handler(stream) + + assert stream.reset_called is True + + +@pytest.mark.trio +async def test_send_request_timeout_resets_stream() -> None: + stream = DummyNetStream(read_delay=0.2) + host = DummyHost(stream) + rr = RequestResponse(host) # type: ignore[arg-type] + + with pytest.raises(RequestTimeoutError): + await rr.send_request( + peer_id="peer", # type: ignore[arg-type] + protocol_ids=[PROTO_V1], + request=b"ping", + codec=BytesCodec(), + config=RequestResponseConfig(timeout=0.05), + ) + + assert stream.reset_called is True + + +@pytest.mark.trio +async def test_send_request_clean_close_before_response_raises_transport_error() -> ( + None +): + stream = DummyNetStream() + host = DummyHost(stream) + rr = RequestResponse(host) # type: ignore[arg-type] + + with pytest.raises(RequestTransportError): + await rr.send_request( + peer_id="peer", # type: ignore[arg-type] + protocol_ids=[PROTO_V1], + request=b"ping", + codec=BytesCodec(), + ) + + assert stream.reset_called is True + + +def test_remove_handler_unregisters_protocol() -> None: + host = DummyHost() + rr = RequestResponse(host) # type: ignore[arg-type] + + async def handler(request: bytes, context: object) -> bytes: + return request + + rr.set_handler(PROTO_V1, handler=handler, codec=BytesCodec()) + assert PROTO_V1 in host.handlers + + rr.remove_handler(PROTO_V1) + assert PROTO_V1 not in host.handlers + + +@pytest.mark.trio +async def test_inbound_concurrency_limit_resets_extra_streams() -> None: + host = DummyHost() + rr = RequestResponse(host) # type: ignore[arg-type] + release = trio.Event() + started = trio.Event() + + async def slow_handler(request: bytes, context: object) -> bytes: + started.set() + await release.wait() + return request + + rr.set_handler( + PROTO_V1, + handler=slow_handler, + codec=BytesCodec(), + config=RequestResponseConfig(max_concurrent_inbound=1), + ) + + stream_handler = host.handlers[PROTO_V1] + stream_one = DummyNetStream(encode_varint_prefixed(b"first")) + stream_two = DummyNetStream(encode_varint_prefixed(b"second")) + + async with trio.open_nursery() as nursery: + nursery.start_soon(stream_handler, stream_one) + await started.wait() + nursery.start_soon(stream_handler, stream_two) + await trio.sleep(0.05) + assert stream_two.reset_called is True + release.set() + + +@pytest.mark.trio +async def test_request_response_round_trip_integration(security_protocol) -> None: + async with HostFactory.create_batch_and_listen( + 2, security_protocol=security_protocol + ) as hosts: + rr_client = RequestResponse(hosts[0]) + rr_server = RequestResponse(hosts[1]) + + async def handler( + request: dict[str, str], context: RequestContext + ) -> dict[str, str]: + return {"echo": request["msg"], "peer": str(context.peer_id)} + + rr_server.set_handler(PROTO_V1, handler=handler, codec=JSONCodec()) + await connect(hosts[0], hosts[1]) + + response = await rr_client.send_request( + peer_id=hosts[1].get_id(), + protocol_ids=[PROTO_V1], + request={"msg": "hello"}, + codec=JSONCodec(), + ) + + assert response["echo"] == "hello" + + +@pytest.mark.trio +async def test_request_response_protocol_preference_integration( + security_protocol, +) -> None: + async with HostFactory.create_batch_and_listen( + 2, security_protocol=security_protocol + ) as hosts: + rr_client = RequestResponse(hosts[0]) + rr_server = RequestResponse(hosts[1]) + + async def handler(request: bytes, context: object) -> bytes: + return b"pong" + + rr_server.set_handler(PROTO_V1, handler=handler, codec=BytesCodec()) + await connect(hosts[0], hosts[1]) + + response = await rr_client.send_request( + peer_id=hosts[1].get_id(), + protocol_ids=[PROTO_V2, PROTO_V1], + request=b"ping", + codec=BytesCodec(), + ) + + assert response == b"pong" + + +@pytest.mark.trio +async def test_request_response_removed_handler_integration(security_protocol) -> None: + async with HostFactory.create_batch_and_listen( + 2, security_protocol=security_protocol + ) as hosts: + rr_client = RequestResponse(hosts[0]) + rr_server = RequestResponse(hosts[1]) + + async def handler(request: bytes, context): + return request + + rr_server.set_handler(PROTO_V1, handler=handler, codec=BytesCodec()) + rr_server.remove_handler(PROTO_V1) + await connect(hosts[0], hosts[1]) + + with pytest.raises(ProtocolNotSupportedError): + await rr_client.send_request( + peer_id=hosts[1].get_id(), + protocol_ids=[PROTO_V1], + request=b"ping", + codec=BytesCodec(), + ) diff --git a/tests/core/request_response/test_codec.py b/tests/core/request_response/test_codec.py new file mode 100644 index 000000000..c446ab989 --- /dev/null +++ b/tests/core/request_response/test_codec.py @@ -0,0 +1,21 @@ +from libp2p.request_response import BytesCodec, JSONCodec + + +def test_bytes_codec_round_trip() -> None: + codec = BytesCodec() + payload = b"hello" + + assert codec.encode_request(payload) == payload + assert codec.decode_request(payload) == payload + assert codec.encode_response(payload) == payload + assert codec.decode_response(payload) == payload + + +def test_json_codec_round_trip() -> None: + codec = JSONCodec() + payload = {"msg": "hello", "count": 3} + + encoded = codec.encode_request(payload) + + assert codec.decode_request(encoded) == payload + assert codec.decode_response(codec.encode_response(payload)) == payload