diff --git a/benchmarks/bench_noise_pq.py b/benchmarks/bench_noise_pq.py new file mode 100644 index 000000000..dff9fdf6d --- /dev/null +++ b/benchmarks/bench_noise_pq.py @@ -0,0 +1,509 @@ +""" +Noise PQ vs classical Noise benchmarks for py-libp2p. + +Measures: + 1. X-Wing KEM micro-benchmarks: keygen, encapsulate, decapsulate + 2. Classical Noise XX handshake latency (round-trip) + 3. Noise XXhfs (PQ) handshake latency (round-trip) + 4. Post-handshake transport throughput: 1 KB, 10 KB, 100 KB + +All latencies are median over N_HANDSHAKES iterations. +Throughput is measured as MB/s sustained over N_THROUGHPUT rounds. + +Run: + cd py-libp2p + python benchmarks/bench_noise_pq.py +""" + +import math +from pathlib import Path +import statistics +import time + +import trio + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +N_WARMUP = 3 # warm-up rounds discarded before measurement +N_HANDSHAKES = 50 # handshake latency iterations +N_THROUGHPUT = 200 # throughput iterations per message size +N_KEM = 200 # KEM micro-benchmark iterations + +# --------------------------------------------------------------------------- +# In-memory connection (same as test helpers) +# --------------------------------------------------------------------------- + + +class _MemoryConn: + def __init__(self, send_chan, recv_chan) -> None: + self._send = send_chan + self._recv = recv_chan + self._buf = bytearray() + + async def read(self, n: int | None = None) -> bytes: + while not self._buf: + try: + chunk = await self._recv.receive() + except trio.EndOfChannel: + return b"" + self._buf.extend(chunk) + if n is None: + data = bytes(self._buf) + self._buf.clear() + return data + data = bytes(self._buf[:n]) + del self._buf[:n] + return data + + async def write(self, data: bytes) -> None: + await self._send.send(bytes(data)) + + async def close(self) -> None: + await 
self._send.aclose() + + def get_remote_address(self) -> None: + return None + + def get_transport_addresses(self) -> list: + return [] + + def get_connection_type(self): + from libp2p.connection_types import ConnectionType + + return ConnectionType.UNKNOWN + + +def _make_conn_pair() -> tuple[_MemoryConn, _MemoryConn]: + a_to_b_s, a_to_b_r = trio.open_memory_channel(math.inf) + b_to_a_s, b_to_a_r = trio.open_memory_channel(math.inf) + return _MemoryConn(a_to_b_s, b_to_a_r), _MemoryConn(b_to_a_s, a_to_b_r) + + +# --------------------------------------------------------------------------- +# Key helpers +# --------------------------------------------------------------------------- + + +def _make_pq_pair(): + """Return (transport_local, peer_remote, transport_remote).""" + from libp2p.crypto.ed25519 import create_new_key_pair + from libp2p.crypto.keys import KeyPair + from libp2p.crypto.x25519 import X25519PrivateKey + from libp2p.peer.id import ID + from libp2p.security.noise.pq.transport_pq import TransportPQ + + kp_l = create_new_key_pair() + kp_r = create_new_key_pair() + t_l = TransportPQ( + KeyPair(kp_l.private_key, kp_l.public_key), X25519PrivateKey.new() + ) + t_r = TransportPQ( + KeyPair(kp_r.private_key, kp_r.public_key), X25519PrivateKey.new() + ) + peer_r = ID.from_pubkey(kp_r.public_key) + return t_l, peer_r, t_r + + +def _make_classical_pair(): + """Return (transport_local, peer_remote, transport_remote).""" + from libp2p.crypto.ed25519 import create_new_key_pair + from libp2p.crypto.keys import KeyPair + from libp2p.crypto.x25519 import X25519PrivateKey + from libp2p.peer.id import ID + from libp2p.security.noise.transport import Transport + + kp_l = create_new_key_pair() + kp_r = create_new_key_pair() + t_l = Transport(KeyPair(kp_l.private_key, kp_l.public_key), X25519PrivateKey.new()) + t_r = Transport(KeyPair(kp_r.private_key, kp_r.public_key), X25519PrivateKey.new()) + peer_r = ID.from_pubkey(kp_r.public_key) + return t_l, peer_r, t_r + + +# 
--------------------------------------------------------------------------- +# Benchmark helpers +# --------------------------------------------------------------------------- + + +def _fmt(ms: float, ops_s: float) -> str: + return f"{ms:7.2f} ms/op ({ops_s:8.1f} ops/s)" + + +def _stats(samples_s: list[float]) -> tuple[float, float]: + """Return (median_ms, ops_per_sec) from list of elapsed seconds.""" + med_ms = statistics.median(samples_s) * 1000 + ops_s = 1.0 / statistics.median(samples_s) + return med_ms, ops_s + + +# --------------------------------------------------------------------------- +# 1. KEM micro-benchmarks (synchronous, no trio needed) +# --------------------------------------------------------------------------- + + +def bench_kem() -> dict: + from libp2p.security.noise.pq.kem import XWingKem + + kem = XWingKem() + + # --- keygen --- + for _ in range(N_WARMUP): + kem.keygen() + samples: list[float] = [] + for _ in range(N_KEM): + t0 = time.perf_counter() + pk, sk = kem.keygen() + samples.append(time.perf_counter() - t0) + keygen_ms, keygen_ops = _stats(samples) + + # --- encapsulate --- + pk, sk = kem.keygen() + for _ in range(N_WARMUP): + kem.encapsulate(pk) + samples = [] + for _ in range(N_KEM): + t0 = time.perf_counter() + ct, ss = kem.encapsulate(pk) + samples.append(time.perf_counter() - t0) + encap_ms, encap_ops = _stats(samples) + + # --- decapsulate --- + ct, _ = kem.encapsulate(pk) + for _ in range(N_WARMUP): + kem.decapsulate(ct, sk) + samples = [] + for _ in range(N_KEM): + t0 = time.perf_counter() + kem.decapsulate(ct, sk) + samples.append(time.perf_counter() - t0) + decap_ms, decap_ops = _stats(samples) + + return { + "keygen_ms": keygen_ms, + "keygen_ops": keygen_ops, + "encap_ms": encap_ms, + "encap_ops": encap_ops, + "decap_ms": decap_ms, + "decap_ops": decap_ops, + } + + +# --------------------------------------------------------------------------- +# 2. 
Handshake latency benchmarks (async) +# --------------------------------------------------------------------------- + + +async def _one_pq_handshake(t_l, peer_r, t_r) -> float: + conn_l, conn_r = _make_conn_pair() + t0 = time.perf_counter() + async with trio.open_nursery() as n: + n.start_soon(t_l.secure_outbound, conn_l, peer_r) + n.start_soon(t_r.secure_inbound, conn_r) + return time.perf_counter() - t0 + + +async def _one_classical_handshake(t_l, peer_r, t_r) -> float: + conn_l, conn_r = _make_conn_pair() + t0 = time.perf_counter() + async with trio.open_nursery() as n: + n.start_soon(t_l.secure_outbound, conn_l, peer_r) + n.start_soon(t_r.secure_inbound, conn_r) + return time.perf_counter() - t0 + + +async def bench_handshakes() -> dict: + # --- classical XX --- + t_l, peer_r, t_r = _make_classical_pair() + for _ in range(N_WARMUP): + await _one_classical_handshake(t_l, peer_r, t_r) + samples_classical: list[float] = [] + for _ in range(N_HANDSHAKES): + t_l, peer_r, t_r = _make_classical_pair() # fresh keys each run + samples_classical.append(await _one_classical_handshake(t_l, peer_r, t_r)) + xx_ms, xx_ops = _stats(samples_classical) + + # --- XXhfs (PQ) --- + t_l, peer_r, t_r = _make_pq_pair() + for _ in range(N_WARMUP): + await _one_pq_handshake(t_l, peer_r, t_r) + samples_pq: list[float] = [] + for _ in range(N_HANDSHAKES): + t_l, peer_r, t_r = _make_pq_pair() # fresh keys each run + samples_pq.append(await _one_pq_handshake(t_l, peer_r, t_r)) + xxhfs_ms, xxhfs_ops = _stats(samples_pq) + + overhead = xxhfs_ms / xx_ms if xx_ms > 0 else float("inf") + + return { + "xx_ms": xx_ms, + "xx_ops": xx_ops, + "xxhfs_ms": xxhfs_ms, + "xxhfs_ops": xxhfs_ops, + "overhead_x": overhead, + } + + +# --------------------------------------------------------------------------- +# 3. 
Transport throughput (MB/s) after handshake completes +# --------------------------------------------------------------------------- + + +async def _throughput_one(session_out, session_in, payload: bytes) -> float: + t0 = time.perf_counter() + await session_out.write(payload) + await session_in.read(len(payload)) + return time.perf_counter() - t0 + + +async def _bench_throughput_one_size(make_pair_fn, size: int, n_rounds: int) -> float: + """Return throughput in MB/s for a single payload size.""" + payload = b"X" * size + + # One handshake to get the sessions + t_l, peer_r, t_r = make_pair_fn() + conn_l, conn_r = _make_conn_pair() + sessions: list = [None, None] + + async def do_out() -> None: + sessions[0] = await t_l.secure_outbound(conn_l, peer_r) + + async def do_in() -> None: + sessions[1] = await t_r.secure_inbound(conn_r) + + async with trio.open_nursery() as n: + n.start_soon(do_out) + n.start_soon(do_in) + + sess_out, sess_in = sessions + + # Warm-up + for _ in range(N_WARMUP): + await sess_out.write(payload) + await sess_in.read(size) + + # Timed rounds + elapsed: list[float] = [] + for _ in range(n_rounds): + t0 = time.perf_counter() + await sess_out.write(payload) + await sess_in.read(size) + elapsed.append(time.perf_counter() - t0) + + med_s = statistics.median(elapsed) + return (size / (1024 * 1024)) / med_s # MB/s + + +async def bench_throughput() -> dict: + # Noise spec caps single messages at 65535 bytes; 60 KB stays safely under + # the per-frame limit of both transports (both use 2-byte length prefixes). 
+ sizes = [1024, 10 * 1024, 60 * 1024] + results: dict = {"classical": {}, "pq": {}} + + for size in sizes: + results["classical"][size] = await _bench_throughput_one_size( + _make_classical_pair, size, N_THROUGHPUT + ) + results["pq"][size] = await _bench_throughput_one_size( + _make_pq_pair, size, N_THROUGHPUT + ) + + return results + + +# --------------------------------------------------------------------------- +# Wire-size accounting (no runtime measurement needed — pure arithmetic) +# --------------------------------------------------------------------------- + + +def wire_sizes() -> dict: + from libp2p.security.noise.pq.kem import XWING_CT_SIZE, XWING_PK_SIZE + + x25519 = 32 + aead_tag = 16 + + # Classical XX (Noise spec, no libp2p payload for size accounting) + # Msg 1: e (32) + # Msg 2: e (32) + enc_s (48) + enc_payload (variable — use 0 here) + # Msg 3: enc_s (48) + enc_payload (variable) + classical_fixed = 32 + (32 + 48) + 48 # = 160 B fixed; payload adds ~32+ per side + + # XXhfs + # Msg A: e_pk (32) + e1_pk (1216) = 1248 + # Msg B: e (32) + enc_ct (1120+16=1136) + enc_s (48) + enc_payload + # Msg C: enc_s (48) + enc_payload + msg_a = x25519 + XWING_PK_SIZE # 32 + 1216 = 1248 + msg_b_fixed = x25519 + (XWING_CT_SIZE + aead_tag) + (x25519 + aead_tag) # 1216 + msg_c_fixed = x25519 + aead_tag # 48 + + return { + "classical_msg1": 32, + "classical_msg2_fixed": 32 + 48, + "classical_msg3_fixed": 48, + "xxhfs_msg_a": msg_a, + "xxhfs_msg_b_fixed": msg_b_fixed, + "xxhfs_msg_c_fixed": msg_c_fixed, + "xxhfs_total_fixed": msg_a + msg_b_fixed + msg_c_fixed, + "classical_total_fixed": classical_fixed, + } + + +# --------------------------------------------------------------------------- +# Main entry point + pretty-print results +# --------------------------------------------------------------------------- + + +def print_section(title: str) -> None: + print() + print("=" * 60) + print(f" {title}") + print("=" * 60) + + +async def run_all() -> dict: + print("Running 
benchmarks … (this may take ~30–60 seconds)") + print(f" KEM iterations: {N_KEM}") + print(f" Handshake iterations: {N_HANDSHAKES}") + print(f" Throughput rounds: {N_THROUGHPUT}") + + kem = bench_kem() + handshakes = await bench_handshakes() + throughput = await bench_throughput() + wires = wire_sizes() + + # ---- print ---- + + print_section("X-Wing KEM micro-benchmarks") + print(f" keygen : {_fmt(kem['keygen_ms'], kem['keygen_ops'])}") + print(f" encapsulate: {_fmt(kem['encap_ms'], kem['encap_ops'])}") + print(f" decapsulate: {_fmt(kem['decap_ms'], kem['decap_ops'])}") + kem_round_trip = kem["encap_ms"] + kem["decap_ms"] + print(f" round-trip (encap+decap): {kem_round_trip:.2f} ms") + + print_section("Handshake latency (in-memory, round-trip)") + print(f" Classical XX : {_fmt(handshakes['xx_ms'], handshakes['xx_ops'])}") + print(f" XXhfs (PQ) : {_fmt(handshakes['xxhfs_ms'], handshakes['xxhfs_ops'])}") + print(f" Overhead : {handshakes['overhead_x']:.1f}x") + + print_section("Transport throughput (after handshake)") + print(f" {'Size':>8} {'Classical':>12} {'XXhfs (PQ)':>12} {'Ratio':>8}") + for size in [1024, 10 * 1024, 60 * 1024]: + label = f"{size // 1024} KB" + c = throughput["classical"][size] + p = throughput["pq"][size] + ratio = p / c if c > 0 else float("inf") + print(f" {label:>8} {c:>10.1f} MB/s {p:>10.1f} MB/s {ratio:>7.2f}x") + + print_section("Wire sizes (handshake bytes, excluding libp2p payload)") + print(f" Classical XX total (fixed): {wires['classical_total_fixed']} B") + print(f" Msg 1: {wires['classical_msg1']} B") + print(f" Msg 2: {wires['classical_msg2_fixed']} B (fixed, + payload)") + print(f" Msg 3: {wires['classical_msg3_fixed']} B (fixed, + payload)") + print() + print(f" XXhfs total (fixed): {wires['xxhfs_total_fixed']} B") + print(f" Msg A: {wires['xxhfs_msg_a']} B (e + e1_pk)") + print(f" Msg B: {wires['xxhfs_msg_b_fixed']} B (e + enc_ct + enc_s)") + print(f" Msg C: {wires['xxhfs_msg_c_fixed']} B (enc_s, fixed)") + overhead_b = 
wires["xxhfs_total_fixed"] - wires["classical_total_fixed"] + overhead_x = overhead_b / wires["classical_total_fixed"] + print( + f" Wire overhead vs classical: +{overhead_b} B ({overhead_x:.0f}x)" + ) + + print() + return { + "kem": kem, + "handshakes": handshakes, + "throughput": throughput, + "wire_sizes": wires, + } + + +def save_results(results: dict) -> None: + """Save a markdown results file, mirroring js-libp2p-noise/benchmarks/results.md.""" + kem = results["kem"] + hs = results["handshakes"] + tp = results["throughput"] + ws = results["wire_sizes"] + + lines = [ + "# py-libp2p Noise PQ Benchmark Results", + "", + "> Generated by `benchmarks/bench_noise_pq.py` ", + f"> Iterations: KEM={N_KEM}, handshake={N_HANDSHAKES}," + f" throughput={N_THROUGHPUT}", + "", + "## X-Wing KEM Micro-benchmarks", + "", + "| Operation | Median (ms) | Throughput (ops/s) |", + "|-----------|-------------|--------------------|", + f"| keygen | {kem['keygen_ms']:.2f} | {kem['keygen_ops']:.0f} |", + f"| encapsulate | {kem['encap_ms']:.2f} | {kem['encap_ops']:.0f} |", + f"| decapsulate | {kem['decap_ms']:.2f} | {kem['decap_ops']:.0f} |", + f"| round-trip (encap+decap) | {kem['encap_ms'] + kem['decap_ms']:.2f} | — |", + "", + "## Handshake Latency (in-memory, round-trip)", + "", + "| Pattern | Median (ms) | Throughput (ops/s) |", + "|---------|-------------|--------------------|", + f"| Classical Noise XX | {hs['xx_ms']:.2f} | {hs['xx_ops']:.0f} |", + f"| Noise XXhfs (X-Wing) | {hs['xxhfs_ms']:.2f} | {hs['xxhfs_ops']:.0f} |", + f"| Overhead | {hs['overhead_x']:.1f}x | — |", + "", + "## Transport Throughput (post-handshake)", + "", + "| Payload | Classical (MB/s) | XXhfs (MB/s) | Ratio |", + "|---------|-----------------|--------------|-------|", + ] + for size in [1024, 10 * 1024, 60 * 1024]: + label = f"{size // 1024} KB" + c = tp["classical"][size] + p = tp["pq"][size] + ratio = p / c if c > 0 else float("inf") + lines.append(f"| {label} | {c:.1f} | {p:.1f} | {ratio:.2f}x |") + + 
overhead_b = ws["xxhfs_total_fixed"] - ws["classical_total_fixed"] + lines += [ + "", + "## Wire Sizes (fixed handshake bytes, excluding libp2p payload)", + "", + "| Pattern | Msg 1 | Msg 2 | Msg 3 | Total |", + "|---------|-------|-------|-------|-------|", + ( + f"| Classical XX | {ws['classical_msg1']} B" + f" | {ws['classical_msg2_fixed']} B + payload" + f" | {ws['classical_msg3_fixed']} B + payload" + f" | {ws['classical_total_fixed']} B |" + ), + ( + f"| XXhfs | {ws['xxhfs_msg_a']} B" + f" | {ws['xxhfs_msg_b_fixed']} B + payload" + f" | {ws['xxhfs_msg_c_fixed']} B + payload" + f" | {ws['xxhfs_total_fixed']} B |" + ), + "", + ( + f"KEM ciphertext overhead vs classical:" + f" +{overhead_b} B" + f" (+{overhead_b / ws['classical_total_fixed']:.0f}x fixed bytes)" + ), + "", + "## Comparison with js-libp2p-noise", + "", + "| Metric | js-libp2p (XXhfs) | py-libp2p (XXhfs) |", + "|--------|-------------------|-------------------|", + f"| Handshake latency | ~44 ms | {hs['xxhfs_ms']:.1f} ms |", + f"| vs classical overhead | ~5x | {hs['overhead_x']:.1f}x |", + f"| KEM round-trip | ~20 ms | {kem['encap_ms'] + kem['decap_ms']:.1f} ms |", + "", + ] + + out_path = Path(__file__).parent / "results.md" + out_path.write_text("\n".join(lines), encoding="utf-8") + print(f"\nResults saved to {out_path}") + + +if __name__ == "__main__": + results = trio.run(run_all) + save_results(results) diff --git a/benchmarks/results.md b/benchmarks/results.md new file mode 100644 index 000000000..4586d7c6b --- /dev/null +++ b/benchmarks/results.md @@ -0,0 +1,46 @@ +# py-libp2p Noise PQ Benchmark Results + +> Generated by `benchmarks/bench_noise_pq.py` +> Iterations: KEM=200, handshake=50, throughput=200 + +## X-Wing KEM Micro-benchmarks + +| Operation | Median (ms) | Throughput (ops/s) | +|-----------|-------------|--------------------| +| keygen | 10.54 | 95 | +| encapsulate | 12.27 | 82 | +| decapsulate | 15.46 | 65 | +| round-trip (encap+decap) | 27.73 | — | + +## Handshake Latency 
(in-memory, round-trip) + +| Pattern | Median (ms) | Throughput (ops/s) | +|---------|-------------|--------------------| +| Classical Noise XX | 4.00 | 250 | +| Noise XXhfs (X-Wing) | 40.74 | 25 | +| Overhead | 10.2x | — | + +## Transport Throughput (post-handshake) + +| Payload | Classical (MB/s) | XXhfs (MB/s) | Ratio | +|---------|-----------------|--------------|-------| +| 1 KB | 12.1 | 8.9 | 0.73x | +| 10 KB | 65.9 | 109.3 | 1.66x | +| 60 KB | 215.0 | 311.6 | 1.45x | + +## Wire Sizes (fixed handshake bytes, excluding libp2p payload) + +| Pattern | Msg 1 | Msg 2 | Msg 3 | Total | +|---------|-------|-------|-------|-------| +| Classical XX | 32 B | 80 B + payload | 48 B + payload | 160 B | +| XXhfs | 1248 B | 1216 B + payload | 48 B + payload | 2512 B | + +KEM ciphertext overhead vs classical: +2352 B (+15x fixed bytes) + +## Comparison with js-libp2p-noise + +| Metric | js-libp2p (XXhfs) | py-libp2p (XXhfs) | +|--------|-------------------|-------------------| +| Handshake latency | ~44 ms | 40.7 ms | +| vs classical overhead | ~5x | 10.2x | +| KEM round-trip | ~20 ms | 27.7 ms | diff --git a/libp2p/security/noise/pq/__init__.py b/libp2p/security/noise/pq/__init__.py new file mode 100644 index 000000000..dfdff0df4 --- /dev/null +++ b/libp2p/security/noise/pq/__init__.py @@ -0,0 +1,12 @@ +"""Post-quantum Noise security for py-libp2p. + +Public API:: + + from libp2p.security.noise.pq import TransportPQ, PROTOCOL_ID + + security_options = {PROTOCOL_ID: TransportPQ(libp2p_keypair, noise_privkey)} +""" + +from .transport_pq import PROTOCOL_ID, TransportPQ + +__all__ = ["PROTOCOL_ID", "TransportPQ"] diff --git a/libp2p/security/noise/pq/kem.py b/libp2p/security/noise/pq/kem.py new file mode 100644 index 000000000..69de2579f --- /dev/null +++ b/libp2p/security/noise/pq/kem.py @@ -0,0 +1,200 @@ +""" +X-Wing KEM for the Noise XXhfs handshake. 
+ +X-Wing is a hybrid KEM combining ML-KEM-768 and X25519: + - Public key: ML-KEM-768 encapsulation key (1184 B) || X25519 public key (32 B) + - Secret key: ML-KEM-768 decapsulation key (2400 B) || X25519 private key (32 B) + - Ciphertext: ML-KEM-768 ciphertext (1088 B) || X25519 ephemeral public key (32 B) + - Shared secret: SHA3-256(ss_mlkem || ss_x25519 || ct_x25519 || pk_x25519 || label) + +Reference: draft-connolly-cfrg-xwing-kem +""" + +import hashlib +from typing import Protocol, runtime_checkable + +from kyber_py.ml_kem import ML_KEM_768 +from nacl.bindings import crypto_scalarmult, crypto_scalarmult_base +import nacl.utils + +# X-Wing domain separation label: ASCII bytes for "\.//^\" +_XWING_LABEL = bytes([0x5C, 0x2E, 0x2F, 0x2F, 0x5E, 0x5C]) + +# Key and ciphertext size constants +_ML_KEM_PK_SIZE = 1184 +_ML_KEM_SK_SIZE = 2400 +_ML_KEM_CT_SIZE = 1088 +_X25519_KEY_SIZE = 32 + +XWING_PK_SIZE = _ML_KEM_PK_SIZE + _X25519_KEY_SIZE # 1216 +XWING_SK_SIZE = _ML_KEM_SK_SIZE + _X25519_KEY_SIZE # 2432 +XWING_CT_SIZE = _ML_KEM_CT_SIZE + _X25519_KEY_SIZE # 1120 + + +@runtime_checkable +class IKem(Protocol): + """Backend-agnostic KEM interface for the XXhfs handshake.""" + + def keygen(self) -> tuple[bytes, bytes]: + """ + Generate a KEM key pair. + + Returns: + (public_key, secret_key) as raw bytes. + + """ + ... + + def encapsulate(self, pk: bytes) -> tuple[bytes, bytes]: + """ + Encapsulate a shared secret to a public key. + + Args: + pk: Recipient's public key. + + Returns: + (ciphertext, shared_secret) as raw bytes. + + """ + ... + + def decapsulate(self, ct: bytes, sk: bytes) -> bytes: + """ + Decapsulate a shared secret from a ciphertext. + + Args: + ct: Ciphertext from the encapsulator. + sk: Local secret key. + + Returns: + Shared secret as 32 raw bytes. + + """ + ... 
+ + +def _xwing_combine( + ss_mlkem: bytes, + ss_x25519: bytes, + ct_x25519: bytes, + pk_x25519: bytes, +) -> bytes: + r""" + Combine ML-KEM and X25519 shared secrets per @noble/post-quantum 0.6.0. + + SHA3-256(ss_mlkem || ss_x25519 || ct_x25519 || pk_x25519 || label) + where label = b'\\.//' + b'^\\' (6 bytes, domain separation). + + Note: label is appended LAST to match @noble/post-quantum 0.6.0 combiner: + sha3_256(concatBytes(ss[0], ss[1], ct[1], pk[1], asciiToBytes('\\.//^\\'))) + """ + return hashlib.sha3_256( + ss_mlkem + ss_x25519 + ct_x25519 + pk_x25519 + _XWING_LABEL + ).digest() + + +class XWingKem: + """ + X-Wing hybrid KEM using ML-KEM-768 and X25519. + + Uses kyber-py as the ML-KEM-768 backend and PyNaCl for X25519. + Implements the IKem protocol. + """ + + def keygen(self) -> tuple[bytes, bytes]: + """ + Generate an X-Wing key pair. + + Returns: + (pk, sk) where: + pk = ml_kem_ek (1184 B) || x25519_pk (32 B) -- 1216 bytes total + sk = ml_kem_dk (2400 B) || x25519_sk (32 B) -- 2432 bytes total + + """ + ml_kem_pk, ml_kem_sk = ML_KEM_768.keygen() + + x25519_sk = nacl.utils.random(_X25519_KEY_SIZE) + x25519_pk = bytes(crypto_scalarmult_base(x25519_sk)) + + pk = ml_kem_pk + x25519_pk + sk = ml_kem_sk + x25519_sk + return pk, sk + + def encapsulate(self, pk: bytes) -> tuple[bytes, bytes]: + """ + Encapsulate a shared secret to an X-Wing public key. + + Generates a fresh X25519 ephemeral key pair each call. + + Args: + pk: X-Wing public key (1216 bytes). + + Returns: + (ct, ss) where: + ct = ml_kem_ct (1088 B) || x25519_eph_pk (32 B) -- 1120 bytes total + ss = 32-byte combined shared secret + + Raises: + ValueError: If pk is not 1216 bytes. 
+ + """ + if len(pk) != XWING_PK_SIZE: + raise ValueError( + f"X-Wing public key must be {XWING_PK_SIZE} bytes, got {len(pk)}" + ) + + ml_kem_pk = pk[:_ML_KEM_PK_SIZE] + x25519_pk_r = pk[_ML_KEM_PK_SIZE:] + + # ML-KEM-768 encapsulation + ss_mlkem, ml_kem_ct = ML_KEM_768.encaps(ml_kem_pk) + + # X25519 ephemeral key exchange + x25519_eph_sk = nacl.utils.random(_X25519_KEY_SIZE) + x25519_eph_pk = bytes(crypto_scalarmult_base(x25519_eph_sk)) + ss_x25519 = bytes(crypto_scalarmult(x25519_eph_sk, x25519_pk_r)) + + ss = _xwing_combine(ss_mlkem, ss_x25519, x25519_eph_pk, x25519_pk_r) + ct = ml_kem_ct + x25519_eph_pk + return ct, ss + + def decapsulate(self, ct: bytes, sk: bytes) -> bytes: + """ + Decapsulate a shared secret from an X-Wing ciphertext. + + Args: + ct: X-Wing ciphertext (1120 bytes). + sk: X-Wing secret key (2432 bytes). + + Returns: + 32-byte combined shared secret. + + Raises: + ValueError: If ct or sk have unexpected lengths. + + """ + if len(ct) != XWING_CT_SIZE: + raise ValueError( + f"X-Wing ciphertext must be {XWING_CT_SIZE} bytes, got {len(ct)}" + ) + if len(sk) != XWING_SK_SIZE: + raise ValueError( + f"X-Wing secret key must be {XWING_SK_SIZE} bytes, got {len(sk)}" + ) + + ml_kem_sk = sk[:_ML_KEM_SK_SIZE] + x25519_sk_r = sk[_ML_KEM_SK_SIZE:] + + ml_kem_ct = ct[:_ML_KEM_CT_SIZE] + x25519_eph_pk = ct[_ML_KEM_CT_SIZE:] + + # ML-KEM-768 decapsulation + ss_mlkem = ML_KEM_768.decaps(ml_kem_sk, ml_kem_ct) + + # X25519 DH using our static private key and the ephemeral public key + ss_x25519 = bytes(crypto_scalarmult(x25519_sk_r, x25519_eph_pk)) + + # Reconstruct our X25519 public key for the combiner + x25519_pk_r = bytes(crypto_scalarmult_base(x25519_sk_r)) + + return _xwing_combine(ss_mlkem, ss_x25519, x25519_eph_pk, x25519_pk_r) diff --git a/libp2p/security/noise/pq/noise_state.py b/libp2p/security/noise/pq/noise_state.py new file mode 100644 index 000000000..4b707aaed --- /dev/null +++ b/libp2p/security/noise/pq/noise_state.py @@ -0,0 +1,152 @@ +""" 
+Noise symmetric state for the XXhfs handshake.
+
+Implements CipherState and SymmetricState as defined in the Noise protocol spec
+(https://noiseprotocol.org/noise.html) with ChaCha20-Poly1305 and SHA-256,
+extended for the HFS (Hybrid Forward Secrecy) pattern.
+
+Protocol name: Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256
+"""
+
+import hashlib
+import hmac
+import struct
+
+from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
+
+# Full protocol name -- cryptographically bound to every derived key
+PROTOCOL_NAME = b"Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256"
+
+_HASH_LEN = 32  # SHA-256 output length
+_KEY_LEN = 32  # ChaCha20 key length
+_TAG_LEN = 16  # Poly1305 tag length
+
+
+def _hmac_sha256(key: bytes, data: bytes) -> bytes:
+    return hmac.new(key, data, hashlib.sha256).digest()
+
+
+def _hkdf(chaining_key: bytes, input_key_material: bytes, n: int) -> tuple[bytes, ...]:
+    """
+    Noise HKDF producing n outputs (2 or 3), each 32 bytes.
+
+    temp_k = HMAC-SHA256(ck, ikm)
+    out_i = HMAC-SHA256(temp_k, out_{i-1} || byte(i))
+    """
+    temp_k = _hmac_sha256(chaining_key, input_key_material)
+    out1 = _hmac_sha256(temp_k, b"\x01")
+    out2 = _hmac_sha256(temp_k, out1 + b"\x02")
+    if n == 2:
+        return out1, out2
+    out3 = _hmac_sha256(temp_k, out2 + b"\x03")
+    return out1, out2, out3
+
+
+def _nonce_bytes(n: int) -> bytes:
+    """Encode Noise nonce: 4 zero bytes + 8-byte little-endian counter = 12 bytes."""
+    return b"\x00" * 4 + struct.pack("<Q", n)
+
+
+class CipherState:
+    """Noise CipherState: a ChaCha20-Poly1305 key plus a 64-bit nonce counter."""
+
+    def __init__(self, key: bytes) -> None:
+        if len(key) != _KEY_LEN:
+            raise ValueError(f"Key must be {_KEY_LEN} bytes, got {len(key)}")
+        self._cipher = ChaCha20Poly1305(key)
+        self.n = 0
+
+    def encrypt_with_ad(self, ad: bytes, plaintext: bytes) -> bytes:
+        """Encrypt plaintext with associated data. Increments nonce counter."""
+        ct = self._cipher.encrypt(_nonce_bytes(self.n), plaintext, ad)
+        self.n += 1
+        return ct
+
+    def decrypt_with_ad(self, ad: bytes, ciphertext: bytes) -> bytes:
+        """Decrypt ciphertext with associated data. 
Increments nonce counter.""" + plaintext = self._cipher.decrypt(_nonce_bytes(self.n), ciphertext, ad) + self.n += 1 + return plaintext + + +class SymmetricState: + """ + Noise SymmetricState for Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256. + + Maintains the chaining key (ck) and handshake hash (h) across all + message tokens. Both are initialised to SHA-256(PROTOCOL_NAME). + """ + + ck: bytes # chaining key + h: bytes # handshake hash (running transcript) + _cs: CipherState | None + + def __init__(self) -> None: + # Protocol name > 32 bytes so h = HASH(protocol_name) + digest = hashlib.sha256(PROTOCOL_NAME).digest() + self.ck = digest + self.h = digest + self._cs = None + + def mix_hash(self, data: bytes) -> None: + """h = SHA-256(h || data)""" + self.h = hashlib.sha256(self.h + data).digest() + + def mix_key(self, input_key_material: bytes) -> None: + """Update chaining key and cipher key via HKDF.""" + self.ck, temp_k = _hkdf(self.ck, input_key_material, 2) + self._cs = CipherState(temp_k) + + def mix_key_and_hash(self, input_key_material: bytes) -> None: + """ + 3-output HKDF for HFS tokens (used with KEM shared secret). + + ck, temp_h, temp_k = HKDF(ck, ss, 3) + MixHash(temp_h) + """ + self.ck, temp_h, temp_k = _hkdf(self.ck, input_key_material, 3) + self.mix_hash(temp_h) + self._cs = CipherState(temp_k) + + def encrypt_and_hash(self, plaintext: bytes) -> bytes: + """ + AEAD-encrypt plaintext, then mix the ciphertext into h. + + Returns the ciphertext (plaintext + 16-byte tag). + """ + if self._cs is None: + # No key yet -- send in the clear (used for early tokens) + self.mix_hash(plaintext) + return plaintext + ct = self._cs.encrypt_with_ad(self.h, plaintext) + self.mix_hash(ct) + return ct + + def decrypt_and_hash(self, ciphertext: bytes) -> bytes: + """ + AEAD-decrypt ciphertext, then mix the ciphertext into h. + + Returns the plaintext. 
+ """ + if self._cs is None: + self.mix_hash(ciphertext) + return ciphertext + plaintext = self._cs.decrypt_with_ad(self.h, ciphertext) + self.mix_hash(ciphertext) + return plaintext + + def split(self) -> tuple[CipherState, CipherState]: + """ + Derive two transport CipherStates at the end of the handshake. + + Returns (initiator_cs, responder_cs). + """ + temp_k1, temp_k2 = _hkdf(self.ck, b"", 2) + return CipherState(temp_k1), CipherState(temp_k2) diff --git a/libp2p/security/noise/pq/patterns_pq.py b/libp2p/security/noise/pq/patterns_pq.py new file mode 100644 index 000000000..af9d7464f --- /dev/null +++ b/libp2p/security/noise/pq/patterns_pq.py @@ -0,0 +1,373 @@ +""" +XXhfs Noise handshake pattern for post-quantum security. + +Implements Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256: a three-message +handshake that adds X-Wing KEM tokens to the classical Noise XX pattern +for hybrid post-quantum forward secrecy. + +Message layout (wire bytes, inside 2-byte length-prefixed frames): + A (initiator -> responder): e_pk(32) || e1_pk(1216) = 1248 B + B (responder -> initiator): e_pk(32) || enc_ct(1136) + || enc_s(48) || enc_payload + C (initiator -> responder): enc_s(48) || enc_payload + +Token sequence: + A: e, e1 (no symmetric key yet, plain mix_hash) + B: e, ee, ekem1, s, es (ekem1 = encrypt(ct) then mix_key(ss_kem)) + C: s, se + +Transport keys after split(): + Initiator: encrypt=cs1, decrypt=cs2 + Responder: encrypt=cs2, decrypt=cs1 +""" + +import logging +from typing import cast + +import nacl.utils +from nacl.bindings import ( + crypto_scalarmult, + crypto_scalarmult_base, +) + +from libp2p.abc import ( + IRawConnection, + ISecureConn, +) +from libp2p.crypto.keys import PrivateKey +from libp2p.crypto.x25519 import X25519PublicKey +from libp2p.io.abc import ( + EncryptedMsgReadWriter, + ReadWriteCloser, +) +from libp2p.peer.id import ID +from libp2p.security.secure_session import SecureSession + +from ..exceptions import ( + InvalidSignature, + 
PeerIDMismatchesPubkey, +) +from ..io import NoisePacketReadWriter +from ..messages import ( + NoiseHandshakePayload, + make_handshake_payload_sig, + verify_handshake_payload_sig, +) +from .kem import ( + IKem, + XWING_CT_SIZE, + XWING_PK_SIZE, + XWingKem, +) +from .noise_state import ( + CipherState, + SymmetricState, +) + +logger = logging.getLogger(__name__) + +# Size constants +_X25519_SIZE = 32 +_AEAD_TAG = 16 +_KEM_CT_ENC_SIZE = XWING_CT_SIZE + _AEAD_TAG # 1120 + 16 = 1136 +_S_ENC_SIZE = _X25519_SIZE + _AEAD_TAG # 32 + 16 = 48 + + +class PQTransportReadWriter(EncryptedMsgReadWriter): + """Post-handshake transport that encrypts/decrypts with PQC CipherStates. + + Each direction uses its own CipherState with an independent nonce counter, + matching the Noise spec for transport-phase messages. + """ + + def __init__( + self, + conn: IRawConnection, + send_cs: CipherState, + recv_cs: CipherState, + ) -> None: + super().__init__(conn) # sets self.conn for address delegation + self._packet_rw = NoisePacketReadWriter(cast(ReadWriteCloser, conn)) + self._send_cs = send_cs + self._recv_cs = recv_cs + + def encrypt(self, data: bytes) -> bytes: + return self._send_cs.encrypt_with_ad(b"", data) + + def decrypt(self, data: bytes) -> bytes: + return self._recv_cs.decrypt_with_ad(b"", data) + + async def write_msg(self, msg: bytes) -> None: + await self._packet_rw.write_msg(self.encrypt(msg)) + + async def read_msg(self) -> bytes: + return self.decrypt(await self._packet_rw.read_msg()) + + async def close(self) -> None: + await self._packet_rw.close() + + +class PatternXXhfs: + """Noise XXhfs handshake pattern with X-Wing hybrid KEM. + + Provides mutual authentication (libp2p identity signatures) and + hybrid post-quantum forward secrecy via the X-Wing KEM (ML-KEM-768 + + X25519) alongside the classical X25519 DH exchange. 
+ """ + + PROTOCOL_NAME = b"Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256" + + def __init__( + self, + local_peer: ID, + libp2p_privkey: PrivateKey, + noise_static_key: PrivateKey, + kem: IKem | None = None, + early_data: bytes | None = None, + ) -> None: + self.local_peer = local_peer + self.libp2p_privkey = libp2p_privkey + self.noise_static_key = noise_static_key + self.kem: IKem = kem if kem is not None else XWingKem() + self.early_data = early_data + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _static_pk_bytes(self) -> bytes: + """Return the raw 32-byte X25519 static public key.""" + return self.noise_static_key.get_public_key().to_bytes() + + def _static_sk_bytes(self) -> bytes: + """Return the raw 32-byte X25519 static private key.""" + return self.noise_static_key.to_bytes() + + def _make_payload(self) -> bytes: + """Serialize a libp2p NoiseHandshakePayload (id_pubkey + id_sig).""" + static_pubkey = self.noise_static_key.get_public_key() + sig = make_handshake_payload_sig(self.libp2p_privkey, static_pubkey) + return NoiseHandshakePayload( + id_pubkey=self.libp2p_privkey.get_public_key(), + id_sig=sig, + ).serialize() + + # ------------------------------------------------------------------ + # Initiator (outbound) + # ------------------------------------------------------------------ + + async def handshake_outbound( + self, conn: IRawConnection, remote_peer: ID | None + ) -> ISecureConn: + """Run the initiator side of the XXhfs handshake. + + Args: + conn: Raw underlying connection. + remote_peer: Expected peer ID of the responder (verified against + the responder's libp2p identity signature). + Pass ``None`` to accept any peer identity (useful + for interop tests where the remote peer ID is + not known in advance). + + Returns: + SecureSession ready for post-handshake transport. 
+ + Raises: + PeerIDMismatchesPubkey: If the responder's peer ID does not match + ``remote_peer`` (only raised when non-None). + InvalidSignature: If the responder's identity signature + is invalid. + + """ + ss = SymmetricState() + ss.mix_hash(b"") # MixHash(prologue=empty) — required by Noise spec even when empty + pkt = NoisePacketReadWriter(cast(ReadWriteCloser, conn)) + + # ---- Message A: e, e1 ---------------------------------------- + # e: ephemeral X25519 + e_sk = nacl.utils.random(_X25519_SIZE) + e_pk = bytes(crypto_scalarmult_base(e_sk)) + ss.mix_hash(e_pk) + logger.debug("handshake_outbound: msg A – generated ephemeral X25519") + + # e1: ephemeral X-Wing KEM keypair + e1_pk, e1_sk = self.kem.keygen() + ss.mix_hash(e1_pk) + logger.debug("handshake_outbound: msg A – generated X-Wing KEM keypair") + + # Empty payload (no cipher key yet; encrypt_and_hash = mix_hash + identity) + enc_payload_a = ss.encrypt_and_hash(b"") + await pkt.write_msg(e_pk + e1_pk + enc_payload_a) + logger.debug("handshake_outbound: msg A sent (%d B)", len(e_pk) + len(e1_pk)) + + # ---- Message B: e, ee, ekem1, s, es -------------------------- + msg_b = await pkt.read_msg() + offset = 0 + logger.debug("handshake_outbound: msg B received (%d B)", len(msg_b)) + + # e: responder's ephemeral X25519 public key + resp_e_pk = msg_b[offset : offset + _X25519_SIZE] + offset += _X25519_SIZE + ss.mix_hash(resp_e_pk) + + # ee: DH(e_init, e_resp) + dh_ee = bytes(crypto_scalarmult(e_sk, resp_e_pk)) + ss.mix_key(dh_ee) + + # ekem1: decrypt KEM ciphertext, then mix KEM shared secret + enc_ct = msg_b[offset : offset + _KEM_CT_ENC_SIZE] + offset += _KEM_CT_ENC_SIZE + ct = ss.decrypt_and_hash(enc_ct) # decrypt with ee-derived key + ss_kem = self.kem.decapsulate(ct, e1_sk) # recover KEM shared secret + ss.mix_key(ss_kem) + + # s: decrypt responder's static public key + enc_s = msg_b[offset : offset + _S_ENC_SIZE] + offset += _S_ENC_SIZE + resp_s_pk_bytes = ss.decrypt_and_hash(enc_s) + resp_s_pk = 
X25519PublicKey.from_bytes(resp_s_pk_bytes) + + # es: DH(e_init, s_resp) + dh_es = bytes(crypto_scalarmult(e_sk, resp_s_pk_bytes)) + ss.mix_key(dh_es) + + # Decrypt responder's handshake payload + resp_payload_bytes = ss.decrypt_and_hash(msg_b[offset:]) + resp_payload = NoiseHandshakePayload.deserialize(resp_payload_bytes) + + # Verify responder's libp2p identity signature + if not verify_handshake_payload_sig(resp_payload, resp_s_pk): + raise InvalidSignature + resp_peer_id = ID.from_pubkey(resp_payload.id_pubkey) + if remote_peer is not None and resp_peer_id != remote_peer: + raise PeerIDMismatchesPubkey( + f"peer ID mismatch: expected {remote_peer}, got {resp_peer_id}" + ) + + # ---- Message C: s, se ---------------------------------------- + # s: encrypt our static public key + enc_s_c = ss.encrypt_and_hash(self._static_pk_bytes()) + + # se: DH(s_init, e_resp) + dh_se = bytes(crypto_scalarmult(self._static_sk_bytes(), resp_e_pk)) + ss.mix_key(dh_se) + + # Encrypt our handshake payload + enc_payload_c = ss.encrypt_and_hash(self._make_payload()) + await pkt.write_msg(enc_s_c + enc_payload_c) + logger.debug("handshake_outbound: msg C sent") + + # ---- Split and return ---------------------------------------- + cs1, cs2 = ss.split() + transport = PQTransportReadWriter(conn, send_cs=cs1, recv_cs=cs2) + return SecureSession( + local_peer=self.local_peer, + local_private_key=self.libp2p_privkey, + remote_peer=resp_peer_id, + remote_permanent_pubkey=resp_s_pk, + is_initiator=True, + conn=transport, + ) + + # ------------------------------------------------------------------ + # Responder (inbound) + # ------------------------------------------------------------------ + + async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: + """Run the responder side of the XXhfs handshake. + + Args: + conn: Raw underlying connection. + + Returns: + SecureSession ready for post-handshake transport. 
+ + Raises: + InvalidSignature: If the initiator's identity signature is invalid. + + """ + ss = SymmetricState() + ss.mix_hash(b"") # MixHash(prologue=empty) — required by Noise spec even when empty + pkt = NoisePacketReadWriter(cast(ReadWriteCloser, conn)) + + # ---- Message A: receive e, e1 -------------------------------- + msg_a = await pkt.read_msg() + logger.debug("handshake_inbound: msg A received (%d B)", len(msg_a)) + offset = 0 + + # e: initiator's ephemeral X25519 public key + init_e_pk = msg_a[offset : offset + _X25519_SIZE] + offset += _X25519_SIZE + ss.mix_hash(init_e_pk) + + # e1: initiator's X-Wing KEM public key + init_e1_pk = msg_a[offset : offset + XWING_PK_SIZE] + offset += XWING_PK_SIZE + ss.mix_hash(init_e1_pk) + + # Empty payload (no cipher key yet) + ss.decrypt_and_hash(msg_a[offset:]) # mix_hash(b"") + + # ---- Message B: e, ee, ekem1, s, es -------------------------- + # e: generate our ephemeral X25519 + e_sk = nacl.utils.random(_X25519_SIZE) + e_pk = bytes(crypto_scalarmult_base(e_sk)) + ss.mix_hash(e_pk) + + # ee: DH(e_resp, e_init) + dh_ee = bytes(crypto_scalarmult(e_sk, init_e_pk)) + ss.mix_key(dh_ee) + + # ekem1: encapsulate to initiator's e1, encrypt ct, then mix ss_kem + ct, ss_kem_bytes = self.kem.encapsulate(init_e1_pk) + enc_ct = ss.encrypt_and_hash(ct) # encrypt with ee-derived key + ss.mix_key(ss_kem_bytes) # now ss_kem strengthens key material + + # s: encrypt our static public key + enc_s = ss.encrypt_and_hash(self._static_pk_bytes()) + + # es: DH(s_resp, e_init) + dh_es = bytes(crypto_scalarmult(self._static_sk_bytes(), init_e_pk)) + ss.mix_key(dh_es) + + # Encrypt our handshake payload + enc_payload_b = ss.encrypt_and_hash(self._make_payload()) + + await pkt.write_msg(e_pk + enc_ct + enc_s + enc_payload_b) + logger.debug("handshake_inbound: msg B sent") + + # ---- Message C: receive s, se -------------------------------- + msg_c = await pkt.read_msg() + logger.debug("handshake_inbound: msg C received (%d B)", len(msg_c)) 
+ offset = 0 + + # s: decrypt initiator's static public key + enc_s_c = msg_c[offset : offset + _S_ENC_SIZE] + offset += _S_ENC_SIZE + init_s_pk_bytes = ss.decrypt_and_hash(enc_s_c) + + # se: DH(e_resp, s_init) + dh_se = bytes(crypto_scalarmult(e_sk, init_s_pk_bytes)) + ss.mix_key(dh_se) + + # Decrypt initiator's handshake payload + init_payload_bytes = ss.decrypt_and_hash(msg_c[offset:]) + init_payload = NoiseHandshakePayload.deserialize(init_payload_bytes) + + # Verify initiator's libp2p identity signature + init_s_pk = X25519PublicKey.from_bytes(init_s_pk_bytes) + if not verify_handshake_payload_sig(init_payload, init_s_pk): + raise InvalidSignature + init_peer_id = ID.from_pubkey(init_payload.id_pubkey) + + # ---- Split and return ---------------------------------------- + cs1, cs2 = ss.split() + transport = PQTransportReadWriter(conn, send_cs=cs2, recv_cs=cs1) + return SecureSession( + local_peer=self.local_peer, + local_private_key=self.libp2p_privkey, + remote_peer=init_peer_id, + remote_permanent_pubkey=init_s_pk, + is_initiator=False, + conn=transport, + ) diff --git a/libp2p/security/noise/pq/transport_pq.py b/libp2p/security/noise/pq/transport_pq.py new file mode 100644 index 000000000..473c10e53 --- /dev/null +++ b/libp2p/security/noise/pq/transport_pq.py @@ -0,0 +1,59 @@ +""" +Post-quantum Noise transport for py-libp2p. + +Wraps PatternXXhfs as an ISecureTransport so it integrates with the +standard py-libp2p security negotiation stack. + +Protocol ID: /noise-pq/1.0.0 +""" + +from libp2p.abc import ( + IRawConnection, + ISecureConn, + ISecureTransport, +) +from libp2p.crypto.keys import ( + KeyPair, + PrivateKey, +) +from libp2p.custom_types import TProtocol +from libp2p.peer.id import ID + +from .kem import XWingKem +from .patterns_pq import PatternXXhfs + +PROTOCOL_ID = TProtocol("/noise-pq/1.0.0") + + +class TransportPQ(ISecureTransport): + """ISecureTransport backed by the Noise XXhfs + X-Wing handshake. 
+ + Drop-in replacement for the classical Noise ``Transport``; pass it + as a security option to ``BasicHost`` under the key ``PROTOCOL_ID``. + """ + + def __init__( + self, + libp2p_keypair: KeyPair, + noise_privkey: PrivateKey, + ) -> None: + self.libp2p_privkey = libp2p_keypair.private_key + self.noise_privkey = noise_privkey + self.local_peer = ID.from_pubkey(libp2p_keypair.public_key) + + def get_pattern(self) -> PatternXXhfs: + """Return a fresh PatternXXhfs for a single handshake.""" + return PatternXXhfs( + local_peer=self.local_peer, + libp2p_privkey=self.libp2p_privkey, + noise_static_key=self.noise_privkey, + kem=XWingKem(), + ) + + async def secure_inbound(self, conn: IRawConnection) -> ISecureConn: + """Upgrade an inbound raw connection to a PQC-secured session.""" + return await self.get_pattern().handshake_inbound(conn) + + async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn: + """Upgrade an outbound raw connection to a PQC-secured session.""" + return await self.get_pattern().handshake_outbound(conn, peer_id) diff --git a/scripts/interop_dial.py b/scripts/interop_dial.py new file mode 100644 index 000000000..8e396a844 --- /dev/null +++ b/scripts/interop_dial.py @@ -0,0 +1,137 @@ +""" +Phase 5 live interop: Python dialer connecting to the JS NoiseHFS listener. + +Usage: + # Terminal 1 — start the JS listener: + node js-libp2p-noise/scripts/node-listener.mjs + + # Terminal 2 — run this dialer: + cd py-libp2p + python scripts/interop_dial.py + +Verifies that Python (py-libp2p) and JavaScript (js-libp2p-noise) can +complete a Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256 handshake over a +real TCP connection and exchange encrypted messages. 
+""" + +import asyncio +import logging +import sys + +import anyio +import anyio.abc +from libp2p.crypto.ed25519 import create_new_key_pair as ed25519_keypair +from libp2p.crypto.x25519 import create_new_key_pair as x25519_keypair +from libp2p.peer.id import ID +from libp2p.security.noise.pq.patterns_pq import PatternXXhfs +from libp2p.security.noise.pq.kem import XWingKem + +HOST = "127.0.0.1" +PORT = 8000 + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(name)s] %(levelname)s %(message)s", + stream=sys.stdout, +) +logger = logging.getLogger("interop_dial") + + +class RawTCPConn: + """Minimal IRawConnection wrapping an anyio SocketStream. + + Adapts anyio's ByteStream to the read/write/close interface that + NoisePacketReadWriter (and thus PatternXXhfs) expects. + """ + + is_initiator: bool = True + + def __init__(self, stream: anyio.abc.ByteStream) -> None: + self._stream = stream + + async def read(self, n: int | None = None) -> bytes: + if n is None: + return await self._stream.receive(65536) + return await self._stream.receive(n) + + async def write(self, data: bytes) -> None: + await self._stream.send(data) + + async def close(self) -> None: + await self._stream.aclose() + + def get_transport_addresses(self) -> list: + return [] + + +async def main() -> None: + # ── Key material ──────────────────────────────────────────────────────────── + # libp2p identity (Ed25519) — used in the Noise handshake payload signature + libp2p_kp = ed25519_keypair() + local_peer = ID.from_pubkey(libp2p_kp.public_key) + + # Noise static key (X25519) — used in the XX handshake s/se tokens + noise_kp = x25519_keypair() + noise_static = noise_kp.private_key + + logger.info("Local peer ID: %s", local_peer) + + # ── Connect ───────────────────────────────────────────────────────────────── + logger.info("Connecting to JS listener at %s:%d", HOST, PORT) + async with await anyio.connect_tcp(HOST, PORT) as stream: + conn = RawTCPConn(stream) + logger.info("TCP 
connection established") + + # ── Handshake ─────────────────────────────────────────────────────────── + # We don't know the JS peer ID in advance (it's freshly generated each + # run), so we pass a dummy peer ID and rely on signature verification. + # For a production scenario you'd pass the actual expected peer ID here. + pattern = PatternXXhfs( + local_peer=local_peer, + libp2p_privkey=libp2p_kp.private_key, + noise_static_key=noise_static, + kem=XWingKem(), + ) + + logger.info("Starting XXhfs handshake (Python = initiator)...") + + # Pass None for remote_peer — the JS listener is freshly keyed each run + # so its peer ID isn't known in advance. The signature is still fully + # verified; we just don't constrain which peer ID is acceptable. + try: + secure_conn = await pattern.handshake_outbound(conn, None) + except Exception as exc: + logger.error("Handshake failed: %s", exc) + raise + + logger.info( + "Handshake complete! Remote peer: %s", + secure_conn.get_remote_peer(), + ) + + # ── Exchange messages ──────────────────────────────────────────────────── + # Read greeting from JS — SecureSession.read() decrypts one Noise message + js_greeting_raw = await secure_conn.read() + js_greeting = js_greeting_raw.decode().strip() + logger.info('Received from JS: "%s"', js_greeting) + + if js_greeting != "hello from JS": + logger.error("Unexpected greeting from JS: %r", js_greeting) + sys.exit(1) + + # Send reply to JS — SecureSession.write() encrypts one Noise message + await secure_conn.write(b"hello from Python\n") + logger.info('Sent to JS: "hello from Python"') + + print() + print("=" * 60) + print("INTEROP SUCCESS") + print("Python <-> JavaScript NoiseHFS handshake complete.") + print(f"Protocol: Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256") + print(f"Local peer: {local_peer}") + print(f"Remote peer: {secure_conn.get_remote_peer()}") + print("=" * 60) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/tests/security/__init__.py 
b/tests/security/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/security/noise/__init__.py b/tests/security/noise/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/security/noise/pq/__init__.py b/tests/security/noise/pq/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/security/noise/pq/test_kem.py b/tests/security/noise/pq/test_kem.py new file mode 100644 index 000000000..61361c82f --- /dev/null +++ b/tests/security/noise/pq/test_kem.py @@ -0,0 +1,96 @@ +"""Tests for the IKem interface and XWingKem implementation.""" + +import pytest + +from libp2p.security.noise.pq.kem import XWingKem + + +class TestXWingKemKeySizes: + """Verify X-Wing key and ciphertext sizes match the spec.""" + + def setup_method(self) -> None: + self.kem = XWingKem() + + def test_keygen_public_key_size(self) -> None: + pk, _ = self.kem.keygen() + # ML-KEM-768 ek (1184) + X25519 pk (32) = 1216 + assert len(pk) == 1216 + + def test_keygen_secret_key_size(self) -> None: + _, sk = self.kem.keygen() + # ML-KEM-768 dk (2400) + X25519 sk (32) = 2432 + assert len(sk) == 2432 + + def test_encapsulate_ciphertext_size(self) -> None: + pk, _ = self.kem.keygen() + ct, _ = self.kem.encapsulate(pk) + # ML-KEM-768 ct (1088) + X25519 ephemeral pk (32) = 1120 + assert len(ct) == 1120 + + def test_encapsulate_shared_secret_size(self) -> None: + pk, _ = self.kem.keygen() + _, ss = self.kem.encapsulate(pk) + assert len(ss) == 32 + + def test_decapsulate_shared_secret_size(self) -> None: + pk, sk = self.kem.keygen() + ct, _ = self.kem.encapsulate(pk) + ss = self.kem.decapsulate(ct, sk) + assert len(ss) == 32 + + +class TestXWingKemRoundTrip: + """Verify encapsulate and decapsulate produce the same shared secret.""" + + def setup_method(self) -> None: + self.kem = XWingKem() + + def test_round_trip(self) -> None: + pk, sk = self.kem.keygen() + ct, ss_enc = self.kem.encapsulate(pk) + ss_dec = self.kem.decapsulate(ct, sk) + 
assert ss_enc == ss_dec + + def test_round_trip_produces_32_byte_secret(self) -> None: + pk, sk = self.kem.keygen() + ct, ss_enc = self.kem.encapsulate(pk) + ss_dec = self.kem.decapsulate(ct, sk) + assert len(ss_enc) == 32 + assert len(ss_dec) == 32 + + def test_different_keys_produce_different_secrets(self) -> None: + pk1, sk1 = self.kem.keygen() + pk2, sk2 = self.kem.keygen() + ct1, ss1 = self.kem.encapsulate(pk1) + ct2, ss2 = self.kem.encapsulate(pk2) + assert ss1 != ss2 + + def test_wrong_secret_key_produces_different_secret(self) -> None: + pk, sk = self.kem.keygen() + _, wrong_sk = self.kem.keygen() + ct, ss_enc = self.kem.encapsulate(pk) + ss_wrong = self.kem.decapsulate(ct, wrong_sk) + assert ss_enc != ss_wrong + + +class TestXWingKemCombiner: + """Verify the X-Wing combiner produces deterministic output.""" + + def setup_method(self) -> None: + self.kem = XWingKem() + + def test_same_inputs_produce_same_output(self) -> None: + pk, sk = self.kem.keygen() + ct, ss1 = self.kem.encapsulate(pk) + # Encapsulate is non-deterministic (uses fresh ephemeral each time) + # but decapsulate must be deterministic + ss_dec1 = self.kem.decapsulate(ct, sk) + ss_dec2 = self.kem.decapsulate(ct, sk) + assert ss_dec1 == ss_dec2 + + def test_encapsulate_is_non_deterministic(self) -> None: + pk, _ = self.kem.keygen() + ct1, _ = self.kem.encapsulate(pk) + ct2, _ = self.kem.encapsulate(pk) + # Different ephemeral X25519 keys each time + assert ct1 != ct2 diff --git a/tests/security/noise/pq/test_noise_state.py b/tests/security/noise/pq/test_noise_state.py new file mode 100644 index 000000000..193c90d65 --- /dev/null +++ b/tests/security/noise/pq/test_noise_state.py @@ -0,0 +1,156 @@ +"""Tests for the Noise CipherState and SymmetricState.""" + +import hashlib +import hmac +import struct + +import pytest + +from libp2p.security.noise.pq.noise_state import ( + PROTOCOL_NAME, + CipherState, + SymmetricState, +) + + +class TestCipherState: + """Tests for CipherState (ChaCha20-Poly1305 
with nonce counter).""" + + def test_encrypt_decrypt_round_trip(self) -> None: + key = bytes(range(32)) + cs = CipherState(key) + plaintext = b"hello noise" + ad = b"associated data" + ct = cs.encrypt_with_ad(ad, plaintext) + cs2 = CipherState(key) + result = cs2.decrypt_with_ad(ad, ct) + assert result == plaintext + + def test_nonce_increments_on_encrypt(self) -> None: + key = bytes(range(32)) + cs = CipherState(key) + assert cs.n == 0 + cs.encrypt_with_ad(b"", b"msg1") + assert cs.n == 1 + cs.encrypt_with_ad(b"", b"msg2") + assert cs.n == 2 + + def test_nonce_increments_on_decrypt(self) -> None: + key = bytes(range(32)) + cs_enc = CipherState(key) + cs_dec = CipherState(key) + ct1 = cs_enc.encrypt_with_ad(b"", b"msg1") + ct2 = cs_enc.encrypt_with_ad(b"", b"msg2") + cs_dec.decrypt_with_ad(b"", ct1) + assert cs_dec.n == 1 + cs_dec.decrypt_with_ad(b"", ct2) + assert cs_dec.n == 2 + + def test_wrong_ad_fails_decryption(self) -> None: + key = bytes(range(32)) + cs_enc = CipherState(key) + ct = cs_enc.encrypt_with_ad(b"correct ad", b"msg") + cs_dec = CipherState(key) + with pytest.raises(Exception): + cs_dec.decrypt_with_ad(b"wrong ad", ct) + + def test_empty_plaintext(self) -> None: + key = bytes(range(32)) + cs = CipherState(key) + ct = cs.encrypt_with_ad(b"ad", b"") + cs2 = CipherState(key) + result = cs2.decrypt_with_ad(b"ad", ct) + assert result == b"" + + def test_ciphertext_is_longer_than_plaintext(self) -> None: + """AEAD tag adds 16 bytes.""" + key = bytes(range(32)) + cs = CipherState(key) + plaintext = b"hello" + ct = cs.encrypt_with_ad(b"", plaintext) + assert len(ct) == len(plaintext) + 16 + + +class TestSymmetricState: + """Tests for SymmetricState (HKDF, MixKey, MixHash, EncryptAndHash).""" + + def test_initial_hash_is_protocol_name_hash(self) -> None: + ss = SymmetricState() + # h = SHA256(PROTOCOL_NAME) since len(PROTOCOL_NAME) > 32 + expected_h = hashlib.sha256(PROTOCOL_NAME).digest() + assert ss.h == expected_h + + def 
test_initial_chaining_key_equals_h(self) -> None: + ss = SymmetricState() + assert ss.ck == ss.h + + def test_mix_hash_updates_h(self) -> None: + ss = SymmetricState() + old_h = ss.h + ss.mix_hash(b"some data") + assert ss.h != old_h + # h = SHA256(h || data) + expected_h = hashlib.sha256(old_h + b"some data").digest() + assert ss.h == expected_h + + def test_mix_key_updates_ck(self) -> None: + ss = SymmetricState() + old_ck = ss.ck + ss.mix_key(b"shared secret" + bytes(19)) # 32 bytes + assert ss.ck != old_ck + + def test_encrypt_and_hash_round_trip(self) -> None: + ss_enc = SymmetricState() + ss_dec = SymmetricState() + + # Give both states a key via mix_key + ikm = bytes(32) + ss_enc.mix_key(ikm) + ss_dec.mix_key(ikm) + + plaintext = b"handshake payload" + ct = ss_enc.encrypt_and_hash(plaintext) + result = ss_dec.decrypt_and_hash(ct) + assert result == plaintext + + def test_encrypt_and_hash_updates_h(self) -> None: + ss = SymmetricState() + ss.mix_key(bytes(32)) + old_h = ss.h + ss.encrypt_and_hash(b"payload") + assert ss.h != old_h + + def test_split_returns_two_cipher_states(self) -> None: + ss = SymmetricState() + ss.mix_key(bytes(32)) + c1, c2 = ss.split() + # Both should be usable cipher states + ct = c1.encrypt_with_ad(b"", b"msg") + assert len(ct) > 0 + result = c2.encrypt_with_ad(b"", b"msg") + assert len(result) > 0 + + def test_split_produces_different_keys(self) -> None: + ss = SymmetricState() + ss.mix_key(bytes(32)) + c1, c2 = ss.split() + # Encrypting same data with different keys gives different results + ct1 = c1.encrypt_with_ad(b"", b"test") + ct2 = c2.encrypt_with_ad(b"", b"test") + assert ct1 != ct2 + + def test_identical_states_split_identically(self) -> None: + """Two symmetric states with the same transcript split to the same keys.""" + ikm = bytes(range(32)) + ss1 = SymmetricState() + ss2 = SymmetricState() + ss1.mix_hash(b"ephemeral key") + ss2.mix_hash(b"ephemeral key") + ss1.mix_key(ikm) + ss2.mix_key(ikm) + + c1_init, c1_resp = 
ss1.split() + c2_init, c2_resp = ss2.split() + + # Same input → same output + assert c1_init.encrypt_with_ad(b"", b"msg") == c2_init.encrypt_with_ad(b"", b"msg") diff --git a/tests/security/noise/pq/test_patterns_pq.py b/tests/security/noise/pq/test_patterns_pq.py new file mode 100644 index 000000000..07ea6010f --- /dev/null +++ b/tests/security/noise/pq/test_patterns_pq.py @@ -0,0 +1,393 @@ +""" +Tests for PatternXXhfs: the Noise XXhfs handshake with X-Wing KEM. + +Follows TDD: these tests are written before the implementation and initially fail. +""" + +import math + +import pytest +import trio + +from libp2p.crypto.ed25519 import create_new_key_pair +from libp2p.crypto.x25519 import X25519PrivateKey +from libp2p.peer.id import ID +from libp2p.security.noise.exceptions import ( + PeerIDMismatchesPubkey, +) +from libp2p.security.noise.pq.patterns_pq import PatternXXhfs + +# --------------------------------------------------------------------------- +# In-memory connection helpers +# --------------------------------------------------------------------------- + + +class _MemoryConn: + """ + Async in-memory bidirectional stream backed by trio memory channels. + + Implements the ReadWriteCloser duck-type expected by NoisePacketReadWriter. 
+ """ + + def __init__(self, send_chan, recv_chan) -> None: + self._send = send_chan + self._recv = recv_chan + self._buf = bytearray() + + async def read(self, n: int | None = None) -> bytes: + while not self._buf: + try: + chunk = await self._recv.receive() + except trio.EndOfChannel: + return b"" + self._buf.extend(chunk) + if n is None: + data = bytes(self._buf) + self._buf.clear() + return data + data = bytes(self._buf[:n]) + del self._buf[:n] + return data + + async def write(self, data: bytes) -> None: + await self._send.send(bytes(data)) + + async def close(self) -> None: + await self._send.aclose() + + def get_remote_address(self) -> None: + return None + + def get_transport_addresses(self) -> list: + return [] + + def get_connection_type(self): + from libp2p.connection_types import ConnectionType + + return ConnectionType.UNKNOWN + + +class _WriteCapture: + """Wraps a connection and records every call to write().""" + + def __init__(self, inner: _MemoryConn) -> None: + self._inner = inner + self.writes: list[bytes] = [] + + async def read(self, n: int | None = None) -> bytes: + return await self._inner.read(n) + + async def write(self, data: bytes) -> None: + self.writes.append(bytes(data)) + await self._inner.write(data) + + async def close(self) -> None: + await self._inner.close() + + def get_remote_address(self) -> None: + return None + + def get_transport_addresses(self) -> list: + return [] + + def get_connection_type(self): + from libp2p.connection_types import ConnectionType + + return ConnectionType.UNKNOWN + + +def _make_conn_pair() -> tuple[_MemoryConn, _MemoryConn]: + """Create a pair of in-memory connections wired together.""" + a_to_b_send, a_to_b_recv = trio.open_memory_channel(math.inf) + b_to_a_send, b_to_a_recv = trio.open_memory_channel(math.inf) + init_conn = _MemoryConn(a_to_b_send, b_to_a_recv) + resp_conn = _MemoryConn(b_to_a_send, a_to_b_recv) + return init_conn, resp_conn + + +def _make_pattern() -> tuple[PatternXXhfs, object, 
object, ID]: + """Create a fresh PatternXXhfs with newly-generated keys.""" + kp = create_new_key_pair() + noise_key = X25519PrivateKey.new() + peer = ID.from_pubkey(kp.public_key) + pattern = PatternXXhfs( + local_peer=peer, + libp2p_privkey=kp.private_key, + noise_static_key=noise_key, + ) + return pattern, kp, noise_key, peer + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestPatternXXhfsInit: + """Basic construction and attribute checks.""" + + def test_instantiation_stores_fields(self) -> None: + kp = create_new_key_pair() + noise_key = X25519PrivateKey.new() + peer = ID.from_pubkey(kp.public_key) + pattern = PatternXXhfs( + local_peer=peer, + libp2p_privkey=kp.private_key, + noise_static_key=noise_key, + ) + assert pattern.local_peer is peer + assert pattern.libp2p_privkey is kp.private_key + assert pattern.noise_static_key is noise_key + assert pattern.early_data is None + + def test_protocol_name(self) -> None: + pattern, _, _, _ = _make_pattern() + assert pattern.PROTOCOL_NAME == b"Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256" + + def test_default_kem_is_xwing(self) -> None: + from libp2p.security.noise.pq.kem import XWingKem + + pattern, _, _, _ = _make_pattern() + assert isinstance(pattern.kem, XWingKem) + + +class TestPatternXXhfsHandshake: + """Full-handshake integration tests.""" + + @pytest.mark.trio + async def test_handshake_completes(self) -> None: + """Both sides return a SecureSession after the handshake.""" + init_pat, _, _, _ = _make_pattern() + resp_pat, _, _, resp_peer = _make_pattern() + init_conn, resp_conn = _make_conn_pair() + + sessions: list = [None, None] + + async def run_init() -> None: + sessions[0] = await init_pat.handshake_outbound(init_conn, resp_peer) + + async def run_resp() -> None: + sessions[1] = await resp_pat.handshake_inbound(resp_conn) + + async with trio.open_nursery() as nursery: + 
nursery.start_soon(run_init) + nursery.start_soon(run_resp) + + assert sessions[0] is not None + assert sessions[1] is not None + + @pytest.mark.trio + async def test_bidirectional_data_exchange(self) -> None: + """Data written by each side is received correctly by the other.""" + init_pat, _, _, _ = _make_pattern() + resp_pat, _, _, resp_peer = _make_pattern() + init_conn, resp_conn = _make_conn_pair() + + sessions: list = [None, None] + + async def run_init() -> None: + sessions[0] = await init_pat.handshake_outbound(init_conn, resp_peer) + + async def run_resp() -> None: + sessions[1] = await resp_pat.handshake_inbound(resp_conn) + + async with trio.open_nursery() as nursery: + nursery.start_soon(run_init) + nursery.start_soon(run_resp) + + init_sess, resp_sess = sessions + + # Initiator → Responder + msg_i = b"hello from initiator" + await init_sess.write(msg_i) + assert await resp_sess.read(len(msg_i)) == msg_i + + # Responder → Initiator + msg_r = b"hello from responder" + await resp_sess.write(msg_r) + assert await init_sess.read(len(msg_r)) == msg_r + + @pytest.mark.trio + async def test_peer_ids_are_correct(self) -> None: + """Both sides see the correct remote peer ID after the handshake.""" + init_pat, _, _, init_peer = _make_pattern() + resp_pat, _, _, resp_peer = _make_pattern() + init_conn, resp_conn = _make_conn_pair() + + sessions: list = [None, None] + + async def run_init() -> None: + sessions[0] = await init_pat.handshake_outbound(init_conn, resp_peer) + + async def run_resp() -> None: + sessions[1] = await resp_pat.handshake_inbound(resp_conn) + + async with trio.open_nursery() as nursery: + nursery.start_soon(run_init) + nursery.start_soon(run_resp) + + init_sess, resp_sess = sessions + assert init_sess.remote_peer == resp_peer + assert resp_sess.remote_peer == init_peer + + @pytest.mark.trio + async def test_peer_id_mismatch_raises(self) -> None: + """Initiator raises PeerIDMismatchesPubkey when peer ID is wrong.""" + init_pat, _, _, _ = 
_make_pattern() + resp_pat, _, _, resp_peer = _make_pattern() + _, _, _, wrong_peer = _make_pattern() + init_conn, resp_conn = _make_conn_pair() + + init_error: Exception | None = None + + async def run_init() -> None: + nonlocal init_error + try: + await init_pat.handshake_outbound(init_conn, wrong_peer) + except PeerIDMismatchesPubkey as e: + init_error = e + except Exception: + pass + finally: + # Closing the send side unblocks the responder waiting for Msg C. + await init_conn.close() + + async def run_resp() -> None: + try: + await resp_pat.handshake_inbound(resp_conn) + except Exception: + pass + + async with trio.open_nursery() as nursery: + nursery.start_soon(run_init) + nursery.start_soon(run_resp) + + assert isinstance(init_error, PeerIDMismatchesPubkey) + + @pytest.mark.trio + async def test_large_payload_exchange(self) -> None: + """Transport handles payloads larger than a single cipher block.""" + init_pat, _, _, _ = _make_pattern() + resp_pat, _, _, resp_peer = _make_pattern() + init_conn, resp_conn = _make_conn_pair() + + sessions: list = [None, None] + + async def run_init() -> None: + sessions[0] = await init_pat.handshake_outbound(init_conn, resp_peer) + + async def run_resp() -> None: + sessions[1] = await resp_pat.handshake_inbound(resp_conn) + + async with trio.open_nursery() as nursery: + nursery.start_soon(run_init) + nursery.start_soon(run_resp) + + large_msg = b"Z" * 8192 + await sessions[0].write(large_msg) + received = await sessions[1].read(8192) + assert received == large_msg + + @pytest.mark.trio + async def test_independent_sessions_dont_interfere(self) -> None: + """Two simultaneous handshakes produce independent, non-interfering sessions.""" + ip1, _, _, _ = _make_pattern() + rp1, _, _, rp1_peer = _make_pattern() + ip2, _, _, _ = _make_pattern() + rp2, _, _, rp2_peer = _make_pattern() + + ic1, rc1 = _make_conn_pair() + ic2, rc2 = _make_conn_pair() + + sessions: list = [None] * 4 + + async def h1_init() -> None: + sessions[0] = await 
ip1.handshake_outbound(ic1, rp1_peer) + + async def h1_resp() -> None: + sessions[1] = await rp1.handshake_inbound(rc1) + + async def h2_init() -> None: + sessions[2] = await ip2.handshake_outbound(ic2, rp2_peer) + + async def h2_resp() -> None: + sessions[3] = await rp2.handshake_inbound(rc2) + + async with trio.open_nursery() as nursery: + nursery.start_soon(h1_init) + nursery.start_soon(h1_resp) + nursery.start_soon(h2_init) + nursery.start_soon(h2_resp) + + assert all(s is not None for s in sessions) + + # Both pairs exchange data independently + await sessions[0].write(b"pair1") + await sessions[2].write(b"pair2") + assert await sessions[1].read(5) == b"pair1" + assert await sessions[3].read(5) == b"pair2" + + +class TestPatternXXhfsWireFormat: + """Verify the on-wire message layout.""" + + @pytest.mark.trio + async def test_message_a_is_1248_bytes(self) -> None: + """Message A = e_pk(32) + e1_pk(1216) = 1248 bytes payload.""" + init_pat, _, _, _ = _make_pattern() + resp_pat, _, _, resp_peer = _make_pattern() + + inner_conn, resp_conn = _make_conn_pair() + spy = _WriteCapture(inner_conn) + + sessions: list = [None, None] + + async def run_init() -> None: + sessions[0] = await init_pat.handshake_outbound(spy, resp_peer) + + async def run_resp() -> None: + sessions[1] = await resp_pat.handshake_inbound(resp_conn) + + async with trio.open_nursery() as nursery: + nursery.start_soon(run_init) + nursery.start_soon(run_resp) + + # spy.writes[0] = 2-byte length prefix + message A bytes + assert len(spy.writes) >= 1 + frame = spy.writes[0] + msg_len = int.from_bytes(frame[:2], "big") + assert msg_len == 1248, f"Expected 1248, got {msg_len}" + + @pytest.mark.trio + async def test_message_b_overhead(self) -> None: + """Message B = e(32) + enc_ct(1136) + enc_s(48) + enc_payload(len+16).""" + init_pat, _, _, _ = _make_pattern() + resp_pat, _, _, resp_peer = _make_pattern() + + init_conn, inner_resp_conn = _make_conn_pair() + spy = _WriteCapture(inner_resp_conn) + + 
sessions: list = [None, None] + + async def run_init() -> None: + sessions[0] = await init_pat.handshake_outbound(init_conn, resp_peer) + + async def run_resp() -> None: + sessions[1] = await resp_pat.handshake_inbound(spy) + + async with trio.open_nursery() as nursery: + nursery.start_soon(run_init) + nursery.start_soon(run_resp) + + # spy.writes[0] = framed message B + assert len(spy.writes) >= 1 + frame = spy.writes[0] + msg_len = int.from_bytes(frame[:2], "big") + + # Fixed overhead: 32 (e) + 1136 (enc_ct) + 48 (enc_s) + 16 (AEAD tag) + # Payload size varies (protobuf-serialised NoiseHandshakePayload) but + # total fixed overhead is constant + fixed_overhead = 32 + 1136 + 48 + 16 + assert msg_len >= fixed_overhead, ( + f"Message B too short: {msg_len} < {fixed_overhead}" + ) diff --git a/tests/security/noise/pq/test_transport_pq.py b/tests/security/noise/pq/test_transport_pq.py new file mode 100644 index 000000000..aa14337d8 --- /dev/null +++ b/tests/security/noise/pq/test_transport_pq.py @@ -0,0 +1,203 @@ +"""Tests for TransportPQ: the ISecureTransport wrapper for XXhfs. + +Follows TDD: these tests are written before the implementation. 
"""

import math

import pytest
import trio

from libp2p.crypto.ed25519 import create_new_key_pair
from libp2p.crypto.keys import KeyPair
from libp2p.crypto.x25519 import X25519PrivateKey
from libp2p.peer.id import ID
from libp2p.security.noise.pq.transport_pq import PROTOCOL_ID, TransportPQ


# ---------------------------------------------------------------------------
# Shared in-memory connection helpers (mirrors test_patterns_pq.py)
# ---------------------------------------------------------------------------


class _MemoryConn:
    """One endpoint of an in-memory duplex pipe backed by trio memory channels.

    Writes push whole chunks into ``send_chan``; reads drain ``recv_chan``
    through a local byte buffer so partial reads (``read(n)``) work even when
    chunk boundaries don't line up.
    NOTE(review): duplicated from test_patterns_pq.py — consider moving to a
    shared conftest helper.
    """

    def __init__(self, send_chan, recv_chan) -> None:
        self._send = send_chan
        self._recv = recv_chan
        # Bytes received from the channel but not yet consumed by read().
        self._buf = bytearray()

    async def read(self, n: int | None = None) -> bytes:
        # Block until at least one chunk is buffered; b"" signals EOF,
        # mirroring socket semantics.
        while not self._buf:
            try:
                chunk = await self._recv.receive()
            except trio.EndOfChannel:
                return b""
            self._buf.extend(chunk)
        if n is None:
            # Drain everything currently buffered.
            data = bytes(self._buf)
            self._buf.clear()
            return data
        data = bytes(self._buf[:n])
        del self._buf[:n]
        return data

    async def write(self, data: bytes) -> None:
        # Send as a single chunk; the peer's read() buffers it.
        await self._send.send(bytes(data))

    async def close(self) -> None:
        # Closing the send side makes the peer's read() return b"" (EOF).
        await self._send.aclose()

    def get_remote_address(self) -> None:
        # No real network address for an in-memory pipe.
        return None

    def get_transport_addresses(self) -> list:
        return []

    def get_connection_type(self):
        from libp2p.connection_types import ConnectionType

        return ConnectionType.UNKNOWN


def _make_conn_pair() -> tuple[_MemoryConn, _MemoryConn]:
    """Create two connected in-memory endpoints (unbounded channels)."""
    a_to_b_send, a_to_b_recv = trio.open_memory_channel(math.inf)
    b_to_a_send, b_to_a_recv = trio.open_memory_channel(math.inf)
    return (
        _MemoryConn(a_to_b_send, b_to_a_recv),
        _MemoryConn(b_to_a_send, a_to_b_recv),
    )


def _make_transport() -> tuple[TransportPQ, ID]:
    """Build a TransportPQ with a fresh Ed25519 identity and X25519 noise key.

    Returns:
        (transport, peer_id) — peer_id is derived from the identity pubkey.
    """
    kp = create_new_key_pair()
    noise_key = X25519PrivateKey.new()
    peer = ID.from_pubkey(kp.public_key)
    transport = TransportPQ(
        libp2p_keypair=KeyPair(kp.private_key, kp.public_key),
        noise_privkey=noise_key,
    )
    return transport, peer


# 
---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestTransportPQInit:
    """Constructor-level checks: protocol ID, peer identity, pattern wiring."""

    def test_protocol_id(self) -> None:
        assert PROTOCOL_ID == "/noise-pq/1.0.0"

    def test_instantiation(self) -> None:
        transport, peer = _make_transport()
        assert transport.local_peer == peer

    def test_get_pattern_returns_xxhfs(self) -> None:
        # Import here so collection doesn't fail before the module exists (TDD).
        from libp2p.security.noise.pq.patterns_pq import PatternXXhfs

        transport, _ = _make_transport()
        pattern = transport.get_pattern()
        assert isinstance(pattern, PatternXXhfs)

    def test_get_pattern_protocol_name(self) -> None:
        transport, _ = _make_transport()
        pattern = transport.get_pattern()
        assert pattern.PROTOCOL_NAME == b"Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256"


class TestTransportPQHandshake:
    """End-to-end secure_outbound/secure_inbound over an in-memory pipe."""

    @pytest.mark.trio
    async def test_secure_inbound_and_outbound_complete(self) -> None:
        """secure_outbound + secure_inbound both return a SecureSession."""
        local_transport, local_peer = _make_transport()
        remote_transport, remote_peer = _make_transport()
        local_conn, remote_conn = _make_conn_pair()

        # sessions[0] = initiator session, sessions[1] = responder session.
        sessions: list = [None, None]

        async def do_outbound() -> None:
            sessions[0] = await local_transport.secure_outbound(local_conn, remote_peer)

        async def do_inbound() -> None:
            sessions[1] = await remote_transport.secure_inbound(remote_conn)

        # Both sides must run concurrently: each blocks on the other's messages.
        async with trio.open_nursery() as nursery:
            nursery.start_soon(do_outbound)
            nursery.start_soon(do_inbound)

        assert sessions[0] is not None
        assert sessions[1] is not None

    @pytest.mark.trio
    async def test_data_exchange_after_secure_transport(self) -> None:
        """Data written via secure_outbound is readable via secure_inbound."""
        local_transport, _ = _make_transport()
        remote_transport, remote_peer = _make_transport()
        local_conn, remote_conn = _make_conn_pair()

        sessions: list = [None, None]

        async def do_outbound() -> None:
            sessions[0] = await local_transport.secure_outbound(local_conn, remote_peer)

        async def do_inbound() -> None:
            sessions[1] = await remote_transport.secure_inbound(remote_conn)

        async with trio.open_nursery() as nursery:
            nursery.start_soon(do_outbound)
            nursery.start_soon(do_inbound)

        outbound_sess, inbound_sess = sessions

        # Round-trip in both directions to exercise both cipher states.
        msg = b"post-quantum hello"
        await outbound_sess.write(msg)
        assert await inbound_sess.read(len(msg)) == msg

        reply = b"pq reply"
        await inbound_sess.write(reply)
        assert await outbound_sess.read(len(reply)) == reply

    @pytest.mark.trio
    async def test_peer_ids_correct_after_transport(self) -> None:
        """Both sides see the correct remote peer ID after the secure upgrade."""
        local_transport, local_peer = _make_transport()
        remote_transport, remote_peer = _make_transport()
        local_conn, remote_conn = _make_conn_pair()

        sessions: list = [None, None]

        async def do_outbound() -> None:
            sessions[0] = await local_transport.secure_outbound(local_conn, remote_peer)

        async def do_inbound() -> None:
            sessions[1] = await remote_transport.secure_inbound(remote_conn)

        async with trio.open_nursery() as nursery:
            nursery.start_soon(do_outbound)
            nursery.start_soon(do_inbound)

        outbound_sess, inbound_sess = sessions
        # Cross-check: each session reports the *other* side's identity.
        assert outbound_sess.remote_peer == remote_peer
        assert inbound_sess.remote_peer == local_peer

    @pytest.mark.trio
    async def test_is_initiator_flag(self) -> None:
        """secure_outbound returns is_initiator=True, secure_inbound returns False."""
        local_transport, _ = _make_transport()
        remote_transport, remote_peer = _make_transport()
        local_conn, remote_conn = _make_conn_pair()

        sessions: list = [None, None]

        async def do_outbound() -> None:
            sessions[0] = await local_transport.secure_outbound(local_conn, remote_peer)

        async def do_inbound() -> None:
            sessions[1] = await remote_transport.secure_inbound(remote_conn)

        async with trio.open_nursery() as nursery:
            nursery.start_soon(do_outbound)
            nursery.start_soon(do_inbound)

        assert sessions[0].is_initiator is True
        assert sessions[1].is_initiator is False
diff --git a/tests/security/noise/pq/test_vectors_pq.py b/tests/security/noise/pq/test_vectors_pq.py
new file mode 100644
index 000000000..67d1f368c
--- /dev/null
+++ b/tests/security/noise/pq/test_vectors_pq.py
@@ -0,0 +1,388 @@
"""
Cross-implementation test vectors for Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256.

Loads the 5 committed vectors from js-libp2p-noise and reproduces each
handshake with the exact same fixed seeds. Asserts byte-exact equality of:
  - messages A, B, C
  - final handshake hash (ss.h after Message C)
  - transport cipher keys cs1.k and cs2.k

All 5 vectors passing proves the Python and JavaScript implementations are
wire-compatible.

Key seeding:
  - Initiator ephemeral DH : ephemeral_dh_i_{public,private}
  - Responder ephemeral DH : ephemeral_dh_r_{public,private}
  - Initiator KEM keypair : ephemeral_kem_i_{public,secret}
  - Responder encap seed : encap_seed_hex (64 bytes)
      [0:32] = ML-KEM randomness m
      [32:64] = X25519 ephemeral private key

Prologue for all vectors: empty (ZEROLEN = b""), which the Noise spec requires
to be mixed even when empty: h = SHA256(SHA256(protocol_name)).
"""

import hashlib
import json
from pathlib import Path

import pytest
from nacl.bindings import crypto_scalarmult, crypto_scalarmult_base

from kyber_py.ml_kem import ML_KEM_768

from libp2p.security.noise.pq.kem import (
    _ML_KEM_PK_SIZE,
    _ML_KEM_CT_SIZE,
    _X25519_KEY_SIZE,
    _xwing_combine,
)
from libp2p.security.noise.pq.noise_state import SymmetricState, _hkdf

# ---------------------------------------------------------------------------
# Vector file location
# ---------------------------------------------------------------------------

# Resolved relative to this test file; assumes the js-libp2p-noise checkout
# sits next to py-libp2p under a common PQC-Research/ root — TODO confirm
# this layout holds in CI.
_VECTORS_PATH = (
    Path(__file__).parents[4].parent  # PQC-Research/
    / "js-libp2p-noise"
    / "test"
    / "fixtures"
    / "pqc-test-vectors.json"
)

# Size constants (mirror patterns_pq.py)
_AEAD_TAG = 16  # ChaChaPoly authentication tag length
_KEM_CT_ENC_SIZE = _ML_KEM_CT_SIZE + _X25519_KEY_SIZE + _AEAD_TAG  # 1136
_S_ENC_SIZE = _X25519_KEY_SIZE + _AEAD_TAG  # 48


# ---------------------------------------------------------------------------
# Seeded X-Wing encapsulate (for responder with fixed seed)
# ---------------------------------------------------------------------------


def _xwing_encapsulate_seeded(pk: bytes, encap_seed: bytes) -> tuple[bytes, bytes]:
    """
    Deterministic X-Wing encapsulation using a 64-byte seed.

    seed[0:32] = ML-KEM randomness m (passed to _encaps_internal)
    seed[32:64] = X25519 ephemeral private key

    Matches XWing.encapsulate(pubkey, seed) from @noble/post-quantum.

    Returns:
        (ciphertext, shared_secret) — same layout as XWingKem.encapsulate()
    """
    # X-Wing public key = ML-KEM-768 pk || X25519 pk.
    assert len(pk) == _ML_KEM_PK_SIZE + _X25519_KEY_SIZE, f"bad pk len: {len(pk)}"
    assert len(encap_seed) == 64, f"seed must be 64 bytes, got {len(encap_seed)}"

    ml_kem_pk = pk[:_ML_KEM_PK_SIZE]
    x25519_pk_r = pk[_ML_KEM_PK_SIZE:]

    # ML-KEM with deterministic randomness (internal API bypasses the RNG).
    m = encap_seed[:32]
    ss_mlkem, ml_kem_ct = ML_KEM_768._encaps_internal(ml_kem_pk, m)

    # X25519 with fixed ephemeral private key
    x25519_eph_sk = encap_seed[32:]
    x25519_eph_pk = bytes(crypto_scalarmult_base(x25519_eph_sk))
    ss_x25519 = bytes(crypto_scalarmult(x25519_eph_sk, x25519_pk_r))

    # Combiner binds both shared secrets to both X25519 public keys.
    ss = _xwing_combine(ss_mlkem, ss_x25519, x25519_eph_pk, x25519_pk_r)
    ct = ml_kem_ct + x25519_eph_pk
    return ct, ss


# ---------------------------------------------------------------------------
# X-Wing seed expansion (matches @noble/post-quantum combineKEMS + expandSeedXof)
# ---------------------------------------------------------------------------


def _xwing_sk_from_seed(seed: bytes) -> bytes:
    """
    Expand a 32-byte X-Wing root seed into the full 2432-byte secret key.

    @noble/post-quantum stores secretKey as the 32-byte root seed and
    re-expands on each decapsulate call using:
        expanded = SHAKE-256(seed, 96 bytes)
        expanded[0:32]  = ML-KEM-768 d (randomness)
        expanded[32:64] = ML-KEM-768 z (implicit rejection randomness)
        expanded[64:96] = X25519 private key

    Returns:
        2432-byte X-Wing secret key (ml_kem_dk || x25519_sk)
    """
    assert len(seed) == 32, f"seed must be 32 bytes, got {len(seed)}"
    expanded = hashlib.shake_256(seed).digest(96)
    d = expanded[0:32]
    z = expanded[32:64]
    x25519_sk = expanded[64:96]
    # Deterministic ML-KEM keygen from (d, z); public key is discarded here.
    _ml_kem_pk, ml_kem_sk = ML_KEM_768._keygen_internal(d, z)
    return ml_kem_sk + x25519_sk


# ---------------------------------------------------------------------------
# Deterministic handshake reproducer
# ---------------------------------------------------------------------------


def _run_vector_handshake(v: dict) -> dict:
    """
    Reproduce the XXhfs handshake with all keys fixed from a test vector.

    Runs BOTH roles (initiator and responder) over the same transcript; the
    exact order of mix_hash/mix_key calls below defines the wire format and
    must not be reordered.

    Returns:
        dict with keys: msg_a, msg_b, msg_c, handshake_hash, cs1_k, cs2_k
    """
    # Load fixed values
    e_i_sk = bytes.fromhex(v["ephemeral_dh_i_private"])
    e_i_pk = bytes.fromhex(v["ephemeral_dh_i_public"])
    e_r_sk = bytes.fromhex(v["ephemeral_dh_r_private"])
    e_r_pk = bytes.fromhex(v["ephemeral_dh_r_public"])
    static_i_sk = bytes.fromhex(v["static_i_private"])
    static_i_pk = bytes.fromhex(v["static_i_public"])
    static_r_sk = bytes.fromhex(v["static_r_private"])
    static_r_pk = bytes.fromhex(v["static_r_public"])
    e1_pk = bytes.fromhex(v["ephemeral_kem_i_public"])
    # ephemeral_kem_i_secret is a 32-byte root seed (not the full 2432-byte key)
    # @noble/post-quantum stores the root seed as secretKey and re-expands via SHAKE-256
    e1_sk = _xwing_sk_from_seed(bytes.fromhex(v["ephemeral_kem_i_secret"]))
    encap_seed = bytes.fromhex(v["encap_seed_hex"])
    # Prologue: ZEROLEN = b""

    # ---- Initiator SymmetricState ----------------------------------------
    ss_i = SymmetricState()
    ss_i.mix_hash(b"")  # MixHash(prologue=empty)

    # ---- Responder SymmetricState ----------------------------------------
    ss_r = SymmetricState()
    ss_r.mix_hash(b"")  # MixHash(prologue=empty)

    # ======================================================================
    # Message A: initiator sends e_pk || e1_pk (no payload, no AEAD yet)
    # ======================================================================
    ss_i.mix_hash(e_i_pk)  # writeE token
    ss_i.mix_hash(e1_pk)  # writeE1 token (encryptAndHash = mixHash when no key)
    enc_payload_a = ss_i.encrypt_and_hash(b"")  # empty payload → b""
    msg_a = e_i_pk + e1_pk + enc_payload_a  # 32 + 1216 + 0 = 1248 B

    # Responder processes Message A
    ss_r.mix_hash(e_i_pk)
    ss_r.mix_hash(e1_pk)
    ss_r.decrypt_and_hash(enc_payload_a)  # mix_hash(b"") — keeps states in sync

    # ======================================================================
    # Message B: responder sends e_pk || enc_ct || enc_s || enc_payload
    # Tokens: e, ee, ekem1, s, es
    # ======================================================================
    ss_r.mix_hash(e_r_pk)  # writeE

    dh_ee = bytes(crypto_scalarmult(e_r_sk, e_i_pk))
    ss_r.mix_key(dh_ee)  # writeEE

    # writeEkem1: encapsulate, encrypt ct, then mix KEM ss
    ct, ss_kem_r = _xwing_encapsulate_seeded(e1_pk, encap_seed)
    enc_ct = ss_r.encrypt_and_hash(ct)  # encrypted under ee-derived key
    ss_r.mix_key(ss_kem_r)  # then strengthen with KEM output

    # writeS: encrypt responder static pubkey
    enc_s_r = ss_r.encrypt_and_hash(static_r_pk)

    # writeES (responder role): MixKey(DH(s_responder, e_initiator))
    dh_es_r = bytes(crypto_scalarmult(static_r_sk, e_i_pk))
    ss_r.mix_key(dh_es_r)

    # payload = ZEROLEN
    enc_payload_b = ss_r.encrypt_and_hash(b"")
    msg_b = e_r_pk + enc_ct + enc_s_r + enc_payload_b  # 32+1136+48+16 = 1232 B

    # Initiator processes Message B
    ss_i.mix_hash(e_r_pk)

    dh_ee_i = bytes(crypto_scalarmult(e_i_sk, e_r_pk))
    ss_i.mix_key(dh_ee_i)  # ee token

    # readEkem1: decrypt ct, then decapsulate, then mix KEM ss
    from libp2p.security.noise.pq.kem import XWingKem
    ct_dec = ss_i.decrypt_and_hash(enc_ct)
    ss_kem_i = XWingKem().decapsulate(ct_dec, e1_sk)
    ss_i.mix_key(ss_kem_i)

    # readS: decrypt responder static pubkey
    dec_s_r = ss_i.decrypt_and_hash(enc_s_r)

    # readES (initiator role): MixKey(DH(e_initiator, s_responder))
    dh_es_i = bytes(crypto_scalarmult(e_i_sk, dec_s_r))
    ss_i.mix_key(dh_es_i)

    ss_i.decrypt_and_hash(enc_payload_b)  # empty payload

    # ======================================================================
    # Message C: initiator sends enc_s || enc_payload
    # Tokens: s, se
    # ======================================================================
    # writeS: encrypt initiator static pubkey
    enc_s_i = ss_i.encrypt_and_hash(static_i_pk)

    # writeSE (initiator role): MixKey(DH(s_initiator, e_responder))
    dh_se_i = bytes(crypto_scalarmult(static_i_sk, e_r_pk))
    ss_i.mix_key(dh_se_i)

    enc_payload_c = ss_i.encrypt_and_hash(b"")
    msg_c = enc_s_i + enc_payload_c  # 48 + 16 = 64 B

    # Responder processes Message C
    dec_s_i = ss_r.decrypt_and_hash(enc_s_i)

    # readSE (responder role): MixKey(DH(e_responder, s_initiator))
    dh_se_r = bytes(crypto_scalarmult(e_r_sk, dec_s_i))
    ss_r.mix_key(dh_se_r)

    ss_r.decrypt_and_hash(enc_payload_c)

    # ======================================================================
    # Split — both sides must derive the same cipher keys
    # ======================================================================
    cs1_k, cs2_k = _hkdf(ss_i.ck, b"", 2)
    cs1_k_r, cs2_k_r = _hkdf(ss_r.ck, b"", 2)
    assert cs1_k == cs1_k_r, "cs1 key mismatch between initiator and responder"
    assert cs2_k == cs2_k_r, "cs2 key mismatch between initiator and responder"

    return {
        "msg_a": msg_a,
        "msg_b": msg_b,
        "msg_c": msg_c,
        "handshake_hash": ss_i.h,
        "cs1_k": cs1_k,
        "cs2_k": cs2_k,
    }


# 
--------------------------------------------------------------------------- +# Test fixture loading +# --------------------------------------------------------------------------- + + +def _load_vectors() -> list[dict]: + if not _VECTORS_PATH.exists(): + pytest.skip(f"JS test vectors not found at {_VECTORS_PATH}") + with open(_VECTORS_PATH) as f: + data = json.load(f) + assert data["protocol"] == "Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256" + return data["vectors"] + + +# --------------------------------------------------------------------------- +# Parameterised tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def vectors() -> list[dict]: + return _load_vectors() + + +class TestVectorsMeta: + def test_protocol_name(self) -> None: + if not _VECTORS_PATH.exists(): + pytest.skip("JS test vectors not found") + with open(_VECTORS_PATH) as f: + data = json.load(f) + assert data["protocol"] == "Noise_XXhfs_25519+XWing_ChaChaPoly_SHA256" + + def test_five_vectors_present(self, vectors: list[dict]) -> None: + assert len(vectors) == 5 + + +class TestVectorHandshake: + """One test class; vectors are parameterised inside each test method.""" + + @pytest.mark.parametrize("idx", range(5)) + def test_msg_a_bytes(self, vectors: list[dict], idx: int) -> None: + v = vectors[idx] + result = _run_vector_handshake(v) + assert len(result["msg_a"]) == v["msg_a_bytes"], ( + f"Vector {idx}: msg_a length {len(result['msg_a'])} != {v['msg_a_bytes']}" + ) + + @pytest.mark.parametrize("idx", range(5)) + def test_msg_a_content(self, vectors: list[dict], idx: int) -> None: + v = vectors[idx] + result = _run_vector_handshake(v) + got = result["msg_a"].hex() + assert got == v["msg_a"], ( + f"Vector {idx}: msg_a mismatch\n" + f" got: {got[:64]}...\n" + f" expected: {v['msg_a'][:64]}..." 
+ ) + + @pytest.mark.parametrize("idx", range(5)) + def test_msg_b_bytes(self, vectors: list[dict], idx: int) -> None: + v = vectors[idx] + result = _run_vector_handshake(v) + assert len(result["msg_b"]) == v["msg_b_bytes"], ( + f"Vector {idx}: msg_b length {len(result['msg_b'])} != {v['msg_b_bytes']}" + ) + + @pytest.mark.parametrize("idx", range(5)) + def test_msg_b_content(self, vectors: list[dict], idx: int) -> None: + v = vectors[idx] + result = _run_vector_handshake(v) + got = result["msg_b"].hex() + assert got == v["msg_b"], ( + f"Vector {idx}: msg_b mismatch\n" + f" got: {got[:64]}...\n" + f" expected: {v['msg_b'][:64]}..." + ) + + @pytest.mark.parametrize("idx", range(5)) + def test_msg_c_bytes(self, vectors: list[dict], idx: int) -> None: + v = vectors[idx] + result = _run_vector_handshake(v) + assert len(result["msg_c"]) == v["msg_c_bytes"], ( + f"Vector {idx}: msg_c length {len(result['msg_c'])} != {v['msg_c_bytes']}" + ) + + @pytest.mark.parametrize("idx", range(5)) + def test_msg_c_content(self, vectors: list[dict], idx: int) -> None: + v = vectors[idx] + result = _run_vector_handshake(v) + got = result["msg_c"].hex() + assert got == v["msg_c"], ( + f"Vector {idx}: msg_c mismatch\n" + f" got: {got[:64]}...\n" + f" expected: {v['msg_c'][:64]}..." 
+ ) + + @pytest.mark.parametrize("idx", range(5)) + def test_handshake_hash(self, vectors: list[dict], idx: int) -> None: + v = vectors[idx] + result = _run_vector_handshake(v) + got = result["handshake_hash"].hex() + assert got == v["handshake_hash"], ( + f"Vector {idx}: handshake_hash mismatch — transcript diverged\n" + f" got: {got}\n" + f" expected: {v['handshake_hash']}" + ) + + @pytest.mark.parametrize("idx", range(5)) + def test_cs1_k(self, vectors: list[dict], idx: int) -> None: + v = vectors[idx] + result = _run_vector_handshake(v) + got = result["cs1_k"].hex() + assert got == v["cs1_k"], ( + f"Vector {idx}: cs1_k mismatch\n" + f" got: {got}\n" + f" expected: {v['cs1_k']}" + ) + + @pytest.mark.parametrize("idx", range(5)) + def test_cs2_k(self, vectors: list[dict], idx: int) -> None: + v = vectors[idx] + result = _run_vector_handshake(v) + got = result["cs2_k"].hex() + assert got == v["cs2_k"], ( + f"Vector {idx}: cs2_k mismatch\n" + f" got: {got}\n" + f" expected: {v['cs2_k']}" + )