diff --git a/libp2p/relay/circuit_v2/pb/circuit.proto b/libp2p/relay/circuit_v2/pb/circuit.proto index 4ffc9413d..1fd19d817 100644 --- a/libp2p/relay/circuit_v2/pb/circuit.proto +++ b/libp2p/relay/circuit_v2/pb/circuit.proto @@ -49,6 +49,7 @@ message Status { PERMISSION_DENIED = 102; CONNECTION_FAILED = 200; DIAL_REFUSED = 201; + NO_RESERVATION = 204; STOP_FAILED = 300; MALFORMED_MESSAGE = 400; } diff --git a/libp2p/relay/circuit_v2/pb/circuit_pb2.py b/libp2p/relay/circuit_v2/pb/circuit_pb2.py index ebd95e3ba..1df14ae86 100644 --- a/libp2p/relay/circuit_v2/pb/circuit_pb2.py +++ b/libp2p/relay/circuit_v2/pb/circuit_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(libp2p/relay/circuit_v2/pb/circuit.proto\x12\rcircuit.pb.v2\"\x9f\x02\n\nHopMessage\x12,\n\x04type\x18\x01 \x01(\x0e\x32\x1e.circuit.pb.v2.HopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12/\n\x0breservation\x18\x03 \x01(\x0b\x32\x1a.circuit.pb.v2.Reservation\x12#\n\x05limit\x18\x04 \x01(\x0b\x32\x14.circuit.pb.v2.Limit\x12%\n\x06status\x18\x05 \x01(\x0b\x32\x15.circuit.pb.v2.Status\x12\x19\n\x0csenderRecord\x18\x06 \x01(\x0cH\x00\x88\x01\x01\",\n\x04Type\x12\x0b\n\x07RESERVE\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\n\n\x06STATUS\x10\x02\x42\x0f\n\r_senderRecord\"\xbe\x01\n\x0bStopMessage\x12-\n\x04type\x18\x01 \x01(\x0e\x32\x1f.circuit.pb.v2.StopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12%\n\x06status\x18\x03 \x01(\x0b\x32\x15.circuit.pb.v2.Status\x12\x19\n\x0csenderRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\"\x1f\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\n\n\x06STATUS\x10\x01\x42\x0f\n\r_senderRecord\"A\n\x0bReservation\x12\x0f\n\x07voucher\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c\x12\x0e\n\x06\x65xpire\x18\x03 \x01(\x03\"\'\n\x05Limit\x12\x10\n\x08\x64uration\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x03\"\xf6\x01\n\x06Status\x12(\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1a.circuit.pb.v2.Status.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xb0\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\x17\n\x13RESERVATION_REFUSED\x10\x64\x12\x1b\n\x17RESOURCE_LIMIT_EXCEEDED\x10\x65\x12\x15\n\x11PERMISSION_DENIED\x10\x66\x12\x16\n\x11\x43ONNECTION_FAILED\x10\xc8\x01\x12\x11\n\x0c\x44IAL_REFUSED\x10\xc9\x01\x12\x10\n\x0bSTOP_FAILED\x10\xac\x02\x12\x16\n\x11MALFORMED_MESSAGE\x10\x90\x03\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(libp2p/relay/circuit_v2/pb/circuit.proto\x12\rcircuit.pb.v2\"\x9f\x02\n\nHopMessage\x12,\n\x04type\x18\x01 \x01(\x0e\x32\x1e.circuit.pb.v2.HopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12/\n\x0breservation\x18\x03 \x01(\x0b\x32\x1a.circuit.pb.v2.Reservation\x12#\n\x05limit\x18\x04 \x01(\x0b\x32\x14.circuit.pb.v2.Limit\x12%\n\x06status\x18\x05 \x01(\x0b\x32\x15.circuit.pb.v2.Status\x12\x19\n\x0csenderRecord\x18\x06 \x01(\x0cH\x00\x88\x01\x01\",\n\x04Type\x12\x0b\n\x07RESERVE\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\n\n\x06STATUS\x10\x02\x42\x0f\n\r_senderRecord\"\xbe\x01\n\x0bStopMessage\x12-\n\x04type\x18\x01 \x01(\x0e\x32\x1f.circuit.pb.v2.StopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12%\n\x06status\x18\x03 \x01(\x0b\x32\x15.circuit.pb.v2.Status\x12\x19\n\x0csenderRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\"\x1f\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\n\n\x06STATUS\x10\x01\x42\x0f\n\r_senderRecord\"A\n\x0bReservation\x12\x0f\n\x07voucher\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c\x12\x0e\n\x06\x65xpire\x18\x03 \x01(\x03\"\'\n\x05Limit\x12\x10\n\x08\x64uration\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x03\"\x8b\x02\n\x06Status\x12(\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1a.circuit.pb.v2.Status.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc5\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\x17\n\x13RESERVATION_REFUSED\x10\x64\x12\x1b\n\x17RESOURCE_LIMIT_EXCEEDED\x10\x65\x12\x15\n\x11PERMISSION_DENIED\x10\x66\x12\x16\n\x11\x43ONNECTION_FAILED\x10\xc8\x01\x12\x11\n\x0c\x44IAL_REFUSED\x10\xc9\x01\x12\x13\n\x0eNO_RESERVATION\x10\xcc\x01\x12\x10\n\x0bSTOP_FAILED\x10\xac\x02\x12\x16\n\x11MALFORMED_MESSAGE\x10\x90\x03\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -44,7 +44,7 @@ _globals['_LIMIT']._serialized_start=609 _globals['_LIMIT']._serialized_end=648 _globals['_STATUS']._serialized_start=651 - _globals['_STATUS']._serialized_end=897 + _globals['_STATUS']._serialized_end=918 _globals['_STATUS_CODE']._serialized_start=721 - _globals['_STATUS_CODE']._serialized_end=897 + _globals['_STATUS_CODE']._serialized_end=918 # @@protoc_insertion_point(module_scope) diff --git a/libp2p/relay/circuit_v2/pb/circuit_pb2.pyi b/libp2p/relay/circuit_v2/pb/circuit_pb2.pyi index abe0cd47b..936cd44cb 100644 --- a/libp2p/relay/circuit_v2/pb/circuit_pb2.pyi +++ b/libp2p/relay/circuit_v2/pb/circuit_pb2.pyi @@ -166,6 +166,7 @@ class Status(google.protobuf.message.Message): PERMISSION_DENIED: Status._Code.ValueType # 102 CONNECTION_FAILED: Status._Code.ValueType # 200 DIAL_REFUSED: Status._Code.ValueType # 201 + NO_RESERVATION: Status._Code.ValueType # 204 STOP_FAILED: Status._Code.ValueType # 300 MALFORMED_MESSAGE: Status._Code.ValueType # 400 @@ -176,6 +177,7 @@ class Status(google.protobuf.message.Message): PERMISSION_DENIED: Status.Code.ValueType # 102 CONNECTION_FAILED: Status.Code.ValueType # 200 DIAL_REFUSED: Status.Code.ValueType # 201 + NO_RESERVATION: Status.Code.ValueType # 204 STOP_FAILED: Status.Code.ValueType # 300 MALFORMED_MESSAGE: Status.Code.ValueType # 400 diff --git a/libp2p/relay/circuit_v2/protocol.py b/libp2p/relay/circuit_v2/protocol.py index ca7cb6da3..a7e9533b4 100644 --- a/libp2p/relay/circuit_v2/protocol.py +++ b/libp2p/relay/circuit_v2/protocol.py @@ -585,7 +585,7 @@ async def _handle_reserve(self, stream: INetStream, msg: HopMessage) -> None: status_msg_text = "Reservation accepted" # Get the reservation object to access its voucher and sign it - reservation_obj = self.resource_manager._reservations.get(peer_id) + reservation_obj = self.resource_manager.get_reservation(peer_id) if not reservation_obj: raise ValueError(f"Failed to create reservation for peer {peer_id}") @@ -688,14 +688,27 @@ async def _handle_connect(self, stream: INetStream, msg: HopMessage) -> None: await stream.reset() return - # Check resource limits - if not self.resource_manager.can_accept_connection(peer_id=source_addr): + if not self.resource_manager.can_accept_connection(peer_id=peer_id): + relay_envelope_bytes, _ = env_to_send_in_RPC(self.host) + relay_envelope = unmarshal_envelope(relay_envelope_bytes) + await self._send_status( + stream, + StatusCode.NO_RESERVATION, + "Destination peer has no active reservation on this relay", + relay_envelope, + ) + await stream.reset() + return + + # Separately enforce the source peer's per-reservation connection limit. + source_reservation = self.resource_manager.get_reservation(source_addr) + if source_reservation and not source_reservation.can_accept_connection(): relay_envelope_bytes, _ = env_to_send_in_RPC(self.host) relay_envelope = unmarshal_envelope(relay_envelope_bytes) await self._send_status( stream, StatusCode.RESOURCE_LIMIT_EXCEEDED, - "Connection limit exceeded", + "Source peer has exceeded its connection limit", relay_envelope, ) await stream.reset() @@ -772,7 +785,7 @@ async def _handle_connect(self, stream: INetStream, msg: HopMessage) -> None: logger.debug("Connection established for peer %s", peer_id) # Update reservation connection count - reservation = self.resource_manager._reservations.get(peer_id) + reservation = self.resource_manager.get_reservation(peer_id) if reservation: reservation.active_connections += 1 logger.debug( @@ -848,7 +861,7 @@ async def _relay_data( """ try: # Get the reservation for tracking data usage - reservation = self.resource_manager._reservations.get(peer_id) + reservation = self.resource_manager.get_reservation(peer_id) total_bytes = 0 while True: @@ -896,7 +909,7 @@ async def _relay_data( break # Update resource usage - reservation = self.resource_manager._reservations.get(peer_id) + reservation = self.resource_manager.get_reservation(peer_id) if reservation: reservation.data_used += len(data) if reservation.data_used >= reservation.limits.data: diff --git a/libp2p/relay/circuit_v2/protocol_buffer.py b/libp2p/relay/circuit_v2/protocol_buffer.py index 509cea1c6..2f5375623 100644 --- a/libp2p/relay/circuit_v2/protocol_buffer.py +++ b/libp2p/relay/circuit_v2/protocol_buffer.py @@ -23,6 +23,7 @@ class StatusCode(IntEnum): PERMISSION_DENIED = 102 CONNECTION_FAILED = 200 DIAL_REFUSED = 201 + NO_RESERVATION = 204 STOP_FAILED = 300 MALFORMED_MESSAGE = 400 diff --git a/libp2p/relay/circuit_v2/resources.py b/libp2p/relay/circuit_v2/resources.py index b509f2f49..b94aadec4 100644 --- a/libp2p/relay/circuit_v2/resources.py +++ b/libp2p/relay/circuit_v2/resources.py @@ -461,7 +461,7 @@ def can_accept_connection(self, peer_id: ID) -> bool: True if the connection can be accepted """ - reservation = self._reservations.get(peer_id) + reservation = self.get_reservation(peer_id) return reservation is not None and reservation.can_accept_connection() def track_data_transfer(self, peer_id: ID, bytes_transferred: int) -> bool: @@ -552,3 +552,23 @@ def refresh_reservation(self, peer_id: ID) -> int: return self.limits.duration return 0 + + def get_reservation(self, peer_id: ID) -> Reservation | None: + """ + Get an active reservation for a peer. + + Parameters + ---------- + peer_id : ID + The peer ID to get the reservation for + + Returns + ------- + Reservation | None + The reservation if it exists and is active, None otherwise + + """ + reservation = self._reservations.get(peer_id) + if reservation and not reservation.is_expired(): + return reservation + return None diff --git a/newsfragments/1342.bugfix.rst b/newsfragments/1342.bugfix.rst new file mode 100644 index 000000000..fc1c52c02 --- /dev/null +++ b/newsfragments/1342.bugfix.rst @@ -0,0 +1 @@ +Fixed Circuit Relay v2 issue where a source peer could connect to a destination peer without an active reservation on the relay. diff --git a/tests/core/relay/test_circuit_v2_protocol.py b/tests/core/relay/test_circuit_v2_protocol.py index 8d3878310..6271792e8 100644 --- a/tests/core/relay/test_circuit_v2_protocol.py +++ b/tests/core/relay/test_circuit_v2_protocol.py @@ -1218,3 +1218,165 @@ async def test_reservation_fails_with_invalid_record_transfer(): logger.info("Invalid SPR was correctly rejected") logger.info("Invalid SPR correctly rejected, peerstore protected") + + +@pytest.mark.trio +async def test_circuit_v2_connect_fails_without_reservation(): + """Test that relay rejects CONNECT requests for unreserved destinations.""" + async with HostFactory.create_batch_and_listen(3) as hosts: + relay_host, source_host, dest_host = hosts + logger.info( + "Created hosts for test_circuit_v2_connect_fails_without_reservation" + ) + + # Setup relay + limits = RelayLimits( + duration=DEFAULT_RELAY_LIMITS.duration, + data=DEFAULT_RELAY_LIMITS.data, + max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns, + max_reservations=DEFAULT_RELAY_LIMITS.max_reservations, + ) + relay_protocol = CircuitV2Protocol(relay_host, limits, allow_hop=True) + + async with background_trio_service(relay_protocol): + await relay_protocol.event_started.wait() + + # Connect source to relay + await connect(source_host, relay_host) + await trio.sleep(SLEEP_TIME) + + # Source sends CONNECT request for dest (who has NO reservation) + stream = None + try: + with trio.fail_after(STREAM_TIMEOUT): + stream = await source_host.new_stream( + relay_host.get_id(), [PROTOCOL_ID] + ) + + connect_msg = proto.HopMessage( + type=proto.HopMessage.CONNECT, + peer=dest_host.get_id().to_bytes(), + ) + + await stream.write(connect_msg.SerializeToString()) + logger.info( + "Sent CONNECT request for destination without reservation" + ) + + # Read response + response_bytes = await stream.read(MAX_READ_LEN) + assert response_bytes, "No response received" + + response = proto.HopMessage() + response.ParseFromString(response_bytes) + + # Verify response status is NO_RESERVATION (204) + assert response.type == proto.HopMessage.STATUS + assert response.status.code == proto.Status.NO_RESERVATION, ( + "Expected status NO_RESERVATION(204), " + f"got {response.status.code}" + ) + + # Verify stream is reset (or EOF) + try: + next_data = await stream.read(MAX_READ_LEN) + assert not next_data, ( + "Stream should be closed/reset after error status" + ) + except (StreamEOF, StreamReset, StreamError): + pass + + finally: + if stream: + await close_stream(stream) + + +@pytest.mark.trio +async def test_circuit_v2_connect_fails_when_source_limit_exceeded(): + """Test that relay rejects CONNECT when source exceeds connection limit.""" + async with HostFactory.create_batch_and_listen(3) as hosts: + relay_host, source_host, dest_host = hosts + logger.info( + "Created hosts for test_circuit_v2_connect_fails_when_source_limit_exceeded" + ) + + limits = RelayLimits( + duration=DEFAULT_RELAY_LIMITS.duration, + data=DEFAULT_RELAY_LIMITS.data, + max_circuit_conns=1, + max_reservations=DEFAULT_RELAY_LIMITS.max_reservations, + ) + relay_protocol = CircuitV2Protocol(relay_host, limits, allow_hop=True) + + async with background_trio_service(relay_protocol): + await relay_protocol.event_started.wait() + + async def send_reserve(host) -> None: + envelope_bytes, _ = env_to_send_in_RPC(host) + stream = await host.new_stream(relay_host.get_id(), [PROTOCOL_ID]) + try: + reserve_msg = proto.HopMessage( + type=proto.HopMessage.RESERVE, + peer=host.get_id().to_bytes(), + senderRecord=envelope_bytes, + ) + await stream.write(reserve_msg.SerializeToString()) + await stream.read(MAX_READ_LEN) + finally: + await close_stream(stream) + + await connect(dest_host, relay_host) + await connect(source_host, relay_host) + await trio.sleep(SLEEP_TIME) + + await send_reserve(dest_host) + await send_reserve(source_host) + await trio.sleep(SLEEP_TIME) + + source_reservation = relay_protocol.resource_manager.get_reservation( + source_host.get_id() + ) + assert source_reservation is not None + source_reservation.active_connections = 1 + + stream = None + try: + with trio.fail_after(STREAM_TIMEOUT): + stream = await source_host.new_stream( + relay_host.get_id(), [PROTOCOL_ID] + ) + + connect_msg = proto.HopMessage( + type=proto.HopMessage.CONNECT, + peer=dest_host.get_id().to_bytes(), + ) + await stream.write(connect_msg.SerializeToString()) + logger.info( + "Sent CONNECT request with source connection limit exceeded" + ) + + response_bytes = await stream.read(MAX_READ_LEN) + assert response_bytes, "No response received" + + response = proto.HopMessage() + response.ParseFromString(response_bytes) + + assert response.type == proto.HopMessage.STATUS + assert ( + response.status.code == proto.Status.RESOURCE_LIMIT_EXCEEDED + ), ( + "Expected status RESOURCE_LIMIT_EXCEEDED(101), " + f"got {response.status.code}" + ) + + try: + next_data = await stream.read(MAX_READ_LEN) + assert not next_data, ( + "Stream should be closed/reset after error status" + ) + except (StreamEOF, StreamReset, StreamError): + pass + + finally: + if stream: + await close_stream(stream) diff --git a/tests/core/relay/test_circuit_v2_transport.py b/tests/core/relay/test_circuit_v2_transport.py index 34ecabe8c..d08a9bc26 100644 --- a/tests/core/relay/test_circuit_v2_transport.py +++ b/tests/core/relay/test_circuit_v2_transport.py @@ -21,6 +21,7 @@ from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from libp2p.relay.circuit_v2.config import RelayConfig, RelayRole from libp2p.relay.circuit_v2.discovery import ( RelayDiscovery, @@ -373,7 +374,24 @@ async def app_echo_handler(stream): await trio.sleep(SLEEP_TIME) - # Step 2: Source connects to Relay + # Step 2: Destination makes a reservation on the relay. + with trio.fail_after(CONNECT_TIMEOUT): + dest_relay_stream = await target_host.new_stream( + relay_host.get_id(), [PROTOCOL_ID] + ) + envelope_bytes, _ = env_to_send_in_RPC(target_host) + reserve_msg = HopMessage( + type=HopMessage.RESERVE, + peer=target_host.get_id().to_bytes(), + senderRecord=envelope_bytes, + ) + await dest_relay_stream.write(reserve_msg.SerializeToString()) + # Read and discard the STATUS response from the relay + await dest_relay_stream.read(1024) + + await trio.sleep(SLEEP_TIME) + + # Step 3: Source connects to Relay with trio.fail_after(CONNECT_TIMEOUT): await connect(client_host, relay_host) assert relay_host.get_id() in client_host.get_network().connections @@ -383,7 +401,7 @@ async def app_echo_handler(stream): relay_id = relay_host.get_id() client_discovery.get_relay = lambda: relay_id - # Step 3: Source tries to dial the destination via p2p-circuit and opens stream + # Step 4: Source tries to dial the destination via p2p-circuit and opens stream relay_addr = relay_host.get_addrs()[0] dest_id = target_host.get_id() p2p_circuit_addr = Multiaddr(f"{relay_addr}/p2p-circuit/p2p/{dest_id}")