diff --git a/extra/multihash-spec b/extra/multihash-spec deleted file mode 160000 index b43ec1026..000000000 --- a/extra/multihash-spec +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b43ec1026a610fa87878e53b3daecf3a14b3ef6f diff --git a/extra/py-multihash b/extra/py-multihash deleted file mode 160000 index dfae0dd7a..000000000 --- a/extra/py-multihash +++ /dev/null @@ -1 +0,0 @@ -Subproject commit dfae0dd7a66e0f5a0346d0297e03582443297b9c diff --git a/extra/pymultihash b/extra/pymultihash deleted file mode 160000 index 215298fa2..000000000 --- a/extra/pymultihash +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 215298fa2faa55027384d1f22519229d0918cfb0 diff --git a/libp2p/security/noise/io.py b/libp2p/security/noise/io.py index 18fbbcd5c..d1460ed13 100644 --- a/libp2p/security/noise/io.py +++ b/libp2p/security/noise/io.py @@ -22,6 +22,9 @@ MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1 SIZE_NOISE_MESSAGE_BODY_LEN = 2 MAX_NOISE_MESSAGE_BODY_LEN = MAX_NOISE_MESSAGE_LEN - SIZE_NOISE_MESSAGE_BODY_LEN +# Max plaintext per Noise message: 65535 - 16 bytes Poly1305 MAC overhead. +# Matches go-libp2p's MaxPlaintextLength in p2p/security/noise/rw.go. +MAX_PLAINTEXT_LENGTH = MAX_NOISE_MESSAGE_LEN - 16 BYTE_ORDER = "big" # | Noise packet | @@ -53,14 +56,26 @@ def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None: self.noise_state = noise_state async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None: - logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes") - data_encrypted = self.encrypt(msg) - if prefix_encoded: - # Manually add the prefix if needed - data_encrypted = self.prefix + data_encrypted - logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes") - await self.read_writer.write_msg(data_encrypted) - logger.debug("Noise write_msg: write completed successfully") + # Chunk large messages to stay within the Noise 65535-byte transport + # message limit, matching go-libp2p's noise/rw.go Write() approach. + if len(msg) <= MAX_PLAINTEXT_LENGTH: + # Fast path: single message (covers handshake and small writes) + data_encrypted = self.encrypt(msg) + if prefix_encoded: + data_encrypted = self.prefix + data_encrypted + await self.read_writer.write_msg(data_encrypted) + else: + # Slow path: chunk into multiple Noise messages + total = len(msg) + written = 0 + while written < total: + end = min(written + MAX_PLAINTEXT_LENGTH, total) + chunk = msg[written:end] + data_encrypted = self.encrypt(chunk) + if prefix_encoded and written == 0: + data_encrypted = self.prefix + data_encrypted + await self.read_writer.write_msg(data_encrypted) + written = end async def read_msg(self, prefix_encoded: bool = False) -> bytes: logger.debug("Noise read_msg: reading encrypted message") diff --git a/libp2p/security/secure_session.py b/libp2p/security/secure_session.py index 29a970507..1147c9ce9 100644 --- a/libp2p/security/secure_session.py +++ b/libp2p/security/secure_session.py @@ -94,24 +94,42 @@ async def read(self, n: int | None = None) -> bytes: return b"" data_from_buffer = self._drain(n) - if len(data_from_buffer) > 0: + if n is None and len(data_from_buffer) > 0: return data_from_buffer - msg = await self.conn.read_msg() + if n is None: + msg = await self.conn.read_msg() - # If underlying connection returned empty bytes, treat as closed - # and raise to signal that reads after close are invalid. - if msg == b"": - raise Exception("Connection closed") + # If underlying connection returned empty bytes, treat as closed + # and raise to signal that reads after close are invalid. + if msg == b"": + raise Exception("Connection closed") - if n is None: return msg - if n < len(msg): - self._fill(msg) - return self._drain(n) - else: - return msg + if len(data_from_buffer) == n: + return data_from_buffer + + result = bytearray(data_from_buffer) + while len(result) < n: + msg = await self.conn.read_msg() + + # If the connection closes after a partial read, return the bytes + # we already assembled. This preserves the stream-read behavior + # expected by higher layers. + if msg == b"": + if result: + return bytes(result) + raise Exception("Connection closed") + + remaining = n - len(result) + if len(msg) <= remaining: + result.extend(msg) + else: + result.extend(msg[:remaining]) + self._fill(msg[remaining:]) + + return bytes(result) async def write(self, data: bytes) -> None: await self.conn.write_msg(data) diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index 8acc1a9ea..d565e7e72 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -70,6 +70,9 @@ # Network byte order: version (B), type (B), flags (H), stream_id (I), length (I) YAMUX_HEADER_FORMAT = "!BBHII" DEFAULT_WINDOW_SIZE = 256 * 1024 +MAX_WINDOW_SIZE = 16 * 1024 * 1024 # 16 MB max receive window (matches go-yamux) +MAX_MESSAGE_SIZE = 64 * 1024 # 64KB max frame payload, matches go-yamux default +RTT_MEASURE_INTERVAL = 30 # seconds between RTT measurements GO_AWAY_NORMAL = 0x0 GO_AWAY_PROTOCOL_ERROR = 0x1 @@ -77,6 +80,9 @@ class YamuxStream(IMuxedStream): + target_recv_window: int + epoch_start: float + def __init__(self, stream_id: int, conn: "Yamux", is_initiator: bool) -> None: self.stream_id = stream_id self.conn = conn @@ -89,6 +95,8 @@ def __init__(self, stream_id: int, conn: "Yamux", is_initiator: bool) -> None: self.send_window = DEFAULT_WINDOW_SIZE self.recv_window = DEFAULT_WINDOW_SIZE self.window_lock = trio.Lock() + self.target_recv_window = DEFAULT_WINDOW_SIZE # grows up to MAX_WINDOW_SIZE + self.epoch_start = 0.0 # trio.current_time() of last window update self.rw_lock = ReadWriteLock() self.close_lock = trio.Lock() @@ -143,8 +151,13 @@ async def write(self, data: bytes) -> None: if self.closed: raise MuxedStreamError("Stream is closed") - # Calculate how much we can send now - to_send = min(self.send_window, total_len - sent) + # Calculate how much we can send now (cap at MaxMessageSize + # minus header, matching go-yamux's per-frame limit) + to_send = min( + self.send_window, + MAX_MESSAGE_SIZE - HEADER_SIZE, + total_len - sent, + ) chunk = data[sent : sent + to_send] self.send_window -= to_send @@ -152,7 +165,7 @@ async def write(self, data: bytes) -> None: header = struct.pack( YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk) ) - await self.conn.secured_conn.write(header + chunk) + await self.conn._write_frame(header + chunk) sent += to_send async def send_window_update(self, increment: int, skip_lock: bool = False) -> None: @@ -192,7 +205,7 @@ async def _do_window_update() -> None: increment, ) try: - await self.conn.secured_conn.write(header) + await self.conn._write_frame(header) except ConnectionClosedError as e: # Typed exception from transports (e.g., WebSocket) that # properly signal connection closure — handle gracefully. @@ -234,6 +247,53 @@ async def _do_window_update() -> None: async with self.window_lock: await _do_window_update() + async def _auto_tune_and_send_window_update(self: "YamuxStream") -> None: + """ + Auto-tune receive window size based on RTT and send window update. + + Ports go-yamux's two-pass GrowTo + sendWindowUpdate logic: + - Pass 1: GrowTo(current_target) — restore window to current target + - Auto-tune: if within 4x RTT of last epoch, double the target + - Pass 2: GrowTo(new_target, force=True) — grow to new target + - Only the final delta is sent to the peer (matches go-yamux behavior) + """ + async with self.window_lock: + # Match go-yamux GrowTo: currentWindow = cap + len + buffered = len(self.conn.stream_buffers.get(self.stream_id, b"")) + current_window = self.recv_window + buffered + + # Pass 1: GrowTo(target_recv_window) — like go's first GrowTo call + delta = self.target_recv_window - current_window + if delta <= 0: + return + # Hysteresis: skip if delta < 50% of target (matches go-yamux GrowTo) + if delta < self.target_recv_window // 2: + return + # Apply first pass growth to recv_window (like go's cap += delta) + self.recv_window += delta + + # Auto-tune: if within 4x RTT of last epoch, double the target + now = trio.current_time() + rtt = self.conn.rtt() + if rtt > 0 and self.epoch_start > 0 and (now - self.epoch_start) < rtt * 4: + new_target = min(self.target_recv_window * 2, MAX_WINDOW_SIZE) + if new_target > self.target_recv_window: + self.target_recv_window = new_target + # Pass 2: GrowTo(new_target, force=True) — incremental + # Recompute current_window after pass 1 growth + new_current = self.recv_window + buffered + extra_delta = self.target_recv_window - new_current + if extra_delta > 0: + self.recv_window += extra_delta + delta += extra_delta # Send total delta (pass 1 + pass 2) + + self.epoch_start = now + logger.debug( + f"Stream {self.stream_id}: Auto-tune window update " + f"delta={delta}, target={self.target_recv_window}" + ) + await self.send_window_update(delta, skip_lock=True) + async def read(self, n: int | None = -1) -> bytes: """ Read data from the stream. @@ -288,11 +348,8 @@ async def read(self, n: int | None = -1) -> bytes: buffer.clear() data += chunk - # Send window update for the chunk we just read - async with self.window_lock: - self.recv_window += len(chunk) - logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}") - await self.send_window_update(len(chunk), skip_lock=True) + # Auto-tune and send window update for the chunk we just read + await self._auto_tune_and_send_window_update() # Check for reset if self.reset_received: @@ -337,13 +394,7 @@ async def read(self, n: int | None = -1) -> bytes: return b"" else: data = await self.conn.read_stream(self.stream_id, n) - async with self.window_lock: - self.recv_window += len(data) - logger.debug( - f"Stream {self.stream_id}: Sending window update after read, " - f"increment={len(data)}" - ) - await self.send_window_update(len(data), skip_lock=True) + await self._auto_tune_and_send_window_update() return data async def close(self) -> None: @@ -352,9 +403,14 @@ async def close(self) -> None: logger.debug(f"Half-closing stream {self.stream_id} (local end)") try: header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0 + YAMUX_HEADER_FORMAT, + 0, + TYPE_WINDOW_UPDATE, + FLAG_FIN, + self.stream_id, + 0, ) - await self.conn.secured_conn.write(header) + await self.conn._write_frame(header) except (RawConnError, ConnectionClosedError) as e: logger.debug(f"Error sending FIN, connection likely closed: {e}") finally: @@ -373,9 +429,14 @@ async def reset(self) -> None: logger.debug(f"Resetting stream {self.stream_id}") try: header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0 + YAMUX_HEADER_FORMAT, + 0, + TYPE_WINDOW_UPDATE, + FLAG_RST, + self.stream_id, + 0, ) - await self.conn.secured_conn.write(header) + await self.conn._write_frame(header) except (RawConnError, ConnectionClosedError) as e: logger.debug(f"Error sending RST, connection likely closed: {e}") finally: @@ -450,8 +511,45 @@ def __init__( self.event_started = trio.Event() self.stream_buffers: dict[int, bytearray] = {} self.stream_events: dict[int, trio.Event] = {} + self._write_lock = trio.Lock() self._nursery: Nursery | None = None self._established: bool = False + self._rtt: float = 0.0 # smoothed RTT in seconds + self._ping_id: int = 0 # incrementing ping nonce + self._ping_sent_time: float = 0.0 # trio.current_time() when ping sent + self._ping_event: trio.Event = trio.Event() + + def rtt(self) -> float: + """Return the current smoothed RTT estimate in seconds.""" + return self._rtt + + async def _measure_rtt_loop(self) -> None: + """Background task that periodically measures RTT via ping/pong.""" + # Initial delay to let the connection establish + await trio.sleep(0.5) + while not self.event_shutting_down.is_set(): + try: + self._ping_id += 1 + self._ping_event = trio.Event() + header = struct.pack( + YAMUX_HEADER_FORMAT, 0, TYPE_PING, FLAG_SYN, 0, self._ping_id + ) + await self._write_frame(header) + # Record time AFTER write completes, matching go-yamux which + # times after dispatch to avoid including write-lock wait time. + self._ping_sent_time = trio.current_time() + # Wait for pong with timeout + with trio.move_on_after(10.0): + await self._ping_event.wait() + except Exception: + # Connection likely closed, exit the loop + break + if self.event_shutting_down.is_set(): + break + # Sleep between measurements, checking shutdown periodically + with trio.move_on_after(RTT_MEASURE_INTERVAL): + while not self.event_shutting_down.is_set(): + await trio.sleep(1.0) @property def is_established(self) -> bool: @@ -480,10 +578,14 @@ async def start(self) -> None: logger.debug( f"Yamux.start() starting handle_incoming task for {self.peer_id}" ) + + nursery.start_soon(self._measure_rtt_loop) # Use nursery.start() to ensure handle_incoming has started # before we set event_started. This prevents race conditions # where streams are opened before the muxer is ready. + # When handle_incoming exits, the finally block cancels the nursery. await nursery.start(self._handle_incoming_with_ready_signal) + logger.debug(f"Yamux.start() setting event_started for {self.peer_id}") self._established = True self.event_started.set() @@ -512,7 +614,7 @@ async def close(self, error_code: int = GO_AWAY_NORMAL) -> None: header = struct.pack( YAMUX_HEADER_FORMAT, 0, TYPE_GO_AWAY, 0, 0, error_code ) - await self.secured_conn.write(header) + await self._write_frame(header) except Exception as e: logger.debug(f"Failed to send GO_AWAY: {e}") self.event_shutting_down.set() @@ -559,6 +661,32 @@ def get_connection_type(self) -> ConnectionType: """ return self.secured_conn.get_connection_type() + async def _write_frame(self, data: bytes) -> None: + """Write a frame to the connection, serializing all writes.""" + if len(data) >= HEADER_SIZE: + _, typ, flags, sid, length = struct.unpack( + YAMUX_HEADER_FORMAT, data[:HEADER_SIZE] + ) + flag_names = [] + if flags & FLAG_SYN: + flag_names.append("SYN") + if flags & FLAG_ACK: + flag_names.append("ACK") + if flags & FLAG_FIN: + flag_names.append("FIN") + if flags & FLAG_RST: + flag_names.append("RST") + type_names = {0: "DATA", 1: "WINDOW_UPDATE", 2: "PING", 3: "GO_AWAY"} + logger.debug( + f"YAMUX TX: type={type_names.get(typ, typ)} " + f"flags={'+'.join(flag_names) or '0'} " + f"stream={sid} length={length} " + f"is_initiator={self.is_initiator_value} " + f"payload_bytes={len(data) - HEADER_SIZE}" + ) + async with self._write_lock: + await self.secured_conn.write(data) + async def open_stream(self) -> YamuxStream: # Wait for backlog slot await self.stream_backlog_semaphore.acquire() @@ -576,10 +704,15 @@ async def open_stream(self) -> YamuxStream: # If stream is rejected or errors, release the semaphore try: header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0 + YAMUX_HEADER_FORMAT, + 0, + TYPE_WINDOW_UPDATE, + FLAG_SYN, + stream_id, + 0, ) logger.debug(f"Sending SYN header for stream {stream_id}") - await self.secured_conn.write(header) + await self._write_frame(header) return stream except Exception as e: self.stream_backlog_semaphore.release() @@ -721,13 +854,18 @@ async def _handle_incoming_with_ready_signal( This method uses trio's task_status to signal that the handle_incoming loop is ready to process frames. This prevents race conditions where streams are opened before the muxer is ready to handle them. + When handle_incoming exits, this cancels the nursery scope. """ logger.debug( f"Yamux _handle_incoming_with_ready_signal() starting for " f"peer {self.peer_id}" ) task_status.started() - await self.handle_incoming() + try: + await self.handle_incoming() + finally: + if self._nursery is not None: + self._nursery.cancel_scope.cancel() async def handle_incoming(self) -> None: logger.debug(f"Yamux handle_incoming() started for peer {self.peer_id}") @@ -795,10 +933,21 @@ async def handle_incoming(self) -> None: version, typ, flags, stream_id, length = struct.unpack( YAMUX_HEADER_FORMAT, header ) + type_names = {0: "DATA", 1: "WINDOW_UPDATE", 2: "PING", 3: "GO_AWAY"} + flag_names = [] + if flags & FLAG_SYN: + flag_names.append("SYN") + if flags & FLAG_ACK: + flag_names.append("ACK") + if flags & FLAG_FIN: + flag_names.append("FIN") + if flags & FLAG_RST: + flag_names.append("RST") logger.debug( - f"Received header for peer {self.peer_id}:" - f"type={typ}, flags={flags}, stream_id={stream_id}," - f"length={length}" + f"YAMUX RX: type={type_names.get(typ, typ)} " + f"flags={'+'.join(flag_names) or '0'} " + f"stream={stream_id} length={length} " + f"is_initiator={self.is_initiator_value}" ) if (typ == TYPE_DATA or typ == TYPE_WINDOW_UPDATE) and flags & FLAG_SYN: async with self.streams_lock: @@ -808,11 +957,28 @@ async def handle_incoming(self) -> None: self.stream_buffers[stream_id] = bytearray() self.stream_events[stream_id] = trio.Event() - # Read any data that came with the SYN frame - if length > 0: + if typ == TYPE_WINDOW_UPDATE and length > 0: + # Window update SYN: length is a delta + # to add to the initial send window + async with stream.window_lock: + stream.send_window += length + logger.debug( + f"SYN window update for stream " + f"{stream_id}: window={length}" + ) + elif typ == TYPE_DATA and length > 0: + # Data SYN: length is payload bytes try: data = await read_exactly(self.secured_conn, length) self.stream_buffers[stream_id].extend(data) + stream.recv_window -= len(data) + if stream.recv_window < 0: + logger.warning( + f"Stream {stream_id}: peer exceeded " + f"receive window by " + f"{-stream.recv_window} bytes" + ) + stream.recv_window = 0 self.stream_events[stream_id].set() logger.debug( f"Read {length} bytes with SYN " @@ -820,10 +986,9 @@ async def handle_incoming(self) -> None: ) except IncompleteReadError as e: logger.error( - "Incomplete read for SYN data on stream " - f"{stream_id}: {e}" + "Incomplete read for SYN data on " + f"stream {stream_id}: {e}" ) - # Mark stream as closed stream.recv_closed = True stream.closed = True if stream_id in self.stream_events: @@ -832,12 +997,12 @@ async def handle_incoming(self) -> None: ack_header = struct.pack( YAMUX_HEADER_FORMAT, 0, - TYPE_DATA, + TYPE_WINDOW_UPDATE, FLAG_ACK, stream_id, 0, ) - await self.secured_conn.write(ack_header) + await self._write_frame(ack_header) logger.debug( f"Sending stream {stream_id}" f"to channel for peer {self.peer_id}" @@ -847,40 +1012,62 @@ async def handle_incoming(self) -> None: rst_header = struct.pack( YAMUX_HEADER_FORMAT, 0, - TYPE_DATA, + TYPE_WINDOW_UPDATE, FLAG_RST, stream_id, 0, ) - await self.secured_conn.write(rst_header) - elif typ == TYPE_DATA and flags & FLAG_ACK: + await self._write_frame(rst_header) + elif ( + typ == TYPE_DATA or typ == TYPE_WINDOW_UPDATE + ) and flags & FLAG_ACK: async with self.streams_lock: if stream_id in self.streams: - # Read any data that came with the ACK - if length > 0: + stream = self.streams[stream_id] + if typ == TYPE_WINDOW_UPDATE: + # Window update ACK: length is a delta + # (matches go-yamux incrSendWindow). + if length > 0: + async with stream.window_lock: + stream.send_window += length + logger.debug( + f"Received WINDOW_UPDATE ACK for stream " + f"{stream_id}, send_window={length} " + f"for peer {self.peer_id}" + ) + elif typ == TYPE_DATA and length > 0: + # Data ACK: length is payload bytes try: data = await read_exactly(self.secured_conn, length) self.stream_buffers[stream_id].extend(data) + self.streams[stream_id].recv_window -= len(data) + if self.streams[stream_id].recv_window < 0: + logger.warning( + f"Stream {stream_id}: peer exceeded " + f"receive window by " + f"{-self.streams[stream_id].recv_window}" + f" bytes" + ) + self.streams[stream_id].recv_window = 0 self.stream_events[stream_id].set() logger.debug( - f"Received ACK with {length} bytes for stream " - f"{stream_id} for peer {self.peer_id}" + f"Received ACK with {length} bytes " + f"for stream {stream_id} " + f"for peer {self.peer_id}" ) except IncompleteReadError as e: logger.error( - "Incomplete read for ACK data on stream " - f"{stream_id}: {e}" + "Incomplete read for ACK data on " + f"stream {stream_id}: {e}" ) - # Mark stream as closed - stream = self.streams[stream_id] stream.recv_closed = True stream.closed = True if stream_id in self.stream_events: self.stream_events[stream_id].set() else: logger.debug( - f"Received ACK (no data) for stream {stream_id} " - f"for peer {self.peer_id}" + f"Received ACK (no data) for stream " + f"{stream_id} for peer {self.peer_id}" ) elif typ == TYPE_GO_AWAY: error_code = length @@ -914,11 +1101,19 @@ async def handle_incoming(self) -> None: ping_header = struct.pack( YAMUX_HEADER_FORMAT, 0, TYPE_PING, FLAG_ACK, 0, length ) - await self.secured_conn.write(ping_header) + await self._write_frame(ping_header) elif flags & FLAG_ACK: + # Compute RTT with exponential smoothing + now = trio.current_time() + new_rtt = now - self._ping_sent_time + if self._rtt == 0.0: + self._rtt = new_rtt + else: + self._rtt = (self._rtt + new_rtt) / 2 + self._ping_event.set() logger.debug( f"Received ping response with value" - f"{length} for peer {self.peer_id}" + f"{length} for peer {self.peer_id}, rtt={self._rtt:.4f}s" ) elif typ == TYPE_DATA: try: @@ -962,6 +1157,15 @@ async def handle_incoming(self) -> None: async with self.streams_lock: if stream_id in self.streams: self.stream_buffers[stream_id].extend(data) + self.streams[stream_id].recv_window -= len(data) + if self.streams[stream_id].recv_window < 0: + logger.warning( + f"Stream {stream_id}: peer exceeded " + f"receive window by " + f"{-self.streams[stream_id].recv_window}" + f" bytes" + ) + self.streams[stream_id].recv_window = 0 # Always set event, even if no data # in case FIN/RST is set self.stream_events[stream_id].set() diff --git a/newsfragments/1270.feature.rst b/newsfragments/1270.feature.rst new file mode 100644 index 000000000..7b1cd8137 --- /dev/null +++ b/newsfragments/1270.feature.rst @@ -0,0 +1 @@ +Added yamux receive window auto-tuning: the per-stream receive window starts at 256 KB and doubles each RTT epoch up to 16 MB, matching go-yamux behavior for improved throughput on high-bandwidth connections. diff --git a/newsfragments/1271.bugfix.rst b/newsfragments/1271.bugfix.rst new file mode 100644 index 000000000..b26f965db --- /dev/null +++ b/newsfragments/1271.bugfix.rst @@ -0,0 +1 @@ +Fixed yamux interoperability with go-yamux: SYN/ACK/FIN/RST frames are now sent as TYPE_WINDOW_UPDATE (not TYPE_DATA), writes are serialized with a lock to prevent frame interleaving, and SYN/ACK window values match go-yamux conventions so peers no longer get an inflated send window. diff --git a/tests/core/security/noise/test_large_payloads.py b/tests/core/security/noise/test_large_payloads.py index deb2985c1..107f418a1 100644 --- a/tests/core/security/noise/test_large_payloads.py +++ b/tests/core/security/noise/test_large_payloads.py @@ -16,6 +16,23 @@ class TestLargePayloads: """Test large payload handling in Noise transport.""" + @pytest.mark.trio + async def test_go_large_payload_roundtrip(self, nursery): + """Match go-libp2p's large-payload transport test.""" + async with noise_conn_factory(nursery) as conns: + local_conn, remote_conn = conns + + random.seed(1234) + size = 100000 + test_data = bytes(random.getrandbits(8) for _ in range(size)) + + await local_conn.write(test_data) + + received_data = await remote_conn.read(len(test_data)) + + assert len(received_data) == len(test_data) + assert received_data == test_data + @pytest.mark.trio async def test_large_payload_roundtrip(self, nursery): """Test large payload requiring multiple Noise messages.""" diff --git a/tests/core/stream_muxer/test_yamux.py b/tests/core/stream_muxer/test_yamux.py index 8e0befc89..fc1cc02af 100644 --- a/tests/core/stream_muxer/test_yamux.py +++ b/tests/core/stream_muxer/test_yamux.py @@ -323,9 +323,10 @@ async def test_yamux_flow_control(yamux_pair): # Send the data await client_stream.write(large_data) - # Check that window was reduced - assert client_stream.send_window < initial_window, ( - "Window should be reduced after sending" + # Window was reduced by the send; ACK may have already restored some, + # but it should differ from the initial value. + assert client_stream.send_window != initial_window, ( + "Window should have changed after sending data and receiving ACK" ) # Read the data on the server side diff --git a/tests/core/stream_muxer/yamux/test_yamux_window_update_error_handling.py b/tests/core/stream_muxer/yamux/test_yamux_window_update_error_handling.py index 92715e6b9..617a91f05 100644 --- a/tests/core/stream_muxer/yamux/test_yamux_window_update_error_handling.py +++ b/tests/core/stream_muxer/yamux/test_yamux_window_update_error_handling.py @@ -41,8 +41,7 @@ async def test_send_window_update_handles_connection_closed_error(): by type — no string matching required. """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock( + mock_conn._write_frame = AsyncMock( side_effect=ConnectionClosedError( "WebSocket connection closed by peer during write operation", close_code=1000, @@ -57,7 +56,7 @@ async def test_send_window_update_handles_connection_closed_error(): # Should not raise — ConnectionClosedError is handled gracefully await stream.send_window_update(32) - assert mock_conn.secured_conn.write.called + assert mock_conn._write_frame.called @pytest.mark.trio @@ -75,14 +74,13 @@ async def test_send_window_update_handles_connection_closed_error_any_message(): for msg in unusual_messages: mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock( + mock_conn._write_frame = AsyncMock( side_effect=ConnectionClosedError(msg, close_code=1000) ) stream = YamuxStream(1, mock_conn, is_initiator=True) await stream.send_window_update(32) # Should not raise - assert mock_conn.secured_conn.write.called + assert mock_conn._write_frame.called # --------------------------------------------------------------------------- @@ -97,10 +95,7 @@ async def test_send_window_update_handles_raw_conn_error(): gracefully (string-matching fallback for TCP transport). """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock( - side_effect=RawConnError("Connection closed") - ) + mock_conn._write_frame = AsyncMock(side_effect=RawConnError("Connection closed")) stream_id = 1 stream = YamuxStream(stream_id, mock_conn, is_initiator=True) @@ -108,7 +103,7 @@ async def test_send_window_update_handles_raw_conn_error(): # Should not raise — falls through to string-matching fallback await stream.send_window_update(32) - assert mock_conn.secured_conn.write.called + assert mock_conn._write_frame.called @pytest.mark.trio @@ -126,14 +121,13 @@ async def test_send_window_update_handles_various_closure_messages(): for error_msg in closure_messages: mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock(side_effect=IOException(error_msg)) + mock_conn._write_frame = AsyncMock(side_effect=IOException(error_msg)) stream = YamuxStream(1, mock_conn, is_initiator=True) # Should not raise for any of these messages await stream.send_window_update(32) - assert mock_conn.secured_conn.write.called + assert mock_conn._write_frame.called # --------------------------------------------------------------------------- @@ -147,8 +141,7 @@ async def test_send_window_update_raises_unexpected_errors(): Test that unexpected errors (not connection closure) are still raised. """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock(side_effect=ValueError("Unexpected error")) + mock_conn._write_frame = AsyncMock(side_effect=ValueError("Unexpected error")) stream_id = 1 stream = YamuxStream(stream_id, mock_conn, is_initiator=True) @@ -163,8 +156,7 @@ async def test_send_window_update_raises_non_closure_io_exception(): Test that plain IOException with non-closure message is still raised. """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock(side_effect=IOException("Disk full error")) + mock_conn._write_frame = AsyncMock(side_effect=IOException("Disk full error")) stream_id = 1 stream = YamuxStream(stream_id, mock_conn, is_initiator=True) @@ -184,16 +176,15 @@ async def test_send_window_update_succeeds_when_connection_open(): Test that send_window_update succeeds normally when connection is open. """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock() # No error + mock_conn._write_frame = AsyncMock() # No error stream_id = 1 stream = YamuxStream(stream_id, mock_conn, is_initiator=True) await stream.send_window_update(32) - assert mock_conn.secured_conn.write.called - call_args = mock_conn.secured_conn.write.call_args[0][0] + assert mock_conn._write_frame.called + call_args = mock_conn._write_frame.call_args[0][0] assert len(call_args) == 12 # Yamux header is 12 bytes assert call_args[1] == 0x1 # Window update type @@ -204,13 +195,13 @@ async def test_send_window_update_skips_zero_increment(): Test that send_window_update skips sending when increment is zero or negative. """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() + mock_conn._write_frame = AsyncMock() stream_id = 1 stream = YamuxStream(stream_id, mock_conn, is_initiator=True) await stream.send_window_update(0) - assert not mock_conn.secured_conn.write.called + assert not mock_conn._write_frame.called await stream.send_window_update(-1) - assert not mock_conn.secured_conn.write.called + assert not mock_conn._write_frame.called