diff --git a/libp2p/relay/circuit_v2/discovery.py b/libp2p/relay/circuit_v2/discovery.py index f62fb9ff1..95a9fc180 100644 --- a/libp2p/relay/circuit_v2/discovery.py +++ b/libp2p/relay/circuit_v2/discovery.py @@ -39,6 +39,11 @@ ) from .pb.circuit_pb2 import ( HopMessage, + Peer, +) +from .pb_framing import ( + read_circuit_v2_pb, + write_circuit_v2_pb, ) from .protocol import ( PROTOCOL_ID, @@ -398,17 +403,19 @@ async def make_reservation(self, peer_id: ID) -> bool: # Prepare signed envelope envelope_bytes, _ = env_to_send_in_RPC(self.host) # Create and send reservation request + rpeer = Peer() + rpeer.id = self.host.get_id().to_bytes() request = HopMessage( - type=HopMessage.RESERVE, - peer=self.host.get_id().to_bytes(), + type=HopMessage.Type.RESERVE, senderRecord=envelope_bytes, ) + request.peer.CopyFrom(rpeer) with trio.fail_after(self.stream_timeout): - await stream.write(request.SerializeToString()) + await write_circuit_v2_pb(stream, request.SerializeToString()) # Wait for response - response_bytes = await stream.read(1024) + response_bytes = await read_circuit_v2_pb(stream) if not response_bytes: logger.error("No response received from relay %s", peer_id) return False @@ -428,12 +435,10 @@ async def make_reservation(self, peer_id: ID) -> bool: await stream.close() return False - # Check if reservation was successful - if response.type == HopMessage.STATUS and response.HasField( + if response.type == HopMessage.Type.STATUS and response.HasField( "status" ): - # Access status code directly from protobuf object - status_code = getattr(response.status, "code", StatusCode.OK) + status_code = StatusCode(response.status) if status_code == StatusCode.OK: # Update relay info with reservation details @@ -453,11 +458,9 @@ async def make_reservation(self, peer_id: ID) -> bool: ) return True - # Reservation failed error_message = "Unknown error" if response.HasField("status"): - # Access message directly from protobuf object - error_message = getattr(response.status, "message", "") + error_message = StatusCode(response.status).name logger.warning( "Reservation request rejected by relay %s: %s", diff --git a/libp2p/relay/circuit_v2/pb/__init__.py b/libp2p/relay/circuit_v2/pb/__init__.py index b4c96d734..80d1763d3 100644 --- a/libp2p/relay/circuit_v2/pb/__init__.py +++ b/libp2p/relay/circuit_v2/pb/__init__.py @@ -13,9 +13,18 @@ from .circuit_pb2 import ( HopMessage, Limit, + Peer, Reservation, Status, StopMessage, ) -__all__ = ["HopMessage", "Limit", "Reservation", "Status", "StopMessage", "HolePunch"] +__all__ = [ + "HopMessage", + "Limit", + "Peer", + "Reservation", + "Status", + "StopMessage", + "HolePunch", +] diff --git a/libp2p/relay/circuit_v2/pb/circuit.proto b/libp2p/relay/circuit_v2/pb/circuit.proto index 4ffc9413d..258ec780a 100644 --- a/libp2p/relay/circuit_v2/pb/circuit.proto +++ b/libp2p/relay/circuit_v2/pb/circuit.proto @@ -2,7 +2,9 @@ syntax = "proto3"; package circuit.pb.v2; -// Circuit v2 message types +// Aligned with libp2p relay/circuit-v2 and rust-libp2p 0.52 +// (protocols/relay/src/generated/message.proto), proto3 + optional for presence. + message HopMessage { enum Type { RESERVE = 0; @@ -10,12 +12,12 @@ message HopMessage { STATUS = 2; } - Type type = 1; - bytes peer = 2; - Reservation reservation = 3; - Limit limit = 4; - Status status = 5; - optional bytes senderRecord = 6; // Envelope(PeerRecord) + optional Type type = 1; + optional Peer peer = 2; + optional Reservation reservation = 3; + optional Limit limit = 4; + optional Status status = 5; + optional bytes senderRecord = 6; } message StopMessage { @@ -24,34 +26,39 @@ message StopMessage { STATUS = 1; } - Type type = 1; - bytes peer = 2; - Status status = 3; - optional bytes senderRecord = 4; // Envelope(PeerRecord) encoded + optional Type type = 1; + optional Peer peer = 2; + optional Limit limit = 3; + optional Status status = 4; + optional bytes senderRecord = 5; +} + +message Peer { + optional bytes id = 1; + repeated bytes addrs = 2; } message Reservation { - bytes voucher = 1; - bytes signature = 2; - int64 expire = 3; + optional uint64 expire = 1; + repeated bytes addrs = 2; + optional bytes voucher = 3; + // py-libp2p relay↔relay signing extension; ignored by rust-libp2p + optional bytes signature = 4; } message Limit { - int64 duration = 1; - int64 data = 2; + optional uint32 duration = 1; + optional uint64 data = 2; } -message Status { - enum Code { - OK = 0; - RESERVATION_REFUSED = 100; - RESOURCE_LIMIT_EXCEEDED = 101; - PERMISSION_DENIED = 102; - CONNECTION_FAILED = 200; - DIAL_REFUSED = 201; - STOP_FAILED = 300; - MALFORMED_MESSAGE = 400; - } - Code code = 1; - string message = 2; +enum Status { + UNUSED = 0; + OK = 100; + RESERVATION_REFUSED = 200; + RESOURCE_LIMIT_EXCEEDED = 201; + PERMISSION_DENIED = 202; + CONNECTION_FAILED = 203; + NO_RESERVATION = 204; + MALFORMED_MESSAGE = 400; + UNEXPECTED_MESSAGE = 401; } diff --git a/libp2p/relay/circuit_v2/pb/circuit_pb2.py b/libp2p/relay/circuit_v2/pb/circuit_pb2.py index ebd95e3ba..4767ed67f 100644 --- a/libp2p/relay/circuit_v2/pb/circuit_pb2.py +++ b/libp2p/relay/circuit_v2/pb/circuit_pb2.py @@ -2,7 +2,7 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE # source: libp2p/relay/circuit_v2/pb/circuit.proto -# Protobuf Python Version: 6.32.1 +# Protobuf Python Version: 7.34.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -11,8 +11,8 @@ from google.protobuf.internal import builder as _builder _runtime_version.ValidateProtobufRuntimeVersion( _runtime_version.Domain.PUBLIC, - 6, - 32, + 7, + 34, 1, '', 'libp2p/relay/circuit_v2/pb/circuit.proto' @@ -24,27 +24,27 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(libp2p/relay/circuit_v2/pb/circuit.proto\x12\rcircuit.pb.v2\"\x9f\x02\n\nHopMessage\x12,\n\x04type\x18\x01 \x01(\x0e\x32\x1e.circuit.pb.v2.HopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12/\n\x0breservation\x18\x03 \x01(\x0b\x32\x1a.circuit.pb.v2.Reservation\x12#\n\x05limit\x18\x04 \x01(\x0b\x32\x14.circuit.pb.v2.Limit\x12%\n\x06status\x18\x05 \x01(\x0b\x32\x15.circuit.pb.v2.Status\x12\x19\n\x0csenderRecord\x18\x06 \x01(\x0cH\x00\x88\x01\x01\",\n\x04Type\x12\x0b\n\x07RESERVE\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\n\n\x06STATUS\x10\x02\x42\x0f\n\r_senderRecord\"\xbe\x01\n\x0bStopMessage\x12-\n\x04type\x18\x01 \x01(\x0e\x32\x1f.circuit.pb.v2.StopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12%\n\x06status\x18\x03 \x01(\x0b\x32\x15.circuit.pb.v2.Status\x12\x19\n\x0csenderRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\"\x1f\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\n\n\x06STATUS\x10\x01\x42\x0f\n\r_senderRecord\"A\n\x0bReservation\x12\x0f\n\x07voucher\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c\x12\x0e\n\x06\x65xpire\x18\x03 \x01(\x03\"\'\n\x05Limit\x12\x10\n\x08\x64uration\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x03\"\xf6\x01\n\x06Status\x12(\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1a.circuit.pb.v2.Status.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xb0\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\x17\n\x13RESERVATION_REFUSED\x10\x64\x12\x1b\n\x17RESOURCE_LIMIT_EXCEEDED\x10\x65\x12\x15\n\x11PERMISSION_DENIED\x10\x66\x12\x16\n\x11\x43ONNECTION_FAILED\x10\xc8\x01\x12\x11\n\x0c\x44IAL_REFUSED\x10\xc9\x01\x12\x10\n\x0bSTOP_FAILED\x10\xac\x02\x12\x16\n\x11MALFORMED_MESSAGE\x10\x90\x03\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(libp2p/relay/circuit_v2/pb/circuit.proto\x12\rcircuit.pb.v2\"\x84\x03\n\nHopMessage\x12\x31\n\x04type\x18\x01 \x01(\x0e\x32\x1e.circuit.pb.v2.HopMessage.TypeH\x00\x88\x01\x01\x12&\n\x04peer\x18\x02 \x01(\x0b\x32\x13.circuit.pb.v2.PeerH\x01\x88\x01\x01\x12\x34\n\x0breservation\x18\x03 \x01(\x0b\x32\x1a.circuit.pb.v2.ReservationH\x02\x88\x01\x01\x12(\n\x05limit\x18\x04 \x01(\x0b\x32\x14.circuit.pb.v2.LimitH\x03\x88\x01\x01\x12*\n\x06status\x18\x05 \x01(\x0e\x32\x15.circuit.pb.v2.StatusH\x04\x88\x01\x01\x12\x19\n\x0csenderRecord\x18\x06 \x01(\x0cH\x05\x88\x01\x01\",\n\x04Type\x12\x0b\n\x07RESERVE\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\n\n\x06STATUS\x10\x02\x42\x07\n\x05_typeB\x07\n\x05_peerB\x0e\n\x0c_reservationB\x08\n\x06_limitB\t\n\x07_statusB\x0f\n\r_senderRecord\"\xb3\x02\n\x0bStopMessage\x12\x32\n\x04type\x18\x01 \x01(\x0e\x32\x1f.circuit.pb.v2.StopMessage.TypeH\x00\x88\x01\x01\x12&\n\x04peer\x18\x02 \x01(\x0b\x32\x13.circuit.pb.v2.PeerH\x01\x88\x01\x01\x12(\n\x05limit\x18\x03 \x01(\x0b\x32\x14.circuit.pb.v2.LimitH\x02\x88\x01\x01\x12*\n\x06status\x18\x04 \x01(\x0e\x32\x15.circuit.pb.v2.StatusH\x03\x88\x01\x01\x12\x19\n\x0csenderRecord\x18\x05 \x01(\x0cH\x04\x88\x01\x01\"\x1f\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\n\n\x06STATUS\x10\x01\x42\x07\n\x05_typeB\x07\n\x05_peerB\x08\n\x06_limitB\t\n\x07_statusB\x0f\n\r_senderRecord\"-\n\x04Peer\x12\x0f\n\x02id\x18\x01 \x01(\x0cH\x00\x88\x01\x01\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x42\x05\n\x03_id\"\x84\x01\n\x0bReservation\x12\x13\n\x06\x65xpire\x18\x01 \x01(\x04H\x00\x88\x01\x01\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12\x14\n\x07voucher\x18\x03 \x01(\x0cH\x01\x88\x01\x01\x12\x16\n\tsignature\x18\x04 \x01(\x0cH\x02\x88\x01\x01\x42\t\n\x07_expireB\n\n\x08_voucherB\x0c\n\n_signature\"G\n\x05Limit\x12\x15\n\x08\x64uration\x18\x01 \x01(\rH\x00\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x02 \x01(\x04H\x01\x88\x01\x01\x42\x0b\n\t_durationB\x07\n\x05_data*\xca\x01\n\x06Status\x12\n\n\x06UNUSED\x10\x00\x12\x06\n\x02OK\x10\x64\x12\x18\n\x13RESERVATION_REFUSED\x10\xc8\x01\x12\x1c\n\x17RESOURCE_LIMIT_EXCEEDED\x10\xc9\x01\x12\x16\n\x11PERMISSION_DENIED\x10\xca\x01\x12\x16\n\x11\x43ONNECTION_FAILED\x10\xcb\x01\x12\x13\n\x0eNO_RESERVATION\x10\xcc\x01\x12\x16\n\x11MALFORMED_MESSAGE\x10\x90\x03\x12\x17\n\x12UNEXPECTED_MESSAGE\x10\x91\x03\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.relay.circuit_v2.pb.circuit_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None + _globals['_STATUS']._serialized_start=1016 + _globals['_STATUS']._serialized_end=1218 _globals['_HOPMESSAGE']._serialized_start=60 - _globals['_HOPMESSAGE']._serialized_end=347 - _globals['_HOPMESSAGE_TYPE']._serialized_start=286 - _globals['_HOPMESSAGE_TYPE']._serialized_end=330 - _globals['_STOPMESSAGE']._serialized_start=350 - _globals['_STOPMESSAGE']._serialized_end=540 - _globals['_STOPMESSAGE_TYPE']._serialized_start=492 - _globals['_STOPMESSAGE_TYPE']._serialized_end=523 - _globals['_RESERVATION']._serialized_start=542 - _globals['_RESERVATION']._serialized_end=607 - _globals['_LIMIT']._serialized_start=609 - _globals['_LIMIT']._serialized_end=648 - _globals['_STATUS']._serialized_start=651 - _globals['_STATUS']._serialized_end=897 - _globals['_STATUS_CODE']._serialized_start=721 - _globals['_STATUS_CODE']._serialized_end=897 + _globals['_HOPMESSAGE']._serialized_end=448 + _globals['_HOPMESSAGE_TYPE']._serialized_start=332 + _globals['_HOPMESSAGE_TYPE']._serialized_end=376 + _globals['_STOPMESSAGE']._serialized_start=451 + _globals['_STOPMESSAGE']._serialized_end=758 + _globals['_STOPMESSAGE_TYPE']._serialized_start=671 + _globals['_STOPMESSAGE_TYPE']._serialized_end=702 + _globals['_PEER']._serialized_start=760 + _globals['_PEER']._serialized_end=805 + _globals['_RESERVATION']._serialized_start=808 + _globals['_RESERVATION']._serialized_end=940 + _globals['_LIMIT']._serialized_start=942 + _globals['_LIMIT']._serialized_end=1013 # @@protoc_insertion_point(module_scope) diff --git a/libp2p/relay/circuit_v2/pb/circuit_pb2.pyi b/libp2p/relay/circuit_v2/pb/circuit_pb2.pyi index abe0cd47b..eca657ef6 100644 --- a/libp2p/relay/circuit_v2/pb/circuit_pb2.pyi +++ b/libp2p/relay/circuit_v2/pb/circuit_pb2.pyi @@ -1,10 +1,14 @@ """ -@generated by mypy-protobuf. Do not edit manually! +Types for ``circuit_pb2`` (aligned with ``circuit.proto`` / generated runtime). + +Edit when the ``.proto`` changes; keep in sync with mypy-protobuf patterns. isort:skip_file """ import builtins +import collections.abc import google.protobuf.descriptor +import google.protobuf.internal.containers import google.protobuf.internal.enum_type_wrapper import google.protobuf.message import sys @@ -17,9 +21,102 @@ else: DESCRIPTOR: google.protobuf.descriptor.FileDescriptor +class _Status: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _StatusEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_Status.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + UNUSED: _Status.ValueType # 0 + OK: _Status.ValueType # 100 + RESERVATION_REFUSED: _Status.ValueType # 200 + RESOURCE_LIMIT_EXCEEDED: _Status.ValueType # 201 + PERMISSION_DENIED: _Status.ValueType # 202 + CONNECTION_FAILED: _Status.ValueType # 203 + NO_RESERVATION: _Status.ValueType # 204 + MALFORMED_MESSAGE: _Status.ValueType # 400 + UNEXPECTED_MESSAGE: _Status.ValueType # 401 + +class Status(_Status, metaclass=_StatusEnumTypeWrapper): ... +UNUSED: Status.ValueType # 0 +OK: Status.ValueType # 100 +RESERVATION_REFUSED: Status.ValueType # 200 +RESOURCE_LIMIT_EXCEEDED: Status.ValueType # 201 +PERMISSION_DENIED: Status.ValueType # 202 +CONNECTION_FAILED: Status.ValueType # 203 +NO_RESERVATION: Status.ValueType # 204 +MALFORMED_MESSAGE: Status.ValueType # 400 +UNEXPECTED_MESSAGE: Status.ValueType # 401 +global___Status = Status + +@typing.final +class Peer(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ID_FIELD_NUMBER: builtins.int + ADDRS_FIELD_NUMBER: builtins.int + id: builtins.bytes + @property + def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + def __init__( + self, + *, + id: builtins.bytes | None = ..., + addrs: collections.abc.Iterable[builtins.bytes] | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["id", b"id"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "id", b"id"]) -> None: ... + +global___Peer = Peer + +@typing.final +class Reservation(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + EXPIRE_FIELD_NUMBER: builtins.int + ADDRS_FIELD_NUMBER: builtins.int + VOUCHER_FIELD_NUMBER: builtins.int + SIGNATURE_FIELD_NUMBER: builtins.int + expire: builtins.int + voucher: builtins.bytes + signature: builtins.bytes + @property + def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + def __init__( + self, + *, + expire: builtins.int | None = ..., + addrs: collections.abc.Iterable[builtins.bytes] | None = ..., + voucher: builtins.bytes | None = ..., + signature: builtins.bytes | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_expire", b"_expire", "_signature", b"_signature", "_voucher", b"_voucher", "expire", b"expire", "signature", b"signature", "voucher", b"voucher"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "expire", b"expire", "signature", b"signature", "voucher", b"voucher"]) -> None: ... + +global___Reservation = Reservation + +@typing.final +class Limit(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DURATION_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + duration: builtins.int + data: builtins.int + def __init__( + self, + *, + duration: builtins.int | None = ..., + data: builtins.int | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_data", b"_data", "_duration", b"_duration", "data", b"data", "duration", b"duration"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["data", b"data", "duration", b"duration"]) -> None: ... + +global___Limit = Limit + @typing.final class HopMessage(google.protobuf.message.Message): - """Circuit v2 message types""" + """Circuit v2 hop stream messages.""" DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -44,28 +141,26 @@ class HopMessage(google.protobuf.message.Message): LIMIT_FIELD_NUMBER: builtins.int STATUS_FIELD_NUMBER: builtins.int SENDERRECORD_FIELD_NUMBER: builtins.int - type: global___HopMessage.Type.ValueType - peer: builtins.bytes + type: HopMessage.Type.ValueType + peer: global___Peer senderRecord: builtins.bytes - """Envelope(PeerRecord)""" @property def reservation(self) -> global___Reservation: ... @property def limit(self) -> global___Limit: ... - @property - def status(self) -> global___Status: ... + status: global___Status.ValueType def __init__( self, *, - type: global___HopMessage.Type.ValueType = ..., - peer: builtins.bytes = ..., + type: HopMessage.Type.ValueType | None = ..., + peer: global___Peer | None = ..., reservation: global___Reservation | None = ..., limit: global___Limit | None = ..., - status: global___Status | None = ..., + status: builtins.int | global___Status.ValueType | None = ..., senderRecord: builtins.bytes | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["_senderRecord", b"_senderRecord", "limit", b"limit", "reservation", b"reservation", "senderRecord", b"senderRecord", "status", b"status"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_senderRecord", b"_senderRecord", "limit", b"limit", "peer", b"peer", "reservation", b"reservation", "senderRecord", b"senderRecord", "status", b"status", "type", b"type"]) -> None: ... + def HasField(self, field_name: typing.Literal["_peer", b"_peer", "_reservation", b"_reservation", "_senderRecord", b"_senderRecord", "limit", b"limit", "peer", b"peer", "reservation", b"reservation", "senderRecord", b"senderRecord", "status", b"status", "type", b"type"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_peer", b"_peer", "_reservation", b"_reservation", "_senderRecord", b"_senderRecord", "limit", b"limit", "peer", b"peer", "reservation", b"reservation", "senderRecord", b"senderRecord", "status", b"status", "type", b"type"]) -> None: ... def WhichOneof(self, oneof_group: typing.Literal["_senderRecord", b"_senderRecord"]) -> typing.Literal["senderRecord"] | None: ... global___HopMessage = HopMessage @@ -89,106 +184,26 @@ class StopMessage(google.protobuf.message.Message): TYPE_FIELD_NUMBER: builtins.int PEER_FIELD_NUMBER: builtins.int + LIMIT_FIELD_NUMBER: builtins.int STATUS_FIELD_NUMBER: builtins.int SENDERRECORD_FIELD_NUMBER: builtins.int - type: global___StopMessage.Type.ValueType - peer: builtins.bytes + type: StopMessage.Type.ValueType + peer: global___Peer senderRecord: builtins.bytes - """Envelope(PeerRecord) encoded""" @property - def status(self) -> global___Status: ... + def limit(self) -> global___Limit: ... + status: global___Status.ValueType def __init__( self, *, - type: global___StopMessage.Type.ValueType = ..., - peer: builtins.bytes = ..., - status: global___Status | None = ..., + type: StopMessage.Type.ValueType | None = ..., + peer: global___Peer | None = ..., + limit: global___Limit | None = ..., + status: builtins.int | global___Status.ValueType | None = ..., senderRecord: builtins.bytes | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["_senderRecord", b"_senderRecord", "senderRecord", b"senderRecord", "status", b"status"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_senderRecord", b"_senderRecord", "peer", b"peer", "senderRecord", b"senderRecord", "status", b"status", "type", b"type"]) -> None: ... + def HasField(self, field_name: typing.Literal["_limit", b"_limit", "_peer", b"_peer", "_senderRecord", b"_senderRecord", "limit", b"limit", "peer", b"peer", "senderRecord", b"senderRecord", "status", b"status", "type", b"type"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_limit", b"_limit", "_peer", b"_peer", "_senderRecord", b"_senderRecord", "limit", b"limit", "peer", b"peer", "senderRecord", b"senderRecord", "status", b"status", "type", b"type"]) -> None: ... def WhichOneof(self, oneof_group: typing.Literal["_senderRecord", b"_senderRecord"]) -> typing.Literal["senderRecord"] | None: ... global___StopMessage = StopMessage - -@typing.final -class Reservation(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - VOUCHER_FIELD_NUMBER: builtins.int - SIGNATURE_FIELD_NUMBER: builtins.int - EXPIRE_FIELD_NUMBER: builtins.int - voucher: builtins.bytes - signature: builtins.bytes - expire: builtins.int - def __init__( - self, - *, - voucher: builtins.bytes = ..., - signature: builtins.bytes = ..., - expire: builtins.int = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["expire", b"expire", "signature", b"signature", "voucher", b"voucher"]) -> None: ... - -global___Reservation = Reservation - -@typing.final -class Limit(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - DURATION_FIELD_NUMBER: builtins.int - DATA_FIELD_NUMBER: builtins.int - duration: builtins.int - data: builtins.int - def __init__( - self, - *, - duration: builtins.int = ..., - data: builtins.int = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["data", b"data", "duration", b"duration"]) -> None: ... - -global___Limit = Limit - -@typing.final -class Status(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class _Code: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _CodeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Status._Code.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - OK: Status._Code.ValueType # 0 - RESERVATION_REFUSED: Status._Code.ValueType # 100 - RESOURCE_LIMIT_EXCEEDED: Status._Code.ValueType # 101 - PERMISSION_DENIED: Status._Code.ValueType # 102 - CONNECTION_FAILED: Status._Code.ValueType # 200 - DIAL_REFUSED: Status._Code.ValueType # 201 - STOP_FAILED: Status._Code.ValueType # 300 - MALFORMED_MESSAGE: Status._Code.ValueType # 400 - - class Code(_Code, metaclass=_CodeEnumTypeWrapper): ... - OK: Status.Code.ValueType # 0 - RESERVATION_REFUSED: Status.Code.ValueType # 100 - RESOURCE_LIMIT_EXCEEDED: Status.Code.ValueType # 101 - PERMISSION_DENIED: Status.Code.ValueType # 102 - CONNECTION_FAILED: Status.Code.ValueType # 200 - DIAL_REFUSED: Status.Code.ValueType # 201 - STOP_FAILED: Status.Code.ValueType # 300 - MALFORMED_MESSAGE: Status.Code.ValueType # 400 - - CODE_FIELD_NUMBER: builtins.int - MESSAGE_FIELD_NUMBER: builtins.int - code: global___Status.Code.ValueType - message: builtins.str - def __init__( - self, - *, - code: global___Status.Code.ValueType = ..., - message: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["code", b"code", "message", b"message"]) -> None: ... - -global___Status = Status diff --git a/libp2p/relay/circuit_v2/pb_framing.py b/libp2p/relay/circuit_v2/pb_framing.py new file mode 100644 index 000000000..37ddb5746 --- /dev/null +++ b/libp2p/relay/circuit_v2/pb_framing.py @@ -0,0 +1,38 @@ +""" +Length-prefixed protobuf framing for circuit relay v2. + +rust-libp2p uses ``quick_protobuf_codec::Codec`` (unsigned varint length + payload) +on hop/stop streams; raw ``SerializeToString()`` bytes are not compatible. +See ``protocols/relay`` in rust-libp2p (``MAX_MESSAGE_SIZE``). +""" + +from libp2p.io.abc import ( + Reader, + Writer, +) +from libp2p.utils.varint import ( + encode_varint_prefixed, + read_varint_prefixed_bytes, +) + +# protocols/relay/src/protocol.rs +MAX_CIRCUIT_V2_FRAME_PAYLOAD = 4096 + + +async def write_circuit_v2_pb(stream: Writer, payload: bytes) -> None: + if len(payload) > MAX_CIRCUIT_V2_FRAME_PAYLOAD: + raise ValueError( + f"circuit v2 protobuf frame ({len(payload)} bytes) exceeds max " + f"{MAX_CIRCUIT_V2_FRAME_PAYLOAD}" + ) + await stream.write(encode_varint_prefixed(payload)) + + +async def read_circuit_v2_pb(stream: Reader) -> bytes: + data = await read_varint_prefixed_bytes(stream) + if len(data) > MAX_CIRCUIT_V2_FRAME_PAYLOAD: + raise ValueError( + f"circuit v2 protobuf frame ({len(data)} bytes) exceeds max " + f"{MAX_CIRCUIT_V2_FRAME_PAYLOAD}" + ) + return data diff --git a/libp2p/relay/circuit_v2/protocol.py b/libp2p/relay/circuit_v2/protocol.py index ca7cb6da3..d4301bc31 100644 --- a/libp2p/relay/circuit_v2/protocol.py +++ b/libp2p/relay/circuit_v2/protocol.py @@ -61,12 +61,15 @@ from .pb.circuit_pb2 import ( HopMessage, Limit, - Status as PbStatus, + Peer, StopMessage, ) +from .pb_framing import ( + read_circuit_v2_pb, + write_circuit_v2_pb, +) from .protocol_buffer import ( StatusCode, - create_status, ) from .resources import ( RelayLimits, @@ -76,8 +79,18 @@ logger = logging.getLogger(__name__) -PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0") -STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0/stop") + +def _hop_reserve_peer_id(msg: HopMessage, stream_remote: ID) -> ID: + """Destination / reserving peer from HopMessage, else the muxed stream remote.""" + if msg.HasField("peer") and msg.peer.HasField("id") and msg.peer.id: + return ID(msg.peer.id) + return stream_remote + + +# Must match other libp2p implementations (e.g. rust-libp2p HOP_PROTOCOL_NAME / +# STOP_PROTOCOL_NAME). +PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/0.2.0/hop") +STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/0.2.0/stop") # Default limits for relay resources @@ -315,52 +328,44 @@ async def _handle_hop_stream(self, stream: INetStream) -> None: # First, handle the read timeout gracefully try: with trio.fail_after(STREAM_READ_TIMEOUT * 2): - msg_bytes = await stream.read(1024) + msg_bytes = await read_circuit_v2_pb(stream) if not msg_bytes: logger.error(f"Empty read from stream from {remote_id}") - pb_status = PbStatus() - pb_status.code = PbStatus.Code.MALFORMED_MESSAGE - pb_status.message = "Empty message received" - signed_envelope, _ = env_to_send_in_RPC(self.host) + signed_envelope_bytes, _ = env_to_send_in_RPC(self.host) response = HopMessage( - type=HopMessage.STATUS, - status=pb_status, - senderRecord=signed_envelope, + type=HopMessage.Type.STATUS, + status=int(StatusCode.MALFORMED_MESSAGE), + senderRecord=signed_envelope_bytes, + ) + await write_circuit_v2_pb( + stream, response.SerializeToString() ) - await stream.write(response.SerializeToString()) await trio.sleep( 0.5 ) # Longer wait to ensure message is sent continue except trio.TooSlowError: logger.error(f"Timeout reading from hop stream from {remote_id}") - # Create a proto Status directly - pb_status = PbStatus() - pb_status.code = PbStatus.Code.CONNECTION_FAILED - pb_status.message = "Stream read timeout" - signed_envelope, _ = env_to_send_in_RPC(self.host) + signed_envelope_bytes, _ = env_to_send_in_RPC(self.host) response = HopMessage( - type=HopMessage.STATUS, - status=pb_status, - senderRecord=signed_envelope, + type=HopMessage.Type.STATUS, + status=int(StatusCode.CONNECTION_FAILED), + senderRecord=signed_envelope_bytes, ) - await stream.write(response.SerializeToString()) + await write_circuit_v2_pb(stream, response.SerializeToString()) await trio.sleep(0.5) break except Exception as e: print(f"Error reading from hop stream from {remote_id}: {str(e)}") - pb_status = PbStatus() - pb_status.code = PbStatus.Code.MALFORMED_MESSAGE - pb_status.message = f"Read error: {str(e)}" - signed_envelope, _ = env_to_send_in_RPC(self.host) + signed_envelope_bytes, _ = env_to_send_in_RPC(self.host) response = HopMessage( - type=HopMessage.STATUS, - status=pb_status, - senderRecord=signed_envelope, + type=HopMessage.Type.STATUS, + status=int(StatusCode.MALFORMED_MESSAGE), + senderRecord=signed_envelope_bytes, ) - await stream.write(response.SerializeToString()) + await write_circuit_v2_pb(stream, response.SerializeToString()) await trio.sleep(0.5) # Longer wait to ensure the message is sent break # Parse the message @@ -370,16 +375,13 @@ async def _handle_hop_stream(self, stream: INetStream) -> None: except Exception as e: logger.error(f"Error parsing hop message from {remote_id}: {e}") - pb_status = PbStatus() - pb_status.code = PbStatus.Code.MALFORMED_MESSAGE - pb_status.message = f"Parse error: {str(e)}" - signed_envelope, _ = env_to_send_in_RPC(self.host) + signed_envelope_bytes, _ = env_to_send_in_RPC(self.host) response = HopMessage( - type=HopMessage.STATUS, - status=pb_status, - senderRecord=signed_envelope, + type=HopMessage.Type.STATUS, + status=int(StatusCode.MALFORMED_MESSAGE), + senderRecord=signed_envelope_bytes, ) - await stream.write(response.SerializeToString()) + await write_circuit_v2_pb(stream, response.SerializeToString()) await trio.sleep(0.5) continue if hop_msg.HasField("senderRecord"): @@ -391,9 +393,9 @@ async def _handle_hop_stream(self, stream: INetStream) -> None: return # Process based on message type - if hop_msg.type == HopMessage.RESERVE: - await self._handle_reserve(stream, hop_msg) - elif hop_msg.type == HopMessage.CONNECT: + if hop_msg.type == HopMessage.Type.RESERVE: + await self._handle_reserve(stream, hop_msg, remote_peer_id) + elif hop_msg.type == HopMessage.Type.CONNECT: await self._handle_connect(stream, hop_msg) else: logger.error( @@ -438,7 +440,7 @@ async def _handle_stop_stream(self, stream: INetStream) -> None: try: # Read the incoming message with timeout with trio.fail_after(STREAM_READ_TIMEOUT): - msg_bytes = await stream.read(1024) + msg_bytes = await read_circuit_v2_pb(stream) stop_msg = StopMessage() stop_msg.ParseFromString(msg_bytes) @@ -448,7 +450,7 @@ async def _handle_stop_stream(self, stream: INetStream) -> None: await self._close_stream(stream) return - if stop_msg.type != StopMessage.CONNECT: + if stop_msg.type != StopMessage.Type.CONNECT: # Use direct attribute access to create status object for error response relay_envelope_bytes, _ = env_to_send_in_RPC(self.host) relay_envelope = unmarshal_envelope(relay_envelope_bytes) @@ -461,8 +463,20 @@ async def _handle_stop_stream(self, stream: INetStream) -> None: await self._close_stream(stream) return + if not stop_msg.HasField("peer") or not stop_msg.peer.HasField("id"): + relay_envelope_bytes, _ = env_to_send_in_RPC(self.host) + relay_envelope = unmarshal_envelope(relay_envelope_bytes) + await self._send_stop_status( + stream, + StatusCode.MALFORMED_MESSAGE, + "CONNECT missing peer", + relay_envelope, + ) + await self._close_stream(stream) + return + # Get the source peer's SPR to send to destination - src_peer_id = ID(stop_msg.peer) + src_peer_id = ID(stop_msg.peer.id) src_peer_envelope = self.host.get_peerstore().get_peer_record(src_peer_id) # Get the destination peer's SPR to send to source @@ -476,7 +490,7 @@ async def _handle_stop_stream(self, stream: INetStream) -> None: src_peer_envelope, ) - await self.handle_incoming_connection(stream, stop_msg.peer) + await self.handle_incoming_connection(stream, stop_msg.peer.id) except trio.TooSlowError: logger.error("Timeout reading from stop stream") relay_envelope_bytes, _ = env_to_send_in_RPC(self.host) @@ -544,12 +558,14 @@ async def handle_incoming_connection( await stream.close() raise ConnectionError(f"Failed to handle incoming connection: {str(e)}") - async def _handle_reserve(self, stream: INetStream, msg: HopMessage) -> None: + async def _handle_reserve( + self, stream: INetStream, msg: HopMessage, stream_remote: ID + ) -> None: """Handle a reservation request.""" - peer_id = None - signed_envelope = None + peer_id: ID | None = None + signed_envelope: Envelope | None = None try: - peer_id = ID(msg.peer) + peer_id = _hop_reserve_peer_id(msg, stream_remote) logger.debug("Handling reservation request from peer %s", peer_id) signed_envelope_bytes, _ = env_to_send_in_RPC(self.host) signed_envelope = unmarshal_envelope(signed_envelope_bytes) @@ -559,46 +575,32 @@ async def _handle_reserve(self, stream: INetStream, msg: HopMessage) -> None: logger.debug("Peer %s already has a reservation — refreshing", peer_id) self.resource_manager.refresh_reservation(peer_id) status_code = StatusCode.OK - status_msg_text = "Reservation refreshed" else: # Check if we can accept more reservations if not self.resource_manager.can_accept_reservation(peer_id): logger.debug("Reservation limit exceeded for peer %s", peer_id) - # Send status message with STATUS type - status = create_status( - code=StatusCode.RESOURCE_LIMIT_EXCEEDED, - message="Reservation limit exceeded", - ) - status_msg = HopMessage( - type=HopMessage.STATUS, - status=status, + type=HopMessage.Type.STATUS, + status=int(StatusCode.RESOURCE_LIMIT_EXCEEDED), senderRecord=signed_envelope.marshal_envelope(), ) - await stream.write(status_msg.SerializeToString()) + await write_circuit_v2_pb(stream, status_msg.SerializeToString()) return # Accept reservation logger.debug("Accepting new reservation from peer %s", peer_id) self.resource_manager.reserve(peer_id) status_code = StatusCode.OK - status_msg_text = "Reservation accepted" # Get the reservation object to access its voucher and sign it reservation_obj = self.resource_manager._reservations.get(peer_id) if not reservation_obj: raise ValueError(f"Failed to create reservation for peer {peer_id}") - # Create the protobuf reservation with voucher and signature - pb_reservation = reservation_obj.to_proto() - # Get the peer's addresses from the peerstore if available addrs: list[bytes] = [] try: - # Try to get peer addresses from the host's peerstore - # Most host implementations have a peerstore attribute peer_addrs = self.host.get_peerstore().addrs(peer_id) - # Convert addresses to bytes for the protocol buffer addrs = [addr.to_bytes() for addr in peer_addrs] logger.debug( "Including %d addresses for peer %s in reservation response", @@ -606,39 +608,33 @@ async def _handle_reserve(self, stream: INetStream, msg: HopMessage) -> None: peer_id, ) except AttributeError: - # Host does not have peerstore or peerstore doesn't have addrs method logger.debug("Host peerstore not available for address lookup") except Exception as e: logger.warning("Error getting peer addresses: %s", str(e)) - # Add addresses to the reservation object reservation_obj.addrs = addrs # type: ignore - # Note: pb_reservation.addrs not available in current protobuf definition - # pb_reservation.addrs.extend(addrs) + pb_reservation = reservation_obj.to_proto() # Send reservation success response with trio.fail_after(self.write_timeout): - status = create_status(code=status_code, message=status_msg_text) - + lim_dur = min(int(self.limits.duration), 0xFFFFFFFF) response = HopMessage( - type=HopMessage.STATUS, - status=status, - reservation=pb_reservation, + type=HopMessage.Type.STATUS, + status=int(status_code), limit=Limit( - duration=self.limits.duration, + duration=lim_dur, data=self.limits.data, ), senderRecord=signed_envelope.marshal_envelope(), ) + response.reservation.CopyFrom(pb_reservation) - # Log the response message details for debugging logger.debug( "Sending reservation response: type=%s status=%s", response.type, - getattr(response.status, "code", "unknown"), + response.status, ) - await stream.write(response.SerializeToString()) - # Add a small wait to ensure the message is fully sent + await write_circuit_v2_pb(stream, response.SerializeToString()) await trio.sleep(0.1) logger.debug("Reservation response sent successfully") @@ -647,7 +643,6 @@ async def _handle_reserve(self, stream: INetStream, msg: HopMessage) -> None: logger.error("Error handling reservation request: %s", str(e)) if cast(INetStreamWithExtras, stream).is_open(): try: - # Send error response await self._send_status( stream, StatusCode.CONNECTION_FAILED, @@ -666,19 +661,27 @@ async def _handle_reserve(self, stream: INetStream, msg: HopMessage) -> None: async def _handle_connect(self, stream: INetStream, msg: HopMessage) -> None: """Handle a connect request.""" - peer_id = ID(msg.peer) + relay_envelope_bytes, _ = env_to_send_in_RPC(self.host) + relay_envelope = unmarshal_envelope(relay_envelope_bytes) + if not msg.HasField("peer") or not msg.peer.HasField("id") or not msg.peer.id: + await self._send_status( + stream, + StatusCode.MALFORMED_MESSAGE, + "CONNECT missing destination peer", + relay_envelope, + ) + await stream.reset() + return + + peer_id = ID(msg.peer.id) source_addr = stream.muxed_conn.peer_id logger.debug("Handling CONNECT request for peer %s", peer_id) dst_stream: INetStream | None = None logger.debug("Handling connect request to peer %s", peer_id) - # Verify reservation if provided + # Verify reservation if provided (voucher is for the destination peer) if msg.HasField("reservation"): - if not self.resource_manager.verify_reservation( - source_addr, msg.reservation - ): - relay_envelope_bytes, _ = env_to_send_in_RPC(self.host) - relay_envelope = unmarshal_envelope(relay_envelope_bytes) + if not self.resource_manager.verify_reservation(peer_id, msg.reservation): await self._send_status( stream, StatusCode.PERMISSION_DENIED, @@ -689,9 +692,7 @@ async def _handle_connect(self, stream: INetStream, msg: HopMessage) -> None: return # Check resource limits - if not self.resource_manager.can_accept_connection(peer_id=source_addr): - relay_envelope_bytes, _ = env_to_send_in_RPC(self.host) - relay_envelope = unmarshal_envelope(relay_envelope_bytes) + if not self.resource_manager.can_accept_connection(peer_id): await self._send_status( stream, StatusCode.RESOURCE_LIMIT_EXCEEDED, @@ -702,10 +703,6 @@ async def _handle_connect(self, stream: INetStream, msg: HopMessage) -> None: return try: - # Store the source stream with properly typed None - self._active_relays[source_addr] = (stream, None) - logger.debug("Stored source stream for peer %s", source_addr) - # Try to connect to the destination with timeout with trio.fail_after(STREAM_READ_TIMEOUT): logger.debug("Attempting to connect to destination %s", peer_id) @@ -723,16 +720,18 @@ async def _handle_connect(self, stream: INetStream, msg: HopMessage) -> None: logger.debug("Connected to destination peer %s", peer_id) # Send STOP CONNECT message + src_peer = Peer() + src_peer.id = source_addr.to_bytes() stop_msg = StopMessage( - type=StopMessage.CONNECT, - peer=source_addr.to_bytes(), + type=StopMessage.Type.CONNECT, senderRecord=relay_envelope_bytes, ) + stop_msg.peer.CopyFrom(src_peer) - await dst_stream.write(stop_msg.SerializeToString()) + await write_circuit_v2_pb(dst_stream, stop_msg.SerializeToString()) # Wait for response from destination - resp_bytes = await dst_stream.read(1024) + resp_bytes = await read_circuit_v2_pb(dst_stream) resp = StopMessage() resp.ParseFromString(resp_bytes) @@ -745,15 +744,12 @@ async def _handle_connect(self, stream: INetStream, msg: HopMessage) -> None: await self._close_stream(stream) return - # Handle status attributes from the response if resp.HasField("status"): - # Get code and message attributes with defaults - status_code = getattr(resp.status, "code", StatusCode.OK) - # Get message with default - status_msg = getattr(resp.status, "message", "Unknown error") + status_code = StatusCode(resp.status) + status_msg = status_code.name else: status_code = StatusCode.OK - status_msg = "No status provided" + status_msg = status_code.name if status_code != StatusCode.OK: logger.warning( @@ -796,7 +792,7 @@ async def _handle_connect(self, stream: INetStream, msg: HopMessage) -> None: # Start relaying data async with trio.open_nursery() as nursery: - nursery.start_soon(self._relay_data, stream, dst_stream, source_addr) + nursery.start_soon(self._relay_data, stream, dst_stream, peer_id) nursery.start_soon(self._relay_data, dst_stream, stream, peer_id) except (trio.TooSlowError, ConnectionError) as e: @@ -926,21 +922,13 @@ async def _send_status( message: str, envelope: Envelope | None = None, ) -> None: - """Send a status message.""" + """Send status; ``status`` is the top-level enum, not a submessage.""" try: logger.debug("Sending status message with code %s: %s", code, message) with trio.fail_after(STREAM_WRITE_TIMEOUT): - # Create a proto Status directly - pb_status = PbStatus() - pb_status.code = cast( - Any, int(code) - ) # Cast to Any to avoid type errors - pb_status.message = message - - # Send destination records to source in case of HOP status OK message status_msg = HopMessage( - type=HopMessage.STATUS, - status=pb_status, + type=HopMessage.Type.STATUS, + status=int(code), ) if envelope is not None: status_msg.senderRecord = envelope.marshal_envelope() @@ -948,7 +936,7 @@ async def _send_status( msg_bytes = status_msg.SerializeToString() logger.debug("Status message serialized (%d bytes)", len(msg_bytes)) - await stream.write(msg_bytes) + await write_circuit_v2_pb(stream, msg_bytes) logger.debug("Status message sent successfully") except trio.TooSlowError: logger.error( @@ -968,20 +956,13 @@ async def _send_stop_status( try: logger.debug("Sending stop status message with code %s: %s", code, message) with trio.fail_after(STREAM_WRITE_TIMEOUT): - # Create a proto Status directly - pb_status = PbStatus() - pb_status.code = cast( - Any, int(code) - ) # Cast to Any to avoid type errors - pb_status.message = message - status_msg = StopMessage( - type=StopMessage.STATUS, - status=pb_status, + type=StopMessage.Type.STATUS, + status=int(code), ) if senderRecord is not None: status_msg.senderRecord = senderRecord.marshal_envelope() - await stream.write(status_msg.SerializeToString()) + await write_circuit_v2_pb(stream, status_msg.SerializeToString()) except Exception as e: logger.error("Error sending stop status message: %s", str(e)) diff --git a/libp2p/relay/circuit_v2/protocol_buffer.py b/libp2p/relay/circuit_v2/protocol_buffer.py index 509cea1c6..6e2f171a8 100644 --- a/libp2p/relay/circuit_v2/protocol_buffer.py +++ b/libp2p/relay/circuit_v2/protocol_buffer.py @@ -1,55 +1,25 @@ """ -Protocol buffer wrapper classes for Circuit Relay v2. +Protocol buffer helpers for Circuit Relay v2. -This module provides wrapper classes for protocol buffer generated objects -to make them easier to work with in type-checked code. +``Status`` is a top-level protobuf enum (rust-libp2p / spec), not a submessage. """ -from enum import ( - IntEnum, -) -from typing import ( - cast, -) +from enum import IntEnum -from .pb.circuit_pb2 import Status as PbStatus - -# Define Status codes as an Enum for better type safety and organization class StatusCode(IntEnum): - OK = 0 - RESERVATION_REFUSED = 100 - RESOURCE_LIMIT_EXCEEDED = 101 - PERMISSION_DENIED = 102 - CONNECTION_FAILED = 200 - DIAL_REFUSED = 201 - STOP_FAILED = 300 + """Same numeric values as ``circuit_pb2.Status`` (spec / rust-libp2p).""" + + UNUSED = 0 + OK = 100 + RESERVATION_REFUSED = 200 + RESOURCE_LIMIT_EXCEEDED = 201 + PERMISSION_DENIED = 202 + CONNECTION_FAILED = 203 + NO_RESERVATION = 204 MALFORMED_MESSAGE = 400 + UNEXPECTED_MESSAGE = 401 -def create_status(code: int = StatusCode.OK, message: str = "") -> PbStatus: - """ - Create a protocol buffer Status object. - - Parameters - ---------- - code : int - The status code. Can be a StatusCode enum value or an integer. - message : str - The status message - - Returns - ------- - PbStatus - The protocol buffer Status object - - """ - pb_obj = PbStatus() - - # Convert the status code (int or StatusCode enum) to the protobuf enum value type. - # The code field expects PbStatus.Code.ValueType (a NewType wrapper around int). - # At runtime, protobuf accepts int directly, but type checker requires ValueType. - pb_obj.code = cast(PbStatus.Code.ValueType, int(code)) # type: ignore[assignment,attr-defined] - pb_obj.message = message - - return pb_obj +def pb_status_value(code: int | StatusCode) -> int: + return int(code) diff --git a/libp2p/relay/circuit_v2/resources.py b/libp2p/relay/circuit_v2/resources.py index b509f2f49..7ccb0ef65 100644 --- a/libp2p/relay/circuit_v2/resources.py +++ b/libp2p/relay/circuit_v2/resources.py @@ -219,11 +219,13 @@ def to_proto(self) -> PbReservation: self.peer_id, ) - return PbReservation( + pb = PbReservation( expire=int(self.expires_at), voucher=self.voucher, signature=signature, ) + pb.addrs.extend(getattr(self, "addrs", []) or []) + return pb def get_data_to_sign(self) -> bytes: """ @@ -368,12 +370,23 @@ def verify_reservation(self, peer_id: ID, proto_res: PbReservation) -> bool: ) return False - # Signature verification is required for security if not proto_res.signature: + if self.host is None: + logger.debug( + "No signature on reservation proto and no relay host; rejecting " + "peer %s", + peer_id, + ) + return False + # Other implementations (e.g. rust-libp2p) may omit the separate + # signature field; voucher+expire matching our reservation is enough + # at this relay when we could have verified a signature if present. logger.debug( - "No signature provided, rejecting reservation for peer %s", peer_id + "Reservation for peer %s has empty signature field; " + "accepting voucher+expire match (interop)", + peer_id, ) - return False + return True if self.host is None: logger.warning( @@ -383,7 +396,6 @@ def verify_reservation(self, peer_id: ID, proto_res: PbReservation) -> bool: ) return False - # Verify the signature using the relay's public key (not the client's) data_to_sign = self._get_data_to_sign(proto_res.voucher, proto_res.expire) return self._verify_signature_with_relay_key(data_to_sign, proto_res.signature) diff --git a/libp2p/relay/circuit_v2/transport.py b/libp2p/relay/circuit_v2/transport.py index 0307e4055..0e286b1c0 100644 --- a/libp2p/relay/circuit_v2/transport.py +++ b/libp2p/relay/circuit_v2/transport.py @@ -55,9 +55,14 @@ from .exceptions import RelayConnectionError from .pb.circuit_pb2 import ( HopMessage, + Peer, Reservation, StopMessage, ) +from .pb_framing import ( + read_circuit_v2_pb, + write_circuit_v2_pb, +) from .performance_tracker import ( RelayPerformanceTracker, ) @@ -385,24 +390,36 @@ async def dial_peer_info( logger.warning( "Failed to make reservation with relay %s", relay_peer_id ) + # rust-libp2p relay (v0.52) finishes each HOP substream after one + # exchange. Use a new stream for CONNECT so the relay does not drop + # the connection before we read the STATUS response. + await relay_stream.close() + relay_stream = await self.host.new_stream(relay_peer_id, [PROTOCOL_ID]) + if not relay_stream: + raise ConnectionError( + f"Could not open hop stream for CONNECT to relay " + f"{relay_peer_id}" + ) # Create signed peer record to send with the HOP message envelope_bytes, _ = env_to_send_in_RPC(self.host) - # Send HOP CONNECT message + dest_peer = Peer() + dest_peer.id = dest_info.peer_id.to_bytes() connect_msg = HopMessage( - type=HopMessage.CONNECT, - peer=dest_info.peer_id.to_bytes(), + type=HopMessage.Type.CONNECT, senderRecord=envelope_bytes, ) + connect_msg.peer.CopyFrom(dest_peer) - reservation_proof = self._reservation_proofs.get(relay_peer_id) - if reservation_proof and reservation_proof.expire > int(time.time()): - connect_msg.reservation.CopyFrom(reservation_proof) - await relay_stream.write(connect_msg.SerializeToString()) + # Do not attach ``_reservation_proofs[relay]`` here: that voucher is for + # *this* host's reservation with the relay. HOP CONNECT must carry the + # *destination* peer's reservation when present (obtained out-of-band); + # wrong voucher fails ``verify_reservation(dest, ...)`` on the relay. + await write_circuit_v2_pb(relay_stream, connect_msg.SerializeToString()) # Read response with timeout with trio.fail_after(STREAM_READ_TIMEOUT): - resp_bytes = await relay_stream.read(1024) + resp_bytes = await read_circuit_v2_pb(relay_stream) resp = HopMessage() resp.ParseFromString(resp_bytes) @@ -416,9 +433,12 @@ async def dial_peer_info( # Don't fail the connection - the senderRecord is optional # and the relay might not have the destination's signed peer record - # Access status attributes directly - status_code = getattr(resp.status, "code", StatusCode.OK) - status_msg = getattr(resp.status, "message", "Unknown error") + if resp.HasField("status"): + status_code = StatusCode(resp.status) + status_msg = status_code.name + else: + status_code = StatusCode.OK + status_msg = status_code.name if status_code != StatusCode.OK: raise RelayConnectionError( @@ -580,18 +600,22 @@ async def _dial_via_circuit_addr( raise ConnectionError(f"Could not open stream to relay {relay_peer_id}") try: - hop_msg = HopMessage( - type=HopMessage.CONNECT, - peer=peer_info.peer_id.to_bytes(), - ) - await relay_stream.write(hop_msg.SerializeToString()) + dest_peer = Peer() + dest_peer.id = peer_info.peer_id.to_bytes() + hop_msg = HopMessage(type=HopMessage.Type.CONNECT) + hop_msg.peer.CopyFrom(dest_peer) + await write_circuit_v2_pb(relay_stream, hop_msg.SerializeToString()) - resp_bytes = await relay_stream.read() + resp_bytes = await read_circuit_v2_pb(relay_stream) resp = HopMessage() resp.ParseFromString(resp_bytes) - status_code = getattr(resp.status, "code", StatusCode.OK) - status_msg = getattr(resp.status, "message", "Unknown error") + if resp.HasField("status"): + status_code = StatusCode(resp.status) + status_msg = status_code.name + else: + status_code = StatusCode.OK + status_msg = status_code.name if status_code != StatusCode.OK: await relay_stream.close() @@ -798,14 +822,16 @@ async def _make_reservation( # Create signed envelope for the reservation request to relay envelope_bytes, _ = env_to_send_in_RPC(self.host) # Send reservation request + rpeer = Peer() + rpeer.id = self.host.get_id().to_bytes() reserve_msg = HopMessage( - type=HopMessage.RESERVE, - peer=self.host.get_id().to_bytes(), + type=HopMessage.Type.RESERVE, senderRecord=envelope_bytes, ) + reserve_msg.peer.CopyFrom(rpeer) try: - await stream.write(reserve_msg.SerializeToString()) + await write_circuit_v2_pb(stream, reserve_msg.SerializeToString()) logger.debug("Successfully sent reservation request") except Exception as e: logger.error("Failed to send reservation request: %s", str(e)) @@ -814,7 +840,7 @@ async def _make_reservation( # Read response with timeout with trio.fail_after(STREAM_READ_TIMEOUT): try: - resp_bytes = await stream.read(1024) + resp_bytes = await read_circuit_v2_pb(stream) logger.debug( "Received reservation response: %d bytes", len(resp_bytes) ) @@ -834,9 +860,12 @@ async def _make_reservation( ) # Don't fail the reservation - the senderRecord is optional - # Access status attributes directly - status_code = getattr(resp.status, "code", StatusCode.OK) - status_msg = getattr(resp.status, "message", "Unknown error") + if resp.HasField("status"): + status_code = StatusCode(resp.status) + status_msg = status_code.name + else: + status_code = StatusCode.OK + status_msg = status_code.name expires = getattr(resp.reservation, "expire", 0) logger.debug( @@ -852,11 +881,10 @@ async def _make_reservation( return False self._reservations[relay_peer_id] = expires - self._reservation_proofs[relay_peer_id] = Reservation( - expire=expires, - voucher=getattr(resp.reservation, "voucher", b""), - signature=getattr(resp.reservation, "signature", b""), - ) + proof = Reservation() + if resp.HasField("reservation"): + proof.CopyFrom(resp.reservation) + self._reservation_proofs[relay_peer_id] = proof ttl = max(0, expires - int(time.time())) logger.info("Reserved peer %s (ttl=%ss)", relay_peer_id, ttl) @@ -1014,16 +1042,19 @@ async def handle_incoming_connection( try: # Read STOP message - msg_bytes = await stream.read() + msg_bytes = await read_circuit_v2_pb(stream) stop_msg = StopMessage() stop_msg.ParseFromString(msg_bytes) - if stop_msg.type != StopMessage.CONNECT: + if stop_msg.type != StopMessage.Type.CONNECT: raise ConnectionError("Invalid STOP message type") + if not stop_msg.HasField("peer") or not stop_msg.peer.HasField("id"): + raise ConnectionError("Invalid STOP message peer") + # Create raw connection for relayed connection # Construct circuit multiaddr: /p2p/{relay}/p2p-circuit/p2p/{source} - peer_id = ID(stop_msg.peer) + peer_id = ID(stop_msg.peer.id) relay_peer_id = self.host.get_id() circuit_ma = multiaddr.Multiaddr( f"/p2p/{relay_peer_id.to_base58()}/p2p-circuit/p2p/{peer_id.to_base58()}" diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py index 0572fcfb9..49bef8a1b 100644 --- a/libp2p/transport/quic/transport.py +++ b/libp2p/transport/quic/transport.py @@ -13,7 +13,6 @@ from aioquic.quic.connection import ( QuicConnection as NativeQUICConnection, ) -from aioquic.quic.logger import QuicLogger import multiaddr import trio @@ -269,12 +268,16 @@ async def dial( # Get appropriate QUIC client configuration config_key = TProtocol(f"{quic_version}_client") logger.debug("config_key", config_key, self._quic_configs.keys()) - config = self._quic_configs.get(config_key) - if not config: + template = self._quic_configs.get(config_key) + if not template: raise QUICDialError(f"Unsupported QUIC version: {quic_version}") + # Per-dial copy: the cached template must not be mutated (is_client, + # quic_logger) or concurrent/overlapping dials share one QuicConfiguration + # and aioquic raises "QuicLoggerTrace does not belong to QuicLogger". + config = copy.copy(template) config.is_client = True - config.quic_logger = QuicLogger() + config.quic_logger = None # Ensure client certificate is properly set for mutual authentication if not config.certificate or not config.private_key: diff --git a/newsfragments/1304.bugfix.rst b/newsfragments/1304.bugfix.rst new file mode 100644 index 000000000..00e26ce15 --- /dev/null +++ b/newsfragments/1304.bugfix.rst @@ -0,0 +1 @@ +Circuit relay v2: open a new HOP stream to the relay for CONNECT after a client RESERVE, matching relays that close the substream after one exchange. QUIC: dial with a per-dial copy of the client ``QuicConfiguration`` and without sharing a ``QuicLogger`` on the cached template, fixing intermittent handshake and aioquic logger errors. diff --git a/pyproject.toml b/pyproject.toml index a12ceb021..3287b7ce5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,8 @@ dependencies = [ "multiaddr==0.0.11", "mypy-protobuf>=3.0.0", "noiseprotocol>=0.3.0", - "protobuf>=4.25.0,<7.0.0", + # Match generated *_pb2.py (e.g. circuit_v2 uses runtime_version + 7.34.1 validation). + "protobuf>=7.34.1,<8.0.0", "pycryptodome>=3.9.2", "py-multibase>=2.0.0", "py-multihash>=3.0.0", diff --git a/tests/core/relay/test_circuit_v2_discovery.py b/tests/core/relay/test_circuit_v2_discovery.py index 4d90208fd..de141c322 100644 --- a/tests/core/relay/test_circuit_v2_discovery.py +++ b/tests/core/relay/test_circuit_v2_discovery.py @@ -10,6 +10,10 @@ RelayDiscovery, ) from libp2p.relay.circuit_v2.pb import circuit_pb2 as proto +from libp2p.relay.circuit_v2.pb_framing import ( + read_circuit_v2_pb, + write_circuit_v2_pb, +) from libp2p.relay.circuit_v2.protocol import ( PROTOCOL_ID, STOP_PROTOCOL_ID, @@ -17,9 +21,6 @@ from libp2p.tools.anyio_service import ( background_trio_service, ) -from libp2p.tools.constants import ( - MAX_READ_LEN, -) from libp2p.tools.utils import ( connect, ) @@ -43,7 +44,7 @@ async def simple_stream_handler(stream): logger.info("Simple stream handler invoked") try: # Read the request - request_data = await stream.read(MAX_READ_LEN) + request_data = await read_circuit_v2_pb(stream) if not request_data: logger.error("Empty request received") return @@ -54,14 +55,11 @@ async def simple_stream_handler(stream): logger.info("Received request: type=%s", request.type) # Only handle RESERVE requests - if request.type == proto.HopMessage.RESERVE: + if request.type == proto.HopMessage.Type.RESERVE: # Create a valid response response = proto.HopMessage( - type=proto.HopMessage.STATUS, - status=proto.Status( - code=proto.Status.OK, - message="Test reservation accepted", - ), + type=proto.HopMessage.Type.STATUS, + status=proto.Status.OK, reservation=proto.Reservation( expire=int(time.time()) + 3600, # 1 hour from now voucher=b"test-voucher", @@ -75,7 +73,7 @@ async def simple_stream_handler(stream): # Send the response logger.info("Sending response") - await stream.write(response.SerializeToString()) + await write_circuit_v2_pb(stream, response.SerializeToString()) logger.info("Response sent") except Exception as e: logger.error("Error in simple stream handler: %s", str(e)) diff --git a/tests/core/relay/test_circuit_v2_protocol.py b/tests/core/relay/test_circuit_v2_protocol.py index 8d3878310..ba51e2373 100644 --- a/tests/core/relay/test_circuit_v2_protocol.py +++ b/tests/core/relay/test_circuit_v2_protocol.py @@ -28,6 +28,10 @@ ) from libp2p.relay.circuit_v2.pb import circuit_pb2 as proto from libp2p.relay.circuit_v2.pb.circuit_pb2 import Reservation as PbReservation +from libp2p.relay.circuit_v2.pb_framing import ( + read_circuit_v2_pb, + write_circuit_v2_pb, +) from libp2p.relay.circuit_v2.protocol import ( DEFAULT_RELAY_LIMITS, PROTOCOL_ID, @@ -42,12 +46,10 @@ from libp2p.tools.anyio_service import ( background_trio_service, ) -from libp2p.tools.constants import ( - MAX_READ_LEN, -) from libp2p.tools.utils import ( connect, ) +from libp2p.utils.varint import decode_varint_with_size from tests.utils.factories import ( HostFactory, ) @@ -143,7 +145,7 @@ async def assert_stream_response( # Try to read response logger.debug("Attempt %d: Reading response from stream", attempt + 1) - response_bytes = await stream.read(MAX_READ_LEN) + response_bytes = await read_circuit_v2_pb(stream) # Check if we got any data if not response_bytes: @@ -167,7 +169,7 @@ async def assert_stream_response( "Attempt %d: Received HOP response: type=%s, status=%s", attempt + 1, response.type, - response.status.code + int(response.status) if response.HasField("status") else "No status", ) @@ -175,12 +177,10 @@ async def assert_stream_response( all_responses.append( { "type": response.type, - "status": response.status.code - if response.HasField("status") - else None, - "message": response.status.message + "status": int(response.status) if response.HasField("status") else None, + "message": None, } ) @@ -188,7 +188,7 @@ async def assert_stream_response( if ( expected_status is not None and response.HasField("status") - and response.status.code == expected_status + and int(response.status) == expected_status ): if response.type != expected_type: logger.warning( @@ -216,15 +216,15 @@ async def assert_stream_response( # Check status code if present if response.HasField("status"): - if response.status.code != expected_status: + if int(response.status) != expected_status: logger.warning( "Wrong status code: expected %s, got %s", expected_status, - response.status.code, + int(response.status), ) last_error = ( f"Wrong status code: expected {expected_status}, " - f"got {response.status.code}" + f"got {int(response.status)}" ) if attempt < retries - 1: # Not the last attempt continue @@ -257,8 +257,8 @@ async def assert_stream_response( status_code = None status_message = None if has_status: - status_code = stop_msg.status.code - status_message = stop_msg.status.message + status_code = int(stop_msg.status) + status_message = None response_dict: dict[str, Any] = { "stop_type": stop_msg.type, # Keep original type @@ -473,7 +473,8 @@ async def test_circuit_v2_voucher_verification_complete(): ) logger.info("Reservation for wrong peer correctly rejected") - # Test with missing signature (should fail) + # Omitted signature: spec allows peers to send only voucher+expire; we still + # accept when they match the in-memory reservation (interop, e.g. rust-libp2p). no_sig_reservation = proto.Reservation( expire=pb_reservation.expire, voucher=pb_reservation.voucher, @@ -483,10 +484,10 @@ async def test_circuit_v2_voucher_verification_complete(): is_valid_no_sig = resource_manager.verify_reservation( client_peer_id, no_sig_reservation ) - assert is_valid_no_sig is False, ( - "Reservation without signature should fail verification" + assert is_valid_no_sig is True, ( + "Reservation without signature should verify when voucher+expire match" ) - logger.info("Reservation without signature correctly rejected") + logger.info("Reservation without signature accepted (interop)") # Test with expired reservation expired_reservation = resource_manager._reservations[client_peer_id] @@ -542,8 +543,10 @@ async def test_handle_reserve_returns_signed_reservation_payload(): stream = AsyncMock() stream.write = AsyncMock() - reserve_msg = proto.HopMessage(type=proto.HopMessage.RESERVE) - reserve_msg.peer = client_peer_id.to_bytes() + reserve_msg = proto.HopMessage(type=proto.HopMessage.Type.RESERVE) + _p = proto.Peer() + _p.id = client_peer_id.to_bytes() + reserve_msg.peer.CopyFrom(_p) fake_envelope = Mock() fake_envelope.marshal_envelope.return_value = b"signed-relay-record" @@ -558,17 +561,18 @@ async def test_handle_reserve_returns_signed_reservation_payload(): return_value=fake_envelope, ), ): - await protocol._handle_reserve(stream, reserve_msg) + await protocol._handle_reserve(stream, reserve_msg, client_peer_id) assert stream.write.await_count == 1 await_args = stream.write.await_args assert await_args is not None - response_bytes = await_args.args[0] + framed = await_args.args[0] + plen, off = decode_varint_with_size(framed) response = proto.HopMessage() - response.ParseFromString(response_bytes) + response.ParseFromString(framed[off : off + plen]) - assert response.type == proto.HopMessage.STATUS - assert response.status.code == proto.Status.OK + assert response.type == proto.HopMessage.Type.STATUS + assert response.status == proto.Status.OK assert response.reservation.voucher != b"" assert response.reservation.signature != b"" @@ -588,13 +592,13 @@ async def mock_reserve_handler(stream): # Read the request logger.info("Mock handler received stream request") try: - request_data = await stream.read(MAX_READ_LEN) + request_data = await read_circuit_v2_pb(stream) request = proto.HopMessage() request.ParseFromString(request_data) logger.info("Mock handler parsed request: type=%s", request.type) # Only handle RESERVE requests - if request.type == proto.HopMessage.RESERVE: + if request.type == proto.HopMessage.Type.RESERVE: # Check if the request contains signed peer records (SPR validation) if request.HasField("senderRecord"): logger.info("Request contains senderRecord, validating SPR") @@ -610,11 +614,8 @@ async def mock_reserve_handler(stream): # Create a valid response response = proto.HopMessage( - type=proto.HopMessage.RESERVE, - status=proto.Status( - code=proto.Status.OK, - message="Reservation accepted", - ), + type=proto.HopMessage.Type.STATUS, + status=proto.Status.OK, reservation=proto.Reservation( expire=int(time.time()) + 3600, # 1 hour from now voucher=b"test-voucher", @@ -628,7 +629,7 @@ async def mock_reserve_handler(stream): # Send the response logger.info("Mock handler sending response") - await stream.write(response.SerializeToString()) + await write_circuit_v2_pb(stream, response.SerializeToString()) logger.info("Mock handler sent response") # Keep stream open for client to read response @@ -671,14 +672,16 @@ async def mock_reserve_handler(stream): # Get the peer record, handle None case client_host_envelope, _ = env_to_send_in_RPC(client_host) request = proto.HopMessage( - type=proto.HopMessage.RESERVE, - peer=client_host.get_id().to_bytes(), + type=proto.HopMessage.Type.RESERVE, # Client sends its signed-peer records in reservation request senderRecord=client_host_envelope, ) + _cp = proto.Peer() + _cp.id = client_host.get_id().to_bytes() + request.peer.CopyFrom(_cp) logger.info("Sending reservation request") - await stream.write(request.SerializeToString()) + await write_circuit_v2_pb(stream, request.SerializeToString()) logger.info("Reservation request sent") # Wait to ensure the request is processed @@ -686,7 +689,7 @@ async def mock_reserve_handler(stream): # Read response directly logger.info("Reading response directly") - response_bytes = await stream.read(MAX_READ_LEN) + response_bytes = await read_circuit_v2_pb(stream) assert response_bytes, "No response received" # Parse response @@ -694,12 +697,12 @@ async def mock_reserve_handler(stream): response.ParseFromString(response_bytes) # Verify response - assert response.type == proto.HopMessage.RESERVE, ( + assert response.type == proto.HopMessage.Type.STATUS, ( f"Wrong response type: {response.type}" ) assert response.HasField("status"), "No status field" - assert response.status.code == proto.Status.OK, ( - f"Wrong status code: {response.status.code}" + assert response.status == proto.Status.OK, ( + f"Wrong status code: {response.status}" ) # Verify reservation details @@ -737,121 +740,95 @@ async def test_circuit_v2_reservation_limit(): # Custom handler that responds based on reservation limits async def mock_reserve_handler(stream): - # Read the request logger.info("Mock handler received stream request") try: - request_data = await stream.read(MAX_READ_LEN) + request_data = await read_circuit_v2_pb(stream) request = proto.HopMessage() request.ParseFromString(request_data) logger.info("Mock handler parsed request: type=%s", request.type) - # Only handle RESERVE requests - if request.type == proto.HopMessage.RESERVE: - # Extract peer ID from request - peer_id = ID(request.peer) - logger.info( - "Mock handler received reservation request from %s", peer_id - ) - # Check if reservation request has senderRecord - if request.HasField("senderRecord"): - try: - # Validate the SPR using the real validation function - if maybe_consume_signed_record( - request, relay_host, peer_id - ): - logger.info( - "Reservation request from %s contain valid records", - peer_id, - ) - else: - logger.warning("Invalid senderRecord from %s", peer_id) - response = proto.HopMessage( - type=proto.HopMessage.RESERVE, - status=proto.Status( - code=proto.Status.PERMISSION_DENIED, - message="Invalid senderRecord", - ), - ) - await stream.write(response.SerializeToString()) - return - except Exception as e: - logger.warning( - "SPR validation error for %s: %s", peer_id, e - ) + if request.type != proto.HopMessage.Type.RESERVE: + return + + if not request.HasField("peer") or not request.peer.id: + return + peer_id = ID(request.peer.id) + logger.info( + "Mock handler received reservation request from %s", peer_id + ) + + if request.HasField("senderRecord"): + try: + if not maybe_consume_signed_record( + request, relay_host, peer_id + ): + logger.warning("Invalid senderRecord from %s", peer_id) response = proto.HopMessage( - type=proto.HopMessage.RESERVE, - status=proto.Status( - code=proto.Status.PERMISSION_DENIED, - message=f"SPR validation error: {e}", - ), + type=proto.HopMessage.Type.STATUS, + status=proto.Status.PERMISSION_DENIED, + ) + await write_circuit_v2_pb( + stream, response.SerializeToString() ) - await stream.write(response.SerializeToString()) return - else: - logger.warning( - "Reservation request from %s is missing senderRecord", + logger.info( + "Reservation request from %s contain valid records", peer_id, ) + except Exception as e: + logger.warning("SPR validation error for %s: %s", peer_id, e) response = proto.HopMessage( - type=proto.HopMessage.RESERVE, - status=proto.Status( - code=proto.Status.PERMISSION_DENIED, - message="Missing senderRecord", - ), + type=proto.HopMessage.Type.STATUS, + status=proto.Status.PERMISSION_DENIED, ) - await stream.write(response.SerializeToString()) + await write_circuit_v2_pb(stream, response.SerializeToString()) return + else: + logger.warning( + "Reservation request from %s is missing senderRecord", + peer_id, + ) + response = proto.HopMessage( + type=proto.HopMessage.Type.STATUS, + status=proto.Status.PERMISSION_DENIED, + ) + await write_circuit_v2_pb(stream, response.SerializeToString()) + return + + if ( + peer_id in reserved_clients + or len(reserved_clients) < max_reservations + ): + if peer_id not in reserved_clients: + reserved_clients.add(peer_id) + response = proto.HopMessage( + type=proto.HopMessage.Type.STATUS, + status=proto.Status.OK, + reservation=proto.Reservation( + expire=int(time.time()) + 3600, + voucher=b"test-voucher", + signature=b"", + ), + limit=proto.Limit( + duration=3600, + data=1024 * 1024 * 1024, + ), + ) + logger.info("Mock handler accepting reservation for %s", peer_id) + else: + response = proto.HopMessage( + type=proto.HopMessage.Type.STATUS, + status=proto.Status.RESOURCE_LIMIT_EXCEEDED, + ) + logger.info( + "Mock handler rejecting reservation for %s due to limit", + peer_id, + ) - # Check if we've reached reservation limit - if ( - peer_id in reserved_clients - or len(reserved_clients) < max_reservations - ): - # Accept the reservation - if peer_id not in reserved_clients: - reserved_clients.add(peer_id) - - # Create a success response - response = proto.HopMessage( - type=proto.HopMessage.RESERVE, - status=proto.Status( - code=proto.Status.OK, - message="Reservation accepted", - ), - reservation=proto.Reservation( - expire=int(time.time()) + 3600, # 1 hour from now - voucher=b"test-voucher", - signature=b"", - ), - limit=proto.Limit( - duration=3600, # 1 hour - data=1024 * 1024 * 1024, # 1GB - ), - ) - logger.info( - "Mock handler accepting reservation for %s", peer_id - ) - else: - # Reject the reservation due to limits - response = proto.HopMessage( - type=proto.HopMessage.RESERVE, - status=proto.Status( - code=proto.Status.RESOURCE_LIMIT_EXCEEDED, - message="Reservation limit exceeded", - ), - ) - logger.info( - "Mock handler rejecting reservation for %s due to limit", - peer_id, - ) - - # Send the response - logger.info("Mock handler sending response") - await stream.write(response.SerializeToString()) - logger.info("Mock handler sent response") - - # Keep stream open for client to read response - await trio.sleep(5) + logger.info("Mock handler sending response") + await write_circuit_v2_pb(stream, response.SerializeToString()) + logger.info("Mock handler sent response") + await trio.sleep(5) except Exception as e: logger.error("Error in mock handler: %s", str(e)) @@ -893,13 +870,15 @@ async def mock_reserve_handler(stream): client1_host_envelope, _ = env_to_send_in_RPC(client1_host) logger.info("Preparing reservation request for client1") request1 = proto.HopMessage( - type=proto.HopMessage.RESERVE, - peer=client1_host.get_id().to_bytes(), + type=proto.HopMessage.Type.RESERVE, senderRecord=client1_host_envelope, ) + _p1 = proto.Peer() + _p1.id = client1_host.get_id().to_bytes() + request1.peer.CopyFrom(_p1) logger.info("Sending reservation request for client1") - await stream1.write(request1.SerializeToString()) + await write_circuit_v2_pb(stream1, request1.SerializeToString()) logger.info("Sent reservation request for client1") # Wait to ensure the request is processed @@ -907,7 +886,7 @@ async def mock_reserve_handler(stream): # Read response directly logger.info("Reading response for client1") - response_bytes = await stream1.read(MAX_READ_LEN) + response_bytes = await read_circuit_v2_pb(stream1) assert response_bytes, "No response received for client1" # Parse response @@ -915,12 +894,12 @@ async def mock_reserve_handler(stream): response1.ParseFromString(response_bytes) # Verify response - assert response1.type == proto.HopMessage.RESERVE, ( + assert response1.type == proto.HopMessage.Type.STATUS, ( f"Wrong response type: {response1.type}" ) assert response1.HasField("status"), "No status field" - assert response1.status.code == proto.Status.OK, ( - f"Wrong status code: {response1.status.code}" + assert response1.status == proto.Status.OK, ( + f"Wrong status code: {response1.status}" ) # Verify reservation details @@ -952,13 +931,15 @@ async def mock_reserve_handler(stream): client2_host_envelope, _ = env_to_send_in_RPC(client2_host) logger.info("Preparing reservation request for client2") request2 = proto.HopMessage( - type=proto.HopMessage.RESERVE, - peer=client2_host.get_id().to_bytes(), + type=proto.HopMessage.Type.RESERVE, senderRecord=client2_host_envelope, ) + _p2 = proto.Peer() + _p2.id = client2_host.get_id().to_bytes() + request2.peer.CopyFrom(_p2) logger.info("Sending reservation request for client2") - await stream2.write(request2.SerializeToString()) + await write_circuit_v2_pb(stream2, request2.SerializeToString()) logger.info("Sent reservation request for client2") # Wait to ensure the request is processed @@ -966,7 +947,7 @@ async def mock_reserve_handler(stream): # Read response directly logger.info("Reading response for client2") - response_bytes = await stream2.read(MAX_READ_LEN) + response_bytes = await read_circuit_v2_pb(stream2) assert response_bytes, "No response received for client2" # Parse response @@ -974,12 +955,12 @@ async def mock_reserve_handler(stream): response2.ParseFromString(response_bytes) # Verify response - assert response2.type == proto.HopMessage.RESERVE, ( + assert response2.type == proto.HopMessage.Type.STATUS, ( f"Wrong response type: {response2.type}" ) assert response2.HasField("status"), "No status field" - assert response2.status.code == proto.Status.RESOURCE_LIMIT_EXCEEDED, ( - f"Wrong status code: {response2.status.code}, " + assert response2.status == proto.Status.RESOURCE_LIMIT_EXCEEDED, ( + f"Wrong status code: {response2.status}, " f"expected RESOURCE_LIMIT_EXCEEDED" ) logger.info("Verified client2 was correctly rejected") @@ -1014,40 +995,37 @@ async def test_circuit_v2_fails_with_invalid_SPR(): # Handler that checks SPR validity async def spr_validation_handler(stream): try: - request_data = await stream.read(MAX_READ_LEN) + request_data = await read_circuit_v2_pb(stream) request = proto.HopMessage() request.ParseFromString(request_data) - if request.type == proto.HopMessage.RESERVE: + if request.type == proto.HopMessage.Type.RESERVE: # Reject specific invalid SPR if ( request.HasField("senderRecord") and request.senderRecord == b"invalid-spr" ): status_code = proto.Status.MALFORMED_MESSAGE - message = "Invalid SPR rejected" else: status_code = proto.Status.OK - message = "Valid SPR accepted" response = proto.HopMessage( - type=proto.HopMessage.RESERVE, - status=proto.Status(code=status_code, message=message), + type=proto.HopMessage.Type.STATUS, + status=status_code, ) - await stream.write(response.SerializeToString()) + await write_circuit_v2_pb(stream, response.SerializeToString()) await trio.sleep(2) # Brief wait for client to read except Exception as e: logger.error("Handler error: %s", str(e)) try: error_response = proto.HopMessage( - type=proto.HopMessage.RESERVE, - status=proto.Status( - code=proto.Status.MALFORMED_MESSAGE, - message=f"Handler error: {str(e)}", - ), + type=proto.HopMessage.Type.STATUS, + status=proto.Status.MALFORMED_MESSAGE, + ) + await write_circuit_v2_pb( + stream, error_response.SerializeToString() ) - await stream.write(error_response.SerializeToString()) except Exception: pass @@ -1063,22 +1041,24 @@ async def spr_validation_handler(stream): relay_host.get_id(), [PROTOCOL_ID] ) request = proto.HopMessage( - type=proto.HopMessage.RESERVE, - peer=client_host.get_id().to_bytes(), + type=proto.HopMessage.Type.RESERVE, senderRecord=b"invalid-spr", # Invalid SPR ) - await stream.write(request.SerializeToString()) + _ip = proto.Peer() + _ip.id = client_host.get_id().to_bytes() + request.peer.CopyFrom(_ip) + await write_circuit_v2_pb(stream, request.SerializeToString()) await trio.sleep(SLEEP_TIME) - response_bytes = await stream.read(MAX_READ_LEN) + response_bytes = await read_circuit_v2_pb(stream) assert response_bytes, "No response received" response = proto.HopMessage() response.ParseFromString(response_bytes) assert response.HasField("status"), "No status field" - assert response.status.code == proto.Status.MALFORMED_MESSAGE, ( - f"Expected MALFORMED_MESSAGE, got {response.status.code}" + assert response.status == proto.Status.MALFORMED_MESSAGE, ( + f"Expected MALFORMED_MESSAGE, got {response.status}" ) logger.info("Successfully verified invalid SPR rejection") finally: @@ -1148,18 +1128,20 @@ async def test_reservation_fails_with_invalid_record_transfer(): ) request = proto.HopMessage( - type=proto.HopMessage.RESERVE, - peer=client_host.get_id().to_bytes(), + type=proto.HopMessage.Type.RESERVE, senderRecord=corrupted_env.marshal_envelope(), # Invalid SPR ) + _rp = proto.Peer() + _rp.id = client_host.get_id().to_bytes() + request.peer.CopyFrom(_rp) - await stream.write(request.SerializeToString()) + await write_circuit_v2_pb(stream, request.SerializeToString()) logger.info("Sent request with invalid SPR") await trio.sleep(SLEEP_TIME) # Try to read response, but expect the stream to be closed try: - response_bytes = await stream.read(MAX_READ_LEN) + response_bytes = await read_circuit_v2_pb(stream) if not response_bytes: # Empty response indicates stream was closed stream_closed_by_relay = True @@ -1171,11 +1153,9 @@ async def test_reservation_fails_with_invalid_record_transfer(): if ( response.HasField("status") - and response.status.code != proto.Status.OK + and response.status != proto.Status.OK ): - logger.info( - f"Invalid SPR rejected : {response.status.code}" - ) + logger.info(f"Invalid SPR rejected : {response.status}") else: logger.warning("Unexpected response to invalid SPR") except (StreamEOF, StreamError, StreamReset) as e: diff --git a/tests/core/relay/test_circuit_v2_transport.py b/tests/core/relay/test_circuit_v2_transport.py index 34ecabe8c..32ea0e682 100644 --- a/tests/core/relay/test_circuit_v2_transport.py +++ b/tests/core/relay/test_circuit_v2_transport.py @@ -35,7 +35,7 @@ CircuitV2Protocol, RelayLimits, ) -from libp2p.relay.circuit_v2.protocol_buffer import StatusCode, create_status +from libp2p.relay.circuit_v2.protocol_buffer import StatusCode from libp2p.relay.circuit_v2.transport import ( ID, PROTOCOL_ID, @@ -49,6 +49,7 @@ from libp2p.tools.utils import ( connect, ) +from libp2p.utils.varint import decode_varint_with_size from tests.utils.factories import ( HostFactory, ) @@ -343,7 +344,7 @@ async def test_circuit_v2_transport_message_routing_through_relay(): roles=RelayRole.STOP | RelayRole.CLIENT, limits=dest_limits ) dest_protocol = CircuitV2Protocol(target_host, dest_limits, allow_hop=False) - CircuitV2Transport(target_host, dest_protocol, dest_config) + dest_transport = CircuitV2Transport(target_host, dest_protocol, dest_config) target_host.set_stream_handler(PROTOCOL_ID, dest_protocol._handle_hop_stream) target_host.set_stream_handler( STOP_PROTOCOL_ID, dest_protocol._handle_stop_stream @@ -373,6 +374,17 @@ async def app_echo_handler(stream): await trio.sleep(SLEEP_TIME) + # Destination must reserve a slot on the relay before inbound circuits are + # allowed. + relay_id_for_dest = relay_host.get_id() + res_stream = await target_host.new_stream(relay_id_for_dest, [PROTOCOL_ID]) + try: + assert await dest_transport._make_reservation(res_stream, relay_id_for_dest) + finally: + await res_stream.close() + + await trio.sleep(SLEEP_TIME) + # Step 2: Source connects to Relay with trio.fail_after(CONNECT_TIMEOUT): await connect(client_host, relay_host) @@ -1629,12 +1641,14 @@ async def test_dial_peer_info_creates_and_stores_circuit(protocol): peerstore.addrs.return_value = [relay_addr] - status = create_status(code=StatusCode.OK, message="OK") - hop_resp = HopMessage(type=HopMessage.STATUS, status=status) - relay_stream.read.return_value = hop_resp.SerializeToString() + hop_resp = HopMessage(type=HopMessage.Type.STATUS, status=int(StatusCode.OK)) relay_stream.write = AsyncMock() - conn = await transport.dial_peer_info(peer_info) + with patch( + "libp2p.relay.circuit_v2.transport.read_circuit_v2_pb", + AsyncMock(return_value=hop_resp.SerializeToString()), + ): + conn = await transport.dial_peer_info(peer_info) peerstore.add_addrs.assert_called_once() assert isinstance(conn, TrackedRawConnection) @@ -1642,7 +1656,8 @@ async def test_dial_peer_info_creates_and_stores_circuit(protocol): @pytest.mark.trio -async def test_dial_peer_info_includes_reservation_proof(protocol): +async def test_dial_peer_info_connect_does_not_send_client_reservation(protocol): + """CONNECT must not embed this host's relay reservation (dest's voucher).""" mock_host = Mock() peerstore = Mock() mock_host.get_peerstore.return_value = peerstore @@ -1664,12 +1679,6 @@ async def test_dial_peer_info_includes_reservation_proof(protocol): mock_host.connect = AsyncMock(return_value=None) relay_stream = AsyncMock() relay_stream.write = AsyncMock() - relay_stream.read = AsyncMock( - return_value=HopMessage( - type=HopMessage.STATUS, - status=create_status(code=StatusCode.OK, message="connected"), - ).SerializeToString() - ) mock_host.new_stream = AsyncMock(return_value=relay_stream) transport = CircuitV2Transport( @@ -1679,27 +1688,36 @@ async def test_dial_peer_info_includes_reservation_proof(protocol): ) transport._select_relay = AsyncMock(return_value=relay_peer_id) - reservation_expiry = int(time.time()) + 120 transport._reservation_proofs[relay_peer_id] = Reservation( - expire=reservation_expiry, + expire=int(time.time()) + 120, voucher=b"voucher-bytes", signature=b"signature-bytes", ) - with patch( - "libp2p.relay.circuit_v2.transport.env_to_send_in_RPC", - return_value=(b"", None), + connect_resp = HopMessage( + type=HopMessage.Type.STATUS, + status=int(StatusCode.OK), + ) + with ( + patch( + "libp2p.relay.circuit_v2.transport.env_to_send_in_RPC", + return_value=(b"", None), + ), + patch( + "libp2p.relay.circuit_v2.transport.read_circuit_v2_pb", + AsyncMock(return_value=connect_resp.SerializeToString()), + ), ): await transport.dial_peer_info(dest_info) outbound_bytes = relay_stream.write.await_args_list[0].args[0] + plen, off = decode_varint_with_size(outbound_bytes) outbound_hop = HopMessage() - outbound_hop.ParseFromString(outbound_bytes) + outbound_hop.ParseFromString(outbound_bytes[off : off + plen]) - assert outbound_hop.type == HopMessage.CONNECT - assert outbound_hop.reservation.expire == reservation_expiry - assert outbound_hop.reservation.voucher == b"voucher-bytes" - assert outbound_hop.reservation.signature == b"signature-bytes" + assert outbound_hop.type == HopMessage.Type.CONNECT + assert outbound_hop.peer.id == dest_peer_id.to_bytes() + assert not outbound_hop.HasField("reservation") def test_valid_circuit_multiaddr(circuit_v2_transport):