diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4991a5968..1bf1e8196 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,5 +1,23 @@ name: build -on: [push, pull_request] +on: + push: + paths-ignore: + - '**/*.md' + - 'docs/**' + - '.gitignore' + - 'LICENSE-APACHE' + - 'LICENSE-MIT' + - 'funding.json' + - '.github/ISSUE_TEMPLATE/**' + pull_request: + paths-ignore: + - '**/*.md' + - 'docs/**' + - '.gitignore' + - 'LICENSE-APACHE' + - 'LICENSE-MIT' + - 'funding.json' + - '.github/ISSUE_TEMPLATE/**' jobs: gradle: strategy: @@ -12,12 +30,16 @@ jobs: with: distribution: temurin java-version: 11 + - name: Install and run ipfs + run: ./install-run-ipfs.sh - name: Setup Gradle uses: gradle/gradle-build-action@v2 - name: Setup Android SDK - uses: android-actions/setup-android@v2 + uses: android-actions/setup-android@v3 + with: + cmdline-tools-version: 8512546 - name: Execute Gradle build run: ./gradlew -s build dokkaJar \ No newline at end of file diff --git a/.github/workflows/generated-pr.yml b/.github/workflows/generated-pr.yml new file mode 100644 index 000000000..b8c5cc631 --- /dev/null +++ b/.github/workflows/generated-pr.yml @@ -0,0 +1,14 @@ +name: Close Generated PRs + +on: + schedule: + - cron: '0 0 * * *' + workflow_dispatch: + +permissions: + issues: write + pull-requests: write + +jobs: + stale: + uses: ipdxco/unified-github-workflows/.github/workflows/reusable-generated-pr.yml@v1 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 0eeea244e..bc94539fa 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -3,6 +3,14 @@ on: push: branches: - "develop" + paths-ignore: + - '**/*.md' + - 'docs/**' + - '.gitignore' + - 'LICENSE-APACHE' + - 'LICENSE-MIT' + - 'funding.json' + - '.github/ISSUE_TEMPLATE/**' jobs: publish: runs-on: ubuntu-latest diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 16d65d721..7c955c414 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -1,8 +1,9 @@ -name: Close and mark stale issue +name: Close Stale Issues on: schedule: - cron: '0 0 * * *' + workflow_dispatch: permissions: issues: write @@ -10,4 +11,4 @@ permissions: jobs: stale: - uses: pl-strflt/.github/.github/workflows/reusable-stale-issue.yml@v0.3 + uses: ipdxco/unified-github-workflows/.github/workflows/reusable-stale-issue.yml@v1 diff --git a/.gitignore b/.gitignore index a939ebb2f..73272b159 100644 --- a/.gitignore +++ b/.gitignore @@ -190,3 +190,11 @@ $RECYCLE.BIN/ node_modules package-lock.json /src/jmh/java/generated/ + +#Jenv +.java-version + +# Claude +CLAUDE.local.md +.claude/settings.local.json +.worktrees/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..79fd53bfd --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,314 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +jvm-libp2p is a JVM implementation of the [libp2p](https://libp2p.io/) networking stack, written in Kotlin. It provides peer-to-peer networking capabilities including transport protocols (TCP, QUIC, WebSocket), security channels (Noise, TLS), stream multiplexing (Yamux, Mplex), and pub/sub messaging (Gossipsub, Floodsub). + +Notable users: Teku (Ethereum Consensus Layer client), Nabu (minimal IPFS), Peergos (peer-to-peer encrypted filesystem). + +## Build Commands + +```bash +# Build the entire project +./gradlew build + +# Run all tests (excludes interop tests tagged with "interop") +./gradlew test + +# Run tests for a specific module +./gradlew :libp2p:test + +# Run a specific test class +./gradlew :libp2p:test --tests "io.libp2p.pubsub.gossip.GossipRpcPartsQueueTest" + +# Run a specific test method +./gradlew :libp2p:test --tests "io.libp2p.pubsub.gossip.GossipRpcPartsQueueTest.mergeMessageParts*" + +# Check code formatting +./gradlew spotlessCheck + +# Apply code formatting +./gradlew spotlessApply + +# Run static analysis (Detekt) +./gradlew detekt + +# Generate documentation +./gradlew dokkaHtml +# Output in build/dokka/ + +# Clean build artifacts +./gradlew clean +``` + +**Requirements:** JDK 11 or higher + +**Module Structure:** +- `:libp2p` - Main library module +- `:tools:simulator` - Gossip network simulator +- `:tools:schedulers` - Test scheduling utilities +- `:examples:chatter`, `:examples:cli-chatter`, `:examples:pinger` - Example applications +- `:interop-test-client` - Interoperability testing client + +## Architecture Overview + +### Core Abstraction Layers + +The library follows a layered architecture with protocol negotiation at each layer: + +``` +Application Layer + ↓ (Protocol negotiation via multistream-select) +Stream/Protocol Layer (PingProtocol, ChatProtocol, PubsubRouter) + ↓ (Stream creation) +Stream Multiplexing Layer (Yamux, Mplex) + ↓ (Multiplexer negotiation) +Security Layer (Noise, TLS) + ↓ (Security negotiation) +Transport Layer (TCP, QUIC, WebSocket) + ↓ +Raw Network +``` + +### Key Interfaces and Their Roles + +**`Host`** (`core/Host.kt`): +- Main entry point for all libp2p operations +- Manages identity (`PeerId`, `PrivKey`), network, and protocol handlers +- Created via DSL builder: `host { identity { ... }; transports { ... }; protocols { ... } }` + +**`Network`** (`core/Network.kt`): +- Manages transports and active connections +- Handles `listen()` and `dial()` operations +- Reuses connections to the same peer + +**`Connection`** and **`Stream`** (both extend `P2PChannel`): +- `Connection`: Secured, multiplexed connection between two peers +- `Stream`: Logical stream over a connection for a specific protocol + +**`Transport`** (`transport/Transport.kt`): +- Handles raw connection establishment (TCP, QUIC, WebSocket) +- Each transport parses specific multiaddr formats (e.g., `/ip4/127.0.0.1/tcp/30333`) + +**`SecureChannel`** (`security/SecureChannel.kt`): +- Protocol binding for security layer negotiation +- Returns `SecureChannel.Session` with `remoteId`, `remotePubKey` +- Implementations: `NoiseXXSecureChannel` (production), `TlsSecureChannel` (beta) + +**`StreamMuxer`** (`mux/StreamMuxer.kt`): +- Protocol binding for multiplexer negotiation +- Returns `StreamMuxer.Session` for creating/receiving streams +- Implementations: `MplexStreamMuxer` (production), `YamuxStreamMuxer` (beta) + +### The Connection Upgrade Pipeline + +When a raw transport connection is established, it goes through staged upgrades: + +``` +1. Raw Transport (TCP/QUIC/WS) + ↓ +2. ConnectionBuilder (transport/implementation/ConnectionBuilder.kt) + ↓ +3. Security Negotiation → SecureChannel.Session + ↓ +4. Multiplexer Negotiation → StreamMuxer.Session + ↓ +5. Full Connection Ready → ConnectionOverNetty +``` + +**Key Class:** `ConnectionUpgrader` (`transport/implementation/ConnectionUpgrader.kt`) +- Orchestrates security and muxer protocol negotiation +- Uses `MultistreamProtocol` for protocol selection +- Supports early muxer negotiation (TLS 1.3 feature) + +### Protocol Handler Pattern + +Custom protocols implement `ProtocolHandler`: + +```kotlin +// Define protocol binding +StrictProtocolBinding("/ipfs/ping/1.0.0", PingProtocol()) + +// Implement handler +class PingProtocol : ProtocolHandler { + override fun onStartInitiator(stream: Stream): CompletableFuture + override fun onStartResponder(stream: Stream): CompletableFuture +} +``` + +See `examples/chatter/ChatProtocol.kt` for a complete example. + +### Pub/Sub Architecture + +The pub/sub system is located in `pubsub/` and follows this structure: + +**`AbstractRouter`** (`pubsub/AbstractRouter.kt`): +- Base class providing common pubsub logic +- Manages peer subscriptions via `peersTopics` (multi-bimap) +- Implements message validation, deduplication (via `SeenCache`), and batching +- Uses single-threaded event loop (`P2PService`) for thread-safety + +**Message Batching via `RpcPartsQueue`**: +- Per-peer queue that accumulates message parts before transmission +- Pattern: accumulate parts → flush via `takeMerged()` → send merged RPC +- Default implementation merges all parts into single RPC +- Gossip implementation (`GossipRpcPartsQueue`) splits messages to respect per-category limits + +**Message Flow:** +``` +Outbound: publish() → validateAndBroadcast() → submitPublishMessage(peer) + → queue.addPublish() → flushPending() → queue.takeMerged() → send() + +Inbound: channelRead() → onInbound() → validate & deduplicate + → broadcastInbound() → queue.addPublish() → flushPending() +``` + +**Gossip-Specific:** +- **`GossipRouter`** extends `AbstractRouter` with mesh topology management +- Heartbeat mechanism for GRAFT/PRUNE/IHAVE/IWANT control messages +- Peer scoring for spam resistance +- Control messages batched via `GossipRpcPartsQueue` + +**Key Flush Triggers:** +- After processing inbound messages (sync validation complete) +- After async message validation completes +- On peer activation (sends initial subscriptions) +- During Gossip heartbeat (mesh management operations) +- After explicit publish/subscribe API calls + +### Multistream Protocol Negotiation + +**`MultistreamProtocol`** (`protocol/multistream/MultistreamProtocol.kt`): +- Used at three layers: security negotiation, muxer negotiation, protocol negotiation +- Contains list of `ProtocolBinding`s with protocol names +- Delegates to `Negotiator` (initiator/responder) +- Completes with `ProtocolSelect` containing selected protocol handler + +**Pattern:** Any negotiable component extends `ProtocolBinding`: +- Security channels, stream muxers, application protocols all use this pattern + +## Development Patterns + +### Netty Integration + +All protocol logic is implemented as Netty `ChannelHandler`s: +- **`P2PChannelOverNetty`**: Base wrapper for both `Connection` and `Stream` +- **`ConnectionOverNetty`**: Wraps connection-level channel with secure and muxer sessions +- **`StreamOverNetty`**: Wraps stream-level channel with protocol negotiation + +### Async Pattern + +Extensive use of `CompletableFuture` for async operations: +- Protocol negotiation with timeouts +- Connection establishment across multiple addresses +- Message publishing and validation + +### Event Thread Safety + +The pub/sub system (and other components) use single-threaded event loops via `P2PService`: +- All operations run on `executor: ScheduledExecutorService` +- Components like `RpcPartsQueue` are explicitly "NOT thread safe" but guaranteed single-threaded access +- Methods: `runOnEventThread {}`, `submitOnEventThread {}`, `submitAsyncOnEventThread {}` + +### Testing Patterns + +**JUnit 5** with: +- `@Test` for standard tests +- `@ParameterizedTest` with `@MethodSource` for data-driven tests +- AssertJ for fluent assertions (`assertThat(...)`) +- MockK for mocking + +**Test Infrastructure:** +- Test fixtures in `src/testFixtures/` for shared test utilities +- Host builder DSL used extensively in tests +- `TestChannel` and `TestLogAppender` utilities + +**Example Test Pattern (from GossipRpcPartsQueueTest):** +```kotlin +@ParameterizedTest +@MethodSource("testCases") +fun `test message merging`(params: GossipParams, queue: TestQueue) { + val monolith = queue.mergedSingle() // Ground truth + val split = queue.takeMerged() // Actual implementation + + // Verify limits respected + assertThat(split).allMatch { router.validateMessageListLimits(it) } + + // Verify semantic equivalence + assertThat(split.merge().disperse()).isEqualTo(monolith.disperse()) +} +``` + +### Code Style + +- Kotlin 1.6 with JVM target 11 +- ktlint formatting (run `./gradlew spotlessApply`) +- Detekt static analysis +- Wildcard imports allowed +- No trailing commas enforced +- All warnings as errors (`allWarningsAsErrors = true`) + +## Important Implementation Details + +### Protobuf Code Generation + +Protobuf definitions in `src/main/proto/` are compiled via `com.google.protobuf` Gradle plugin. +Generated code in `build/generated/source/proto/main/java/`. + +To regenerate: `./gradlew :libp2p:clean :libp2p:build` + +### Multiaddr Format + +Network addresses use multiaddr format: +- Example: `/ip4/127.0.0.1/tcp/30333/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N` +- Parsed/managed in `core/multiformats/` +- Each transport validates specific multiaddr components + +### PeerId Generation + +`PeerId` is derived from peer's public key: +- Multihash of the public key bytes +- 32-50 bytes depending on key type +- Supports RSA, Ed25519, Secp256k1, ECDSA + +### Security Handshake Timeout + +Default timeout for security handshakes: **5 seconds** +- Applies to Noise and TLS handshakes +- Configurable in protocol implementations + +## Common Development Workflows + +### Adding a New Protocol + +1. Define protocol binding with multistream name (e.g., `/myapp/myprotocol/1.0.0`) +2. Implement `ProtocolHandler` with initiator/responder logic +3. Register with Host via `protocols { add(...) }` in builder +4. Implement controller interface for protocol operations + +See `examples/chatter/` for a complete example. + +### Adding a New Transport + +1. Extend `Transport` interface +2. Implement `listen()` and `dial()` for raw connection establishment +3. Delegate to `ConnectionUpgrader` for security/muxer negotiation +4. Add multiaddr parsing logic for transport-specific components +5. Register with Host via `transports { add(...) }` + +### Debugging Connection Issues + +- Use `ConnectionVisitor` and `StreamVisitor` for lifecycle observation +- Enable debug logging for `io.libp2p` package +- Check multiaddr format compatibility between peers +- Verify protocol versions match (especially for security/muxer) + +### Working with Pub/Sub + +- All pub/sub operations run on event thread (thread-safe by design) +- Message validation happens before broadcasting +- Seen cache prevents duplicate message processing +- Control messages automatically batched for efficiency +- Gossip mesh heartbeat runs every 1 second (default) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..4971e5a0e --- /dev/null +++ b/Dockerfile @@ -0,0 +1,12 @@ +FROM eclipse-temurin:11-jdk AS build +COPY . /jvm-libp2p +WORKDIR /jvm-libp2p +RUN ./gradlew build -x test --no-daemon + +FROM eclipse-temurin:11-jdk +WORKDIR /jvm-libp2p +COPY --from=build /jvm-libp2p/interop-test-client/build/distributions/interop-test-client*.tar . +RUN tar -xf interop-test-client*.tar && rm interop-test-client*.tar + +ENTRYPOINT ["/jvm-libp2p/interop-test-client-develop/bin/interop-test-client"] +EXPOSE 4001 diff --git a/README.md b/README.md index b7ec0ae09..36d4a165b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![](https://img.shields.io/badge/project-libp2p-yellow.svg?style=flat-square)](https://libp2p.io/) [![Gitter](https://img.shields.io/gitter/room/libp2p/jvm-libp2p.svg)](https://gitter.im/jvm-libp2p/community) [![](https://img.shields.io/badge/freenode-%23libp2p-yellow.svg?style=flat-square)](http://webchat.freenode.net/?channels=%23libp2p) -![Build Status](https://github.com/libp2p/jvm-libp2p/actions/workflows/build.yml/badge.svg?branch=master) +[![Build Status](https://github.com/libp2p/jvm-libp2p/actions/workflows/build.yml/badge.svg?branch=master)](https://github.com/libp2p/jvm-libp2p/actions/workflows/build.yml) [![Discourse posts](https://img.shields.io/discourse/https/discuss.libp2p.io/posts.svg)](https://discuss.libp2p.io) [Libp2p](https://libp2p.io/) implementation for the JVM, written in Kotlin 🔥 @@ -15,7 +15,7 @@ List of components in the Libp2p spec and their JVM implementation status | | Component | Status | |--------------------------|-------------------------------------------------------------------------------------------------|:----------------:| | **Transport** | tcp | :green_apple: | -| | [quic](https://github.com/libp2p/specs/tree/master/quic) | :tomato: | +| | [quic](https://github.com/libp2p/specs/tree/master/quic) | :lemon: | | | websocket | :lemon: | | | [webtransport](https://github.com/libp2p/specs/tree/master/webtransport) | | | | [webrtc-browser-to-server](https://github.com/libp2p/specs/blob/master/webrtc/webrtc-direct.md) | | diff --git a/build.gradle.kts b/build.gradle.kts index 17ad137bc..76aff4b51 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,9 +18,9 @@ plugins { id("java") id("maven-publish") id("org.jetbrains.dokka").version("1.9.20") - id("com.diffplug.spotless").version("6.25.0") + id("com.diffplug.spotless").version("7.2.1") id("java-test-fixtures") - id("io.spring.dependency-management").version("1.1.6") + id("io.spring.dependency-management").version("1.1.7") id("org.jetbrains.kotlin.android") version kotlinVersion apply false id("com.android.application") version "7.4.2" apply false @@ -36,7 +36,7 @@ configure( } ) { group = "io.libp2p" - version = "1.2.2-RELEASE" + version = "1.3.0-RELEASE" apply(plugin = "kotlin") apply(plugin = "idea") @@ -63,6 +63,8 @@ configure( testImplementation("org.junit.jupiter:junit-jupiter") testImplementation("org.junit.jupiter:junit-jupiter-params") + testImplementation("org.junit.platform:junit-platform-launcher") + testRuntimeOnly("org.junit.platform:junit-platform-engine") testImplementation("io.mockk:mockk") testImplementation("org.assertj:assertj-core") testRuntimeOnly("org.apache.logging.log4j:log4j-slf4j2-impl") @@ -84,6 +86,9 @@ configure( tasks.withType { duplicatesStrategy = DuplicatesStrategy.INCLUDE } + tasks.withType().configureEach { + jvmTarget = "11" + } // Parallel build execution tasks.test { @@ -108,8 +113,10 @@ configure( } configure { + // https://github.com/pinterest/ktlint/releases + val ktlintVersion = "1.1.1" kotlin { - ktlint().editorConfigOverride( + ktlint(ktlintVersion).editorConfigOverride( mapOf( "ktlint_standard_no-wildcard-imports" to "disabled", "ktlint_standard_enum-entry-name-case" to "disabled", @@ -140,7 +147,7 @@ configure( jdkVersion.set(11) reportUndocumented.set(false) externalDocumentationLink { - url.set(URI.create("https://netty.io/4.1/api/").toURL()) + url.set(URI.create("https://netty.io/4.2/api/").toURL()) } } } @@ -187,7 +194,7 @@ configure( } detekt { - config = files("$rootDir/detekt/config.yml") + config.from("$rootDir/detekt/config.yml") buildUponDefaultConfig = true } } diff --git a/docs/partial-messages.md b/docs/partial-messages.md new file mode 100644 index 000000000..4c715a7b0 --- /dev/null +++ b/docs/partial-messages.md @@ -0,0 +1,450 @@ +# Gossipsub Partial Messages — Design Document + +Status: **Draft / MVP design** +Tracking issue: [libp2p/jvm-libp2p#435](https://github.com/libp2p/jvm-libp2p/issues/435) +Last updated: see `git log -- docs/partial-messages.md` + +This document is the source of truth for the jvm-libp2p implementation of the +gossipsub partial-messages extension. It captures the scope, the jvm-libp2p ↔ +client responsibility boundary, the public API, routing semantics, and the +implementation plan. It is a **living document** — append to the decision log +(§9) when we revise anything. + +--- + +## 1. Scope and non-goals + +### In scope (MVP) + +- Full wire-level support for the `PartialMessagesExtension` RPC: + - Per-topic negotiation via `SubOpts.requestsPartial` / `SubOpts.supportsSendingPartial`. + - Inbound and outbound handling of `RPC.partial`. + - Both metadata-only and payload-only variants in both directions. +- A Kotlin API that lets a client (Teku) plug in its own per-peer state, + metadata encoding, group-ID generation, part-level validation, and publish + decisions. +- Integration with the existing gossipsub routing rules: + - Suppress full-message send to peers that requested partial on that topic. + - Suppress IDONTWANT to peers we request partial from. + - Replace IHAVE with an `onEmitGossip` callback for partial-capable peers + in the lazy-push loop. +- Per-group lifecycle (TTL in heartbeats, DoS caps on peer-initiated groups). +- A side-channel `peerFeedback` API so the client can drive peer scoring + explicitly instead of via callback return values. + +### Non-goals (deferred, but documented for future) + +- `interop-test-client` partial-messages support. Deferred; see §7 for notes. +- Partial-specific peer-scoring rules beyond what the Extensions handshake + already enforces. Spec is silent; match go-libp2p (no scoring) for MVP. +- Topic-level "partial-only" mode. Spec explicitly defers this to a future + extension. +- Reassembling a full `Message` and re-entering the normal gossip flow. MVP + delivers parts upward to the application only; the application is free to + never republish a reconstructed full message (matches Ethereum PeerDAS). +- New wire messages. Spec and go-libp2p use the single + `PartialMessagesExtension` for both lazy-push and payload delivery — no + `partialIHAVE` / `partialIWANT`. + +--- + +## 2. Reference pins + +The partial-messages spec is **Lifecycle 1A (Working Draft)** and may change. +When revising this document, update these pins. + +| Source | Pin | Location | +|---|---|---| +| libp2p/specs | merge commit `6b6203ee` (PR #685, merged 2026-02-26) | `pubsub/gossipsub/partial-messages.md` | +| libp2p/go-libp2p-pubsub | `master` at time of MVP (note in decision log when pinned) | `extensions.go`, `partialmessages/partialmsgs.go`, `gossipsub.go`, `pubsub.go` | +| libp2p/test-plans gossipsub-interop | `master` | `gossipsub-interop/go-libp2p/experiment.go`, `main.go` | +| OffchainLabs/prysm | branch `prysm/partial-cells-current`, latest seen `e8480a86` (2026-03-31) | `beacon-chain/p2p/partialdatacolumnbroadcaster/`, `consensus-types/blocks/partialdatacolumn.go`, `proto/prysm/v1alpha1/partial_data_columns.proto` | + +### Related in-flight spec work (watch) + +- libp2p/specs#681 — Choke extension. +- libp2p/specs#699 — Topic table. +- libp2p/specs#706 — Gossipsub v1.4. +- libp2p/specs#654 — Message preamble. + +None directly modify partial-messages, but v1.4 and message-preamble overlap +in motivation. + +--- + +## 3. Responsibility boundary (jvm-libp2p ↔ client) + +The one-line model: + +> **jvm-libp2p is a transport + per-peer bookkeeper for opaque partial-message +> RPCs. The client (Teku) owns everything about what those bytes mean, when a +> group is "complete", and who gets what.** + +| Concern | jvm-libp2p | Client (Teku) | +|---|---|---| +| v1.3 Control Extensions handshake | ✅ (done on this branch) | — | +| `SubOpts.requestsPartial` / `supportsSendingPartial` wire handling | ✅ | — | +| Per-peer partial-capability state (node-level and topic-level) | ✅ | — | +| Per-`(topic, groupID)` state container, TTL GC, DoS caps | ✅ | — | +| Routing: suppress full-msg send to partial-requesting peers | ✅ | — | +| Routing: suppress IDONTWANT to peers we request partial from | ✅ | — | +| Routing: replace IHAVE with `onEmitGossip` for partial peers | ✅ | — | +| Wire framing of `PartialMessagesExtension` in/out | ✅ | — | +| Spec MUST: omit `partialMessage` if peer supports-but-didn't-request | ✅ | — | +| `partsMetadata` encoding (bitmap / Bloom / whatever) | ❌ opaque | ✅ | +| `groupID` generation | ❌ opaque | ✅ | +| Merging incoming `partsMetadata` into local per-peer view | ❌ | ✅ | +| Deciding which parts to send to which peer | ❌ | ✅ (`PublishActionsFn`) | +| Reassembling a full message | ❌ never | ✅ | +| Validating individual parts (e.g. KZG) | ❌ | ✅ (inside `onIncomingRpc`) | +| Detecting "group complete" and delivering upward | ❌ | ✅ | +| Per-part peer scoring (spammy parts, etc.) | ❌ MVP | Future, in coordination | + +Rationale for each line is grounded in go-libp2p's and Prysm's current +behaviour — see §9 and the research notes that produced this document. + +--- + +## 4. Public API (jvm-libp2p surface) + +### 4.1 Builder wiring + +```kotlin +GossipRouterBuilder().apply { + enabledGossipExtensions(GossipExtension.PARTIAL_MESSAGES) + partialMessagesHandler = MyTekuPartialMessagesHandler() // new +} +``` + +- The `GossipExtension.PARTIAL_MESSAGES` feature flag stays as the capability + switch (already wired). +- `partialMessagesHandler: PartialMessagesHandler<*>?` is a new optional + field on the builder. Null + flag enabled = build-time error. + +### 4.2 Client-supplied handler + +```kotlin +interface PartialMessagesHandler { + + /** + * Called on every inbound PartialMessagesExtension RPC on the pubsub + * event thread. MUST be fast and non-blocking: dispatch heavy work + * (decoding, validation) to your own executor. + * + * Any of rpc.partialMessage and rpc.partsMetadata may be absent; all + * four combinations are valid. + */ + fun onIncomingRpc( + from: PeerId, + peerStates: Map, + rpc: Rpc.PartialMessagesExtension + ) + + /** + * Called once per group during the gossipsub heartbeat, for gossip + * targets that are partial-capable. The client typically responds by + * calling publishPartial(...) for the same (topic, groupId). + */ + fun onEmitGossip( + topic: Topic, + groupId: ByteArray, + gossipPeers: Collection, + peerStates: Map + ) +} +``` + +Notes: +- `PeerState` is fully generic. The library stores it per + `(topic, groupId, peerId)` and never interprets it. +- Both callbacks run on the pubsub event thread. Document prominently. + +### 4.3 Publishing + +```kotlin +fun interface PublishActionsFn { + fun decide( + peerStates: Map, + peerRequestsPartial: (PeerId) -> Boolean + ): Sequence>> +} + +data class PublishAction( + val partialMessage: ByteArray? = null, + val partsMetadata: ByteArray? = null, + val nextPeerState: PeerState? = null, // library applies atomically + val error: Throwable? = null +) + +// Entry point on the Gossip facade +fun Gossip.publishPartial( + topic: Topic, + groupId: ByteArray, + actions: PublishActionsFn<*> +): CompletableFuture +``` + +Key API differences vs. go-libp2p (deliberate): + +1. **No in-place map mutation.** `PublishAction.nextPeerState` is applied + atomically by the library per peer, instead of asking the client to + mutate `Map` inside the iterator. Prysm has fixed race + bugs in the in-place pattern (see commits on `prysm/partial-cells-current`, + Mar 31 2026); Kotlin's single-threaded event loop makes the atomic-return + shape natural. +2. **`Unit`-returning callbacks.** Errors do not drive scoring; see §4.4. + +### 4.4 Peer feedback (scoring side-channel) + +```kotlin +interface PartialMessagesPeerFeedback { + fun reportFeedback(topic: Topic, peer: PeerId, kind: FeedbackKind) +} + +enum class FeedbackKind { USEFUL, INVALID, IGNORED } +``` + +The handler receives a `PartialMessagesPeerFeedback` instance (via +constructor or context object — TBD during implementation) and uses it to +drive peer score adjustments. This mirrors Prysm's `peerFeedback` pattern. +`INVALID` hooks into the existing `notifyRouterMisbehavior` path. + +### 4.5 Topic options + +Subscribing to a topic with partial-message flags: + +```kotlin +gossip.subscribe(topic, handler, + requestsPartial = true, + supportsSendingPartial = true) // implied if requestsPartial = true +``` + +Go-libp2p exposes `RequestPartialMessages()` and `SupportsPartialMessages()` +as separate topic options. In Prysm's real integration, only +`RequestPartialMessages()` is ever used; the "supports-but-doesn't-request" +half is currently unexercised. MVP supports both flags in the API but only +the `requests` path needs end-to-end testing. + +--- + +## 5. Routing rules (inside `GossipRouter`) + +Three modifications to the existing routing, all behind +`partialMessagesEnabled()` and the per-peer handshake state. + +### 5.1 Full-message suppression + +When broadcasting a `Message` for topic `T` to peer `P`: +- If `gossipExtensionsState.peerSupportsPartialMessages(P)` **and** + `partialTopicState.peerRequestsPartial(P, T)` → **do not** send the full + message to `P`. The client is responsible for pushing parts via + `publishPartial(...)`. +- This filter applies in `broadcastInbound` and `broadcastOutbound`, before + messages are queued into `GossipRpcPartsQueue`. +- Spec MUST (§Wire rules): if peer supports sending partial but did *not* + request, we still send the full message, but when we send a + `PartialMessagesExtension` to that peer we MUST omit `partialMessage`. + +### 5.2 IDONTWANT suppression + +When emitting IDONTWANT for a message on topic `T`: +- If, for peer `P`, we `iRequestPartial(T)` **and** + `peerSupportsSendingPartial(P, T)` → skip IDONTWANT to `P`. +- go-libp2p: `gossipsub.go:892-904`. + +### 5.3 IHAVE replacement with `onEmitGossip` + +During gossipsub heartbeat lazy-push: +- Partition the selected IHAVE targets into `fullPeers` and + `partialPeers = { p | iSupportSendingPartial(T) ∧ peerRequestsPartial(p, T) }`. +- Do not enqueue IHAVE for `partialPeers`. +- After the normal loop, for every locally-initiated group under `T`, call + `handler.onEmitGossip(T, groupId, partialPeers, peerStatesForGroup)` once. +- go-libp2p: `gossipsub.go:2018-2074`. + +--- + +## 6. State and lifecycle + +### 6.1 Per-topic-per-peer partial-capability state + +Per-peer flags per topic, updated from every inbound `SubOpts` (where +`subscribe = true`): + +- `requestsPartial: Boolean` +- `supportsSendingPartial: Boolean` + +Spec + go-libp2p coercion: on receive, store +`supportsSendingPartial := requestsPartial || supportsSendingPartial`. + +MUST ignore both flags on `SubOpts` with `subscribe = false`. + +### 6.2 Per-`(topic, groupID)` group state + +``` +GroupState { + ttlInHeartbeats: Int // counts down each heartbeat, GC at 0 + peerInitiated: Boolean // true if first seen from a peer, not us + peerStates: Map // app-opaque +} +``` + +- Stored in a plain `HashMap` — not thread-safe; access is serialised on the + pubsub event loop (per the project-wide invariant; do **not** use + `ConcurrentHashMap`). +- TTL reset whenever `publishPartial(topic, groupId, …)` is called for the + group. +- GC on `ttl == 0` **or** `peerStates` empty. + +### 6.3 DoS caps (match go-libp2p defaults) + +Applies only to **peer-initiated** groups (first touched from an inbound +RPC, not via `publishPartial`). + +| Cap | Default | Where | +|---|---|---| +| `peerInitiatedGroupLimitPerTopic` | 255 | Across all peers, per topic | +| `peerInitiatedGroupLimitPerTopicPerPeer` | 8 | Per (topic, peer) | + +Over-cap: log and drop the RPC. No disconnect. No score penalty (match go; +revise if spec adds guidance). + +### 6.4 Cleanup hooks + +- Peer disconnect → remove all `peerStates[peer]` entries across groups. +- Unsubscribe (we leave a topic) → drop all group state for that topic. +- Heartbeat → decrement TTLs, GC expired groups. + +--- + +## 7. Known gaps vs. full spec + +Explicitly deferred in MVP; listed here so future work can pick them up. + +1. **Validator pipeline for partial RPCs** — bypassed entirely (matches + go-libp2p). Client validates inside `onIncomingRpc`. +2. **Scoring rules for partial misbehaviour** — spec silent, go silent. MVP + only scores via the existing `notifyRouterMisbehavior` path plus the + client's `peerFeedback` calls. +3. **Message-ID of reassembled full messages** — spec silent. MVP does not + reassemble at all; the reconstructed message never re-enters gossip. +4. **Topic-level "partial-only" mode** — spec explicitly defers; no + implementation. +5. **`SupportsPartialMessages()`-only (support without request) path** — + supported by the API, but Prysm doesn't exercise it and we don't have an + end-to-end test for it. Flag if we ship without coverage. +6. **Fanout peers in publish** — MVP does mesh peers (+ fanout fallback if + mesh empty), mirroring go-libp2p's `MeshPeers`. Fanout specifically for + partial is not independently exercised. +7. **`interop-test-client`** — deferred. Future work should: + - Implement `PartialMessagesHandler` with SSZ-like + bitlists for `partsMetadata`. + - Test the 4-combo matrix (payload+meta / meta-only / payload-only / + neither) on both send and receive. + - Test mixed-peer topic: one partial-enabled node, one full-only; verify + full-only path still works end-to-end. + - Test `ControlExtensions` handshake ordering: extension RPCs arriving + before the handshake completes must be ignored. + +--- + +## 8. Implementation plan + +Order chosen so an end-to-end partial round-trip works before any of the +fragile routing rules are touched. Each step is independently testable and +mergeable. + +Mirror this checklist in issue #435. + +- [ ] **Step 1** — Per-topic `SubOpts` flag plumbing. Outbound: flags added + to subscribe announce RPCs. Inbound: parse flags into a + `PartialTopicState` (`Map>`). + Coercion rule applied on receive. Flags ignored on `subscribe=false`. +- [ ] **Step 2** — `PartialMessagesHandler` interface, + `PublishAction` (with `nextPeerState`), + `PublishActionsFn`, `PartialMessagesPeerFeedback`, and + `GroupState` container with TTL + DoS caps. No routing yet. +- [ ] **Step 3** — Inbound `RPC.partial` dispatch: replace the stub at + `GossipRouter.kt:476` with the full flow (validate caps, create/update + group state, call `onIncomingRpc`). +- [ ] **Step 4** — Outbound `publishPartial(...)` on the `Gossip` facade; + route through `GossipRpcPartsQueue` (do **not** bypass — PR #433 got + this wrong). Enforce the "omit `partialMessage` when peer supports but + didn't request" MUST. +- [ ] **Step 5** — End-to-end integration test with a trivial bitmap-based + handler. Exercises Steps 1-4 before any routing changes. +- [ ] **Step 6** — Routing: full-message suppression (§5.1). +- [ ] **Step 7** — Routing: IDONTWANT suppression (§5.2). +- [ ] **Step 8** — Heartbeat tick + TTL GC + cleanup hooks (§6.4). +- [ ] **Step 9** — Routing: IHAVE replacement with `onEmitGossip` (§5.3). +- [ ] **Step 10** — Simulator scenario + mixed-peer interop test (partial + + non-partial nodes on the same topic). + +--- + +## 9. Decision log + +Append entries here when design choices change. Keep most-recent on top. + +### 2026-04-20 — Initial design + +- Scope, boundary, and API agreed per research summarised in this document. +- `PublishAction` returns `nextPeerState` rather than asking the client to + mutate a shared map in place. Motivation: cleaner Kotlin ergonomics, + avoids the category of race that Prysm's + `prysm/partial-cells-current` fixed on 2026-03-31. +- Peer scoring feedback lives on a side-channel + `PartialMessagesPeerFeedback`, not on callback return values. Matches + Prysm's `peerFeedback` pattern. +- MVP does not ship `interop-test-client` support; see §7.7 for the future + checklist. +- DoS caps pinned to go-libp2p defaults (255 / 8). +- Spec pinned to libp2p/specs#685 merge `6b6203ee`. Spec is lifecycle 1A; + revise this document when spec revisions land. + +### Open questions to resolve during implementation + +- Exact wiring of `PartialMessagesPeerFeedback` — constructor arg on the + handler, or a context object passed to each callback? Decide during + Step 2. +- Whether `publishPartial` on the `Gossip` facade takes a single + `(topic, groupId)` or supports batched `Seq<(topic, groupId)>`. Prysm + calls per-topic and iterates; MVP will match. +- Exact return type of `publishPartial` — `CompletableFuture` follows + jvm-libp2p convention; finalise during Step 4. + +--- + +## 10. References + +### Spec + +- [libp2p/specs — Gossipsub Partial Messages spec (PR #685)](https://github.com/libp2p/specs/pull/685) +- [libp2p/specs — partial-messages.md @ 6b6203ee](https://github.com/libp2p/specs/blob/6b6203ee16ef2e01e6b86fc8f6c3fae0d1c6490e/pubsub/gossipsub/partial-messages.md) + +### Related in-flight spec work + +- [libp2p/specs#681 — Choke extension](https://github.com/libp2p/specs/pull/681) +- [libp2p/specs#699 — Topic table](https://github.com/libp2p/specs/pull/699) +- [libp2p/specs#706 — Gossipsub v1.4](https://github.com/libp2p/specs/pull/706) +- [libp2p/specs#654 — Message preamble](https://github.com/libp2p/specs/pull/654) + +### Implementations + +- [go-libp2p-pubsub — extensions.go](https://github.com/libp2p/go-libp2p-pubsub/blob/master/extensions.go) +- [go-libp2p-pubsub — partialmessages/partialmsgs.go](https://github.com/libp2p/go-libp2p-pubsub/blob/master/partialmessages/partialmsgs.go) +- [go-libp2p-pubsub — gossipsub.go](https://github.com/libp2p/go-libp2p-pubsub/blob/master/gossipsub.go) +- [go-libp2p-pubsub — pubsub.go](https://github.com/libp2p/go-libp2p-pubsub/blob/master/pubsub.go) +- [OffchainLabs/prysm — branch `prysm/partial-cells-current`](https://github.com/OffchainLabs/prysm/tree/prysm/partial-cells-current) + - [`beacon-chain/p2p/partialdatacolumnbroadcaster/`](https://github.com/OffchainLabs/prysm/tree/prysm/partial-cells-current/beacon-chain/p2p/partialdatacolumnbroadcaster) + - [`consensus-types/blocks/partialdatacolumn.go`](https://github.com/OffchainLabs/prysm/blob/prysm/partial-cells-current/consensus-types/blocks/partialdatacolumn.go) + - [`proto/prysm/v1alpha1/partial_data_columns.proto`](https://github.com/OffchainLabs/prysm/blob/prysm/partial-cells-current/proto/prysm/v1alpha1/partial_data_columns.proto) + +### Interop testing + +- [libp2p/test-plans — gossipsub-interop experiment.go](https://github.com/libp2p/test-plans/blob/master/gossipsub-interop/go-libp2p/experiment.go) +- [libp2p/test-plans — gossipsub-interop main.go](https://github.com/libp2p/test-plans/blob/master/gossipsub-interop/go-libp2p/main.go) + +### Tracking + +- [libp2p/jvm-libp2p#435 — Partial messages tracking issue](https://github.com/libp2p/jvm-libp2p/issues/435) diff --git a/examples/android-chatter/build.gradle b/examples/android-chatter/build.gradle index 41dd1bfcd..fa37a35bd 100644 --- a/examples/android-chatter/build.gradle +++ b/examples/android-chatter/build.gradle @@ -22,9 +22,17 @@ android { } } packagingOptions { - exclude 'META-INF/io.netty.versions.properties' - exclude 'META-INF/INDEX.LIST' - exclude 'META-INF/versions/9/OSGI-INF/MANIFEST.MF' + resources { + excludes.add("META-INF/io.netty.versions.properties") + excludes.add("META-INF/INDEX.LIST") + excludes.add("META-INF/versions/9/OSGI-INF/MANIFEST.MF") + excludes.add("META-INF/native-image/io.netty/netty-codec-native-quic/jni-config.json") + excludes.add("META-INF/native-image/io.netty/netty-codec-native-quic/reflect-config.json") + excludes.add("META-INF/native-image/io.netty/netty-codec-native-quic/resource-config.json") + excludes.add("META-INF/native-image/io.netty/netty-codec-native-quic/native-image.properties") + excludes.add("META-INF/license/*") + } + } kotlinOptions { jvmTarget = "11" diff --git a/examples/pinger/build.gradle b/examples/pinger/build.gradle index b4ce18580..f6dd794d5 100644 --- a/examples/pinger/build.gradle +++ b/examples/pinger/build.gradle @@ -9,5 +9,5 @@ dependencies { } application { - mainClassName = 'io.libp2p.example.Pinger' + mainClassName = 'io.libp2p.example.ping.Pinger' } \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index ccebba771..1b33c55ba 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 6a93cb7a1..78cb6e16a 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,8 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=31c55713e40233a8303827ceb42ca48a47267a0ad4bab9177123121e71524c26 -distributionUrl=https\://services.gradle.org/distributions/gradle-8.10.2-bin.zip +distributionSha256Sum=bd71102213493060956ec229d946beee57158dbd89d0e62b91bca0fa2c5f3531 +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 79a61d421..23d15a936 100755 --- a/gradlew +++ b/gradlew @@ -15,6 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# SPDX-License-Identifier: Apache-2.0 +# ############################################################################## # @@ -55,7 +57,7 @@ # Darwin, MinGW, and NonStop. # # (3) This script is generated from the Groovy template -# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt # within the Gradle project. # # You can find Gradle at https://github.com/gradle/gradle/. @@ -83,10 +85,8 @@ done # This is normally unused # shellcheck disable=SC2034 APP_BASE_NAME=${0##*/} -APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s\n' "$PWD" ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum @@ -114,7 +114,7 @@ case "$( uname )" in #( NONSTOP* ) nonstop=true ;; esac -CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar +CLASSPATH="\\\"\\\"" # Determine the Java command to use to start the JVM. @@ -133,10 +133,13 @@ location of your Java installation." fi else JAVACMD=java - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." + fi fi # Increase the maximum file descriptors if we can. @@ -144,7 +147,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then case $MAX_FD in #( max*) # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. - # shellcheck disable=SC3045 + # shellcheck disable=SC2039,SC3045 MAX_FD=$( ulimit -H -n ) || warn "Could not query maximum file descriptor limit" esac @@ -152,7 +155,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then '' | soft) :;; #( *) # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. - # shellcheck disable=SC3045 + # shellcheck disable=SC2039,SC3045 ulimit -n "$MAX_FD" || warn "Could not set maximum file descriptor limit to $MAX_FD" esac @@ -197,16 +200,20 @@ if "$cygwin" || "$msys" ; then done fi -# Collect all arguments for the java command; -# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of -# shell script including quotes and variable substitutions, so put them in -# double quotes to make sure that they get re-expanded; and -# * put everything else in single quotes, so that it's not re-expanded. + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. set -- \ "-Dorg.gradle.appname=$APP_BASE_NAME" \ -classpath "$CLASSPATH" \ - org.gradle.wrapper.GradleWrapperMain \ + -jar "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" \ "$@" # Stop when "xargs" is not available. diff --git a/gradlew.bat b/gradlew.bat index 93e3f59f1..db3a6ac20 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -13,6 +13,8 @@ @rem See the License for the specific language governing permissions and @rem limitations under the License. @rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem @if "%DEBUG%"=="" @echo off @rem ########################################################################## @@ -43,11 +45,11 @@ set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 if %ERRORLEVEL% equ 0 goto execute -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -57,22 +59,22 @@ set JAVA_EXE=%JAVA_HOME%/bin/java.exe if exist "%JAVA_EXE%" goto execute -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail :execute @rem Setup the command line -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar +set CLASSPATH= @rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %* :end @rem End local scope for the variables with windows NT shell diff --git a/install-run-ipfs.sh b/install-run-ipfs.sh new file mode 100755 index 000000000..ef35062e7 --- /dev/null +++ b/install-run-ipfs.sh @@ -0,0 +1,13 @@ +#! /bin/sh +wget https://dist.ipfs.io/kubo/v0.34.1/kubo_v0.34.1_linux-amd64.tar.gz -O /tmp/kubo_linux-amd64.tar.gz +hash="$(sha256sum /tmp/kubo_linux-amd64.tar.gz)" +expected=42045802fe60c64fb01350bc071190c534d600fe269759c06e27e22b2012fd3e +if [[ "$hash" != "$expected" ]] +then + echo "incorrect ipfs hash!" 1>&2 + exit 64 +fi +tar -xvf /tmp/kubo_linux-amd64.tar.gz +export PATH=$PATH:$PWD/kubo/ +ipfs init +ipfs daemon --routing=dhtserver & diff --git a/interop-test-client/README.md b/interop-test-client/README.md new file mode 100644 index 000000000..0cb7bdf97 --- /dev/null +++ b/interop-test-client/README.md @@ -0,0 +1,48 @@ +# Interop Tests + +For more info: https://github.com/libp2p/test-plans/tree/master/transport-interop#readme + +## Requirements + +To run the interop test framework locally, you need: + +- Docker +- node, nvm and ts-node + +## Running it locally + +The first thing to be able to run the test locally is build the images of each livp2p implementation +being tested. You need to run the following steps for each one of the implementations that you are +planning to run: + +1. Checkout the project https://github.com/libp2p/test-plans. +2. Navigate to transport-interop/impl//; where is the implementation that you want to + build (e.g. `jvm`) and is what version you want (e.g. `v1.2`). +3. Once in the specific version folder run `make` to build the image. This will create a + `image.json` file with the hash of the Docker image built. + +Once you have the images that you want, navigate back to the `transport-interop` folder and run: + +``` +npm test --name-filter=jvm-v1.2 +``` + +The parameter `--name-filter` can be used to limit the pairs that are going to be executed. +In the previous example, only pairs with `jvm-1.2` are going to run. +Similarly, `--name-ignore` can be used to remove pairs. + +Here is the output of a sample run: + +``` +npm test --name-filter=jvm-v1.2 --name-ignore= --verbose=true + +> @libp2p/transport-interop@0.0.1 test +> ts-node src/compose-stdout-helper.ts && ts-node testplans.ts + +Checking jvm-v1.2 x jvm-v1.2 (tcp, tls, mplex)...ACCEPTED (filter match: '*') +Running 1 tests +Running test spec: jvm-v1.2 x jvm-v1.2 (tcp, tls, mplex) +Finished: jvm-v1.2 x jvm-v1.2 (tcp, tls, mplex) { handshakePlusOneRTTMillis: 380, pingRTTMilllis: 2 } +0 failures [] +Run complete +``` \ No newline at end of file diff --git a/interop-test-client/build.gradle.kts b/interop-test-client/build.gradle.kts new file mode 100644 index 000000000..3a30ae115 --- /dev/null +++ b/interop-test-client/build.gradle.kts @@ -0,0 +1,16 @@ +import org.jetbrains.kotlin.cli.jvm.compiler.findMainClass + +plugins { + id("application") + id("kotlin") +} + +application { + mainClass = "io.libp2p.interop.InteropTestAgentKt" +} + +dependencies { + implementation(project(":libp2p")) + implementation("redis.clients:jedis:6.1.0") + runtimeOnly("org.apache.logging.log4j:log4j-slf4j2-impl") +} diff --git a/interop-test-client/src/main/kotlin/io/libp2p/interop/InteropTestAgent.kt b/interop-test-client/src/main/kotlin/io/libp2p/interop/InteropTestAgent.kt new file mode 100644 index 000000000..6070ffda8 --- /dev/null +++ b/interop-test-client/src/main/kotlin/io/libp2p/interop/InteropTestAgent.kt @@ -0,0 +1,251 @@ +package io.libp2p.interop + +import identify.pb.IdentifyOuterClass +import io.libp2p.core.Connection +import io.libp2p.core.ConnectionHandler +import io.libp2p.core.Host +import io.libp2p.core.PeerId.Companion.fromPubKey +import io.libp2p.core.crypto.PrivKey +import io.libp2p.core.dsl.Builder +import io.libp2p.core.dsl.hostJ +import io.libp2p.core.multiformats.Multiaddr +import io.libp2p.core.multistream.ProtocolBinding +import io.libp2p.core.mux.StreamMuxerProtocol +import io.libp2p.core.mux.StreamMuxerProtocol.Companion.Mplex +import io.libp2p.core.mux.StreamMuxerProtocol.Companion.getYamux +import io.libp2p.crypto.keys.generateEd25519KeyPair +import io.libp2p.etc.types.toProtobuf +import io.libp2p.protocol.Identify +import io.libp2p.protocol.Ping +import io.libp2p.security.noise.NoiseXXSecureChannel +import io.libp2p.security.tls.TlsSecureChannel.Companion.ECDSA +import io.libp2p.transport.quic.QuicTransport +import io.libp2p.transport.tcp.TcpTransport +import io.libp2p.transport.ws.WsTransport +import redis.clients.jedis.Jedis +import java.util.concurrent.CompletableFuture +import java.util.concurrent.TimeUnit +import java.util.stream.Collectors +import kotlin.random.Random +import kotlin.system.exitProcess + +const val TCP = "tcp" +const val WS = "ws" +const val QUIC_V1 = "quic-v1" +private const val REDIS_KEY_LISTENER_ADDRESS = "listenerAddr" + +class InteropTestAgent(val params: InteropTestParams) { + + private val advertisedAddress: Multiaddr + private val node: Host + + init { + val port = 10000 + Random.nextInt(50000) + val transport = params.transport + val protocol = when (transport) { + TCP -> TCP + WS -> TCP + else -> "udp" + } + val maybeSuffix = when (transport) { + TCP -> "" + WS -> "/ws" + else -> "/quic-v1" + } + val address = + Multiaddr.fromString("/ip4/${params.ip}/$protocol/${port}$maybeSuffix") + + val privateKey = generateEd25519KeyPair().first + val peerID = fromPubKey(privateKey.publicKey()) + advertisedAddress = address.withP2P(peerID) + + val listenAddresses = ArrayList() + listenAddresses.add(address.toString()) + val protocols = createProtocols(privateKey, listenAddresses) + node = createHost(privateKey, protocols, listenAddresses) + } + + fun run(): CompletableFuture { + return node.start() + .thenCompose { startJedisConnection() } + .thenCompose { jedis -> + if (params.isDialer) { + startDialer(jedis, node, advertisedAddress) + } else { + startListener(jedis, advertisedAddress) + } + }.whenComplete { _, _ -> node.stop() } + } + + private fun createHost( + privateKey: PrivKey, + protocols: ArrayList>, + listenAddresses: ArrayList + ): Host = hostJ(Builder.Defaults.None, fn = { + it.identity.factory = { privateKey } + when (params.transport) { + QUIC_V1 -> it.secureTransports.add(QuicTransport::ECDSA) + WS -> it.transports.add(::WsTransport) + else -> it.transports.add(::TcpTransport) + } + + if ("noise" == params.security) { + it.secureChannels.add(::NoiseXXSecureChannel) + } else if ("tls" == params.security) { + it.secureChannels.add(::ECDSA) + } + + val muxers = ArrayList() + if ("mplex" == params.muxer) { + muxers.add(Mplex) + } else if ("yamux" == params.muxer) { + muxers.add(getYamux()) + } + it.muxers.addAll(muxers) + + for (protocol in protocols) { + it.protocols.add(protocol) + } + + for (listenAddr in listenAddresses) { + it.network.listen(listenAddr) + } + + it.connectionHandlers.add { + ConnectionHandler { conn: Connection -> + printDiagnosticsLog( + ( + conn.localAddress() + .toString() + " received connection from " + + conn.remoteAddress() + + " on transport " + + conn.transport() + ) + ) + } + } + }) + + private fun startJedisConnection(): CompletableFuture { + return CompletableFuture.supplyAsync { + val jedis = Jedis("http://${params.redisAddress}") + var isReady = false + while (!isReady) { + if ("PONG" == jedis.ping()) { + isReady = true + } else { + printDiagnosticsLog("waiting for redis to start...") + Thread.sleep(1000) + } + } + printDiagnosticsLog("Connection established to Redis ($jedis)") + jedis + } + } + + /* + Start dialer and try to connect with a listener + */ + private fun startDialer( + jedis: Jedis, + node: Host, + advertisedAddress: Multiaddr + ): CompletableFuture { + return CompletableFuture.supplyAsync { + printDiagnosticsLog("Starting dialer with advertisedAddress: $advertisedAddress") + + val listenerAddresses = + jedis.blpop(params.testTimeoutInSeconds, REDIS_KEY_LISTENER_ADDRESS) + if (listenerAddresses == null || listenerAddresses.isEmpty()) { + throw IllegalStateException("listenerAddr not set") + } + + val listenerAddr = + Multiaddr.fromString(listenerAddresses.first { s -> s.startsWith("/") }) + + printDiagnosticsLog("Sending ping messages to $listenerAddr") + + val handshakeStart = System.currentTimeMillis() + + val pingController = Ping().dial(node, listenerAddr).controller.join() + val pingRTTMillis = pingController.ping().join() + val handshakeEnd = System.currentTimeMillis() + val handshakePlusOneRTT = handshakeEnd - handshakeStart + + printDiagnosticsLog("Ping latency $pingRTTMillis ms") + + val jsonResult = + "{\"handshakePlusOneRTTMillis\":${handshakePlusOneRTT.toDouble()}, \"pingRTTMilllis\": ${pingRTTMillis.toDouble()}}" + + emitResult(jsonResult) + null + } + } + + /* + Start listener and wait up to testTimeoutInSeconds for a message from dialer + */ + private fun startListener( + jedis: Jedis, + advertisedAddress: Multiaddr + ): CompletableFuture { + return CompletableFuture.supplyAsync { + printDiagnosticsLog("Starting listener with advertisedAddress: $advertisedAddress") + + jedis.rpush(REDIS_KEY_LISTENER_ADDRESS, advertisedAddress.toString()) + + // Wait for dialer + Thread.sleep(params.testTimeoutInSeconds.toLong() * 1000L) + null + } + } + + private fun createProtocols( + privateKey: PrivKey, + listenAddresses: ArrayList + ): ArrayList> { + var identifyBuilder = + IdentifyOuterClass.Identify.newBuilder() + .setProtocolVersion("ipfs/0.1.0") + .setAgentVersion("jvm-libp2p/v1.0.0") + .setPublicKey(privateKey.publicKey().bytes().toProtobuf()) + .addAllListenAddrs( + listenAddresses.stream() + .map(Multiaddr::fromString) + .map(Multiaddr::serialize) + .map(ByteArray::toProtobuf) + .collect(Collectors.toList()) + ) + + val protocols = ArrayList>() + protocols.add(Ping()) + for (protocol in protocols) { + identifyBuilder = + identifyBuilder.addAllProtocols(protocol.protocolDescriptor.announceProtocols) + } + protocols.add(Identify(identifyBuilder.build())) + + return protocols + } +} + +private fun emitResult(json: String) { + println(json) +} + +private fun printDiagnosticsLog(msg: String) { + System.err.println(msg) +} + +fun main() { + try { + val params = InteropTestParams.Builder().fromEnvironmentVariables().build() + + InteropTestAgent(params).run() + .orTimeout(params.testTimeoutInSeconds.toLong(), TimeUnit.SECONDS) + .join() + } catch (e: Exception) { + printDiagnosticsLog("Unexpected exit: $e") + exitProcess(-1) + } +} diff --git a/interop-test-client/src/main/kotlin/io/libp2p/interop/InteropTestParams.kt b/interop-test-client/src/main/kotlin/io/libp2p/interop/InteropTestParams.kt new file mode 100644 index 000000000..b6ebc8117 --- /dev/null +++ b/interop-test-client/src/main/kotlin/io/libp2p/interop/InteropTestParams.kt @@ -0,0 +1,98 @@ +package io.libp2p.interop + +import java.net.Inet6Address +import java.net.NetworkInterface +import java.util.stream.Collectors + +class InteropTestParams( + val transport: String?, + val muxer: String?, + val security: String?, + val isDialer: Boolean, + val ip: String?, + val redisAddress: String?, + val testTimeoutInSeconds: Int +) { + + data class Builder( + var transport: String? = "", + var muxer: String? = "", + var security: String? = "", + var isDialer: Boolean = false, + var ip: String? = "", + var redisAddress: String? = "", + var testTimeoutInSeconds: Int = 180 + ) { + fun transport(transport: String) = apply { this.transport = transport } + fun muxer(muxer: String) = apply { this.muxer = muxer } + fun security(security: String) = apply { this.security = security } + fun isDialer(isDialer: Boolean) = apply { this.isDialer = isDialer } + fun ip(ip: String) = apply { this.ip = ip } + fun redisAddress(redisAddress: String) = apply { this.redisAddress = redisAddress } + fun testTimeoutInSeconds(testTimeoutInSeconds: Int) = + apply { this.testTimeoutInSeconds = testTimeoutInSeconds } + + fun build(): InteropTestParams { + checkNonEmptyParam("transport", transport) + if (transport != QUIC_V1) { + checkNonEmptyParam("security", security) + checkNonEmptyParam("muxer", muxer) + } + + if (redisAddress == null || redisAddress!!.isBlank()) { + redisAddress = "redis:6379" + } + + if (ip == null || ip!!.isBlank()) { + ip = "0.0.0.0" + } + if (!isDialer && ip.equals("0.0.0.0")) { + ip = getLocalIPAddress() + } + + return InteropTestParams( + transport, + muxer, + security, + isDialer, + ip, + redisAddress, + testTimeoutInSeconds + ) + } + + private fun checkNonEmptyParam(paramName: String, paramValue: String?) { + if (paramValue == null) { + throw IllegalArgumentException("Parameter '$paramName' must be non-empty") + } + } + + fun fromEnvironmentVariables(): Builder { + return Builder( + transport = System.getenv("transport"), + muxer = System.getenv("muxer"), + security = System.getenv("security"), + isDialer = System.getenv("is_dialer")?.toBooleanStrictOrNull() ?: false, + ip = System.getenv("ip"), + redisAddress = System.getenv("redis_addr"), + testTimeoutInSeconds = System.getenv("test_timeout_seconds")?.toInt() ?: 180 + ) + } + + private fun getLocalIPAddress(): String { + val interfaces = + NetworkInterface.networkInterfaces().collect(Collectors.toList()) + for (inter in interfaces) { + for (addr in inter.interfaceAddresses) { + val address = addr.address + if (!address.isLoopbackAddress && address !is Inet6Address) return address.hostAddress + } + } + throw IllegalStateException("Unable to determine local IPAddress") + } + } + + override fun toString(): String { + return "InteropTestParams(transport=$transport, muxer=$muxer, security=$security, isDialer=$isDialer, ip=$ip, redisAddress=$redisAddress, testTimeoutInSeconds=$testTimeoutInSeconds)" + } +} diff --git a/interop-test-client/src/test/resources/compose.yaml b/interop-test-client/src/test/resources/compose.yaml new file mode 100644 index 000000000..6375c37e9 --- /dev/null +++ b/interop-test-client/src/test/resources/compose.yaml @@ -0,0 +1,26 @@ +services: + dialer: + build: . + depends_on: + - redis + environment: + transport: tcp + is_dialer: true + ip: 0.0.0.0 + muxer: mplex + security: tls + listener: + init: true + build: . + depends_on: + - redis + environment: + transport: tcp + is_dialer: false + ip: 0.0.0.0 + muxer: mplex + security: tls + redis: + image: redis:7-alpine + environment: + REDIS_ARGS: --loglevel warning \ No newline at end of file diff --git a/libp2p/build.gradle.kts b/libp2p/build.gradle.kts index 4810f6745..d9037e61c 100644 --- a/libp2p/build.gradle.kts +++ b/libp2p/build.gradle.kts @@ -1,6 +1,6 @@ plugins { id("com.google.protobuf").version("0.9.4") - id("me.champeau.jmh").version("0.7.2") + id("me.champeau.jmh").version("0.7.3") } // https://docs.gradle.org/current/userguide/java_testing.html#ex-disable-publishing-of-test-fixtures-variants @@ -14,6 +14,20 @@ dependencies { api("io.netty:netty-transport") implementation("io.netty:netty-handler") implementation("io.netty:netty-codec-http") + implementation("io.netty:netty-codec-protobuf") + implementation("io.netty:netty-transport-classes-epoll") + implementation("io.netty:netty-codec-native-quic") + // OS-specific bindings + implementation("io.netty:netty-codec-native-quic::linux-x86_64") + implementation("io.netty:netty-codec-native-quic::linux-aarch_64") + implementation("io.netty:netty-codec-native-quic::osx-x86_64") + implementation("io.netty:netty-codec-native-quic::osx-aarch_64") + implementation("io.netty:netty-codec-native-quic::windows-x86_64") + implementation("io.netty:netty-tcnative-boringssl-static::linux-x86_64") + implementation("io.netty:netty-tcnative-boringssl-static::linux-aarch_64") + implementation("io.netty:netty-tcnative-boringssl-static::osx-x86_64") + implementation("io.netty:netty-tcnative-boringssl-static::osx-aarch_64") + implementation("io.netty:netty-tcnative-boringssl-static::windows-x86_64") api("com.google.protobuf:protobuf-java") @@ -22,7 +36,6 @@ dependencies { implementation("org.bouncycastle:bcprov-jdk18on") implementation("org.bouncycastle:bcpkix-jdk18on") - implementation("org.bouncycastle:bctls-jdk18on") testImplementation(project(":tools:schedulers")) @@ -33,8 +46,6 @@ dependencies { testFixturesImplementation("org.junit.jupiter:junit-jupiter-api") jmhImplementation(project(":tools:schedulers")) - jmhImplementation("org.openjdk.jmh:jmh-core") - jmhAnnotationProcessor("org.openjdk.jmh:jmh-generator-annprocess") } protobuf { diff --git a/libp2p/src/jmh/java/io/libp2p/pubsub/gossip/GossipScoreBenchmark.java b/libp2p/src/jmh/java/io/libp2p/pubsub/gossip/GossipScoreBenchmark.java index ce87b32d0..1128a2512 100644 --- a/libp2p/src/jmh/java/io/libp2p/pubsub/gossip/GossipScoreBenchmark.java +++ b/libp2p/src/jmh/java/io/libp2p/pubsub/gossip/GossipScoreBenchmark.java @@ -107,12 +107,4 @@ public void scoresDelay10000(Blackhole bh) { bh.consume(s); } } - - /** Uncomment for debugging */ - // public static void main(String[] args) { - // GossipScoreBenchmark benchmark = new GossipScoreBenchmark(); - // Blackhole blackhole = new Blackhole("Today's password is swordfish. I understand - // instantiating Blackholes directly is dangerous."); - // benchmark.scoresDelay0(blackhole); - // } } diff --git a/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java b/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java index 6eba6a226..a9223f737 100644 --- a/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java +++ b/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java @@ -1,7 +1,7 @@ package io.libp2p.core.dsl; import io.libp2p.core.Host; -import io.libp2p.core.crypto.PrivKey; +import io.libp2p.core.crypto.*; import io.libp2p.core.multistream.ProtocolBinding; import io.libp2p.core.mux.*; import io.libp2p.core.security.SecureChannel; @@ -57,11 +57,23 @@ public final HostBuilder protocol(ProtocolBinding... protocols) { return this; } + @SafeVarargs + public final HostBuilder secureTransport( + BiFunction>, Transport>... transports) { + secureTransports_.addAll(Arrays.asList(transports)); + return this; + } + public final HostBuilder listen(String... addresses) { listenAddresses_.addAll(Arrays.asList(addresses)); return this; } + public HostBuilder keyType(KeyType keyType) { + this.keyType = keyType; + return this; + } + public final HostBuilder builderModifier(Consumer builderModifier) { this.builderModifier = builderModifier; return this; @@ -72,8 +84,9 @@ public Host build() { return BuilderJKt.hostJ( defaultMode_.asBuilderDefault(), b -> { - b.getIdentity().random(); + b.getIdentity().random(keyType); + secureTransports_.forEach(st -> b.getSecureTransports().add(st::apply)); transports_.forEach(t -> b.getTransports().add(t::apply)); secureChannels_.forEach( sc -> b.getSecureChannels().add((k, m) -> sc.apply(k, (List) m))); @@ -85,6 +98,10 @@ public Host build() { } // build private DefaultMode defaultMode_; + private KeyType keyType = KeyType.ECDSA; + private List>, Transport>> + secureTransports_ = new ArrayList<>(); + private List> transports_ = new ArrayList<>(); private List, SecureChannel>> secureChannels_ = new ArrayList<>(); diff --git a/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java b/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java index be2be179d..687880c15 100644 --- a/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java +++ b/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java @@ -315,14 +315,22 @@ public void onMessage(@NotNull Stream stream, Circuit.HopMessage msg) { new CircuitStopProtocol.StopRemover()); // connect these streams with time + bytes enforcement - fromRequestor.pushHandler(new InboundTrafficLimitHandler(resv.maxBytes)); - fromRequestor.pushHandler( - new TotalTimeoutHandler( - Duration.of(resv.durationSeconds, ChronoUnit.SECONDS))); - toTarget.pushHandler(new InboundTrafficLimitHandler(resv.maxBytes)); - toTarget.pushHandler( - new TotalTimeoutHandler( - Duration.of(resv.durationSeconds, ChronoUnit.SECONDS))); + if (resv.maxBytes > 0) { + fromRequestor.pushHandler(new InboundTrafficLimitHandler(resv.maxBytes)); + } + if (resv.durationSeconds > 0) { + fromRequestor.pushHandler( + new TotalTimeoutHandler( + Duration.of(resv.durationSeconds, ChronoUnit.SECONDS))); + } + if (resv.maxBytes > 0) { + toTarget.pushHandler(new InboundTrafficLimitHandler(resv.maxBytes)); + } + if (resv.durationSeconds > 0) { + toTarget.pushHandler( + new TotalTimeoutHandler( + Duration.of(resv.durationSeconds, ChronoUnit.SECONDS))); + } fromRequestor.pushHandler(new ProxyHandler(toTarget)); toTarget.pushHandler(new ProxyHandler(fromRequestor)); } else { diff --git a/libp2p/src/main/kotlin/io/libp2p/core/dsl/BuilderJ.kt b/libp2p/src/main/kotlin/io/libp2p/core/dsl/BuilderJ.kt index 013bb7201..108f7ee3d 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/dsl/BuilderJ.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/dsl/BuilderJ.kt @@ -19,6 +19,7 @@ class BuilderJ : Builder() { public override val identity = super.identity public override val secureChannels = super.secureChannels public override val muxers = super.muxers + public override val secureTransports = super.secureTransports public override val transports = super.transports public override val addressBook = super.addressBook public override val protocols = super.protocols diff --git a/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt b/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt index ce1416dfd..7f6e6d9c5 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt @@ -37,6 +37,7 @@ import io.netty.handler.logging.LoggingHandler import java.util.concurrent.CopyOnWriteArrayList typealias TransportCtor = (ConnectionUpgrader) -> Transport +typealias SecureTransportCtor = (PrivKey, List>) -> Transport typealias SecureChannelCtor = (PrivKey, List) -> SecureChannel typealias IdentityFactory = () -> PrivKey @@ -58,6 +59,7 @@ open class Builder { protected open val secureChannels = SecureChannelsBuilder() protected open val muxers = MuxersBuilder() protected open val transports = TransportsBuilder() + protected open val secureTransports = SecureTransportsBuilder() protected open val addressBook = AddressBookBuilder() protected open val protocols = ProtocolsBuilder() protected open val connectionHandlers = ConnectionHandlerBuilder() @@ -88,6 +90,11 @@ open class Builder { */ open fun transports(fn: TransportsBuilder.() -> Unit): Builder = apply { fn(transports) } + /** + * Manipulates the secure transports for this host. + */ + open fun secureTransports(fn: SecureTransportsBuilder.() -> Unit): Builder = apply { fn(secureTransports) } + /** * [AddressBook] implementation */ @@ -126,9 +133,9 @@ open class Builder { if (def == Defaults.None) { if (identity.factory == null) throw IllegalStateException("No identity builder") - if (transports.values.isEmpty()) throw HostConfigurationException("at least one transport is required") - if (secureChannels.values.isEmpty()) throw HostConfigurationException("at least one secure channel is required") - if (muxers.values.isEmpty()) throw HostConfigurationException("at least one muxer is required") + if (secureTransports.isEmpty() && transports.values.isEmpty()) throw HostConfigurationException("at least one transport is required") + if (secureTransports.isEmpty() && secureChannels.values.isEmpty()) throw HostConfigurationException("at least one secure channel or secure transport is required") + if (secureTransports.isEmpty() && muxers.values.isEmpty()) throw HostConfigurationException("at least one muxer or secure transport is required") } if (def == Defaults.Standard) { if (identity.factory == null) identity.random() @@ -189,7 +196,12 @@ open class Builder { val upgrader = ConnectionUpgrader(secureMultistreamProtocol, secureChannels, muxerMultistreamProtocol, muxers) - val transports = transports.values.map { it(upgrader) } + val allTransports = + listOf( + transports.values.map { it(upgrader) }, + secureTransports.values.map { it(privKey, updatableProtocols) } + ).flatten() + val addressBook = addressBook.impl val connHandlerProtocols = protocols.values.mapNotNull { it as? ConnectionHandler } @@ -197,7 +209,7 @@ open class Builder { connHandlerProtocols + connectionHandlers.values ) - val networkImpl = NetworkImpl(transports, broadcastConnHandler) + val networkImpl = NetworkImpl(allTransports, broadcastConnHandler) return HostImpl( privKey, @@ -230,6 +242,7 @@ class AddressBookBuilder { fun memory(): AddressBookBuilder = apply { impl = MemoryAddressBook() } } +class SecureTransportsBuilder : Enumeration() class TransportsBuilder : Enumeration() class SecureChannelsBuilder : Enumeration() class MuxersBuilder : Enumeration() diff --git a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Protocol.kt b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Protocol.kt index 5d171b811..dcba6b811 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Protocol.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Protocol.kt @@ -42,6 +42,8 @@ enum class Protocol( DNS6(55, LENGTH_PREFIXED_VAR_SIZE, "dns6", UTF8_PARSER, UTF8_STRINGIFIER, UTF8_VALIDATOR), DNSADDR(56, LENGTH_PREFIXED_VAR_SIZE, "dnsaddr", UTF8_PARSER, UTF8_STRINGIFIER, UTF8_VALIDATOR), SCTP(132, 16, "sctp", UINT16_PARSER, UINT16_STRINGIFIER), + WEBRTC_DIRECT(280, 0, "webrtc-direct"), + WEBRTC(28, 0, "webrtc"), UTP(301, 0, "utp"), UDT(302, 0, "udt"), UNIX(400, LENGTH_PREFIXED_VAR_SIZE, "unix", UNIX_PATH_PARSER, UTF8_STRINGIFIER, UTF8_VALIDATOR, isPath = true), diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/types/NettyExt.kt b/libp2p/src/main/kotlin/io/libp2p/etc/types/NettyExt.kt index e4c9c1d49..e22d68efc 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/types/NettyExt.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/types/NettyExt.kt @@ -4,6 +4,7 @@ import io.netty.channel.Channel import io.netty.channel.ChannelFuture import io.netty.channel.ChannelHandler import io.netty.channel.ChannelPipeline +import io.netty.util.concurrent.Future import java.util.concurrent.CompletableFuture fun ChannelFuture.toVoidCompletableFuture(): CompletableFuture = toCompletableFuture().thenApply { } @@ -20,6 +21,21 @@ fun ChannelFuture.toCompletableFuture(): CompletableFuture { return ret } +fun Future<*>.toVoidCompletableFuture(): CompletableFuture = toCompletableFuture().thenApply { } + +fun Future.toCompletableFuture(): CompletableFuture { + val ret = CompletableFuture() + this.addListener { + if (it.isSuccess) { + @Suppress("UNCHECKED_CAST") + ret.complete(it.get() as T) + } else { + ret.completeExceptionally(it.cause()) + } + } + return ret +} + fun ChannelPipeline.replace(oldHandler: ChannelHandler, newHandlers: List>) { replace(oldHandler, newHandlers[0].first, newHandlers[0].second) for (i in 1 until newHandlers.size) { @@ -32,5 +48,5 @@ fun ChannelPipeline.getHandlerName(handler: ChannelHandler) = ( ?: throw IllegalArgumentException("Handler $handler not found in pipeline $this") ) -fun ChannelPipeline.addAfter(handler: ChannelHandler, newHandlerName: String, newHandler: ChannelHandler) = +fun ChannelPipeline.addAfter(handler: ChannelHandler, newHandlerName: String, newHandler: ChannelHandler): ChannelPipeline = addAfter(getHandlerName(handler), newHandlerName, newHandler) diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/MultiaddrUtils.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/MultiaddrUtils.kt new file mode 100644 index 000000000..8b8716307 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/MultiaddrUtils.kt @@ -0,0 +1,30 @@ +package io.libp2p.etc.util + +import io.libp2p.core.InternalErrorException +import io.libp2p.core.multiformats.Multiaddr +import io.libp2p.core.multiformats.Protocol +import java.net.* + +class MultiaddrUtils { + + companion object { + + fun inetAddressToIpMultiaddr(addr: InetAddress): Multiaddr { + val proto = when (addr) { + is Inet4Address -> Protocol.IP4 + is Inet6Address -> Protocol.IP6 + else -> throw InternalErrorException("Unknown address type $addr") + } + return Multiaddr.empty() + .withComponent(proto, addr.hostAddress) + } + + fun inetSocketAddressToTcpMultiaddr(addr: InetSocketAddress): Multiaddr = + inetAddressToIpMultiaddr(addr.address) + .withComponent(Protocol.TCP, addr.port.toString()) + + fun inetSocketAddressToUdpMultiaddr(addr: InetSocketAddress): Multiaddr = + inetAddressToIpMultiaddr(addr.address) + .withComponent(Protocol.UDP, addr.port.toString()) + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/NettyUtil.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/NettyUtil.kt index 4743c9e6e..05502daa5 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/NettyUtil.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/NettyUtil.kt @@ -1,13 +1,12 @@ package io.libp2p.etc.util.netty import io.libp2p.etc.types.addAfter -import io.libp2p.etc.types.fromHex import io.netty.channel.Channel import io.netty.channel.ChannelHandler import io.netty.channel.ChannelInitializer import io.netty.util.internal.StringUtil -class NettyInit(val channel: Channel, val thisHandler: ChannelHandler) { +class NettyInit(val channel: Channel, thisHandler: ChannelHandler) { private var lastLocalHandler = thisHandler fun addLastLocal(handler: ChannelHandler) { channel.pipeline().addAfter(lastLocalHandler, generateName(channel, handler), handler) @@ -23,13 +22,6 @@ fun nettyInitializer(initer: (NettyInit) -> Unit): ChannelInitializer { } } -private val regex = Regex("\\|[0-9a-fA-F]{8}\\| ") -fun String.fromLogHandler() = lines() - .filter { it.contains(regex) } - .map { it.substring(11, 59).replace(" ", "") } - .flatMap { it.fromHex().asList() } - .toByteArray() - private fun generateName(ch: Channel, handler: ChannelHandler): String { val className = StringUtil.simpleClassName(handler.javaClass) val names = ch.pipeline().names().toSet() diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt index 855046c5a..2c4059c83 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt @@ -32,13 +32,20 @@ class MuxChannel( initializer(this) } + @Suppress("SwallowedException") override fun doWrite(buf: ChannelOutboundBuffer) { while (true) { val msg = buf.current() ?: break + if (localDisconnected) { + // Must not throw from doWrite — exceptions escape uncaught to the Netty event loop. + // Wrap buf.remove() defensively: in some Netty versions promise listeners triggered + // by buf.remove() can propagate back through it. + try { + buf.remove(ConnectionClosedException("The stream was closed for writing locally: $id")) + } catch (e: Throwable) { } + continue + } try { - if (localDisconnected) { - throw ConnectionClosedException("The stream was closed for writing locally: $id") - } // the msg is released by both onChildWrite and buf.remove() so we need to retain // however it is still to be confirmed that no buf leaks happen here TODO ReferenceCountUtil.retain(msg) diff --git a/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandler.kt b/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandler.kt index f59b8d84c..6fd3e2900 100644 --- a/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandler.kt @@ -6,10 +6,6 @@ interface ProtocolMessageHandler { fun onActivated(stream: Stream) = Unit fun onMessage(stream: Stream, msg: TMessage) = Unit fun onClosed(stream: Stream) = Unit + fun onReadClosed(stream: Stream) = Unit fun onException(cause: Throwable?) = Unit - - fun fireMessage(stream: Stream, msg: Any) { - @Suppress("UNCHECKED_CAST") - onMessage(stream, msg as TMessage) - } } diff --git a/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandlerAdapter.kt b/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandlerAdapter.kt index a86dd2cce..4cf9218ec 100644 --- a/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandlerAdapter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandlerAdapter.kt @@ -1,6 +1,7 @@ package io.libp2p.protocol import io.libp2p.core.Stream +import io.libp2p.etc.util.netty.mux.RemoteWriteClosed import io.netty.channel.ChannelHandlerContext import io.netty.channel.SimpleChannelInboundHandler import io.netty.util.ReferenceCounted @@ -33,7 +34,8 @@ class ProtocolMessageHandlerAdapter( } override fun channelRead0(ctx: ChannelHandlerContext?, msg: Any) { - pmh.fireMessage(stream, msg) + @Suppress("UNCHECKED_CAST") + pmh.onMessage(stream, msg as TMessage) } override fun channelUnregistered(ctx: ChannelHandlerContext?) { @@ -44,6 +46,13 @@ class ProtocolMessageHandlerAdapter( pmh.onException(cause) } + override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { + if (evt == RemoteWriteClosed) { + pmh.onReadClosed(stream) + } + super.userEventTriggered(ctx, evt) + } + // /////////////////////// private fun refCount(obj: Any): Int { return if (obj is ReferenceCounted) { diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt index d5e16401c..999bc2532 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt @@ -83,6 +83,15 @@ abstract class AbstractRouter( return true } + /** + * Per-router caps on repeated-field counts inside inbound RPCs. Enforced before + * protobuf materialisation by an [RpcCountFrameDecoder] inserted into the stream + * pipeline. Defaults to [PubsubRpcLimits.NONE] (no pre-decode cap). Subclasses + * with configured limits (e.g. [io.libp2p.pubsub.gossip.GossipRouter]) override. + */ + protected open val rpcLimits: PubsubRpcLimits + get() = PubsubRpcLimits.NONE + /** * Flushes all pending message parts for all peers */ @@ -113,6 +122,7 @@ abstract class AbstractRouter( with(streamHandler.stream) { pushHandler(LimitedProtobufVarint32FrameDecoder(maxMsgSize)) pushHandler(ProtobufVarint32LengthFieldPrepender()) + pushHandler(RpcCountFrameDecoder(rpcLimits)) pushHandler(ProtobufDecoder(Rpc.RPC.getDefaultInstance())) pushHandler(ProtobufEncoder()) handler?.also { pushHandler(it) } @@ -139,6 +149,11 @@ abstract class AbstractRouter( */ protected abstract fun processControl(ctrl: Rpc.ControlMessage, receivedFrom: PeerHandler) + /** + * Processes Gossipsub extensions messages + */ + protected abstract fun processExtensions(msg: Rpc.RPC, receivedFrom: PeerHandler) + override fun onPeerActive(peer: PeerHandler) { val partsQueue = pendingRpcParts.getQueue(peer) subscribedTopics.forEach { @@ -180,6 +195,10 @@ abstract class AbstractRouter( processControl(msg.control, peer) } + if (protocol.supportsExtensions()) { + processExtensions(msg, peer) + } + val (msgSubscribed, nonSubscribed) = msg.publishList .partition { rpcMsg -> rpcMsg.topicIDsList.any { it in subscribedTopics } } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubProtocol.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubProtocol.kt index 49cf95239..1b6186b02 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubProtocol.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubProtocol.kt @@ -7,6 +7,7 @@ enum class PubsubProtocol(val announceStr: ProtocolId) { Gossip_V_1_0("/meshsub/1.0.0"), Gossip_V_1_1("/meshsub/1.1.0"), Gossip_V_1_2("/meshsub/1.2.0"), + Gossip_V_1_3("/meshsub/1.3.0"), Floodsub("/floodsub/1.0.0"); companion object { @@ -18,13 +19,20 @@ enum class PubsubProtocol(val announceStr: ProtocolId) { * https://github.com/libp2p/specs/blob/master/pubsub/gossipsub/gossipsub-v1.1.md#prune-backoff-and-peer-exchange */ fun supportsBackoffAndPX(): Boolean { - return this == Gossip_V_1_1 || this == Gossip_V_1_2 + return this == Gossip_V_1_1 || this == Gossip_V_1_2 || this == Gossip_V_1_3 } /** * https://github.com/libp2p/specs/blob/master/pubsub/gossipsub/gossipsub-v1.2.md#idontwant-message */ fun supportsIDontWant(): Boolean { - return this == Gossip_V_1_2 + return this == Gossip_V_1_2 || this == Gossip_V_1_3 + } + + /** + * https://github.com/libp2p/specs/blob/master/pubsub/gossipsub/gossipsub-v1.3.md#the-extensions-control-message + */ + fun supportsExtensions(): Boolean { + return this == Gossip_V_1_3 } } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRpcLimits.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRpcLimits.kt new file mode 100644 index 000000000..2c9e7f848 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRpcLimits.kt @@ -0,0 +1,60 @@ +package io.libp2p.pubsub + +/** + * Per-router limits on repeated-field counts inside an inbound pubsub RPC. Enforced + * at decode time by [RpcMessageCountValidator] to prevent allocation amplification + * before [pubsub.pb.Rpc.RPC] is materialised. + * + * A null field means "no limit" — same semantics as the corresponding nullable + * fields on `GossipParams`. + */ +data class PubsubRpcLimits( + val maxPublishedMessages: Int?, + val maxTopicsPerPublishedMessage: Int?, + val maxSubscriptions: Int?, + val maxIHaveMessageIds: Int?, + val maxIWantMessageIds: Int?, + val maxGraftMessages: Int?, + val maxPruneMessages: Int?, + val maxPeersPerPruneMessage: Int?, + val maxIDontWantMessages: Int? = null, + val maxIDontWantMessageIds: Int? = null, + val rejectEmptyPublishEntries: Boolean = true, + val rejectEmptyIDontWantEntries: Boolean = true, +) { + /** + * True when no configured limit or reject-flag can fire. Lets + * [RpcCountFrameDecoder] skip the validator walk entirely on the toggle-off + * path. Any new field added to this data class must be considered here. + */ + val isNoop: Boolean = + maxPublishedMessages == null && + maxTopicsPerPublishedMessage == null && + maxSubscriptions == null && + maxIHaveMessageIds == null && + maxIWantMessageIds == null && + maxGraftMessages == null && + maxPruneMessages == null && + maxPeersPerPruneMessage == null && + maxIDontWantMessages == null && + maxIDontWantMessageIds == null && + !rejectEmptyPublishEntries && + !rejectEmptyIDontWantEntries + + companion object { + val NONE = PubsubRpcLimits( + maxPublishedMessages = null, + maxTopicsPerPublishedMessage = null, + maxSubscriptions = null, + maxIHaveMessageIds = null, + maxIWantMessageIds = null, + maxGraftMessages = null, + maxPruneMessages = null, + maxPeersPerPruneMessage = null, + maxIDontWantMessages = null, + maxIDontWantMessageIds = null, + rejectEmptyPublishEntries = false, + rejectEmptyIDontWantEntries = false, + ) + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/RpcCountFrameDecoder.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/RpcCountFrameDecoder.kt new file mode 100644 index 000000000..45b108f23 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/RpcCountFrameDecoder.kt @@ -0,0 +1,59 @@ +package io.libp2p.pubsub + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.CorruptedFrameException +import io.netty.handler.codec.MessageToMessageDecoder +import org.slf4j.LoggerFactory + +/** + * Pre-decode count cap for inbound pubsub RPC frames. Sits between + * [io.libp2p.etc.util.netty.protobuf.LimitedProtobufVarint32FrameDecoder] (byte-size + * cap) and [io.netty.handler.codec.protobuf.ProtobufDecoder] (materialisation). + * + * For each frame, delegates to [RpcMessageCountValidator]. Accepted frames are + * forwarded unchanged as a `ByteBuf` to the next handler. Frames rejected because + * a configured count limit was exceeded are dropped with a debug log; no + * `Rpc$Message` is allocated for them. Frames rejected because the protobuf bytes + * themselves are malformed propagate a [CorruptedFrameException] so that + * downstream handlers (e.g. [io.libp2p.pubsub.AbstractRouter.onPeerWireException]) + * can apply the same behaviour penalty they would have on a [ProtobufDecoder] + * failure. + * + * When [limits] is a no-op (see [PubsubRpcLimits.isNoop], e.g. [PubsubRpcLimits.NONE]) + * the validator is skipped entirely and the buffer is forwarded as-is. Malformed + * bytes still surface downstream from [ProtobufDecoder], which already triggers + * the same wire-exception path the validator would have used. + */ +class RpcCountFrameDecoder(private val limits: PubsubRpcLimits) : MessageToMessageDecoder() { + + override fun decode(ctx: ChannelHandlerContext, msg: ByteBuf, out: MutableList) { + if (limits.isNoop) { + out.add(msg.retain()) + return + } + + val result = try { + RpcMessageCountValidator.validate(msg, limits) + } catch (e: Exception) { + logger.debug("Dropping pubsub RPC frame due to unexpected validator error", e) + return + } + + when (result) { + RpcMessageCountValidator.Result.Accepted -> { + out.add(msg.retain()) + } + is RpcMessageCountValidator.Result.Malformed -> { + throw CorruptedFrameException(result.reason) + } + is RpcMessageCountValidator.Result.Rejected -> { + logger.debug("Dropping pubsub RPC frame: {}", result.reason) + } + } + } + + companion object { + private val logger = LoggerFactory.getLogger(RpcCountFrameDecoder::class.java) + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/RpcMessageCountValidator.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/RpcMessageCountValidator.kt new file mode 100644 index 000000000..179abbeb8 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/RpcMessageCountValidator.kt @@ -0,0 +1,266 @@ +package io.libp2p.pubsub + +import com.google.protobuf.CodedInputStream +import com.google.protobuf.Descriptors +import com.google.protobuf.WireFormat +import io.netty.buffer.ByteBuf +import pubsub.pb.Rpc +import java.io.IOException + +/** + * Walks an inbound pubsub RPC [ByteBuf] without materialising any `pubsub.pb.Rpc` + * message and rejects it if its repeated-field counts violate [PubsubRpcLimits]. + * + * Field numbers are taken from the protobuf-generated `*_FIELD_NUMBER` constants, + * so renames in `libp2p/src/main/proto/rpc.proto` break compilation. New repeated + * fields are caught by `RpcMessageCountValidatorProtoCoverageTest`, which + * recursively walks the descriptors reachable from [Rpc.RPC] and asserts each one + * appears in [ACKNOWLEDGED_REPEATED_FIELDS]. + * + * The walker uses [CodedInputStream] to read tags / lengths and to skip bodies, + * so no `Rpc$Message` / builder is allocated for rejected frames. + */ +object RpcMessageCountValidator { + + sealed interface Result { + object Accepted : Result + data class Rejected(val reason: String) : Result + data class Malformed(val reason: String) : Result + } + + // pubsub.RPC field numbers + private const val RPC_SUBSCRIPTIONS = Rpc.RPC.SUBSCRIPTIONS_FIELD_NUMBER + private const val RPC_PUBLISH = Rpc.RPC.PUBLISH_FIELD_NUMBER + private const val RPC_CONTROL = Rpc.RPC.CONTROL_FIELD_NUMBER + + // pubsub.Message field numbers + private const val MESSAGE_TOPIC_IDS = Rpc.Message.TOPICIDS_FIELD_NUMBER + + // pubsub.ControlMessage field numbers + private const val CTRL_IHAVE = Rpc.ControlMessage.IHAVE_FIELD_NUMBER + private const val CTRL_IWANT = Rpc.ControlMessage.IWANT_FIELD_NUMBER + private const val CTRL_GRAFT = Rpc.ControlMessage.GRAFT_FIELD_NUMBER + private const val CTRL_PRUNE = Rpc.ControlMessage.PRUNE_FIELD_NUMBER + private const val CTRL_IDONTWANT = Rpc.ControlMessage.IDONTWANT_FIELD_NUMBER + + // pubsub.ControlIHave / ControlIWant / ControlIDontWant repeated bytes field numbers + private const val IHAVE_MESSAGE_IDS = Rpc.ControlIHave.MESSAGEIDS_FIELD_NUMBER + private const val IWANT_MESSAGE_IDS = Rpc.ControlIWant.MESSAGEIDS_FIELD_NUMBER + private const val IDONTWANT_MESSAGE_IDS = Rpc.ControlIDontWant.MESSAGEIDS_FIELD_NUMBER + + // pubsub.ControlPrune.peers + private const val PRUNE_PEERS = Rpc.ControlPrune.PEERS_FIELD_NUMBER + + /** + * Single source of truth for every repeated proto field the validator inspects. + * The proto-coverage test asserts this map equals the set of repeated fields + * actually present in the proto, recursively from [Rpc.RPC]. Any new repeated + * field that lands in `rpc.proto` without being added here will fail the test. + */ + internal val ACKNOWLEDGED_REPEATED_FIELDS: Map> = mapOf( + Rpc.RPC.getDescriptor() to setOf(RPC_SUBSCRIPTIONS, RPC_PUBLISH), + Rpc.Message.getDescriptor() to setOf(MESSAGE_TOPIC_IDS), + Rpc.ControlMessage.getDescriptor() to setOf( + CTRL_IHAVE, + CTRL_IWANT, + CTRL_GRAFT, + CTRL_PRUNE, + CTRL_IDONTWANT + ), + Rpc.ControlIHave.getDescriptor() to setOf(IHAVE_MESSAGE_IDS), + Rpc.ControlIWant.getDescriptor() to setOf(IWANT_MESSAGE_IDS), + Rpc.ControlIDontWant.getDescriptor() to setOf(IDONTWANT_MESSAGE_IDS), + Rpc.ControlPrune.getDescriptor() to setOf(PRUNE_PEERS), + ) + + fun validate(buf: ByteBuf, limits: PubsubRpcLimits): Result { + val input = CodedInputStream.newInstance(buf.nioBuffer()) + return try { + validateRpc(input, limits) + } catch (e: IOException) { + Result.Malformed("malformed: ${e.message}") + } catch (e: IndexOutOfBoundsException) { + Result.Malformed("malformed: truncated (${e.message})") + } + } + + private class ControlCounters { + var ihaveMsgIds = 0 + var iwantMsgIds = 0 + var graftCount = 0 + var pruneCount = 0 + var idontwantCount = 0 + var idontwantMsgIds = 0 + } + + private fun validateRpc(input: CodedInputStream, limits: PubsubRpcLimits): Result { + var publishCount = 0 + var subscriptionCount = 0 + val ctrl = ControlCounters() + + while (!input.isAtEnd) { + val tag = input.readTag() + val fieldNumber = WireFormat.getTagFieldNumber(tag) + val wireType = WireFormat.getTagWireType(tag) + when { + fieldNumber == RPC_SUBSCRIPTIONS && + wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED -> { + subscriptionCount++ + limits.maxSubscriptions?.let { + if (subscriptionCount > it) return Result.Rejected("subscriptions count > $it") + } + input.skipField(tag) + } + fieldNumber == RPC_PUBLISH && + wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED -> { + val length = input.readRawVarint32() + if (length == 0 && limits.rejectEmptyPublishEntries) { + return Result.Rejected("empty publish entry") + } + publishCount++ + limits.maxPublishedMessages?.let { + if (publishCount > it) return Result.Rejected("publish count > $it") + } + val oldLimit = input.pushLimit(length) + val maxTopics = limits.maxTopicsPerPublishedMessage + if (maxTopics != null) { + val res = validatePublish(input, maxTopics) + if (res is Result.Rejected) return res + } else { + input.skipMessage() + } + input.popLimit(oldLimit) + } + fieldNumber == RPC_CONTROL && + wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED -> { + val length = input.readRawVarint32() + val oldLimit = input.pushLimit(length) + val res = validateControl(input, limits, ctrl) + if (res is Result.Rejected) return res + input.popLimit(oldLimit) + } + else -> input.skipField(tag) + } + } + return Result.Accepted + } + + private fun validatePublish(input: CodedInputStream, maxTopics: Int): Result { + var topicCount = 0 + while (!input.isAtEnd) { + val tag = input.readTag() + if (WireFormat.getTagFieldNumber(tag) == MESSAGE_TOPIC_IDS && + WireFormat.getTagWireType(tag) == WireFormat.WIRETYPE_LENGTH_DELIMITED + ) { + topicCount++ + if (topicCount > maxTopics) return Result.Rejected("topicIDs per publish > $maxTopics") + input.skipField(tag) + } else { + input.skipField(tag) + } + } + return Result.Accepted + } + + private fun validateControl( + input: CodedInputStream, + limits: PubsubRpcLimits, + c: ControlCounters, + ): Result { + while (!input.isAtEnd) { + val tag = input.readTag() + val fieldNumber = WireFormat.getTagFieldNumber(tag) + val wireType = WireFormat.getTagWireType(tag) + if (wireType != WireFormat.WIRETYPE_LENGTH_DELIMITED) { + input.skipField(tag) + continue + } + when (fieldNumber) { + CTRL_IHAVE -> { + val length = input.readRawVarint32() + val oldLimit = input.pushLimit(length) + val count = countRepeatedBytes(input, IHAVE_MESSAGE_IDS) + c.ihaveMsgIds += count + limits.maxIHaveMessageIds?.let { + if (c.ihaveMsgIds > it) return Result.Rejected("ihave messageIDs > $it") + } + input.popLimit(oldLimit) + } + CTRL_IWANT -> { + val length = input.readRawVarint32() + val oldLimit = input.pushLimit(length) + val count = countRepeatedBytes(input, IWANT_MESSAGE_IDS) + c.iwantMsgIds += count + limits.maxIWantMessageIds?.let { + if (c.iwantMsgIds > it) return Result.Rejected("iwant messageIDs > $it") + } + input.popLimit(oldLimit) + } + CTRL_GRAFT -> { + c.graftCount++ + limits.maxGraftMessages?.let { + if (c.graftCount > it) return Result.Rejected("graft count > $it") + } + input.skipField(tag) + } + CTRL_PRUNE -> { + c.pruneCount++ + limits.maxPruneMessages?.let { + if (c.pruneCount > it) return Result.Rejected("prune count > $it") + } + val length = input.readRawVarint32() + val oldLimit = input.pushLimit(length) + val maxPeers = limits.maxPeersPerPruneMessage + if (maxPeers != null) { + val peerCount = countRepeatedMessages(input, PRUNE_PEERS) + if (peerCount > maxPeers) return Result.Rejected("peers per prune > $maxPeers") + } else { + input.skipMessage() + } + input.popLimit(oldLimit) + } + CTRL_IDONTWANT -> { + c.idontwantCount++ + limits.maxIDontWantMessages?.let { + if (c.idontwantCount > it) return Result.Rejected("idontwant count > $it") + } + val length = input.readRawVarint32() + if (length == 0 && limits.rejectEmptyIDontWantEntries) { + return Result.Rejected("empty idontwant entry") + } + val oldLimit = input.pushLimit(length) + val count = countRepeatedBytes(input, IDONTWANT_MESSAGE_IDS) + c.idontwantMsgIds += count + limits.maxIDontWantMessageIds?.let { + if (c.idontwantMsgIds > it) return Result.Rejected("idontwant messageIDs > $it") + } + input.popLimit(oldLimit) + } + else -> input.skipField(tag) + } + } + return Result.Accepted + } + + /** + * Counts occurrences of a length-delimited repeated field inside a sub-message + * region. The [CodedInputStream] must already be bounded by a `pushLimit` on the + * caller side; this method walks until `isAtEnd` and skips every body. + */ + private fun countRepeatedBytes(input: CodedInputStream, fieldNumber: Int): Int { + var count = 0 + while (!input.isAtEnd) { + val tag = input.readTag() + if (WireFormat.getTagFieldNumber(tag) == fieldNumber && + WireFormat.getTagWireType(tag) == WireFormat.WIRETYPE_LENGTH_DELIMITED + ) { + count++ + } + input.skipField(tag) + } + return count + } + + private fun countRepeatedMessages(input: CodedInputStream, fieldNumber: Int): Int = + countRepeatedBytes(input, fieldNumber) +} diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/flood/FloodRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/flood/FloodRouter.kt index 9bed00ddd..acb4e912e 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/flood/FloodRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/flood/FloodRouter.kt @@ -36,6 +36,10 @@ class FloodRouter(executor: ScheduledExecutorService = Executors.newSingleThread // NOP } + override fun processExtensions(msg: Rpc.RPC, receivedFrom: PeerHandler) { + // NOP + } + private fun broadcast(msg: PubsubMessage, receivedFrom: PeerHandler?): CompletableFuture { val peers = msg.topics .map { getTopicPeers(it) } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt index ae5f3c5e2..39100f10c 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt @@ -12,6 +12,7 @@ import io.libp2p.pubsub.PubsubApiImpl import io.libp2p.pubsub.PubsubProtocol import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder import io.netty.channel.ChannelHandler +import org.slf4j.LoggerFactory import java.util.concurrent.CompletableFuture class Gossip @JvmOverloads constructor( @@ -21,6 +22,8 @@ class Gossip @JvmOverloads constructor( ) : ProtocolBinding, ConnectionHandler, PubsubApi by api { + private val logger = LoggerFactory.getLogger(Gossip::class.java) + fun updateTopicScoreParams(scoreParams: Map) { router.score.updateTopicParams(scoreParams) } @@ -31,6 +34,14 @@ class Gossip @JvmOverloads constructor( override val protocolDescriptor = when (router.protocol) { + PubsubProtocol.Gossip_V_1_3 -> { + ProtocolDescriptor( + PubsubProtocol.Gossip_V_1_3.announceStr, + PubsubProtocol.Gossip_V_1_2.announceStr, + PubsubProtocol.Gossip_V_1_1.announceStr, + PubsubProtocol.Gossip_V_1_0.announceStr + ) + } PubsubProtocol.Gossip_V_1_2 -> { ProtocolDescriptor( PubsubProtocol.Gossip_V_1_2.announceStr, @@ -54,6 +65,7 @@ class Gossip @JvmOverloads constructor( } override fun initChannel(ch: P2PChannel, selectedProtocol: String): CompletableFuture { + logger.trace("Gossip initChannel - selected protocol: {}", selectedProtocol) router.addPeerWithDebugHandler(ch as Stream, debugGossipHandler) return CompletableFuture.completedFuture(Unit) } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsState.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsState.kt new file mode 100644 index 000000000..970f61d86 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsState.kt @@ -0,0 +1,64 @@ +package io.libp2p.pubsub.gossip + +import io.libp2p.core.PeerId +import pubsub.pb.Rpc + +enum class GossipExtension { + // Canonical extensions + PARTIAL_MESSAGES, + + // Non-canonical extensions + TEST_EXTENSION +} + +data class GossipExtensionsConfig( + val partialMessagesEnabled: Boolean = false, + val testExtensionEnabled: Boolean = false +) + +class GossipExtensionsState(gossipExtensionsConfig: GossipExtensionsConfig? = null) { + + val localExtensionSupport: Rpc.ControlExtensions = Rpc.ControlExtensions.newBuilder() + .setTestExtension(gossipExtensionsConfig?.testExtensionEnabled ?: false) + .setPartialMessages(gossipExtensionsConfig?.partialMessagesEnabled ?: false) + .build() + + /* + Tracks the peers that we have already sent a control extensions message + */ + private val outgoingControlExtensionsMsgPeers: MutableSet = mutableSetOf() + + /* + Tracks peers that already sent us a control extensions message + */ + private val peerExtensionSupportMap: MutableMap = mutableMapOf() + + fun onPeerDisconnected(peer: PeerId) { + outgoingControlExtensionsMsgPeers.remove(peer) + peerExtensionSupportMap.remove(peer) + } + + fun onControlExtensionsMessage(ctrlExtensions: Rpc.ControlExtensions, receivedFrom: PeerId) { + peerExtensionSupportMap[receivedFrom] = ctrlExtensions + } + + fun registerControlExtensionMessageSentToPeers(peerId: PeerId) { + outgoingControlExtensionsMsgPeers.add(peerId) + } + + fun peerSupportedExtensions(peerId: PeerId) = peerExtensionSupportMap[peerId] + + fun hasReceivedControlExtensionsFrom(peer: PeerId) = + peerExtensionSupportMap.contains(peer) + + fun hasSentControlExtensionsTo(peer: PeerId) = + outgoingControlExtensionsMsgPeers.contains(peer) + + fun testExtensionsEnabled() = localExtensionSupport.testExtension + fun peerSupportsTestExtensions(peerId: PeerId) = + peerExtensionSupportMap[peerId]?.testExtension == true + + fun partialMessagesEnabled() = localExtensionSupport.partialMessages + fun peerSupportsPartialMessages(peerId: PeerId) = + peerExtensionSupportMap[peerId]?.partialMessages == true +} diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt index b385aaa3b..6b4551f97 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt @@ -7,6 +7,7 @@ import io.libp2p.core.pubsub.ValidationResult import io.libp2p.etc.types.* import io.libp2p.etc.util.P2PService import io.libp2p.pubsub.* +import io.libp2p.pubsub.PubsubRpcLimits import org.slf4j.LoggerFactory import pubsub.pb.Rpc import java.time.Duration @@ -82,6 +83,7 @@ open class GossipRouter( val name: String, val mCache: MCache, val score: GossipScore, + val gossipExtensionsConfig: GossipExtensionsConfig = GossipExtensionsConfig(), subscriptionTopicSubscriptionFilter: TopicSubscriptionFilter, protocol: PubsubProtocol, @@ -132,6 +134,8 @@ open class GossipRouter( private val acceptRequestsWhitelist = mutableMapOf() override val pendingRpcParts = PendingRpcPartsMap { DefaultGossipRpcPartsQueue(params) } + val gossipExtensionsState = GossipExtensionsState(gossipExtensionsConfig) + private fun setBackOff(peer: PeerHandler, topic: Topic) = setBackOff(peer, topic, params.pruneBackoff.toMillis()) private fun setBackOff(peer: PeerHandler, topic: Topic, delay: Long) { backoffExpireTimes[peer.peerId to topic] = currentTimeSupplier() + delay @@ -157,6 +161,7 @@ open class GossipRouter( fanout.values.forEach { it.remove(peer) } acceptRequestsWhitelist -= peer pendingRpcParts.popQueue(peer) // discard them + gossipExtensionsState.onPeerDisconnected(peer.peerId) super.onPeerDisconnected(peer) } @@ -164,6 +169,7 @@ open class GossipRouter( super.onPeerActive(peer) eventBroadcaster.notifyConnected(peer.peerId, peer.getRemoteAddress()) heartbeatTask.hashCode() // force lazy initialization + sendControlExtensions(peer) } override fun notifyUnseenMessage(peer: PeerHandler, msg: PubsubMessage) { @@ -251,6 +257,22 @@ open class GossipRouter( return peerScore >= scoreParams.graylistThreshold } + override val rpcLimits: PubsubRpcLimits by lazy { + PubsubRpcLimits( + maxPublishedMessages = params.maxPublishedMessages, + maxTopicsPerPublishedMessage = params.maxTopicsPerPublishedMessage, + maxSubscriptions = params.maxSubscriptions, + maxIHaveMessageIds = params.maxIHaveLength, + maxIWantMessageIds = params.maxIWantMessageIds, + maxGraftMessages = params.maxGraftMessages, + maxPruneMessages = params.maxPruneMessages, + maxPeersPerPruneMessage = params.maxPeersAcceptedInPruneMsg, + maxIDontWantMessageIds = params.maxIDontWantMessageIds, + rejectEmptyPublishEntries = true, + rejectEmptyIDontWantEntries = true, + ) + } + override fun validateMessageListLimits(msg: Rpc.RPCOrBuilder): Boolean { val iWantMessageIdCount = msg.control?.iwantList?.sumOf { w -> w.messageIDsCount } ?: 0 val iHaveMessageIdCount = msg.control?.ihaveList?.sumOf { w -> w.messageIDsCount } ?: 0 @@ -384,6 +406,100 @@ open class GossipRouter( ctrl.run { (graftList + pruneList + ihaveList + iwantList + idontwantList) }.forEach { processControlMessage(it, receivedFrom) } + + if (protocol.supportsExtensions() && ctrl.hasExtensions()) { + processControlExtensions(ctrl.extensions, receivedFrom) + } + } + + private fun processControlExtensions( + ctrlExtensions: Rpc.ControlExtensions, + receivedFrom: PeerHandler + ) { + logger.trace("Received control extension {}", ctrlExtensions.toString()) + + if (gossipExtensionsState.hasReceivedControlExtensionsFrom(receivedFrom.peerId)) { + logger.trace( + "Received another control extension message from peer {}", + receivedFrom.peerId + ) + notifyRouterMisbehavior(receivedFrom, 10) + return + } else { + gossipExtensionsState.onControlExtensionsMessage(ctrlExtensions, receivedFrom.peerId) + } + } + + override fun processExtensions(msg: Rpc.RPC, receivedFrom: PeerHandler) { + when { + msg.hasTestExtension() -> { + if (!gossipExtensionsState.testExtensionsEnabled()) { + logger.trace( + "Ignoring test extension message from peer {} - test extension disabled", + msg + ) + return + } + + if (!gossipExtensionsState.peerSupportsTestExtensions(receivedFrom.peerId)) { + logger.trace( + "Ignoring test extension message from peer {} - did peer send ControlExtensions prior?", + msg + ) + return + } + + processTestExtensionMessage(msg.testExtension, receivedFrom) + } + + msg.hasPartial() -> { + if (!gossipExtensionsState.partialMessagesEnabled()) { + logger.trace( + "Ignoring partial messages message from peer {} - partial messages extension disabled", + msg + ) + return + } + + if (!gossipExtensionsState.peerSupportsPartialMessages(receivedFrom.peerId)) { + logger.trace( + "Ignoring partial messages message from peer {} - did peer send ControlExtensions prior?", + msg + ) + return + } + + processPartialMessageExtension(msg.partial, receivedFrom) + } + } + } + + private fun processTestExtensionMessage( + testExtensionMessage: Rpc.TestExtension, + receivedFrom: PeerHandler + ) { + logger.trace( + "Processing test extension message {} from {}", + testExtensionMessage.toByteArray(), + receivedFrom.peerId + ) + + val response = + Rpc.RPC.newBuilder().setTestExtension(Rpc.TestExtension.newBuilder().build()).build() + + send(receivedFrom, response) + } + + private fun processPartialMessageExtension( + partialMessagesExtension: Rpc.PartialMessagesExtension, + receivedFrom: PeerHandler + ) { + logger.trace( + "Processing partial message extension message {} from {}", + partialMessagesExtension.toString(), + receivedFrom.peerId + ) + // TODO: implement partial message handling (https://github.com/libp2p/jvm-libp2p/issues/435) } override fun broadcastInbound(msgs: List, receivedFrom: PeerHandler) { @@ -490,11 +606,16 @@ open class GossipRouter( override fun subscribe(topic: Topic) { super.subscribe(topic) + // Peers that are still within their PRUNE backoff window must be excluded when + // seeding the mesh on (re-)subscribe; grafting them during backoff is a P7 + // behaviour-penalty violation in go-libp2p-pubsub scorers and matches the JOIN + // path of the reference implementation. Heartbeat-driven mesh maintenance has + // always filtered by isBackOff; this path historically did not. val fanoutPeers = (fanout[topic] ?: mutableSetOf()) - .filter { score.score(it.peerId) >= 0 && !isDirect(it) } + .filter { score.score(it.peerId) >= 0 && !isDirect(it) && !isBackOff(it, topic) } val meshPeers = mesh.getOrPut(topic) { mutableSetOf() } val otherPeers = (getTopicPeers(topic) - meshPeers - fanoutPeers) - .filter { score.score(it.peerId) >= 0 && !isDirect(it) } + .filter { score.score(it.peerId) >= 0 && !isDirect(it) && !isBackOff(it, topic) } if (meshPeers.size < params.D) { val addFromFanout = fanoutPeers.shuffled(random) @@ -508,6 +629,8 @@ open class GossipRouter( fanout -= topic lastPublished -= topic } + + activePeers.forEach { sendControlExtensions(it) } } override fun unsubscribe(topic: Topic) { @@ -708,6 +831,29 @@ open class GossipRouter( send(peer, iDontWant) } + private fun sendControlExtensions(peer: PeerHandler) { + if (!this.protocol.supportsExtensions()) { + logger.trace( + "Protocol does not support extensions. Won't send control extensions message." + ) + return + } + + if (gossipExtensionsState.hasSentControlExtensionsTo(peer.peerId)) { + logger.trace( + "Already sent control extensions msg to peer {}. Won't send another one.", + peer.peerId + ) + return + } + + logger.trace("Sending control extensions message to peer {}", peer.peerId) + + pendingRpcParts.getQueue(peer) + .addControlExtensions(gossipExtensionsState.localExtensionSupport) + gossipExtensionsState.registerControlExtensionMessageSentToPeers(peer.peerId) + } + data class AcceptRequestsWhitelistEntry(val whitelistedTill: Long, val messagesAccepted: Int = 0) { fun incrementMessageCount() = AcceptRequestsWhitelistEntry(whitelistedTill, messagesAccepted + 1) } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt index e90332589..32e5c908a 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt @@ -26,6 +26,9 @@ interface GossipRpcPartsQueue : RpcPartsQueue { * Gossip 1.1 variant */ fun addPrune(topic: Topic, backoffSeconds: Long, backoffPeers: List) + + // TODO Need to check if we should handle when control extension and extension messages could be separated by split (https://github.com/libp2p/jvm-libp2p/issues/440) + fun addControlExtensions(ctrlMessage: Rpc.ControlExtensions) } /** @@ -81,6 +84,12 @@ open class DefaultGossipRpcPartsQueue( } } + protected data class ControlExtensionPart(val ctrlExtension: Rpc.ControlExtensions) : AbstractPart { + override fun appendToBuilder(builder: Rpc.RPC.Builder) { + builder.controlBuilder.setExtensions(ctrlExtension) + } + } + override fun addIHave(messageId: MessageId, topic: Topic) { addPart(IHavePart(messageId, topic)) } @@ -101,6 +110,10 @@ open class DefaultGossipRpcPartsQueue( addPart(PrunePart(topic, backoffSeconds, backoffPeers)) } + override fun addControlExtensions(ctrlMessage: Rpc.ControlExtensions) { + addPart(ControlExtensionPart(ctrlMessage)) + } + override fun takeMerged(): List { val ret = mutableListOf() var partIdx = 0 diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt index 5c783ce5f..214d4b06d 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt @@ -24,7 +24,7 @@ open class GossipRouterBuilder( var scheduledAsyncExecutor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor( ThreadFactoryBuilder().setDaemon(true).setNameFormat("GossipRouter-event-thread-%d").build() ), - var currentTimeSuppluer: CurrentTimeSupplier = { System.currentTimeMillis() }, + var currentTimeSupplier: CurrentTimeSupplier = { System.currentTimeMillis() }, var random: Random = Random(), var messageFactory: PubsubMessageFactory = { DefaultPubsubMessage(it) }, @@ -33,26 +33,32 @@ open class GossipRouterBuilder( var subscriptionTopicSubscriptionFilter: TopicSubscriptionFilter = TopicSubscriptionFilter.AllowAllTopicSubscriptionFilter(), var scoreFactory: GossipScoreFactory = - { scoreParams1, scheduledAsyncRxecutor, currentTimeSuppluer1, eventsSubscriber -> - val gossipScore = DefaultGossipScore(scoreParams1, scheduledAsyncRxecutor, currentTimeSuppluer1) + { scoreParams1, scheduledAsyncRxecutor, currentTimeSupplier1, eventsSubscriber -> + val gossipScore = DefaultGossipScore(scoreParams1, scheduledAsyncRxecutor, currentTimeSupplier1) eventsSubscriber(gossipScore) gossipScore }, - val gossipRouterEventListeners: MutableList = mutableListOf() + val gossipRouterEventListeners: MutableList = mutableListOf(), + val enabledGossipExtensions: List = mutableListOf(), ) { - var seenCache: SeenCache> by lazyVar { TTLSeenCache(SimpleSeenCache(), params.seenTTL, currentTimeSuppluer) } + var seenCache: SeenCache> by lazyVar { TTLSeenCache(SimpleSeenCache(), params.seenTTL, currentTimeSupplier) } var mCache: MCache by lazyVar { MCache(params.gossipSize, params.gossipHistoryLength) } private var disposed = false + fun enabledGossipExtensions(vararg gossipExtensions: GossipExtension): GossipRouterBuilder { + (enabledGossipExtensions as MutableList).addAll(gossipExtensions) + return this + } + protected open fun createGossipRouter(): GossipRouter { - val gossipScore = scoreFactory(scoreParams, scheduledAsyncExecutor, currentTimeSuppluer, { gossipRouterEventListeners += it }) + val gossipScore = scoreFactory(scoreParams, scheduledAsyncExecutor, currentTimeSupplier, { gossipRouterEventListeners += it }) val router = GossipRouter( params = params, scoreParams = scoreParams, - currentTimeSupplier = currentTimeSuppluer, + currentTimeSupplier = currentTimeSupplier, random = random, name = name, mCache = mCache, @@ -62,7 +68,8 @@ open class GossipRouterBuilder( executor = scheduledAsyncExecutor, messageFactory = messageFactory, seenMessages = seenCache, - messageValidator = messageValidator + messageValidator = messageValidator, + gossipExtensionsConfig = buildGossipExtensionsConfig(), ) router.eventBroadcaster.listeners += gossipRouterEventListeners @@ -74,4 +81,11 @@ open class GossipRouterBuilder( disposed = true return createGossipRouter() } + + private fun buildGossipExtensionsConfig(): GossipExtensionsConfig { + return GossipExtensionsConfig( + partialMessagesEnabled = enabledGossipExtensions.contains(GossipExtension.PARTIAL_MESSAGES), + testExtensionEnabled = enabledGossipExtensions.contains(GossipExtension.TEST_EXTENSION) + ) + } } diff --git a/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt b/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt index f555418c7..fb3c33ca0 100644 --- a/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt +++ b/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt @@ -2,6 +2,7 @@ package io.libp2p.security.tls import crypto.pb.Crypto import io.libp2p.core.* +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.PrivKey import io.libp2p.core.crypto.PubKey import io.libp2p.core.crypto.unmarshalPublicKey @@ -25,6 +26,7 @@ import io.netty.handler.ssl.ApplicationProtocolConfig import io.netty.handler.ssl.ClientAuth import io.netty.handler.ssl.SslContextBuilder import io.netty.handler.ssl.SslHandler +import io.netty.handler.ssl.SslProvider import org.bouncycastle.asn1.* import org.bouncycastle.asn1.edec.EdECObjectIdentifiers import org.bouncycastle.asn1.pkcs.PrivateKeyInfo @@ -36,13 +38,15 @@ import org.bouncycastle.cert.X509v3CertificateBuilder import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter import org.bouncycastle.crypto.params.Ed25519PublicKeyParameters import org.bouncycastle.jcajce.interfaces.EdDSAPublicKey -import org.bouncycastle.jsse.provider.BouncyCastleJsseProvider +import org.bouncycastle.operator.ContentVerifierProvider +import org.bouncycastle.operator.DefaultDigestAlgorithmIdentifierFinder +import org.bouncycastle.operator.bc.BcECContentVerifierProviderBuilder +import org.bouncycastle.operator.bc.BcEdDSAContentVerifierProviderBuilder import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder import java.math.BigInteger import java.security.KeyFactory import java.security.PrivateKey -import java.security.PublicKey -import java.security.Security +import java.security.SecureRandom import java.security.cert.Certificate import java.security.cert.CertificateException import java.security.cert.X509Certificate @@ -64,16 +68,10 @@ val certificatePrefix = "libp2p-tls-handshake:".encodeToByteArray() class TlsSecureChannel(private val localKey: PrivKey, private val muxers: List, private val certAlgorithm: String) : SecureChannel { - constructor(localKey: PrivKey, muxerIds: List) : this(localKey, muxerIds, "Ed25519") {} + constructor(localKey: PrivKey, muxerIds: List) : this(localKey, muxerIds, "ECDSA") {} companion object { const val announce = "/tls/1.0.0" - init { - Security.insertProviderAt(Libp2pCrypto.provider, 1) - Security.insertProviderAt(BouncyCastleJsseProvider(), 2) - Security.setProperty("ssl.KeyManagerFactory.algorithm", "PKIX") - Security.setProperty("ssl.TrustManagerFactory.algorithm", "PKIX") - } @JvmStatic fun ECDSA(localKey: PrivKey, muxerIds: List): TlsSecureChannel { @@ -102,35 +100,37 @@ fun buildTlsHandler( expectedRemotePeer: Optional, muxers: List, certAlgorithm: String, - ch: P2PChannel, + isInitiator: Boolean, handshakeComplete: CompletableFuture, ctx: ChannelHandlerContext ): SslHandler { - val connectionKeys = if (certAlgorithm.equals("ECDSA")) generateEcdsaKeyPair() else generateEd25519KeyPair() + val connectionKeys = if (certAlgorithm == "ECDSA") generateEcdsaKeyPair() else generateEd25519KeyPair() val javaPrivateKey = getJavaKey(connectionKeys.first) val sslContext = ( - if (ch.isInitiator) { - SslContextBuilder.forClient().keyManager(javaPrivateKey, listOf(buildCert(localKey, connectionKeys.first))) + if (isInitiator) { + SslContextBuilder.forClient() + .keyManager(javaPrivateKey, listOf(buildCert(localKey, connectionKeys.first))) } else { SslContextBuilder.forServer(javaPrivateKey, listOf(buildCert(localKey, connectionKeys.first))) + .keyManager(javaPrivateKey, listOf(buildCert(localKey, connectionKeys.first))) } ) .protocols(listOf("TLSv1.3")) .ciphers(listOf("TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384", "TLS_CHACHA20_POLY1305_SHA256")) .clientAuth(ClientAuth.REQUIRE) .trustManager(Libp2pTrustManager(expectedRemotePeer)) - .sslContextProvider(BouncyCastleJsseProvider()) + .sslProvider(SslProvider.OPENSSL) + .secureRandom(SecureRandom()) .applicationProtocolConfig( ApplicationProtocolConfig( ApplicationProtocolConfig.Protocol.ALPN, - ApplicationProtocolConfig.SelectorFailureBehavior.FATAL_ALERT, - ApplicationProtocolConfig.SelectedListenerFailureBehavior.FATAL_ALERT, + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, muxers.allProtocols + NoEarlyMuxerNegotiationEntry // early muxer negotiation ) ) .build() val handler = sslContext.newHandler(ctx.alloc()) - handler.sslCloseFuture().addListener { _ -> ctx.close() } val handshake = handler.handshakeFuture() val engine = handler.engine() handshake.addListener { fut -> @@ -180,17 +180,17 @@ private class ChannelSetup( if (!activated) { activated = true val expectedRemotePeerId = ctx.channel().attr(REMOTE_PEER_ID).get() - ctx.channel().pipeline().addLast( - buildTlsHandler( - localKey, - Optional.ofNullable(expectedRemotePeerId), - muxers, - certAlgorithm, - ch, - handshakeComplete, - ctx - ) + val handler = buildTlsHandler( + localKey, + Optional.ofNullable(expectedRemotePeerId), + muxers, + certAlgorithm, + ch.isInitiator, + handshakeComplete, + ctx ) + ctx.channel().pipeline().addLast(handler) + handler.sslCloseFuture().addListener { _ -> ctx.close() } ctx.channel().pipeline().remove(SetupHandlerName) } } @@ -215,18 +215,25 @@ private class ChannelSetup( } class Libp2pTrustManager(private val expectedRemotePeer: Optional) : X509TrustManager { + var remoteCert: Certificate? + + init { + remoteCert = null + } override fun checkClientTrusted(certs: Array?, authType: String?) { if (certs?.size != 1) { throw CertificateException() } - val claimedPeerId = verifyAndExtractPeerId(arrayOf(certs.get(0))) + val cert = certs.get(0) + remoteCert = cert + val claimedPeerId = verifyAndExtractPeerId(arrayOf(cert)) if (expectedRemotePeer.map { ex -> !ex.equals(claimedPeerId) }.orElse(false)) { throw InvalidRemotePubKey() } } override fun checkServerTrusted(certs: Array?, authType: String?) { - return checkClientTrusted(certs, authType) + checkClientTrusted(certs, authType) } override fun getAcceptedIssuers(): Array { @@ -264,18 +271,11 @@ fun getAsn1EncodedPublicKey(pub: PubKey): ByteArray { throw IllegalArgumentException("Unsupported TLS key type:" + pub.keyType) } -fun getPubKey(pub: PublicKey): PubKey { - if (pub.algorithm.equals("EdDSA") || pub.algorithm.equals("Ed25519")) { - val raw = (pub as EdDSAPublicKey).pointEncoding - return Ed25519PublicKey(Ed25519PublicKeyParameters(raw)) - } - if (pub.algorithm.equals("EC")) { - return EcdsaPublicKey(pub as ECPublicKey) +fun getContentVerifier(bcX509Cert: X509CertificateHolder): ContentVerifierProvider { + if (bcX509Cert.signatureAlgorithm.equals(AlgorithmIdentifier(ASN1ObjectIdentifier("1.3.101.112")))) { + return BcEdDSAContentVerifierProviderBuilder().build(bcX509Cert) } - if (pub.algorithm.equals("RSA")) { - throw IllegalStateException("Unimplemented RSA public key support for TLS") - } - throw IllegalStateException("Unsupported key type: " + pub.algorithm) + return BcECContentVerifierProviderBuilder(DefaultDigestAlgorithmIdentifierFinder()).build(bcX509Cert) } fun verifyAndExtractPeerId(chain: Array): PeerId { @@ -298,11 +298,15 @@ fun verifyAndExtractPeerId(chain: Array): PeerId { val pubKeyProto = (seq.getObjectAt(0) as DEROctetString).octets val signature = (seq.getObjectAt(1) as DEROctetString).octets val pubKey = unmarshalPublicKey(pubKeyProto) - if (!pubKey.verify(certificatePrefix.plus(cert.publicKey.encoded), signature)) { + + val pubKeyAsn1 = bcCert.subjectPublicKeyInfo.encoded + if (!pubKey.verify(certificatePrefix.plus(pubKeyAsn1), signature)) { throw IllegalStateException("Invalid signature on TLS certificate extension!") } - cert.verify(cert.publicKey) + if (!bcX509Cert.isSignatureValid(getContentVerifier(bcX509Cert))) { + throw IllegalStateException("TLS certificate has invalid signature!") + } val now = Date() if (bcCert.endDate.date.before(now)) { throw IllegalStateException("TLS certificate has expired!") @@ -313,12 +317,45 @@ fun verifyAndExtractPeerId(chain: Array): PeerId { return PeerId.fromPubKey(pubKey) } +fun getAlgorithmName(oid: String): String { + if ("1.2.840.113549.1.1.1".equals(oid)) { + return "RSA" + } + if ("1.2.840.10045.2.1".equals(oid)) { + return "EC" + } + if ("1.2.840.10040.4.1".equals(oid)) { + return "DSA" + } + return oid +} + +fun getLibp2pKeyFromCert(publicKeyInfo: SubjectPublicKeyInfo): PubKey { + val spec = X509EncodedKeySpec(publicKeyInfo.encoded) + val algorithmName = getAlgorithmName(publicKeyInfo.getAlgorithm().getAlgorithm().getId()) + val pub = KeyFactory.getInstance(algorithmName, Libp2pCrypto.provider).generatePublic(spec) + if (pub.algorithm.equals("EdDSA") || pub.algorithm.equals("Ed25519")) { + val raw = (pub as EdDSAPublicKey).pointEncoding + return Ed25519PublicKey(Ed25519PublicKeyParameters(raw)) + } + if (pub.algorithm.equals("EC")) { + return EcdsaPublicKey(pub as ECPublicKey) + } + if (pub.algorithm.equals("RSA")) { + throw IllegalStateException("Unimplemented RSA public key support for TLS") + } + throw IllegalStateException("Unsupported key type: " + pub.algorithm) +} + fun getPublicKeyFromCert(chain: Array): PubKey { if (chain.size != 1) { throw java.lang.IllegalStateException("Cert chain must have exactly 1 element!") } val cert = chain.get(0) - return getPubKey(cert.publicKey) + val bcCert = org.bouncycastle.asn1.x509.Certificate + .getInstance(ASN1Primitive.fromByteArray(cert.getEncoded())) + + return getLibp2pKeyFromCert(bcCert.subjectPublicKeyInfo) } /** Build a self signed cert, with an extension containing the host key + sig(cert public key) diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionBuilder.kt b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionBuilder.kt index 960c94f33..d4a1834d4 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionBuilder.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionBuilder.kt @@ -5,7 +5,6 @@ import io.libp2p.core.Connection import io.libp2p.core.ConnectionHandler import io.libp2p.core.P2PChannel import io.libp2p.core.PeerId -import io.libp2p.core.transport.Transport import io.libp2p.etc.REMOTE_PEER_ID import io.libp2p.etc.types.forward import io.libp2p.transport.ConnectionUpgrader @@ -14,7 +13,7 @@ import io.netty.channel.ChannelInitializer import java.util.concurrent.CompletableFuture class ConnectionBuilder( - private val transport: Transport, + private val transport: NettyTransport, private val upgrader: ConnectionUpgrader, private val connHandler: ConnectionHandler, private val initiator: Boolean, diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionOverNetty.kt b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionOverNetty.kt index 90c1d824f..9572c35a6 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionOverNetty.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionOverNetty.kt @@ -1,17 +1,11 @@ package io.libp2p.transport.implementation import io.libp2p.core.Connection -import io.libp2p.core.InternalErrorException import io.libp2p.core.multiformats.Multiaddr -import io.libp2p.core.multiformats.Protocol import io.libp2p.core.mux.StreamMuxer import io.libp2p.core.security.SecureChannel -import io.libp2p.core.transport.Transport import io.libp2p.etc.CONNECTION import io.netty.channel.Channel -import java.net.Inet4Address -import java.net.Inet6Address -import java.net.InetSocketAddress /** * A Connection is a high-level wrapper around a Netty Channel representing the conduit to a peer. @@ -20,7 +14,7 @@ import java.net.InetSocketAddress */ open class ConnectionOverNetty( ch: Channel, - private val transport: Transport, + private val nettyTransport: NettyTransport, initiator: Boolean ) : Connection, P2PChannelOverNetty(ch, initiator) { private lateinit var muxerSession: StreamMuxer.Session @@ -39,29 +33,8 @@ open class ConnectionOverNetty( override fun muxerSession() = muxerSession override fun secureSession() = secureSession - override fun transport() = transport + override fun transport() = nettyTransport - override fun localAddress(): Multiaddr = - toMultiaddr(nettyChannel.localAddress() as InetSocketAddress) - override fun remoteAddress(): Multiaddr = - toMultiaddr(nettyChannel.remoteAddress() as InetSocketAddress) - - private fun toMultiaddr(addr: InetSocketAddress): Multiaddr { - if (transport is NettyTransport) { - return transport.toMultiaddr(addr) - } else { - return toMultiaddrDefault(addr) - } - } - - fun toMultiaddrDefault(addr: InetSocketAddress): Multiaddr { - val proto = when (addr.address) { - is Inet4Address -> Protocol.IP4 - is Inet6Address -> Protocol.IP6 - else -> throw InternalErrorException("Unknown address type $addr") - } - return Multiaddr.empty() - .withComponent(proto, addr.address.hostAddress) - .withComponent(Protocol.TCP, addr.port.toString()) - } // toMultiaddr + override fun localAddress(): Multiaddr = nettyTransport.localAddress(nettyChannel) + override fun remoteAddress(): Multiaddr = nettyTransport.remoteAddress(nettyChannel) } diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/NettyTransport.kt b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/NettyTransport.kt index f29c2bfa6..2e5d9d305 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/NettyTransport.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/NettyTransport.kt @@ -1,208 +1,15 @@ package io.libp2p.transport.implementation -import io.libp2p.core.ChannelVisitor -import io.libp2p.core.Connection -import io.libp2p.core.ConnectionHandler -import io.libp2p.core.Libp2pException -import io.libp2p.core.P2PChannel -import io.libp2p.core.PeerId import io.libp2p.core.multiformats.Multiaddr -import io.libp2p.core.multiformats.MultiaddrDns -import io.libp2p.core.multiformats.Protocol import io.libp2p.core.transport.Transport -import io.libp2p.etc.types.lazyVar -import io.libp2p.etc.types.toCompletableFuture -import io.libp2p.etc.types.toVoidCompletableFuture -import io.libp2p.etc.util.netty.nettyInitializer -import io.libp2p.transport.ConnectionUpgrader -import io.netty.bootstrap.Bootstrap -import io.netty.bootstrap.ServerBootstrap import io.netty.channel.Channel -import io.netty.channel.ChannelHandler -import io.netty.channel.ChannelOption -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.socket.nio.NioSocketChannel -import java.net.InetSocketAddress -import java.time.Duration -import java.util.concurrent.CompletableFuture -abstract class NettyTransport( - private val upgrader: ConnectionUpgrader -) : Transport { - private var closed = false - var connectTimeout = Duration.ofSeconds(15) +/** + * A `Transport` which relies on a Netty `Channel` + */ +interface NettyTransport : Transport { - private val listeners = mutableMapOf() - private val channels = mutableListOf() + fun localAddress(nettyChannel: Channel): Multiaddr - private var workerGroup by lazyVar { NioEventLoopGroup() } - private var bossGroup by lazyVar { NioEventLoopGroup(1) } - - private var client by lazyVar { - Bootstrap().apply { - group(workerGroup) - channel(NioSocketChannel::class.java) - option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout.toMillis().toInt()) - } - } - - private var server by lazyVar { - ServerBootstrap().apply { - group(bossGroup, workerGroup) - channel(NioServerSocketChannel::class.java) - } - } - - override val activeListeners: Int - get() = listeners.size - override val activeConnections: Int - get() = channels.size - - override fun listenAddresses(): List { - return listeners.values.map { - toMultiaddr(it.localAddress() as InetSocketAddress) - } - } - - override fun initialize() { - } - - override fun close(): CompletableFuture { - closed = true - - val unbindsCompleted = listeners - .map { (_, ch) -> ch } - .map { it.close().toVoidCompletableFuture() } - - val channelsClosed = channels - .toMutableList() // need a copy to avoid potential co-modification problems - .map { it.close().toVoidCompletableFuture() } - - val everythingThatNeedsToClose = unbindsCompleted.union(channelsClosed) - val allClosed = CompletableFuture.allOf(*everythingThatNeedsToClose.toTypedArray()) - - return allClosed.thenApply { - workerGroup.shutdownGracefully() - bossGroup.shutdownGracefully() - Unit - } - } // close - - override fun listen(addr: Multiaddr, connHandler: ConnectionHandler, preHandler: ChannelVisitor?): CompletableFuture { - if (closed) throw Libp2pException("Transport is closed") - - val connectionBuilder = makeConnectionBuilder(connHandler, false, preHandler = preHandler) - val channelHandler = serverTransportBuilder(connectionBuilder, addr) ?: connectionBuilder - - val listener = server.clone() - .childHandler( - nettyInitializer { init -> - registerChannel(init.channel) - init.addLastLocal(channelHandler) - } - ) - - val bindComplete = listener.bind(fromMultiaddr(addr)) - - bindComplete.also { - synchronized(this@NettyTransport) { - listeners += addr to it.channel() - it.channel().closeFuture().addListener { - synchronized(this@NettyTransport) { - listeners -= addr - } - } - } - } - - return bindComplete.toVoidCompletableFuture() - } // listener - - protected abstract fun serverTransportBuilder( - connectionBuilder: ConnectionBuilder, - addr: Multiaddr - ): ChannelHandler? - - override fun unlisten(addr: Multiaddr): CompletableFuture { - return listeners[addr]?.close()?.toVoidCompletableFuture() - ?: throw Libp2pException("No listeners on address $addr") - } // unlisten - - override fun dial(addr: Multiaddr, connHandler: ConnectionHandler, preHandler: ChannelVisitor?): CompletableFuture { - if (closed) throw Libp2pException("Transport is closed") - - val remotePeerId = addr.getPeerId() - val connectionBuilder = makeConnectionBuilder(connHandler, true, remotePeerId, preHandler) - val channelHandler = clientTransportBuilder(connectionBuilder, addr) ?: connectionBuilder - - val chanFuture = client.clone() - .handler(channelHandler) - .connect(fromMultiaddr(addr)) - .also { registerChannel(it.channel()) } - - return chanFuture.toCompletableFuture() - .thenCompose { connectionBuilder.connectionEstablished } - } // dial - - protected abstract fun clientTransportBuilder( - connectionBuilder: ConnectionBuilder, - addr: Multiaddr - ): ChannelHandler? - - private fun registerChannel(ch: Channel) { - if (closed) { - ch.close() - return - } - - synchronized(this@NettyTransport) { - channels += ch - ch.closeFuture().addListener { - synchronized(this@NettyTransport) { - channels -= ch - } - } - } - } // registerChannel - - private fun makeConnectionBuilder( - connHandler: ConnectionHandler, - initiator: Boolean, - remotePeerId: PeerId? = null, - preHandler: ChannelVisitor? - ) = ConnectionBuilder( - this, - upgrader, - connHandler, - initiator, - remotePeerId, - preHandler - ) - - protected fun handlesHost(addr: Multiaddr) = - addr.hasAny(Protocol.IP4, Protocol.IP6, Protocol.DNS4, Protocol.DNS6, Protocol.DNSADDR) - - protected fun hostFromMultiaddr(addr: Multiaddr): String { - val resolvedAddresses = MultiaddrDns.resolve(addr) - if (resolvedAddresses.isEmpty()) { - throw Libp2pException("Could not resolve $addr to an IP address") - } - - return resolvedAddresses[0].components.find { - it.protocol in arrayOf(Protocol.IP4, Protocol.IP6) - }?.stringValue ?: throw Libp2pException("Missing IP4/IP6 in multiaddress $addr") - } - - protected fun portFromMultiaddr(addr: Multiaddr) = - addr.components.find { p -> p.protocol == Protocol.TCP } - ?.stringValue?.toInt() ?: throw Libp2pException("Missing TCP in multiaddress $addr") - - private fun fromMultiaddr(addr: Multiaddr): InetSocketAddress { - val host = hostFromMultiaddr(addr) - val port = portFromMultiaddr(addr) - return InetSocketAddress(host, port) - } // fromMultiaddr - - abstract fun toMultiaddr(addr: InetSocketAddress): Multiaddr -} // class NettyTransportBase + fun remoteAddress(nettyChannel: Channel): Multiaddr +} diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/PlainNettyTransport.kt b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/PlainNettyTransport.kt new file mode 100644 index 000000000..8d135fa9e --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/PlainNettyTransport.kt @@ -0,0 +1,228 @@ +package io.libp2p.transport.implementation + +import io.libp2p.core.ChannelVisitor +import io.libp2p.core.Connection +import io.libp2p.core.ConnectionHandler +import io.libp2p.core.Libp2pException +import io.libp2p.core.P2PChannel +import io.libp2p.core.PeerId +import io.libp2p.core.multiformats.Multiaddr +import io.libp2p.core.multiformats.MultiaddrDns +import io.libp2p.core.multiformats.Protocol +import io.libp2p.etc.types.lazyVar +import io.libp2p.etc.types.toCompletableFuture +import io.libp2p.etc.types.toVoidCompletableFuture +import io.libp2p.etc.util.netty.nettyInitializer +import io.libp2p.transport.ConnectionUpgrader +import io.netty.bootstrap.Bootstrap +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.Channel +import io.netty.channel.ChannelHandler +import io.netty.channel.ChannelOption +import io.netty.channel.MultiThreadIoEventLoopGroup +import io.netty.channel.nio.NioIoHandler +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.channel.socket.nio.NioSocketChannel +import java.net.InetSocketAddress +import java.net.SocketAddress +import java.time.Duration +import java.util.concurrent.CompletableFuture + +/** + * A plain `NettyTransport` without embedded security and muxer + */ +abstract class PlainNettyTransport( + private val upgrader: ConnectionUpgrader +) : NettyTransport { // class NettyTransportBase + private var closed = false + var connectTimeout = Duration.ofSeconds(15) + + private val listeners = mutableMapOf() + private val channels = mutableListOf() + + private var workerGroup by lazyVar { + MultiThreadIoEventLoopGroup(NioIoHandler.newFactory()) + } + private var bossGroup by lazyVar { + MultiThreadIoEventLoopGroup(1, NioIoHandler.newFactory()) + } + + private var client by lazyVar { + Bootstrap().apply { + group(workerGroup) + channel(NioSocketChannel::class.java) + option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout.toMillis().toInt()) + } + } + + private var server by lazyVar { + ServerBootstrap().apply { + group(bossGroup, workerGroup) + channel(NioServerSocketChannel::class.java) + } + } + + override val activeListeners: Int + get() = listeners.size + override val activeConnections: Int + get() = channels.size + + override fun listenAddresses(): List { + return listeners.values.map { + toMultiaddr(it.localAddress() as InetSocketAddress) + } + } + + override fun initialize() { + } + + override fun close(): CompletableFuture { + closed = true + + val unbindsCompleted = listeners + .map { (_, ch) -> ch } + .map { it.close().toVoidCompletableFuture() } + + val channelsClosed = channels + .toMutableList() // need a copy to avoid potential co-modification problems + .map { it.close().toVoidCompletableFuture() } + + val everythingThatNeedsToClose = unbindsCompleted.union(channelsClosed) + val allClosed = CompletableFuture.allOf(*everythingThatNeedsToClose.toTypedArray()) + + return allClosed.thenCompose { + CompletableFuture.allOf( + workerGroup.shutdownGracefully().toVoidCompletableFuture(), + bossGroup.shutdownGracefully().toVoidCompletableFuture() + ).thenApply { } + } + } // close + + override fun listen( + addr: Multiaddr, + connHandler: ConnectionHandler, + preHandler: ChannelVisitor? + ): CompletableFuture { + if (closed) throw Libp2pException("Transport is closed") + + val connectionBuilder = makeConnectionBuilder(connHandler, false, preHandler = preHandler) + val channelHandler = serverTransportBuilder(connectionBuilder, addr) ?: connectionBuilder + + val listener = server.clone() + .childHandler( + nettyInitializer { init -> + registerChannel(init.channel) + init.addLastLocal(channelHandler) + } + ) + + val bindComplete = listener.bind(fromMultiaddr(addr)) + + bindComplete.also { + synchronized(this@PlainNettyTransport) { + listeners += addr to it.channel() + it.channel().closeFuture().addListener { + synchronized(this@PlainNettyTransport) { + listeners -= addr + } + } + } + } + + return bindComplete.toVoidCompletableFuture() + } // listener + + protected abstract fun serverTransportBuilder( + connectionBuilder: ConnectionBuilder, + addr: Multiaddr + ): ChannelHandler? + + override fun unlisten(addr: Multiaddr): CompletableFuture { + return listeners[addr]?.close()?.toVoidCompletableFuture() + ?: throw Libp2pException("No listeners on address $addr") + } // unlisten + + override fun dial( + addr: Multiaddr, + connHandler: ConnectionHandler, + preHandler: ChannelVisitor? + ): CompletableFuture { + if (closed) throw Libp2pException("Transport is closed") + + val remotePeerId = addr.getPeerId() + val connectionBuilder = makeConnectionBuilder(connHandler, true, remotePeerId, preHandler) + val channelHandler = clientTransportBuilder(connectionBuilder, addr) ?: connectionBuilder + + val chanFuture = client.clone() + .handler(channelHandler) + .connect(fromMultiaddr(addr)) + .also { registerChannel(it.channel()) } + + return chanFuture.toCompletableFuture() + .thenCompose { connectionBuilder.connectionEstablished } + } // dial + + protected abstract fun clientTransportBuilder( + connectionBuilder: ConnectionBuilder, + addr: Multiaddr + ): ChannelHandler? + + private fun registerChannel(ch: Channel) { + if (closed) { + ch.close() + return + } + + synchronized(this@PlainNettyTransport) { + channels += ch + ch.closeFuture().addListener { + synchronized(this@PlainNettyTransport) { + channels -= ch + } + } + } + } // registerChannel + + private fun makeConnectionBuilder( + connHandler: ConnectionHandler, + initiator: Boolean, + remotePeerId: PeerId? = null, + preHandler: ChannelVisitor? + ) = ConnectionBuilder( + this, + upgrader, + connHandler, + initiator, + remotePeerId, + preHandler + ) + + protected fun handlesHost(addr: Multiaddr) = + addr.hasAny(Protocol.IP4, Protocol.IP6, Protocol.DNS4, Protocol.DNS6, Protocol.DNSADDR) + + protected fun hostFromMultiaddr(addr: Multiaddr): String { + val resolvedAddresses = MultiaddrDns.resolve(addr) + if (resolvedAddresses.isEmpty()) { + throw Libp2pException("Could not resolve $addr to an IP address") + } + + return resolvedAddresses[0].components.find { + it.protocol in arrayOf(Protocol.IP4, Protocol.IP6) + }?.stringValue ?: throw Libp2pException("Missing IP4/IP6 in multiaddress $addr") + } + + protected fun portFromMultiaddr(addr: Multiaddr) = + addr.components.find { p -> p.protocol == Protocol.TCP } + ?.stringValue?.toInt() ?: throw Libp2pException("Missing TCP in multiaddress $addr") + + private fun fromMultiaddr(addr: Multiaddr): InetSocketAddress { + val host = hostFromMultiaddr(addr) + val port = portFromMultiaddr(addr) + return InetSocketAddress(host, port) + } // fromMultiaddr + + override fun localAddress(nettyChannel: Channel): Multiaddr = toMultiaddr(nettyChannel.localAddress()) + override fun remoteAddress(nettyChannel: Channel): Multiaddr = toMultiaddr(nettyChannel.remoteAddress()) + + abstract fun toMultiaddr(addr: SocketAddress): Multiaddr +} diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/StreamOverNetty.kt b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/StreamOverNetty.kt index b0324a7e7..73c550ecc 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/StreamOverNetty.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/StreamOverNetty.kt @@ -8,7 +8,7 @@ import io.libp2p.etc.types.toVoidCompletableFuture import io.netty.channel.Channel import java.util.concurrent.CompletableFuture -class StreamOverNetty( +open class StreamOverNetty( ch: Channel, override val connection: Connection, initiator: Boolean diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicStream.kt b/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicStream.kt new file mode 100644 index 000000000..e2e662ff5 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicStream.kt @@ -0,0 +1,22 @@ +package io.libp2p.transport.quic + +import io.libp2p.core.Connection +import io.libp2p.etc.types.toVoidCompletableFuture +import io.libp2p.transport.implementation.StreamOverNetty +import io.netty.handler.codec.quic.QuicStreamChannel +import java.util.concurrent.CompletableFuture + +class QuicStream( + val quicStreamChannel: QuicStreamChannel, + connection: Connection, + initiator: Boolean +) : StreamOverNetty(quicStreamChannel, connection, initiator) { + + init { + pushHandler(QuicStreamReadCloseEventConverter()) + } + + override fun closeWrite(): CompletableFuture { + return quicStreamChannel.shutdownOutput().toVoidCompletableFuture() + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicStreamReadCloseEventConverter.kt b/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicStreamReadCloseEventConverter.kt new file mode 100644 index 000000000..cde8a7781 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicStreamReadCloseEventConverter.kt @@ -0,0 +1,20 @@ +package io.libp2p.transport.quic + +import io.libp2p.etc.util.netty.mux.RemoteWriteClosed +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelInboundHandlerAdapter +import io.netty.channel.socket.ChannelInputShutdownReadComplete + +/** + * Convert QUIC library specific event on remote stream close to Libp2p specific event + */ +class QuicStreamReadCloseEventConverter : ChannelInboundHandlerAdapter() { + + override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { + if (evt == ChannelInputShutdownReadComplete.INSTANCE) { + ctx.fireUserEventTriggered(RemoteWriteClosed) + } else { + super.userEventTriggered(ctx, evt) + } + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicTransport.kt b/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicTransport.kt new file mode 100644 index 000000000..7ea3c0a14 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicTransport.kt @@ -0,0 +1,442 @@ +package io.libp2p.transport.quic + +import io.libp2p.core.* +import io.libp2p.core.crypto.PrivKey +import io.libp2p.core.crypto.unmarshalPublicKey +import io.libp2p.core.multiformats.Multiaddr +import io.libp2p.core.multiformats.MultiaddrDns +import io.libp2p.core.multiformats.Multihash +import io.libp2p.core.multiformats.Protocol.* +import io.libp2p.core.multistream.MultistreamProtocol +import io.libp2p.core.multistream.MultistreamProtocolV1 +import io.libp2p.core.multistream.ProtocolBinding +import io.libp2p.core.mux.StreamMuxer +import io.libp2p.core.security.SecureChannel +import io.libp2p.crypto.keys.generateEcdsaKeyPair +import io.libp2p.crypto.keys.generateEd25519KeyPair +import io.libp2p.etc.CONNECTION +import io.libp2p.etc.STREAM +import io.libp2p.etc.types.* +import io.libp2p.etc.util.MultiaddrUtils +import io.libp2p.etc.util.netty.nettyInitializer +import io.libp2p.security.tls.Libp2pTrustManager +import io.libp2p.security.tls.buildCert +import io.libp2p.security.tls.getJavaKey +import io.libp2p.security.tls.getPublicKeyFromCert +import io.libp2p.security.tls.verifyAndExtractPeerId +import io.libp2p.transport.implementation.ConnectionOverNetty +import io.libp2p.transport.implementation.NettyTransport +import io.netty.bootstrap.Bootstrap +import io.netty.buffer.AdaptiveByteBufAllocator +import io.netty.buffer.ByteBuf +import io.netty.channel.* +import io.netty.channel.epoll.Epoll +import io.netty.channel.epoll.EpollDatagramChannel +import io.netty.channel.nio.NioIoHandler +import io.netty.channel.socket.nio.NioDatagramChannel +import io.netty.handler.codec.quic.* +import io.netty.handler.ssl.ClientAuth +import org.slf4j.LoggerFactory +import java.net.InetSocketAddress +import java.net.SocketAddress +import java.time.Duration +import java.util.* +import java.util.concurrent.CompletableFuture + +class QuicTransport( + private val localKey: PrivKey, + private val certAlgorithm: String, + private val protocols: List> +) : NettyTransport { + + private val logger = LoggerFactory.getLogger(QuicTransport::class.java) + + private var closed = false + + private val connectTimeout = Duration.ofSeconds(15) + + private val listeners = mutableMapOf() + private val channels = mutableListOf() + + private var workerGroup by lazyVar { + MultiThreadIoEventLoopGroup(NioIoHandler.newFactory()) + } + private var allocator by lazyVar { AdaptiveByteBufAllocator(true) } + private var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1 + private var incomingMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol } + + companion object { + @JvmStatic + fun Ed25519(k: PrivKey, p: List>): QuicTransport { + return QuicTransport(k, "Ed25519", p) + } + + @JvmStatic + fun ECDSA(k: PrivKey, p: List>): QuicTransport { + return QuicTransport(k, "ECDSA", p) + } + + private fun createStream(channel: QuicStreamChannel, connection: Connection, initiator: Boolean): Stream { + val stream = QuicStream(channel, connection, initiator) + channel.attr(STREAM).set(stream) + return stream + } + } + + private var client by lazyVar { + Bootstrap().group(workerGroup) + .channel( + if (Epoll.isAvailable()) { + EpollDatagramChannel::class.java + } else { + NioDatagramChannel::class.java + } + ) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout.toMillis().toInt()) + } + + private var server by lazyVar { + Bootstrap().group(workerGroup) + .channel( + if (Epoll.isAvailable()) { + EpollDatagramChannel::class.java + } else { + NioDatagramChannel::class.java + } + ) + } + + override val activeListeners: Int + get() = listeners.size + override val activeConnections: Int + get() = channels.size + + override fun listenAddresses(): List { + return listeners.values.map { + toMultiaddr(it.localAddress() as InetSocketAddress) + } + } + + override fun initialize() { + } + + override fun close(): CompletableFuture { + closed = true + + val unbindsCompleted = listeners + .map { (_, ch) -> ch } + .map { it.close().toVoidCompletableFuture() } + + val channelsClosed = channels + .toMutableList() // need a copy to avoid potential co-modification problems + .map { it.close().toVoidCompletableFuture() } + + val everythingThatNeedsToClose = unbindsCompleted.union(channelsClosed) + val allClosed = CompletableFuture.allOf(*everythingThatNeedsToClose.toTypedArray()) + + return allClosed.thenCompose { + workerGroup.shutdownGracefully().toVoidCompletableFuture() + } + } + + override fun listen( + addr: Multiaddr, + connHandler: ConnectionHandler, + preHandler: ChannelVisitor? + ): CompletableFuture { + if (closed) throw Libp2pException("Transport is closed") + + val channelHandler = serverTransportBuilder(connHandler, preHandler) + + val listener = server.clone() + .handler( + nettyInitializer { + registerChannel(it.channel) + it.addLastLocal(channelHandler) + } + ) + + val bindComplete = listener.bind(fromMultiaddr(addr)) + + bindComplete.also { + synchronized(this@QuicTransport) { + listeners += addr to it.channel() + it.channel().closeFuture().addListener { + synchronized(this@QuicTransport) { + listeners -= addr + } + } + } + } + + return bindComplete.toVoidCompletableFuture().thenApply { + logger.info("Quic server listening on {}", addr) + } + } + + override fun unlisten(addr: Multiaddr): CompletableFuture { + return listeners[addr]?.close()?.toVoidCompletableFuture() + ?: throw Libp2pException("No listeners on address $addr") + } + + override fun dial( + addr: Multiaddr, + connHandler: ConnectionHandler, + preHandler: ChannelVisitor? + ): CompletableFuture { + if (closed) throw Libp2pException("Transport is closed") + + val trustManager = Libp2pTrustManager(Optional.ofNullable(addr.getPeerId())) + val sslContext = quicSslContext(true, trustManager) + val requestsHandler = QuicClientCodecBuilder() + .sslEngineProvider { q -> sslContext.newEngine(q.alloc()) } + .sslTaskExecutor(workerGroup) + .initialMaxData(1 shl 20) + .initialMaxStreamsBidirectional(64) + .initialMaxStreamDataBidirectionalRemote(1 shl 18) + .initialMaxStreamDataBidirectionalLocal(1 shl 18) + .build() + + return client.clone() + .handler(requestsHandler) + .bind(0) + .toCompletableFuture() + .thenCompose { + QuicChannel.newBootstrap(it) + .streamOption(ChannelOption.ALLOCATOR, allocator) + .option(ChannelOption.AUTO_READ, true) + .option(ChannelOption.ALLOCATOR, allocator) + .remoteAddress(fromMultiaddr(addr)) + .streamHandler(InboundStreamHandler(multistreamProtocol, protocols)) + .connect() + .toCompletableFuture() + } + .thenApply { + registerChannel(it) + val connection = ConnectionOverNetty(it, this@QuicTransport, true) + + connection.setMuxerSession(QuicMuxerSession(it, connection)) + + val pubHash = Multihash.of(addr.getPeerId()!!.bytes.toByteBuf()) + val remotePubKey = if (pubHash.desc.digest == Multihash.Digest.Identity) { + unmarshalPublicKey(pubHash.bytes.toByteArray()) + } else { + getPublicKeyFromCert(arrayOf(trustManager.remoteCert!!)) + } + connection.setSecureSession( + SecureChannel.Session( + PeerId.fromPubKey(localKey.publicKey()), + addr.getPeerId()!!, + remotePubKey, + null + ) + ) + + preHandler?.also { visitor -> visitor.visit(connection) } + connHandler.handleConnection(connection) + + it.attr(CONNECTION).set(connection) + + connection + } + } + + private fun registerChannel(ch: Channel) { + if (closed) { + ch.close() + return + } + + synchronized(this@QuicTransport) { + channels += ch + ch.closeFuture().addListener { + synchronized(this@QuicTransport) { + channels -= ch + } + } + } + } + + private fun handlesHost(addr: Multiaddr) = + addr.hasAny(IP4, IP6, DNS4, DNS6, DNSADDR) + + private fun hostFromMultiaddr(addr: Multiaddr): String { + val resolvedAddresses = MultiaddrDns.resolve(addr) + if (resolvedAddresses.isEmpty()) { + throw Libp2pException("Could not resolve $addr to an IP address") + } + + return resolvedAddresses[0].components.find { + it.protocol in arrayOf(IP4, IP6) + }?.stringValue ?: throw Libp2pException("Missing IP4/IP6 in multiaddress $addr") + } + + override fun handles(addr: Multiaddr) = + handlesHost(addr) && + addr.has(UDP) && + addr.has(QUICV1) && + !addr.has(WS) + + fun quicSslContext(isClient: Boolean, trustManager: Libp2pTrustManager): QuicSslContext { + val connectionKeys = if (certAlgorithm == "ECDSA") generateEcdsaKeyPair() else generateEd25519KeyPair() + val javaPrivateKey = getJavaKey(connectionKeys.first) + val cert = buildCert(localKey, connectionKeys.first) + logger.trace("Building {} keys and cert for peer id {}", certAlgorithm, PeerId.fromPubKey(localKey.publicKey())) + return ( + if (isClient) { + QuicSslContextBuilder.forClient().keyManager(javaPrivateKey, null, cert) + } else { + QuicSslContextBuilder.forServer(javaPrivateKey, null, cert).clientAuth(ClientAuth.REQUIRE) + } + ) + .trustManager(trustManager) + .applicationProtocols("libp2p") + .build() + } + + fun serverTransportBuilder( + connHandler: ConnectionHandler, + preHandler: ChannelVisitor? + ): ChannelHandler { + val trustManager = Libp2pTrustManager(Optional.empty()) + val sslContext = quicSslContext(false, trustManager) + return QuicServerCodecBuilder() + .sslEngineProvider { q -> sslContext.newEngine(q.alloc()) } + .sslTaskExecutor(workerGroup) + .tokenHandler(NoTokenHandler()) + .handler( + nettyInitializer { + val connection = ConnectionOverNetty(it.channel, this@QuicTransport, false) + + connection.setMuxerSession(QuicMuxerSession(it.channel as QuicChannel, connection)) + it.channel.attr(CONNECTION).set(connection) + + // Add a handler to wait for channel activation (handshake completion) + it.channel.pipeline().addFirst( + "quic-handshake-waiter", + object : ChannelInboundHandlerAdapter() { + override fun channelActive(ctx: ChannelHandlerContext) { + // Now the handshake is complete and remoteCert should be available + val remoteCert = trustManager.remoteCert + if (remoteCert != null) { + val remotePeerId = verifyAndExtractPeerId(arrayOf(remoteCert)) + val remotePublicKey = getPublicKeyFromCert(arrayOf(remoteCert)) + + logger.info("Handshake completed with remote peer id: {}", remotePeerId) + + connection.setSecureSession( + SecureChannel.Session( + PeerId.fromPubKey(localKey.publicKey()), + remotePeerId, + remotePublicKey, + null + ) + ) + + // Remove this handler as it's no longer needed + ctx.pipeline().remove(this) + + // Now it's safe to call the connection handler + preHandler?.also { visitor -> visitor.visit(connection) } + connHandler.handleConnection(connection) + } else { + // This should not happen if channelActive is called after handshake + ctx.close() + throw IllegalStateException("Remote certificate still not available after handshake") + } + + super.channelActive(ctx) + } + + @Deprecated("Deprecated in Java") + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + logger.error("An error during handshake", cause) + ctx.close() + } + } + ) + } + ) + .initialMaxData(1 shl 20) + .initialMaxStreamsBidirectional(64) + .initialMaxStreamDataBidirectionalRemote(1 shl 18) + .initialMaxStreamDataBidirectionalLocal(1 shl 18) + .streamHandler(InboundStreamHandler(incomingMultistreamProtocol, protocols)) + .build() + } + + class QuicMuxerSession( + val ch: QuicChannel, + val connection: ConnectionOverNetty + ) : StreamMuxer.Session { + + override fun createStream(protocols: List>): StreamPromise { + val multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1 + val streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol } + val multi = streamMultistreamProtocol.createMultistream(protocols) + + val controller = CompletableFuture() + + val stream = ch.createStream( + QuicStreamType.BIDIRECTIONAL, + nettyInitializer { + val stream = createStream(it.channel as QuicStreamChannel, connection, true) + val streamHandler = multi.toStreamHandler() + streamHandler.handleStream(stream).forward(controller) + } + ).toCompletableFuture() + .thenApply { + it.attr(STREAM).get() + } + .forwardException(controller) + + return StreamPromise(stream, controller) + } + } + + class InboundStreamHandler( + val handler: MultistreamProtocol, + val protocols: List> + ) : ChannelInitializer() { + override fun initChannel(ch: QuicStreamChannel) { + val connection = ch.parent().attr(CONNECTION).get() + val stream = createStream(ch, connection, false) + val streamHandler = handler.createMultistream(protocols).toStreamHandler() + streamHandler.handleStream(stream) + } + } + + class NoTokenHandler : QuicTokenHandler { + override fun writeToken(out: ByteBuf?, dcid: ByteBuf?, address: InetSocketAddress?): Boolean { + return false + } + + override fun validateToken(token: ByteBuf?, address: InetSocketAddress?): Int { + return -1 + } + + override fun maxTokenLength(): Int { + return 0 + } + } + + fun udpPortFromMultiaddr(addr: Multiaddr) = + addr.components.find { p -> p.protocol == UDP } + ?.stringValue?.toInt() ?: throw Libp2pException("Missing UDP in multiaddress $addr") + + fun fromMultiaddr(addr: Multiaddr): SocketAddress { + val host = hostFromMultiaddr(addr) + val port = udpPortFromMultiaddr(addr) + return InetSocketAddress(host, port) + } + + override fun localAddress(nettyChannel: Channel): Multiaddr = + toMultiaddr((nettyChannel as QuicChannel).localSocketAddress()!!) + + override fun remoteAddress(nettyChannel: Channel): Multiaddr = + toMultiaddr((nettyChannel as QuicChannel).remoteSocketAddress()!!) + + fun toMultiaddr(addr: SocketAddress): Multiaddr = + MultiaddrUtils.inetSocketAddressToUdpMultiaddr(addr as InetSocketAddress) + .withComponent(QUICV1) +} diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt b/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt index a081ff67d..375c916d4 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt @@ -1,20 +1,17 @@ package io.libp2p.transport.tcp -import io.libp2p.core.InternalErrorException import io.libp2p.core.multiformats.Multiaddr import io.libp2p.core.multiformats.Protocol.DNSADDR -import io.libp2p.core.multiformats.Protocol.IP4 -import io.libp2p.core.multiformats.Protocol.IP6 import io.libp2p.core.multiformats.Protocol.P2PCIRCUIT import io.libp2p.core.multiformats.Protocol.TCP import io.libp2p.core.multiformats.Protocol.WS +import io.libp2p.etc.util.MultiaddrUtils import io.libp2p.transport.ConnectionUpgrader import io.libp2p.transport.implementation.ConnectionBuilder -import io.libp2p.transport.implementation.NettyTransport +import io.libp2p.transport.implementation.PlainNettyTransport import io.netty.channel.ChannelHandler -import java.net.Inet4Address -import java.net.Inet6Address import java.net.InetSocketAddress +import java.net.SocketAddress /** * The TCP transport can establish libp2p connections via TCP endpoints. @@ -24,7 +21,7 @@ import java.net.InetSocketAddress */ open class TcpTransport( upgrader: ConnectionUpgrader -) : NettyTransport(upgrader) { +) : PlainNettyTransport(upgrader) { override fun handles(addr: Multiaddr) = handlesHost(addr) && @@ -43,14 +40,6 @@ open class TcpTransport( addr: Multiaddr ): ChannelHandler? = null - override fun toMultiaddr(addr: InetSocketAddress): Multiaddr { - val proto = when (addr.address) { - is Inet4Address -> IP4 - is Inet6Address -> IP6 - else -> throw InternalErrorException("Unknown address type $addr") - } - return Multiaddr.empty() - .withComponent(proto, addr.address.hostAddress) - .withComponent(TCP, addr.port.toString()) - } // toMultiaddr + override fun toMultiaddr(addr: SocketAddress): Multiaddr = + MultiaddrUtils.inetSocketAddressToTcpMultiaddr(addr as InetSocketAddress) } // class TcpTransport diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/ws/WebSocketClientInitializer.kt b/libp2p/src/main/kotlin/io/libp2p/transport/ws/WebSocketClientInitializer.kt index 1fde9d5c4..d38db14ae 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/ws/WebSocketClientInitializer.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/ws/WebSocketClientInitializer.kt @@ -12,12 +12,12 @@ internal class WebSocketClientInitializer( private val url: String ) : ChannelInitializer() { - public override fun initChannel(ch: SocketChannel) { + override fun initChannel(ch: SocketChannel) { val pipeline = ch.pipeline() pipeline.addLast(HttpClientCodec()) pipeline.addLast(HttpObjectAggregator(65536)) - pipeline.addLast(WebSocketClientCompressionHandler.INSTANCE) + pipeline.addLast(WebSocketClientCompressionHandler(0)) pipeline.addLast( WebSocketClientHandshake( connectionBuilder, diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/ws/WebSocketServerInitializer.kt b/libp2p/src/main/kotlin/io/libp2p/transport/ws/WebSocketServerInitializer.kt index f1a195ef4..0665aad98 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/ws/WebSocketServerInitializer.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/ws/WebSocketServerInitializer.kt @@ -12,12 +12,12 @@ internal class WebSocketServerInitializer( private val connectionBuilder: ChannelHandler ) : ChannelInitializer() { - public override fun initChannel(ch: SocketChannel) { + override fun initChannel(ch: SocketChannel) { val pipeline = ch.pipeline() pipeline.addLast(HttpServerCodec()) pipeline.addLast(HttpObjectAggregator(65536)) - pipeline.addLast(WebSocketServerCompressionHandler()) + pipeline.addLast(WebSocketServerCompressionHandler(0)) pipeline.addLast(WebSocketServerProtocolHandler("/", null, true)) pipeline.addLast(WebSocketServerHandshakeListener(connectionBuilder)) } // initChannel diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/ws/WsTransport.kt b/libp2p/src/main/kotlin/io/libp2p/transport/ws/WsTransport.kt index 431c30d11..8afdca5d6 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/ws/WsTransport.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/ws/WsTransport.kt @@ -1,18 +1,15 @@ package io.libp2p.transport.ws -import io.libp2p.core.InternalErrorException import io.libp2p.core.multiformats.Multiaddr -import io.libp2p.core.multiformats.Protocol.IP4 -import io.libp2p.core.multiformats.Protocol.IP6 import io.libp2p.core.multiformats.Protocol.TCP import io.libp2p.core.multiformats.Protocol.WS +import io.libp2p.etc.util.MultiaddrUtils import io.libp2p.transport.ConnectionUpgrader import io.libp2p.transport.implementation.ConnectionBuilder -import io.libp2p.transport.implementation.NettyTransport +import io.libp2p.transport.implementation.PlainNettyTransport import io.netty.channel.ChannelHandler -import java.net.Inet4Address -import java.net.Inet6Address import java.net.InetSocketAddress +import java.net.SocketAddress /** * The WS transport can establish libp2p connections @@ -20,7 +17,7 @@ import java.net.InetSocketAddress */ class WsTransport( upgrader: ConnectionUpgrader -) : NettyTransport(upgrader) { +) : PlainNettyTransport(upgrader) { override fun handles(addr: Multiaddr) = handlesHost(addr) && @@ -45,15 +42,7 @@ class WsTransport( return WebSocketClientInitializer(connectionBuilder, url) } // clientTransportBuilder - override fun toMultiaddr(addr: InetSocketAddress): Multiaddr { - val proto = when (addr.address) { - is Inet4Address -> IP4 - is Inet6Address -> IP6 - else -> throw InternalErrorException("Unknown address type $addr") - } - return Multiaddr.empty() - .withComponent(proto, addr.address.hostAddress) - .withComponent(TCP, addr.port.toString()) + override fun toMultiaddr(addr: SocketAddress): Multiaddr = + MultiaddrUtils.inetSocketAddressToTcpMultiaddr(addr as InetSocketAddress) .withComponent(WS) - } // toMultiaddr } // class WsTransport diff --git a/libp2p/src/main/proto/rpc.proto b/libp2p/src/main/proto/rpc.proto index 080eef471..7ff5e2bc8 100644 --- a/libp2p/src/main/proto/rpc.proto +++ b/libp2p/src/main/proto/rpc.proto @@ -7,11 +7,21 @@ message RPC { repeated Message publish = 2; message SubOpts { - optional bool subscribe = 1; // subscribe or unsubcribe + optional bool subscribe = 1; // subscribe or unsubscribe optional string topicid = 2; + // signals to receiver that sender prefers partial messages + optional bool requestsPartial = 3; + // signals to receiver that sender supports sending partial messages + optional bool supportsSendingPartial = 4; } optional ControlMessage control = 3; + + // Canonical Extensions + optional PartialMessagesExtension partial = 10; + + // Experimental Extensions + optional TestExtension testExtension = 6492434; } message Message { @@ -29,6 +39,7 @@ message ControlMessage { repeated ControlGraft graft = 3; repeated ControlPrune prune = 4; repeated ControlIDontWant idontwant = 5; + optional ControlExtensions extensions = 6; } message ControlIHave { @@ -54,11 +65,26 @@ message ControlIDontWant { repeated bytes messageIDs = 1; } +message ControlExtensions { + optional bool partialMessages = 10; + + // Experimental extensions must use field numbers larger than 0x200000 to be + // encoded with at least 4 bytes + optional bool testExtension = 6492434; +} + message PeerInfo { optional bytes peerID = 1; optional bytes signedPeerRecord = 2; } +message PartialMessagesExtension { + optional string topicID = 1; + optional bytes groupID = 2; + optional bytes partialMessage = 3; + optional bytes partsMetadata = 4; +} + message TopicDescriptor { optional string name = 1; optional AuthOpts auth = 2; @@ -86,3 +112,5 @@ message TopicDescriptor { } } } + +message TestExtension {} diff --git a/libp2p/src/test/java/io/libp2p/core/HostTestJava.java b/libp2p/src/test/java/io/libp2p/core/HostTestJava.java index bd4f509e0..4b3b9fd37 100644 --- a/libp2p/src/test/java/io/libp2p/core/HostTestJava.java +++ b/libp2p/src/test/java/io/libp2p/core/HostTestJava.java @@ -38,14 +38,14 @@ void ping() throws Exception { Host clientHost = new HostBuilder() .transport(TcpTransport::new) - .secureChannel((k, m) -> new TlsSecureChannel(k, m, "ECDSA")) + .secureChannel(TlsSecureChannel::ECDSA) .muxer(StreamMuxerProtocol::getYamux) .build(); Host serverHost = new HostBuilder() .transport(TcpTransport::new) - .secureChannel(TlsSecureChannel::new) + .secureChannel(TlsSecureChannel::ECDSA) .muxer(StreamMuxerProtocol::getYamux) .protocol(new Ping()) .listen(localListenAddress) @@ -100,14 +100,14 @@ void largePing() throws Exception { Host clientHost = new HostBuilder() .transport(TcpTransport::new) - .secureChannel((k, m) -> new TlsSecureChannel(k, m, "ECDSA")) + .secureChannel(TlsSecureChannel::ECDSA) .muxer(StreamMuxerProtocol::getYamux) .build(); Host serverHost = new HostBuilder() .transport(TcpTransport::new) - .secureChannel(TlsSecureChannel::new) + .secureChannel(TlsSecureChannel::ECDSA) .muxer(StreamMuxerProtocol::getYamux) .protocol(new Ping(pingSize)) .listen(localListenAddress) @@ -227,14 +227,14 @@ void addPingAfterHostStart() throws Exception { Host clientHost = new HostBuilder() .transport(TcpTransport::new) - .secureChannel((k, m) -> new TlsSecureChannel(k, m, "ECDSA")) + .secureChannel(TlsSecureChannel::ECDSA) .muxer(StreamMuxerProtocol::getYamux) .build(); Host serverHost = new HostBuilder() .transport(TcpTransport::new) - .secureChannel(TlsSecureChannel::new) + .secureChannel(TlsSecureChannel::ECDSA) .muxer(StreamMuxerProtocol::getYamux) .listen(localListenAddress) .build(); diff --git a/libp2p/src/test/java/io/libp2p/transport/quic/QuicKuboTestJava.java b/libp2p/src/test/java/io/libp2p/transport/quic/QuicKuboTestJava.java new file mode 100644 index 000000000..c597eb8ba --- /dev/null +++ b/libp2p/src/test/java/io/libp2p/transport/quic/QuicKuboTestJava.java @@ -0,0 +1,84 @@ +package io.libp2p.transport.quic; + +import io.libp2p.core.Host; +import io.libp2p.core.PeerId; +import io.libp2p.core.Stream; +import io.libp2p.core.StreamPromise; +import io.libp2p.core.crypto.*; +import io.libp2p.core.dsl.*; +import io.libp2p.core.multiformats.*; +import io.libp2p.protocol.*; +import java.io.*; +import java.net.*; +import java.util.concurrent.*; +import kotlin.*; +import org.junit.jupiter.api.*; + +public class QuicKuboTestJava { + @Test + void pingKubo() throws Exception { + if (System.getProperty("os.name").toLowerCase().startsWith("windows")) return; + PeerId peerId = PeerId.fromBase58(getKuboPeerId()); + + Host clientHost = + new HostBuilder().keyType(KeyType.ED25519).secureTransport(QuicTransport::ECDSA).build(); + + CompletableFuture clientStarted = clientHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + + StreamPromise ping = + clientHost + .getNetwork() + .connect(peerId, new Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1")) + .thenApply(it -> it.muxerSession().createStream(new Ping())) + .get(5, TimeUnit.SECONDS); + + Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + + for (int i = 0; i < 1000; i++) { + long latency = pingCtr.ping().get(1, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); + } + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> pingCtr.ping().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + } + + private static String getKuboPeerId() throws IOException, URISyntaxException { + String url = "http://localhost:5001/api/v0/id"; + HttpURLConnection conn = (HttpURLConnection) new URI(url).toURL().openConnection(); + conn.setConnectTimeout(1_000); + conn.setDoInput(true); + conn.setDoOutput(true); + + DataOutputStream dout = new DataOutputStream(conn.getOutputStream()); + + dout.write(new byte[0]); + dout.flush(); + + DataInputStream din = new DataInputStream(conn.getInputStream()); + String resp = new String(din.readAllBytes()); + din.close(); + int start = resp.indexOf("ID") + 5; + int end = resp.indexOf("\"", start); + return resp.substring(start, end); + } + + @Test + void keyPairGeneration() { + Pair pair = KeyKt.generateKeyPair(KeyType.SECP256K1); + PeerId peerId = PeerId.fromPubKey(pair.component2()); + System.out.println("PeerId: " + peerId.toHex()); + } +} diff --git a/libp2p/src/test/java/io/libp2p/transport/quic/QuicServerTestJava.java b/libp2p/src/test/java/io/libp2p/transport/quic/QuicServerTestJava.java new file mode 100644 index 000000000..da5f2966d --- /dev/null +++ b/libp2p/src/test/java/io/libp2p/transport/quic/QuicServerTestJava.java @@ -0,0 +1,494 @@ +package io.libp2p.transport.quic; + +import io.libp2p.core.Connection; +import io.libp2p.core.ConnectionHandler; +import io.libp2p.core.Host; +import io.libp2p.core.PeerId; +import io.libp2p.core.Stream; +import io.libp2p.core.StreamPromise; +import io.libp2p.core.crypto.KeyKt; +import io.libp2p.core.crypto.KeyType; +import io.libp2p.core.crypto.PrivKey; +import io.libp2p.core.crypto.PubKey; +import io.libp2p.core.dsl.HostBuilder; +import io.libp2p.core.multiformats.Multiaddr; +import io.libp2p.core.mux.StreamMuxerProtocol; +import io.libp2p.protocol.Blob; +import io.libp2p.protocol.BlobController; +import io.libp2p.protocol.OneShotPing; +import io.libp2p.protocol.OneShotPingController; +import io.libp2p.protocol.Ping; +import io.libp2p.protocol.PingController; +import io.libp2p.security.noise.NoiseXXSecureChannel; +import io.libp2p.security.tls.TlsSecureChannel; +import io.libp2p.transport.tcp.TcpTransport; +import io.netty.handler.logging.LogLevel; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.stream.Collectors; +import kotlin.Pair; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +public class QuicServerTestJava { + public static int getPort() { + return new Random().nextInt(20_000) + 10_000; + } + + @Test + void pingJava() throws Exception { + String localListenAddress = "/ip4/127.0.0.1/udp/" + getPort() + "/quic-v1"; + + Host clientHost = + new HostBuilder() + .keyType(KeyType.ED25519) + .secureTransport(QuicTransport::ECDSA) + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::ECDSA) + .muxer(StreamMuxerProtocol::getYamux) + .build(); + + Host serverHost = + new HostBuilder() + .keyType(KeyType.ED25519) + .secureTransport(QuicTransport::ECDSA) + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::ECDSA) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping()) + .listen(localListenAddress) + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started " + clientHost.getPeerId()); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started " + serverHost.getPeerId()); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(1, serverHost.listenAddresses().size()); + Assertions.assertEquals( + localListenAddress + "/p2p/" + serverHost.getPeerId(), + serverHost.listenAddresses().get(0).toString()); + System.out.println("Hosts running"); + Thread.sleep(2_000); + + StreamPromise ping = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new Ping(500))) + .get(5000, TimeUnit.SECONDS); + + Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + CompletableFuture controller = ping.getController(); + PingController pingCtr = controller.get(5000, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + pingStream.getConnection().localAddress(); + Multiaddr remote = pingStream.getConnection().remoteAddress(); + Assertions.assertEquals(localListenAddress, remote.toString()); + + for (int i = 0; i < 1000; i++) { + long latency = pingCtr.ping().get(1, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); + } + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> pingCtr.ping().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void checkThatRemotePeerIdCorrectForSECP256K1() throws Exception { + String localListenAddress = "/ip4/127.0.0.1/udp/" + getPort() + "/quic-v1"; + + class TestConnectionHandler implements ConnectionHandler { + public final CompletableFuture remotePeerIdFuture = new CompletableFuture<>(); + + @Override + public void handleConnection(@NotNull Connection conn) { + remotePeerIdFuture.complete(conn.secureSession().getRemoteId()); + } + } + + TestConnectionHandler clientHandler = new TestConnectionHandler(); + TestConnectionHandler serverHandler = new TestConnectionHandler(); + + Host clientHost = + new HostBuilder() + .keyType(KeyType.SECP256K1) + .secureTransport(QuicTransport::ECDSA) + .builderModifier(b -> b.getConnectionHandlers().add(clientHandler)) + .build(); + + Host serverHost = + new HostBuilder() + .keyType(KeyType.SECP256K1) + .secureTransport(QuicTransport::ECDSA) + .transport(TcpTransport::new) + .listen(localListenAddress) + .builderModifier(b -> b.getConnectionHandlers().add(serverHandler)) + .build(); + + clientHost.start().get(5, TimeUnit.SECONDS); + serverHost.start().get(5, TimeUnit.SECONDS); + + clientHost.getNetwork().connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)); + + Assertions.assertEquals( + serverHost.getPeerId(), clientHandler.remotePeerIdFuture.get(10, TimeUnit.SECONDS)); + Assertions.assertEquals( + clientHost.getPeerId(), serverHandler.remotePeerIdFuture.get(10, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + serverHost.stop().get(5, TimeUnit.SECONDS); + } + + @Disabled("Runs too long") + @Test + void checkConnectionIsNotClosedByTimeout() throws Exception { + String localListenAddress = "/ip4/127.0.0.1/udp/" + getPort() + "/quic-v1"; + + Host clientHost = + new HostBuilder().keyType(KeyType.SECP256K1).secureTransport(QuicTransport::ECDSA).build(); + + Host serverHost = + new HostBuilder() + .keyType(KeyType.SECP256K1) + .secureTransport(QuicTransport::ECDSA) + .transport(TcpTransport::new) + .listen(localListenAddress) + .build(); + + clientHost.start().get(5, TimeUnit.SECONDS); + serverHost.start().get(5, TimeUnit.SECONDS); + + Connection connection = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)) + .get(10, TimeUnit.SECONDS); + + try { + long s = System.currentTimeMillis(); + connection.closeFuture().get(60, TimeUnit.SECONDS); + long t = System.currentTimeMillis() - s; + Assertions.fail("closeFuture complete in " + t + " ms"); + } catch (TimeoutException e) { + // expected exception: connection was not closed + } catch (Exception e) { + throw new RuntimeException("Unexpected exception", e); + } + + clientHost.stop().get(5, TimeUnit.SECONDS); + serverHost.stop().get(5, TimeUnit.SECONDS); + } + + @Test + void oneShotPingJava() throws Exception { + String localListenAddress = "/ip4/127.0.0.1/udp/" + getPort() + "/quic-v1"; + + Host clientHost = + new HostBuilder() + .keyType(KeyType.ED25519) + .secureTransport(QuicTransport::ECDSA) + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::ECDSA) + .muxer(StreamMuxerProtocol::getYamux) + .build(); + + Host serverHost = + new HostBuilder() + .keyType(KeyType.ED25519) + .secureTransport(QuicTransport::ECDSA) + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::ECDSA) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new OneShotPing()) + .listen(localListenAddress) + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started " + clientHost.getPeerId()); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started " + serverHost.getPeerId()); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(1, serverHost.listenAddresses().size()); + Assertions.assertEquals( + localListenAddress + "/p2p/" + serverHost.getPeerId(), + serverHost.listenAddresses().get(0).toString()); + System.out.println("Hosts running"); + Thread.sleep(2_000); + + StreamPromise ping = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new OneShotPing(500))) + .get(5000, TimeUnit.SECONDS); + + Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + CompletableFuture controller = ping.getController(); + OneShotPingController pingCtr = controller.get(5000, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + pingStream.getConnection().localAddress(); + Multiaddr remote = pingStream.getConnection().remoteAddress(); + Assertions.assertEquals(localListenAddress, remote.toString()); + + long s = System.currentTimeMillis(); + pingCtr.ping().get(20, TimeUnit.SECONDS); + long l = System.currentTimeMillis() - s; + System.out.println("One Shot Ping is Done in " + l + " ms"); + + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void tlsAndQuicInSameHostPing() throws Exception { + int port = getPort(); + String localQuicListenAddress = "/ip4/127.0.0.1/udp/" + port + "/quic-v1"; + String localTcpListenAddress = "/ip4/127.0.0.1/tcp/" + port; + + Host clientHost = + new HostBuilder() + .keyType(KeyType.ED25519) + .secureTransport(QuicTransport::ECDSA) + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::ECDSA) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .build(); + + Host serverHost = + new HostBuilder() + .keyType(KeyType.ED25519) + .secureTransport(QuicTransport::ECDSA) + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::ECDSA) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping()) + .listen(localQuicListenAddress, localTcpListenAddress) + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started " + clientHost.getPeerId()); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started " + serverHost.getPeerId()); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(2, serverHost.listenAddresses().size()); + Assertions.assertEquals( + Set.of( + localTcpListenAddress + "/p2p/" + serverHost.getPeerId(), + localQuicListenAddress + "/p2p/" + serverHost.getPeerId()), + serverHost.listenAddresses().stream().map(Multiaddr::toString).collect(Collectors.toSet())); + System.out.println("Hosts running"); + Thread.sleep(2_000); + + StreamPromise tcpPing = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localTcpListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new Ping(500))) + .get(5000, TimeUnit.SECONDS); + + Stream pingStream = tcpPing.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + CompletableFuture controller = tcpPing.getController(); + PingController pingCtr = controller.get(5000, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + + for (int i = 0; i < 1000; i++) { + long latency = pingCtr.ping().get(1, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); + } + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> pingCtr.ping().get(5, TimeUnit.SECONDS)); + + StreamPromise quicPing = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localQuicListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new Ping(500))) + .get(5000, TimeUnit.SECONDS); + + Stream quicPingStream = quicPing.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + CompletableFuture quicController = quicPing.getController(); + PingController quicPingCtr = quicController.get(5000, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + + for (int i = 0; i < 1000; i++) { + long latency = quicPingCtr.ping().get(1, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); + } + quicPingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> quicPingCtr.ping().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void largeBlob() throws Exception { + int blobSize = 1024 * 1024; + String localListenAddress = "/ip4/127.0.0.1/udp/" + getPort() + "/quic-v1"; + + Host clientHost = + new HostBuilder() + .keyType(KeyType.ED25519) + .secureTransport(QuicTransport::ECDSA) + .builderModifier( + b -> b.getDebug().getMuxFramesHandler().addCompactLogger(LogLevel.ERROR, "client")) + .build(); + + Host serverHost = + new HostBuilder() + .keyType(KeyType.ED25519) + .secureTransport(QuicTransport::ECDSA) + .protocol(new Blob(blobSize)) + .listen(localListenAddress) + .builderModifier( + b -> b.getDebug().getMuxFramesHandler().addCompactLogger(LogLevel.ERROR, "server")) + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(1, serverHost.listenAddresses().size()); + Assertions.assertEquals( + localListenAddress + "/p2p/" + serverHost.getPeerId(), + serverHost.listenAddresses().get(0).toString()); + + StreamPromise blob = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new Blob(blobSize))) + .join(); + + Stream blobStream = blob.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Blob stream created"); + BlobController blobCtr = blob.getController().get(5, TimeUnit.SECONDS); + System.out.println("Blob controller created"); + + for (int i = 0; i < 10; i++) { + long latency = blobCtr.blob().join(); + System.out.println("Blob round trip is " + latency); + } + blobStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Blob stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> blobCtr.blob().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void startHostAddPing() throws Exception { + String localListenAddress = "/ip4/127.0.0.1/udp/" + getPort() + "/quic-v1"; + + Host clientHost = + new HostBuilder().keyType(KeyType.ED25519).secureTransport(QuicTransport::ECDSA).build(); + + Host serverHost = + new HostBuilder() + .keyType(KeyType.ED25519) + .secureTransport(QuicTransport::ECDSA) + .listen(localListenAddress) + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(1, serverHost.listenAddresses().size()); + Assertions.assertEquals( + localListenAddress + "/p2p/" + serverHost.getPeerId(), + serverHost.listenAddresses().get(0).toString()); + + serverHost.addProtocolHandler(new Ping()); + + StreamPromise ping = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new Ping())) + .get(5, TimeUnit.SECONDS); + + Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + + for (int i = 0; i < 10; i++) { + long latency = pingCtr.ping().get(1, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); + } + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> pingCtr.ping().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void keyPairGeneration() { + Pair pair = KeyKt.generateKeyPair(KeyType.SECP256K1); + PeerId peerId = PeerId.fromPubKey(pair.component2()); + System.out.println("PeerId: " + peerId.toHex()); + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/core/multiformats/MultiaddrTest.kt b/libp2p/src/test/kotlin/io/libp2p/core/multiformats/MultiaddrTest.kt index e1c274aaf..6365114ae 100644 --- a/libp2p/src/test/kotlin/io/libp2p/core/multiformats/MultiaddrTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/core/multiformats/MultiaddrTest.kt @@ -133,6 +133,8 @@ class MultiaddrTest { "/ip4/1.2.3.4/tcp/80/unix/a/b/c/d/e/f", "/ip4/127.0.0.1/ipfs/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC/tcp/1234/unix/stdio", "/ip4/127.0.0.1/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC/tcp/1234/unix/stdio", + "/ip4/127.0.0.1/tcp/127/webrtc-direct", + "/ip4/127.0.0.1/tcp/127/webrtc", "/ip4/127.0.0.1/tcp/40001/p2p/16Uiu2HAkuqGKz8D6khfrnJnDrN5VxWWCoLU8Aq4eCFJuyXmfakB5", "/ip6/2001:6b0:30:1000:d00e:1dff:fe0b:c764/udp/4001/quic-v1/webtransport/certhash/uEiAEz_3prFf34VZff8XqA1iTdq2Ytp467ErTGr5dRFo60Q/certhash/uEiDyL7yksuIGJsYUvf0AHieLkTux5R5KBk-UsFtA1AG18A" ) diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt index bb0f21313..2a74d694e 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt @@ -6,6 +6,7 @@ import io.libp2p.core.StreamHandler import io.libp2p.etc.types.fromHex import io.libp2p.etc.types.getX import io.libp2p.etc.types.toHex +import io.libp2p.etc.util.netty.mux.MuxChannel import io.libp2p.etc.util.netty.mux.RemoteWriteClosed import io.libp2p.etc.util.netty.nettyInitializer import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.* @@ -442,6 +443,26 @@ abstract class MuxHandlerAbstractTest { } } + @Test + fun `write with localDisconnected should fail promise without throwing from doWrite`() { + val handler = openStreamLocal() + readFrameOrThrow() + + // Simulate the state between localDisconnected=true and deactivate() in doDisconnect(), + // which is when a queued WriteTask can reach doWrite with localDisconnected=true while + // the channel is still active (flush0 would take the "not-yet-connected" path otherwise). + @Suppress("UNCHECKED_CAST") + (handler.ctx.channel() as MuxChannel).localDisconnected = true + + val writeFuture = handler.ctx.writeAndFlush(allocateMessage("42")) + ech.runPendingTasks() + + assertTrue(writeFuture.isDone) + assertThrows(ConnectionClosedException::class.java) { + writeFuture.sync() + } + } + @Test fun `should throw when writing to reset stream`() { val handler = openStreamLocal() diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubProtocolTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubProtocolTest.kt new file mode 100644 index 000000000..f121d17ec --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubProtocolTest.kt @@ -0,0 +1,34 @@ +package io.libp2p.pubsub + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test + +class PubsubProtocolTest { + + @Test + fun `supportsBackoffAndPX is true for all GossipSub versions from v1_1 onwards`() { + assertThat(PubsubProtocol.Gossip_V_1_0.supportsBackoffAndPX()).isFalse() + assertThat(PubsubProtocol.Gossip_V_1_1.supportsBackoffAndPX()).isTrue() + assertThat(PubsubProtocol.Gossip_V_1_2.supportsBackoffAndPX()).isTrue() + assertThat(PubsubProtocol.Gossip_V_1_3.supportsBackoffAndPX()).isTrue() + assertThat(PubsubProtocol.Floodsub.supportsBackoffAndPX()).isFalse() + } + + @Test + fun `supportsIDontWant is true for all GossipSub versions from v1_2 onwards`() { + assertThat(PubsubProtocol.Gossip_V_1_0.supportsIDontWant()).isFalse() + assertThat(PubsubProtocol.Gossip_V_1_1.supportsIDontWant()).isFalse() + assertThat(PubsubProtocol.Gossip_V_1_2.supportsIDontWant()).isTrue() + assertThat(PubsubProtocol.Gossip_V_1_3.supportsIDontWant()).isTrue() + assertThat(PubsubProtocol.Floodsub.supportsIDontWant()).isFalse() + } + + @Test + fun `supportsExtensions is true only for GossipSub v1_3`() { + assertThat(PubsubProtocol.Gossip_V_1_0.supportsExtensions()).isFalse() + assertThat(PubsubProtocol.Gossip_V_1_1.supportsExtensions()).isFalse() + assertThat(PubsubProtocol.Gossip_V_1_2.supportsExtensions()).isFalse() + assertThat(PubsubProtocol.Gossip_V_1_3.supportsExtensions()).isTrue() + assertThat(PubsubProtocol.Floodsub.supportsExtensions()).isFalse() + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRpcLimitsDefaultTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRpcLimitsDefaultTest.kt new file mode 100644 index 000000000..99a4342a2 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRpcLimitsDefaultTest.kt @@ -0,0 +1,26 @@ +package io.libp2p.pubsub + +import io.libp2p.pubsub.flood.FloodRouter +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test + +/** + * Pins the toggle-off contract for the inbound count-validation pipeline: any + * [AbstractRouter] subclass that does not opt in must observe + * [PubsubRpcLimits.NONE], so its wire behaviour is unchanged by this defence. + * + * Uses reflection because `rpcLimits` is `protected` and [FloodRouter] is `final`. + */ +class PubsubRpcLimitsDefaultTest { + + @Test + fun `FloodRouter inherits NONE rpcLimits from AbstractRouter`() { + assertThat(FloodRouter().readRpcLimits()).isEqualTo(PubsubRpcLimits.NONE) + } + + private fun AbstractRouter.readRpcLimits(): PubsubRpcLimits { + val getter = AbstractRouter::class.java.getDeclaredMethod("getRpcLimits") + getter.isAccessible = true + return getter.invoke(this) as PubsubRpcLimits + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRpcLimitsTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRpcLimitsTest.kt new file mode 100644 index 000000000..573bddb11 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRpcLimitsTest.kt @@ -0,0 +1,42 @@ +package io.libp2p.pubsub + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource + +class PubsubRpcLimitsTest { + + @Test + fun `NONE is a noop`() { + assertThat(PubsubRpcLimits.NONE.isNoop).isTrue + } + + @ParameterizedTest(name = "non-noop when {0}") + @MethodSource("nonNoopMutations") + fun `any configured limit or reject flag makes isNoop false`( + @Suppress("UNUSED_PARAMETER") label: String, + mutated: PubsubRpcLimits, + ) { + assertThat(mutated.isNoop).isFalse + } + + companion object { + @JvmStatic + fun nonNoopMutations(): List = listOf( + Arguments.of("maxPublishedMessages set", PubsubRpcLimits.NONE.copy(maxPublishedMessages = 1)), + Arguments.of("maxTopicsPerPublishedMessage set", PubsubRpcLimits.NONE.copy(maxTopicsPerPublishedMessage = 1)), + Arguments.of("maxSubscriptions set", PubsubRpcLimits.NONE.copy(maxSubscriptions = 1)), + Arguments.of("maxIHaveMessageIds set", PubsubRpcLimits.NONE.copy(maxIHaveMessageIds = 1)), + Arguments.of("maxIWantMessageIds set", PubsubRpcLimits.NONE.copy(maxIWantMessageIds = 1)), + Arguments.of("maxGraftMessages set", PubsubRpcLimits.NONE.copy(maxGraftMessages = 1)), + Arguments.of("maxPruneMessages set", PubsubRpcLimits.NONE.copy(maxPruneMessages = 1)), + Arguments.of("maxPeersPerPruneMessage set", PubsubRpcLimits.NONE.copy(maxPeersPerPruneMessage = 1)), + Arguments.of("maxIDontWantMessages set", PubsubRpcLimits.NONE.copy(maxIDontWantMessages = 1)), + Arguments.of("maxIDontWantMessageIds set", PubsubRpcLimits.NONE.copy(maxIDontWantMessageIds = 1)), + Arguments.of("rejectEmptyPublishEntries=true", PubsubRpcLimits.NONE.copy(rejectEmptyPublishEntries = true)), + Arguments.of("rejectEmptyIDontWantEntries=true", PubsubRpcLimits.NONE.copy(rejectEmptyIDontWantEntries = true)), + ) + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcCountFrameDecoderAttackTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcCountFrameDecoderAttackTest.kt new file mode 100644 index 000000000..e6bd08a7c --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcCountFrameDecoderAttackTest.kt @@ -0,0 +1,82 @@ +package io.libp2p.pubsub + +import io.libp2p.etc.util.netty.protobuf.LimitedProtobufVarint32FrameDecoder +import io.netty.buffer.ByteBufAllocator +import io.netty.channel.embedded.EmbeddedChannel +import io.netty.handler.codec.protobuf.ProtobufDecoder +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc + +class RpcCountFrameDecoderAttackTest { + + private val limits = PubsubRpcLimits.NONE.copy( + maxPublishedMessages = 1000, + rejectEmptyPublishEntries = true, + ) + + @Test + fun `attack frame of empty publish entries is rejected before materialisation`() { + val maxMsgSize = 12_234_442 // Teku mainnet-preset + + val ch = EmbeddedChannel( + LimitedProtobufVarint32FrameDecoder(maxMsgSize), + RpcCountFrameDecoder(limits), + ProtobufDecoder(Rpc.RPC.getDefaultInstance()), + ) + + // 100_000 empty publish entries: large enough to demonstrate amplification + // would be catastrophic, small enough to keep the test cheap. + val entries = 100_000 + val body = ByteArray(entries * 2) { if (it % 2 == 0) 0x12.toByte() else 0x00.toByte() } + + // Write a length-prefixed frame manually: varint(length) || body. + val framed = ByteBufAllocator.DEFAULT.buffer(body.size + 5) + writeVarint32(framed, body.size) + framed.writeBytes(body) + + ch.writeInbound(framed) + + val received: Any? = ch.readInbound() + assertThat(received).isNull() // ProtobufDecoder never produced an Rpc.RPC + } + + @Test + fun `well-formed RPC under the same limits is still delivered`() { + val maxMsgSize = 12_234_442 + + val ch = EmbeddedChannel( + LimitedProtobufVarint32FrameDecoder(maxMsgSize), + RpcCountFrameDecoder(limits), + ProtobufDecoder(Rpc.RPC.getDefaultInstance()), + ) + + val rpc = Rpc.RPC.newBuilder() + .addPublish( + Rpc.Message.newBuilder().setData(com.google.protobuf.ByteString.copyFromUtf8("ok")) + ) + .build() + val body = rpc.toByteArray() + + val framed = ByteBufAllocator.DEFAULT.buffer(body.size + 5) + writeVarint32(framed, body.size) + framed.writeBytes(body) + + ch.writeInbound(framed) + + val received: Rpc.RPC? = ch.readInbound() + assertThat(received).isEqualTo(rpc) + } + + private fun writeVarint32(buf: io.netty.buffer.ByteBuf, value: Int) { + var v = value + while (true) { + if (v and 0x7F.inv() == 0) { + buf.writeByte(v) + return + } + buf.writeByte((v and 0x7F) or 0x80) + v = v ushr 7 + } + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcCountFrameDecoderTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcCountFrameDecoderTest.kt new file mode 100644 index 000000000..007faf247 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcCountFrameDecoderTest.kt @@ -0,0 +1,112 @@ +package io.libp2p.pubsub + +import com.google.protobuf.ByteString +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled +import io.netty.channel.embedded.EmbeddedChannel +import io.netty.handler.codec.protobuf.ProtobufDecoder +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc + +class RpcCountFrameDecoderTest { + + private val limits = PubsubRpcLimits.NONE.copy( + maxPublishedMessages = 2, + rejectEmptyPublishEntries = true, + ) + + private fun pipeline() = EmbeddedChannel( + RpcCountFrameDecoder(limits), + ProtobufDecoder(Rpc.RPC.getDefaultInstance()), + ) + + @Test + fun `forwards an accepted RPC unchanged`() { + val ch = pipeline() + val rpc = Rpc.RPC.newBuilder() + .addPublish(Rpc.Message.newBuilder().setData(ByteString.copyFromUtf8("x"))) + .build() + + ch.writeInbound(Unpooled.wrappedBuffer(rpc.toByteArray())) + + val received: Rpc.RPC? = ch.readInbound() + assertThat(received).isEqualTo(rpc) + } + + @Test + fun `drops an RPC containing an empty publish entry`() { + val ch = pipeline() + val rpc = Rpc.RPC.newBuilder() + .addPublish(Rpc.Message.getDefaultInstance()) + .build() + + ch.writeInbound(Unpooled.wrappedBuffer(rpc.toByteArray())) + + val received: Any? = ch.readInbound() + assertThat(received).isNull() + } + + @Test + fun `drops an RPC whose publish count exceeds limits`() { + val ch = pipeline() + val rpc = Rpc.RPC.newBuilder() + .apply { + repeat(3) { + addPublish(Rpc.Message.newBuilder().setData(ByteString.copyFromUtf8("x$it"))) + } + } + .build() + + ch.writeInbound(Unpooled.wrappedBuffer(rpc.toByteArray())) + + val received: Any? = ch.readInbound() + assertThat(received).isNull() + } + + /** + * Toggle-off guarantee: with [PubsubRpcLimits.NONE], a frame that would be rejected + * under tighter limits (empty publish entry plus an extra publish over the cap above) + * must pass through the decoder unchanged. + */ + @Test + fun `forwards an otherwise-rejectable RPC when limits are NONE`() { + val ch = EmbeddedChannel( + RpcCountFrameDecoder(PubsubRpcLimits.NONE), + ProtobufDecoder(Rpc.RPC.getDefaultInstance()), + ) + val rpc = Rpc.RPC.newBuilder() + .addPublish(Rpc.Message.getDefaultInstance()) + .apply { + repeat(3) { + addPublish(Rpc.Message.newBuilder().setData(ByteString.copyFromUtf8("x$it"))) + } + } + .build() + + ch.writeInbound(Unpooled.wrappedBuffer(rpc.toByteArray())) + + val received: Rpc.RPC? = ch.readInbound() + assertThat(received).isEqualTo(rpc) + } + + /** + * Fast-path proof: a truncated frame would be flagged `Malformed` by + * [RpcMessageCountValidator] and converted to a [CorruptedFrameException] + * by the decoder. With [PubsubRpcLimits.NONE] the validator must be skipped + * entirely, so the truncated bytes pass through to the next handler unchanged. + */ + @Test + fun `skips validator when limits are noop`() { + val ch = EmbeddedChannel(RpcCountFrameDecoder(PubsubRpcLimits.NONE)) + val rpc = Rpc.RPC.newBuilder() + .addPublish(Rpc.Message.newBuilder().setData(ByteString.copyFromUtf8("x"))) + .build() + val truncated = rpc.toByteArray().copyOfRange(0, rpc.toByteArray().size - 1) + + ch.writeInbound(Unpooled.wrappedBuffer(truncated)) + + val received: Any? = ch.readInbound() + assertThat(received).isInstanceOf(ByteBuf::class.java) + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcMessageCountValidatorProtoCoverageTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcMessageCountValidatorProtoCoverageTest.kt new file mode 100644 index 000000000..1387982fc --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcMessageCountValidatorProtoCoverageTest.kt @@ -0,0 +1,65 @@ +package io.libp2p.pubsub + +import com.google.protobuf.Descriptors +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc + +/** + * Bit-rot guard: any new repeated proto field added to `rpc.proto` must be + * explicitly acknowledged in [RpcMessageCountValidator.ACKNOWLEDGED_REPEATED_FIELDS], + * forcing the author to consider whether it needs a count limit. + * + * Walks every message descriptor reachable from [Rpc.RPC] and compares the + * repeated-field set against the validator's acknowledged set per descriptor. + */ +class RpcMessageCountValidatorProtoCoverageTest { + + @Test + fun `every repeated field reachable from RPC is acknowledged by the validator`() { + val reachable = reachableDescriptors(Rpc.RPC.getDescriptor()) + val expected: Map> = reachable + .associateWith { d -> d.fields.filter { it.isRepeated }.map { it.number }.toSet() } + .filterValues { it.isNotEmpty() } + + val actual = RpcMessageCountValidator.ACKNOWLEDGED_REPEATED_FIELDS + + expected.forEach { (descriptor, expectedFields) -> + val actualFields = actual[descriptor] ?: emptySet() + assertThat(actualFields) + .describedAs( + "Repeated fields in %s must be acknowledged in " + + "RpcMessageCountValidator.ACKNOWLEDGED_REPEATED_FIELDS. " + + "Add new fields (and the corresponding decode-time guard) " + + "before merging.", + descriptor.fullName, + ) + .isEqualTo(expectedFields) + } + + val stale = actual.keys - expected.keys + assertThat(stale) + .describedAs( + "Stale entries in ACKNOWLEDGED_REPEATED_FIELDS — these descriptors " + + "are no longer reachable from Rpc.RPC or no longer contain repeated " + + "fields: %s", + stale.map { it.fullName }, + ) + .isEmpty() + } + + private fun reachableDescriptors(root: Descriptors.Descriptor): Set { + val seen = LinkedHashSet() + val stack: MutableList = mutableListOf(root) + while (stack.isNotEmpty()) { + val descriptor = stack.removeAt(stack.lastIndex) + if (!seen.add(descriptor)) continue + descriptor.fields.forEach { field -> + if (field.javaType == Descriptors.FieldDescriptor.JavaType.MESSAGE) { + stack.add(field.messageType) + } + } + } + return seen + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcMessageCountValidatorTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcMessageCountValidatorTest.kt new file mode 100644 index 000000000..5ad433474 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/RpcMessageCountValidatorTest.kt @@ -0,0 +1,296 @@ +package io.libp2p.pubsub + +import com.google.protobuf.ByteString +import io.netty.buffer.Unpooled +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc + +class RpcMessageCountValidatorTest { + + private val unlimited = PubsubRpcLimits.NONE.copy(rejectEmptyPublishEntries = true) + + private fun bytesOf(rpc: Rpc.RPC) = Unpooled.wrappedBuffer(rpc.toByteArray()) + + private fun message(topics: Int = 0): Rpc.Message { + val b = Rpc.Message.newBuilder().setData(ByteString.copyFromUtf8("x")) + repeat(topics) { b.addTopicIDs("t$it") } + return b.build() + } + + private fun subOpt(topic: String) = + Rpc.RPC.SubOpts.newBuilder().setTopicid(topic).setSubscribe(true).build() + + private fun ihave(ids: Int) = Rpc.ControlIHave.newBuilder() + .setTopicID("t") + .also { repeat(ids) { i -> it.addMessageIDs(ByteString.copyFromUtf8("m$i")) } } + .build() + + private fun iwant(ids: Int) = Rpc.ControlIWant.newBuilder() + .also { repeat(ids) { i -> it.addMessageIDs(ByteString.copyFromUtf8("m$i")) } } + .build() + + private fun idontwant(ids: Int) = Rpc.ControlIDontWant.newBuilder() + .also { repeat(ids) { i -> it.addMessageIDs(ByteString.copyFromUtf8("m$i")) } } + .build() + + private fun pruneWithPeers(peers: Int) = Rpc.ControlPrune.newBuilder() + .setTopicID("t") + .also { + repeat(peers) { i -> + it.addPeers(Rpc.PeerInfo.newBuilder().setPeerID(ByteString.copyFromUtf8("p$i"))) + } + } + .build() + + @Test + fun `rejects RPC containing an empty publish entry`() { + val rpc = Rpc.RPC.newBuilder() + .addPublish(Rpc.Message.getDefaultInstance()) + .build() + + val result = RpcMessageCountValidator.validate(bytesOf(rpc), unlimited) + + assertThat(result).isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `accepts non-empty publish when allowed`() { + val rpc = Rpc.RPC.newBuilder().addPublish(message(topics = 1)).build() + val limits = PubsubRpcLimits.NONE + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isEqualTo(RpcMessageCountValidator.Result.Accepted) + } + + @Test + fun `rejects when publish count exceeds limit`() { + val rpc = Rpc.RPC.newBuilder() + .apply { repeat(3) { addPublish(message(topics = 1)) } } + .build() + val limits = PubsubRpcLimits.NONE.copy(maxPublishedMessages = 2) + val result = RpcMessageCountValidator.validate(bytesOf(rpc), limits) + assertThat(result).isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `rejects when subscriptions count exceeds limit`() { + val rpc = Rpc.RPC.newBuilder() + .apply { repeat(3) { addSubscriptions(subOpt("t$it")) } } + .build() + val limits = PubsubRpcLimits.NONE.copy(maxSubscriptions = 2) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `rejects when topicIDs per publish exceeds limit`() { + val rpc = Rpc.RPC.newBuilder().addPublish(message(topics = 5)).build() + val limits = PubsubRpcLimits.NONE.copy(maxTopicsPerPublishedMessage = 4) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `rejects when ihave messageIDs total exceeds limit`() { + val rpc = Rpc.RPC.newBuilder() + .setControl( + Rpc.ControlMessage.newBuilder() + .addIhave(ihave(ids = 4)) + .addIhave(ihave(ids = 4)) + ) + .build() + val limits = PubsubRpcLimits.NONE.copy(maxIHaveMessageIds = 7) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `rejects when iwant messageIDs total exceeds limit`() { + val rpc = Rpc.RPC.newBuilder() + .setControl(Rpc.ControlMessage.newBuilder().addIwant(iwant(ids = 10))) + .build() + val limits = PubsubRpcLimits.NONE.copy(maxIWantMessageIds = 9) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `rejects when graft count exceeds limit`() { + val rpc = Rpc.RPC.newBuilder() + .setControl( + Rpc.ControlMessage.newBuilder() + .addGraft(Rpc.ControlGraft.newBuilder().setTopicID("a")) + .addGraft(Rpc.ControlGraft.newBuilder().setTopicID("b")) + .addGraft(Rpc.ControlGraft.newBuilder().setTopicID("c")) + ) + .build() + val limits = PubsubRpcLimits.NONE.copy(maxGraftMessages = 2) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `rejects when prune count exceeds limit`() { + val rpc = Rpc.RPC.newBuilder() + .setControl( + Rpc.ControlMessage.newBuilder() + .addPrune(pruneWithPeers(peers = 0)) + .addPrune(pruneWithPeers(peers = 0)) + ) + .build() + val limits = PubsubRpcLimits.NONE.copy(maxPruneMessages = 1) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `rejects when peers per prune exceeds limit`() { + val rpc = Rpc.RPC.newBuilder() + .setControl(Rpc.ControlMessage.newBuilder().addPrune(pruneWithPeers(peers = 17))) + .build() + val limits = PubsubRpcLimits.NONE.copy(maxPeersPerPruneMessage = 16) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `rejects when idontwant messageIDs exceed limit`() { + val rpc = Rpc.RPC.newBuilder() + .setControl(Rpc.ControlMessage.newBuilder().addIdontwant(idontwant(ids = 5))) + .build() + val limits = PubsubRpcLimits.NONE.copy(maxIDontWantMessageIds = 4) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `accepts well-formed RPC under every configured limit`() { + val rpc = Rpc.RPC.newBuilder() + .addSubscriptions(subOpt("t")) + .addPublish(message(topics = 1)) + .setControl( + Rpc.ControlMessage.newBuilder() + .addIhave(ihave(ids = 2)) + .addIwant(iwant(ids = 2)) + .addGraft(Rpc.ControlGraft.newBuilder().setTopicID("t")) + .addPrune(pruneWithPeers(peers = 1)) + .addIdontwant(idontwant(ids = 1)) + ) + .build() + val limits = PubsubRpcLimits( + maxPublishedMessages = 10, + maxTopicsPerPublishedMessage = 4, + maxSubscriptions = 10, + maxIHaveMessageIds = 10, + maxIWantMessageIds = 10, + maxGraftMessages = 10, + maxPruneMessages = 10, + maxPeersPerPruneMessage = 10, + maxIDontWantMessages = 10, + maxIDontWantMessageIds = 10, + rejectEmptyPublishEntries = true, + ) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isEqualTo(RpcMessageCountValidator.Result.Accepted) + } + + @Test + fun `rejects truncated input as malformed`() { + val rpc = Rpc.RPC.newBuilder().addPublish(message(topics = 1)).build() + val full = rpc.toByteArray() + val truncated = full.copyOfRange(0, full.size - 1) + val result = RpcMessageCountValidator.validate(Unpooled.wrappedBuffer(truncated), unlimited) + assertThat(result).isInstanceOf(RpcMessageCountValidator.Result.Malformed::class.java) + } + + @Test + fun `attack payload of empty publish entries rejected on first entry`() { + // 1000 empty publish entries, which would expand to 1000 Rpc.Message objects. + val attack = ByteArray(2 * 1000) { if (it % 2 == 0) 0x12.toByte() else 0x00.toByte() } + val result = RpcMessageCountValidator.validate(Unpooled.wrappedBuffer(attack), unlimited) + assertThat(result).isEqualTo(RpcMessageCountValidator.Result.Rejected("empty publish entry")) + } + + @Test + fun `rejects when idontwant message count exceeds limit`() { + val rpc = Rpc.RPC.newBuilder() + .setControl( + Rpc.ControlMessage.newBuilder() + .addIdontwant(idontwant(ids = 1)) + .addIdontwant(idontwant(ids = 1)) + .addIdontwant(idontwant(ids = 1)) + ) + .build() + val limits = PubsubRpcLimits.NONE.copy(maxIDontWantMessages = 2) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `accepts when count equals limit exactly`() { + val rpc = Rpc.RPC.newBuilder() + .apply { repeat(3) { addPublish(message(topics = 1)) } } + .build() + val limits = PubsubRpcLimits.NONE.copy(maxPublishedMessages = 3) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isEqualTo(RpcMessageCountValidator.Result.Accepted) + } + + @Test + fun `rejects RPC containing an empty idontwant entry`() { + val rpc = Rpc.RPC.newBuilder() + .setControl( + Rpc.ControlMessage.newBuilder() + .addIdontwant(Rpc.ControlIDontWant.getDefaultInstance()) + ) + .build() + val limits = PubsubRpcLimits.NONE.copy(rejectEmptyIDontWantEntries = true) + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } + + @Test + fun `accepts empty idontwant entry when flag is off`() { + val rpc = Rpc.RPC.newBuilder() + .setControl( + Rpc.ControlMessage.newBuilder() + .addIdontwant(Rpc.ControlIDontWant.getDefaultInstance()) + ) + .build() + val limits = PubsubRpcLimits.NONE + assertThat(RpcMessageCountValidator.validate(bytesOf(rpc), limits)) + .isEqualTo(RpcMessageCountValidator.Result.Accepted) + } + + @Test + fun `rejects when graft count across split control fields exceeds limit`() { + // Build a frame with TWO top-level control fields manually, each containing + // 2 grafts. After protobuf merge the ControlMessage has 4 grafts, so a limit + // of 3 must reject. The validator must aggregate across both control fields. + val firstHalf = Rpc.RPC.newBuilder() + .setControl( + Rpc.ControlMessage.newBuilder() + .addGraft(Rpc.ControlGraft.newBuilder().setTopicID("a")) + .addGraft(Rpc.ControlGraft.newBuilder().setTopicID("b")) + ) + .build() + .toByteArray() + val secondHalf = Rpc.RPC.newBuilder() + .setControl( + Rpc.ControlMessage.newBuilder() + .addGraft(Rpc.ControlGraft.newBuilder().setTopicID("c")) + .addGraft(Rpc.ControlGraft.newBuilder().setTopicID("d")) + ) + .build() + .toByteArray() + val combined = firstHalf + secondHalf + + // Sanity: the combined bytes parse to a merged ControlMessage with 4 grafts. + val parsed = Rpc.RPC.parseFrom(combined) + assertThat(parsed.control.graftCount).isEqualTo(4) + + val limits = PubsubRpcLimits.NONE.copy(maxGraftMessages = 3) + val result = RpcMessageCountValidator.validate(Unpooled.wrappedBuffer(combined), limits) + assertThat(result).isInstanceOf(RpcMessageCountValidator.Result.Rejected::class.java) + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsStateTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsStateTest.kt new file mode 100644 index 000000000..93f7aa8b9 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsStateTest.kt @@ -0,0 +1,506 @@ +package io.libp2p.pubsub.gossip + +import io.libp2p.core.PeerId +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource +import pubsub.pb.Rpc +import java.util.stream.Stream + +class GossipExtensionsStateTest { + + private lateinit var extensionsState: GossipExtensionsState + private lateinit var peer1: PeerId + private lateinit var peer2: PeerId + private lateinit var peer3: PeerId + + @BeforeEach + fun setup() { + extensionsState = GossipExtensionsState( + gossipExtensionsConfig = GossipExtensionsConfig(testExtensionEnabled = true) + ) + peer1 = PeerId.random() + peer2 = PeerId.random() + peer3 = PeerId.random() + } + + @Test + fun `onControlExtensionsMessage() stores peer extensions support`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(false) + .build() + + extensionsState.onControlExtensionsMessage(extension, peer1) + + val stored = extensionsState.peerSupportedExtensions(peer1) + assertThat(stored).isNotNull + assertThat(stored!!.partialMessages).isTrue() + assertThat(stored.testExtension).isFalse() + } + + @Test + fun `hasReceivedControlExtensionsFrom() returns true after receiving extensions`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + + extensionsState.onControlExtensionsMessage(extension, peer1) + + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + } + + @Test + fun `hasReceivedControlExtensionsFrom() returns false for unknown peer`() { + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + } + + @Test + fun `peerSupportedExtensions() returns null for unknown peer`() { + val extensions = extensionsState.peerSupportedExtensions(peer1) + assertThat(extensions).isNull() + } + + /* + In practice, we should not receive more than one control message from the same peer on + the same connection, but if this ever happens, it makes sense to override the in-memory + config given it most likely has the most up-to-date data for that particular peer + */ + @Test + fun `onControlExtensionsMessage() overwrites previous extensions from same peer`() { + val extension1 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(false) + .build() + + val extension2 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(false) + .setTestExtension(true) + .build() + + extensionsState.onControlExtensionsMessage(extension1, peer1) + extensionsState.onControlExtensionsMessage(extension2, peer1) + + val stored = extensionsState.peerSupportedExtensions(peer1) + assertThat(stored).isNotNull + assertThat(stored!!.partialMessages).isFalse() + assertThat(stored.testExtension).isTrue() + } + + @Test + fun `hasSentControlExtensionsTo() returns false for unknown peer`() { + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + } + + @Test + fun `registerControlExtensionMessageSentToPeers() registers peer`() { + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + } + + @Test + fun `hasSentControlExtensionsTo() returns true after registration`() { + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + } + + @Test + fun `registerControlExtensionMessageSentToPeers() can register multiple peers`() { + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + extensionsState.registerControlExtensionMessageSentToPeers(peer2) + + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + assertThat(extensionsState.hasSentControlExtensionsTo(peer2)).isTrue() + assertThat(extensionsState.hasSentControlExtensionsTo(peer3)).isFalse() + } + + @Test + fun `sent and received extension tracking are independent`() { + // Register that we sent to peer1 + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + + // Receive from peer2 + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + extensionsState.onControlExtensionsMessage(extension, peer2) + + // Verify sent tracking + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + assertThat(extensionsState.hasSentControlExtensionsTo(peer2)).isFalse() + + // Verify received tracking + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer2)).isTrue() + } + + @Test + fun `peer can be in both sent and received tracking`() { + // Register that we sent to peer1 + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + + // Receive from peer1 + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + extensionsState.onControlExtensionsMessage(extension, peer1) + + // Both should be tracked + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + } + + @Test + fun `tracks multiple peers with different extensions`() { + val extension1 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(false) + .build() + + val extension2 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(false) + .setTestExtension(true) + .build() + + val extension3 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(true) + .build() + + extensionsState.onControlExtensionsMessage(extension1, peer1) + extensionsState.onControlExtensionsMessage(extension2, peer2) + extensionsState.onControlExtensionsMessage(extension3, peer3) + + // Verify each peer has correct extensions + val stored1 = extensionsState.peerSupportedExtensions(peer1) + assertThat(stored1!!.partialMessages).isTrue() + assertThat(stored1.testExtension).isFalse() + + val stored2 = extensionsState.peerSupportedExtensions(peer2) + assertThat(stored2!!.partialMessages).isFalse() + assertThat(stored2.testExtension).isTrue() + + val stored3 = extensionsState.peerSupportedExtensions(peer3) + assertThat(stored3!!.partialMessages).isTrue() + assertThat(stored3.testExtension).isTrue() + + // Verify all peers are tracked + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer2)).isTrue() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer3)).isTrue() + } + + @Test + fun `tracks different peer extension support for partial messages`() { + val withPartial = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + val withoutPartial = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(false) + .build() + + extensionsState.onControlExtensionsMessage(withPartial, peer1) + extensionsState.onControlExtensionsMessage(withoutPartial, peer2) + + assertThat(extensionsState.peerSupportsPartialMessages(peer1)).isTrue() + assertThat(extensionsState.peerSupportsPartialMessages(peer2)).isFalse() + } + + @Test + fun `tracks many peers simultaneously`() { + val peers = (1..10).map { PeerId.random() } + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + peers.forEach { peer -> + extensionsState.onControlExtensionsMessage(extension, peer) + } + + // Verify all peers are tracked + peers.forEach { peer -> + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer)).isTrue() + assertThat(extensionsState.peerSupportedExtensions(peer)).isNotNull + } + } + + @Test + fun `onPeerDisconnected() removes peer from received extensions map`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + extensionsState.onControlExtensionsMessage(extension, peer1) + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + + extensionsState.onPeerDisconnected(peer1) + + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.peerSupportedExtensions(peer1)).isNull() + } + + @Test + fun `onPeerDisconnected() handles unknown peer gracefully`() { + // Should not throw exception for unknown peer + extensionsState.onPeerDisconnected(peer1) + + // State should remain empty + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.peerSupportedExtensions(peer1)).isNull() + } + + @Test + fun `onPeerDisconnected() only removes specified peer`() { + val extension1 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + val extension2 = Rpc.ControlExtensions.newBuilder() + .setTestExtension(true) + .build() + + extensionsState.onControlExtensionsMessage(extension1, peer1) + extensionsState.onControlExtensionsMessage(extension2, peer2) + + // Disconnect peer1 + extensionsState.onPeerDisconnected(peer1) + + // peer1 should be removed + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.peerSupportedExtensions(peer1)).isNull() + + // peer2 should remain + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer2)).isTrue() + assertThat(extensionsState.peerSupportedExtensions(peer2)).isNotNull + assertThat(extensionsState.peerSupportedExtensions(peer2)!!.testExtension).isTrue() + } + + @Test + fun `multiple disconnects and reconnects work correctly`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + // Connect + extensionsState.onControlExtensionsMessage(extension, peer1) + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + + // Disconnect + extensionsState.onPeerDisconnected(peer1) + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + + // Reconnect with different extensions + val newExtension = Rpc.ControlExtensions.newBuilder() + .setTestExtension(true) + .build() + extensionsState.onControlExtensionsMessage(newExtension, peer1) + + val stored = extensionsState.peerSupportedExtensions(peer1) + assertThat(stored).isNotNull + assertThat(stored!!.hasPartialMessages()).isFalse() + assertThat(stored.testExtension).isTrue() + } + + @Test + fun `onPeerDisconnected() removes peer from sent extensions list`() { + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + + extensionsState.onPeerDisconnected(peer1) + + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + } + + @Test + fun `onPeerDisconnected() removes peer from both sent and received tracking`() { + // Register sent + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + + // Register received + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + extensionsState.onControlExtensionsMessage(extension, peer1) + + // Verify both tracked + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + + // Disconnect + extensionsState.onPeerDisconnected(peer1) + + // Both should be removed + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.peerSupportedExtensions(peer1)).isNull() + } + + @Test + fun `onPeerDisconnected() only removes specified peer from sent list`() { + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + extensionsState.registerControlExtensionMessageSentToPeers(peer2) + + extensionsState.onPeerDisconnected(peer1) + + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + assertThat(extensionsState.hasSentControlExtensionsTo(peer2)).isTrue() + } + + @Test + fun `reconnecting peer can have sent extension registered again`() { + // First connection - register sent + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + + // Disconnect + extensionsState.onPeerDisconnected(peer1) + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + + // Reconnect - register sent again + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + } + + @Test + fun `querying empty state returns expected values`() { + extensionsState = GossipExtensionsState() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + assertThat(extensionsState.peerSupportedExtensions(peer1)).isNull() + } + + @Test + fun `peerSupportsTestExtensions returns true when peer has extension`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setTestExtension(true) + .build() + + extensionsState.onControlExtensionsMessage(extension, peer1) + + assertThat(extensionsState.peerSupportsTestExtensions(peer1)).isTrue() + } + + @Test + fun `peerSupportsTestExtensions returns false when peer doesn't have extension`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setTestExtension(false) + .setPartialMessages(true) + .build() + + extensionsState.onControlExtensionsMessage(extension, peer1) + + assertThat(extensionsState.peerSupportsTestExtensions(peer1)).isFalse() + } + + @Test + fun `peerSupportsTestExtensions returns false for unknown peer`() { + assertThat(extensionsState.peerSupportsTestExtensions(peer1)).isFalse() + } + + @Test + fun `peerSupportsPartialMessages returns true when peer has extension`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + extensionsState.onControlExtensionsMessage(extension, peer1) + + assertThat(extensionsState.peerSupportsPartialMessages(peer1)).isTrue() + } + + @Test + fun `peerSupportsPartialMessages returns false when peer doesn't have extension`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(false) + .setTestExtension(true) + .build() + + extensionsState.onControlExtensionsMessage(extension, peer1) + + assertThat(extensionsState.peerSupportsPartialMessages(peer1)).isFalse() + } + + @Test + fun `peerSupportsPartialMessages returns false for unknown peer`() { + assertThat(extensionsState.peerSupportsPartialMessages(peer1)).isFalse() + } + + @Test + fun `default config has both extensions disabled`() { + val state = GossipExtensionsState() + + assertThat(state.testExtensionsEnabled()).isFalse() + assertThat(state.partialMessagesEnabled()).isFalse() + } + + @ParameterizedTest + @MethodSource("gossipExtensionConfigParams") + fun `config flags combinations for all extensions`( + description: String, + testExtensionsEnabled: Boolean, + partialMessagesEnabled: Boolean + ) { + val config = GossipExtensionsConfig( + testExtensionEnabled = testExtensionsEnabled, + partialMessagesEnabled = partialMessagesEnabled + ) + + assertThat(config.testExtensionEnabled).isEqualTo(testExtensionsEnabled) + .withFailMessage("expected $description") + assertThat(config.partialMessagesEnabled).isEqualTo(partialMessagesEnabled) + .withFailMessage("expected $description") + } + + companion object { + @JvmStatic + fun gossipExtensionConfigParams(): Stream { + return Stream.of( + Arguments.of("both extensions enabled", true, true), + Arguments.of("only test extensions enabled", false, true), + Arguments.of("only partial messages enabled", true, false), + Arguments.of("both extensions disabled", false, false) + ) + } + } + + @Test + fun `localExtensionSupport field reflects config`() { + val state = GossipExtensionsState( + GossipExtensionsConfig( + testExtensionEnabled = true, + partialMessagesEnabled = false + ) + ) + + val localSupport = state.localExtensionSupport + assertThat(localSupport.testExtension).isTrue() + assertThat(localSupport.partialMessages).isFalse() + } + + @Test + fun `peer extension support cleared on disconnect`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setTestExtension(true) + .setPartialMessages(true) + .build() + + extensionsState.onControlExtensionsMessage(extension, peer1) + assertThat(extensionsState.peerSupportsTestExtensions(peer1)).isTrue() + assertThat(extensionsState.peerSupportsPartialMessages(peer1)).isTrue() + + extensionsState.onPeerDisconnected(peer1) + + assertThat(extensionsState.peerSupportsTestExtensions(peer1)).isFalse() + assertThat(extensionsState.peerSupportsPartialMessages(peer1)).isFalse() + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterBuilderTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterBuilderTest.kt new file mode 100644 index 000000000..224b5a7ae --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterBuilderTest.kt @@ -0,0 +1,45 @@ +package io.libp2p.pubsub.gossip + +import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test + +class GossipRouterBuilderTest { + + @Test + fun `builds GossipRouter with both extensions disabled by default`() { + val router = GossipRouterBuilder().build() + + assertThat(router.gossipExtensionsState.testExtensionsEnabled()).isFalse() + assertThat(router.gossipExtensionsState.partialMessagesEnabled()).isFalse() + } + + @Test + fun `localExtensionSupport reflects config in built router`() { + val router = GossipRouterBuilder() + // Enabling only test extensions + .enabledGossipExtensions( + GossipExtension.TEST_EXTENSION + ) + .build() + + val localSupport = router.gossipExtensionsState.localExtensionSupport + assertThat(localSupport.testExtension).isTrue() + assertThat(localSupport.partialMessages).isFalse() + } + + @Test + fun `localExtensionSupport with all extensions enabled`() { + val router = GossipRouterBuilder() + // Enabling all extensions + .enabledGossipExtensions( + GossipExtension.TEST_EXTENSION, + GossipExtension.PARTIAL_MESSAGES, + ) + .build() + + val localSupport = router.gossipExtensionsState.localExtensionSupport + assertThat(localSupport.testExtension).isTrue() + assertThat(localSupport.partialMessages).isTrue() + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt index 5b6b35e55..087b1d0c4 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt @@ -8,6 +8,7 @@ import io.libp2p.pubsub.gossip.builders.GossipParamsBuilder import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedInvocationConstants import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.MethodSource @@ -186,7 +187,7 @@ class GossipRpcPartsQueueTest { fun mergeParams(): Stream = testCases.stream() } - @ParameterizedTest(name = "[${ParameterizedTest.INDEX_PLACEHOLDER}] {0}") + @ParameterizedTest(name = "[${ParameterizedInvocationConstants.INDEX_PLACEHOLDER}] {0}") @MethodSource("mergeParams") fun `mergeMessageParts() test various combinations`( gossipParams: GossipParams, @@ -306,4 +307,191 @@ class GossipRpcPartsQueueTest { .addMessageIDs("2222".toWBytes().toProtobuf()).build(), ) } + + @Test + fun `addControlExtensions() sets testExtension flag in control message`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + val extension = Rpc.ControlExtensions.newBuilder() + .setTestExtension(true) + .build() + + partsQueue.addControlExtensions(extension) + + val res = partsQueue.takeMerged().first() + + assertThat(res.hasControl()).isTrue() + assertThat(res.control.hasExtensions()).isTrue() + assertThat(res.control.extensions.testExtension).isTrue() + } + + @Test + fun `addControlExtensions() sets partialMessages flag in control message`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + partsQueue.addControlExtensions(extension) + + val res = partsQueue.takeMerged().first() + + assertThat(res.hasControl()).isTrue() + assertThat(res.control.hasExtensions()).isTrue() + assertThat(res.control.extensions.partialMessages).isTrue() + } + + @Test + fun `addControlExtensions() sets all extension flags`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(true) + .build() + + partsQueue.addControlExtensions(extension) + + val res = partsQueue.takeMerged().first() + + assertThat(res.hasControl()).isTrue() + assertThat(res.control.hasExtensions()).isTrue() + assertThat(res.control.extensions.partialMessages).isTrue() + assertThat(res.control.extensions.testExtension).isTrue() + } + + @Test + fun `control extensions message works with other control messages`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + // Add various control messages + partsQueue.addIHave(byteArrayOf(1).toWBytes(), "topic1") + partsQueue.addIWant(byteArrayOf(2).toWBytes()) + partsQueue.addGraft("topic2") + partsQueue.addPrune("topic3") + + // Add extension + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + partsQueue.addControlExtensions(extension) + + val res = partsQueue.takeMerged().first() + + // Verify all control messages are present + assertThat(res.hasControl()).isTrue() + assertThat(res.control.ihaveList).hasSize(1) + assertThat(res.control.iwantList).hasSize(1) + assertThat(res.control.graftList).hasSize(1) + assertThat(res.control.pruneList).hasSize(1) + + // Verify extension is present + assertThat(res.control.hasExtensions()).isTrue() + assertThat(res.control.extensions.partialMessages).isTrue() + } + + @Test + fun `control extensions message with subscriptions and publishes`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + partsQueue.addSubscribe("topic1") + partsQueue.addPublish(createRpcMessage("topic1", "data1")) + + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + partsQueue.addControlExtensions(extension) + + val res = partsQueue.takeMerged().first() + + // Verify subscriptions and publishes + assertThat(res.subscriptionsList).hasSize(1) + assertThat(res.publishList).hasSize(1) + + // Verify extension + assertThat(res.control.hasExtensions()).isTrue() + assertThat(res.control.extensions.partialMessages).isTrue() + } + + @Test + fun `control extensions message works with message splitting`() { + val partsQueue = TestGossipQueue(gossipParamsWithLimits) + + // Add enough messages to force splitting + (1..20).forEach { + partsQueue.addPublish(createRpcMessage("topic-$it", "data")) + } + + // Add extension + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + partsQueue.addControlExtensions(extension) + + val merged = partsQueue.takeMerged() + + // Should be split into multiple RPCs due to maxPublishedMessages limit + assertThat(merged.size).isGreaterThan(1) + + // Extension should be in the last RPC (since it's added last) + val lastRpc = merged.last() + assertThat(lastRpc.hasControl()).isTrue() + assertThat(lastRpc.control.hasExtensions()).isTrue() + assertThat(lastRpc.control.extensions.partialMessages).isTrue() + } + + @Test + fun `multiple control extensions messages - last one wins`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + // Add first extension + val extension1 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(false) + .build() + partsQueue.addControlExtensions(extension1) + + // Add second extension (should overwrite first) + val extension2 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(false) + .setTestExtension(true) + .build() + partsQueue.addControlExtensions(extension2) + + val res = partsQueue.takeMerged().first() + + // Verify only the last extension is present + assertThat(res.control.hasExtensions()).isTrue() + // Note: false flags may or may not be serialized depending on protobuf default behavior + // But testExtension should definitely be true + assertThat(res.control.extensions.testExtension).isTrue() + } + + @Test + fun `control extensions message does not count toward limits but may be split`() { + val partsQueue = TestGossipQueue(gossipParamsWithLimits) + + // Add exactly maxPublishedMessages messages + (1..maxPublishedMessages).forEach { + partsQueue.addPublish(createRpcMessage("topic-$it", "data")) + } + + // Add extension + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + partsQueue.addControlExtensions(extension) + + val merged = partsQueue.takeMerged() + + // Extension doesn't count toward limits, but it may end up in a separate RPC + // if it comes after parts that exhaust a limit + assertThat(merged).hasSize(2) + assertThat(merged[0].publishList).hasSize(maxPublishedMessages) + + // Extension should be in the second RPC + assertThat(merged[1].control.hasExtensions()).isTrue() + assertThat(merged[1].control.extensions.partialMessages).isTrue() + } } diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt index ecc912256..1917310e9 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt @@ -62,10 +62,19 @@ abstract class GossipTestsBase { val coreParams: GossipParams = GossipParams(), val scoreParams: GossipScoreParams = GossipScoreParams(), val mockRouterFactory: DeterministicFuzzRouterFactory = createMockFuzzRouterFactory(), - val protocol: PubsubProtocol = PubsubProtocol.Gossip_V_1_1 + val protocol: PubsubProtocol = PubsubProtocol.Gossip_V_1_1, + val enabledGossipExtensions: List = listOf(GossipExtension.TEST_EXTENSION) + ) { val fuzz = DeterministicFuzz() - val gossipRouterBuilderFactory = { GossipRouterBuilder(protocol = protocol, params = coreParams, scoreParams = scoreParams) } + val gossipRouterBuilderFactory = { + GossipRouterBuilder( + protocol = protocol, + params = coreParams, + scoreParams = scoreParams, + enabledGossipExtensions = enabledGossipExtensions + ) + } val router1 = fuzz.createTestRouter(createGossipFuzzRouterFactory(gossipRouterBuilderFactory)) val router2 = fuzz.createTestRouter(mockRouterFactory) val gossipRouter = router1.router as GossipRouter diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt index 905a6b489..856c560ed 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt @@ -258,6 +258,57 @@ class GossipV1_1Tests : GossipTestsBase() { assertEquals(0, test.mockRouter.inboundMessages.size) } + @Test + fun testSubscribeRespectsBackoff() { + // Regression test for the subscribe()-bypasses-backoff bug. + // + // Reproduces the production failure mode where Teku, after being PRUNEd by a peer, + // re-subscribes to a topic (e.g., attestation subnet rotation) and immediately + // GRAFTs back onto peers that are still within their backoff window — accumulating + // P7 behaviour penalties on the remote scorer until disconnect. + // + // The heartbeat-driven mesh maintenance paths correctly filter by isBackOff; + // the subscribe() path historically did not. + val test = TwoRoutersTest() + + test.mockRouter.subscribe("topic1") + + // Let the mock peer's subscription propagate so it appears in topicPeers. + test.fuzz.timeController.addTime(1.seconds) + + // Mock peer pre-emptively PRUNEs us with a 30-second backoff before we subscribe. + val pruneMsg = Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().addPrune( + Rpc.ControlPrune.newBuilder() + .setTopicID("topic1") + .setBackoff(30) + ) + ).build() + test.mockRouter.sendToSingle(pruneMsg) + test.fuzz.timeController.addTime(100.millis) + test.mockRouter.inboundMessages.clear() + + // Now subscribe locally — this is the path that historically grafted without + // checking backoff. + test.gossipRouter.subscribe("topic1") + test.fuzz.timeController.addTime(15.seconds) + + // No GRAFT should have been sent while the backoff is active. + assertEquals( + 0, + test.mockRouter.inboundMessages + .count { it.hasControl() && it.control.graftCount > 0 }, + "subscribe() must not GRAFT a peer that is in backoff" + ) + + // After the backoff expires, the next heartbeat is allowed to GRAFT. + test.fuzz.timeController.addTime(20.seconds) + test.mockRouter.waitForMessage { + it.hasControl() && + it.control.graftCount > 0 && it.control.getGraft(0).topicID == "topic1" + } + } + @Test fun testGraftFloodPenalty() { val test = TwoRoutersTest() diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_3Tests.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_3Tests.kt new file mode 100644 index 000000000..891e48cf7 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_3Tests.kt @@ -0,0 +1,63 @@ +@file:Suppress("ktlint:standard:class-naming") + +package io.libp2p.pubsub.gossip + +import io.libp2p.etc.types.seconds +import io.libp2p.pubsub.PubsubProtocol +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc + +class GossipV1_3Tests : GossipTestsBase() { + + @Test + fun selfSanityTest() { + val test = TwoRoutersTest(protocol = PubsubProtocol.Gossip_V_1_3) + + test.mockRouter.subscribe("topic1") + val msg = newMessage("topic1", 0L, "Hello".toByteArray()) + test.gossipRouter.publish(msg) + test.mockRouter.waitForMessage { it.publishCount > 0 } + } + + @Test + fun testBackoffTimeoutOnV1_3() { + // Regression test: v1.3 must honor PRUNE backoff (inherited from v1.1). + // Previously `supportsBackoffAndPX()` only returned true for v1.1/v1.2, + // causing v1.3 routers to ignore the Backoff field and immediately re-GRAFT. + val test = TwoRoutersTest(protocol = PubsubProtocol.Gossip_V_1_3) + + test.mockRouter.subscribe("topic1") + test.gossipRouter.subscribe("topic1") + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + test.mockRouter.waitForMessage { it.hasControl() && it.control.graftCount > 0 } + test.mockRouter.inboundMessages.clear() + + val pruneMsg = Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().addPrune( + Rpc.ControlPrune.newBuilder() + .setTopicID("topic1") + .setBackoff(30) + ) + ).build() + test.mockRouter.sendToSingle(pruneMsg) + + // No GRAFT should be sent during the backoff window + test.fuzz.timeController.addTime(15.seconds) + assertEquals( + 0, + test.mockRouter.inboundMessages + .count { it.hasControl() && it.control.graftCount > 0 } + ) + test.mockRouter.inboundMessages.clear() + + // Expecting GRAFT after backoff expires + test.fuzz.timeController.addTime(20.seconds) + test.mockRouter.waitForMessage { + it.hasControl() && + it.control.graftCount > 0 && it.control.getGraft(0).topicID == "topic1" + } + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt new file mode 100644 index 000000000..c39f4d51b --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt @@ -0,0 +1,267 @@ +package io.libp2p.pubsub.gossip.extensions + +import io.libp2p.pubsub.PubsubProtocol +import io.libp2p.pubsub.gossip.GossipExtension +import io.libp2p.pubsub.gossip.GossipPeerScoreParams +import io.libp2p.pubsub.gossip.GossipScoreParams +import io.libp2p.pubsub.gossip.GossipTestsBase +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import pubsub.pb.Rpc +import java.util.concurrent.TimeoutException + +private const val DEFAULT_WAIT_TIMEOUT_IN_MILLIS = 500L + +class GossipExtensionsMessageHandlingTest : GossipTestsBase() { + + @Test + fun `extension messages sent to peer prior to gossip v1_3 are ignored`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_2 + ) + + test.mockRouter.sendToSingle(rpcMsgWithCtrlExtensionsAndTestExtension) + assertNoResponseFromTestExtension(test) + } + + @Test + fun `extension messages sent to peer prior to sending control extensions messages are ignored`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3 + ) + + test.mockRouter.sendToSingle(rpcMessageWithTestExtension) + assertNoResponseFromTestExtension(test) + } + + @Test + fun `extension message flow with control extensions message before actual extension message`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3 + ) + + test.mockRouter.sendToSingle(rpcMessageWithControlExtensions) + assertThat(test.gossipRouter.gossipExtensionsState.peerSupportedExtensions(test.router2.peerId)).isEqualTo( + rpcMessageWithControlExtensions.control.extensions + ) + + test.mockRouter.sendToSingle(rpcMessageWithTestExtension) + test.mockRouter.waitForMessage { it.hasTestExtension() } + } + + @Test + fun `extension message flow with control extensions and extension message in the same rpc message`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3 + ) + + test.mockRouter.sendToSingle(rpcMsgWithCtrlExtensionsAndTestExtension) + test.mockRouter.waitForMessage { it.hasTestExtension() } + } + + @Test + fun `remove peer control extensions map when disconnecting`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3 + ) + + test.mockRouter.sendToSingle(rpcMsgWithCtrlExtensionsAndTestExtension) + + assertThat(test.gossipRouter.gossipExtensionsState.peerSupportedExtensions(test.router2.peerId)).isEqualTo( + rpcMsgWithCtrlExtensionsAndTestExtension.control.extensions + ) + + test.mockRouter.waitForMessage { it.hasTestExtension() } + + // Successfully registered peer2 extensions support + + assertThat(test.gossipRouter.gossipExtensionsState.peerSupportedExtensions(test.router2.peerId)).isNotNull() + + test.connection.disconnect() + + // After disconnecting removes peer2 from extensions support map + assertThat(test.gossipRouter.gossipExtensionsState.peerSupportedExtensions(test.router2.peerId)).isNull() + } + + @ParameterizedTest + @MethodSource("protocolVersionsWithExtensionSupport") + fun `control extension message sent to peer on connection with extension support`(protocol: PubsubProtocol) { + val test = TwoRoutersTest(protocol = protocol) + + val receivedMessage = test.mockRouter.waitForMessage( + { it.hasControl() && it.control.hasExtensions() }, + DEFAULT_WAIT_TIMEOUT_IN_MILLIS + ) + + assertThat(receivedMessage.control.extensions.testExtension).isTrue() + } + + @ParameterizedTest + @MethodSource("protocolVersionsWithoutExtensionSupport") + fun `control extension message not sent to peer on connection without extension support`( + protocol: PubsubProtocol + ) { + val test = TwoRoutersTest(protocol = protocol) + + // Should not receive control extension message on versions without extension support + assertThrows { + test.mockRouter.waitForMessage( + { it.hasControl() && it.control.hasExtensions() }, + DEFAULT_WAIT_TIMEOUT_IN_MILLIS + ) + } + } + + @Test + fun `local peer ignores test extension messages when they are disabled in config`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3, + enabledGossipExtensions = listOf() + ) + + test.mockRouter.sendToSingle(rpcMsgWithCtrlExtensionsAndTestExtension) + assertThrows { + test.mockRouter.waitForMessage( + { it.hasTestExtension() }, + DEFAULT_WAIT_TIMEOUT_IN_MILLIS + ) + } + } + + @Test + fun `control extension message contains all supported extensions flags`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3, + enabledGossipExtensions = listOf( + GossipExtension.TEST_EXTENSION, + GossipExtension.PARTIAL_MESSAGES + ) + ) + + val receivedMessage = test.mockRouter.waitForMessage( + { it.hasControl() && it.control.hasExtensions() }, + 2000L + ) + + val extensions = receivedMessage.control.extensions + + // Verify both extension flags are set + assertThat(extensions.hasPartialMessages()).isTrue() + assertThat(extensions.partialMessages).isTrue() + assertThat(extensions.hasTestExtension()).isTrue() + assertThat(extensions.testExtension).isTrue() + } + + @Test + fun `extension state tracks that we sent control extension to peer`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3 + ) + + // Wait for control extension message to be sent + test.mockRouter.waitForMessage( + { it.hasControl() && it.control.hasExtensions() }, + DEFAULT_WAIT_TIMEOUT_IN_MILLIS + ) + + // Should be tracked in state + assertThat(test.gossipRouter.gossipExtensionsState.hasSentControlExtensionsTo(test.router2.peerId)).isTrue() + } + + @Test + fun `control extension sent state cleared on peer disconnect`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3 + ) + + // Wait for control extension message + test.mockRouter.waitForMessage( + { it.hasControl() && it.control.hasExtensions() }, + DEFAULT_WAIT_TIMEOUT_IN_MILLIS + ) + + // Verify it's tracked + assertThat(test.gossipRouter.gossipExtensionsState.hasSentControlExtensionsTo(test.router2.peerId)).isTrue() + + // Disconnect + test.connection.disconnect() + + // Should be cleared from sent tracking + assertThat(test.gossipRouter.gossipExtensionsState.hasSentControlExtensionsTo(test.router2.peerId)).isFalse() + } + + @Test + fun `peer sending multiple control extension messages are downscored`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3, + enabledGossipExtensions = listOf(GossipExtension.PARTIAL_MESSAGES), + // Creating GossipScoreParams with behaviourPenaltyWeight (peer bad behavior affecting + // score). Here we are not interested if the weight is "correct". What we want to see if + // that a peer is penalized for sending more than one ControlExtensions message. + scoreParams = GossipScoreParams( + peerScoreParams = GossipPeerScoreParams( + behaviourPenaltyWeight = -1.0 + ) + ) + ) + + val offendingPeer = test.gossipRouter.peers[0].peerId + val initialScore = test.gossipRouter.score.score(offendingPeer) + + // first ControlExtensions message, no downscoring + test.mockRouter.sendToSingle(rpcMessageWithControlExtensions) + assertThat(test.gossipRouter.score.score(offendingPeer)).isEqualTo(initialScore) + + // second ControlExtensions message, peer downscored + test.mockRouter.sendToSingle(rpcMessageWithControlExtensions) + assertThat(test.gossipRouter.score.score(offendingPeer)).isLessThan(initialScore) + } + + companion object { + @JvmStatic + fun protocolVersionsWithExtensionSupport() = listOf( + PubsubProtocol.Gossip_V_1_3 + ) + + @JvmStatic + fun protocolVersionsWithoutExtensionSupport() = listOf( + PubsubProtocol.Gossip_V_1_1, + PubsubProtocol.Gossip_V_1_2 + ) + + val testExtensionMessage: Rpc.TestExtension = Rpc.TestExtension.newBuilder().build() + + val rpcMessageWithControlExtensions = Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().setExtensions(controlExtensionMessage()) + ).build()!! + + val rpcMessageWithTestExtension = + Rpc.RPC.newBuilder().setTestExtension(testExtensionMessage).build()!! + + // An RPC message with both ControlExtensions and TestExtension message (test extension enabled on control) + val rpcMsgWithCtrlExtensionsAndTestExtension = Rpc.RPC.newBuilder() + .setControl( + Rpc.ControlMessage.newBuilder() + .setExtensions(Rpc.ControlExtensions.newBuilder().setTestExtension(true)) + .build() + ) + .setTestExtension(Rpc.TestExtension.newBuilder().build()) + .build()!! + + fun controlExtensionMessage(testExtensionEnabled: Boolean = true): Rpc.ControlExtensions { + return Rpc.ControlExtensions.newBuilder().setTestExtension(testExtensionEnabled).build() + } + + fun assertNoResponseFromTestExtension(test: TwoRoutersTest) { + assertThrows { + test.mockRouter.waitForMessage( + { it.hasTestExtension() }, + DEFAULT_WAIT_TIMEOUT_IN_MILLIS + ) + } + } + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/security/tls/CertificatesTest.kt b/libp2p/src/test/kotlin/io/libp2p/security/tls/CertificatesTest.kt index feb2e6a98..940a3afa4 100644 --- a/libp2p/src/test/kotlin/io/libp2p/security/tls/CertificatesTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/security/tls/CertificatesTest.kt @@ -18,6 +18,7 @@ class CertificatesTest { val certBytes = Hex.decode(hex) val certHolder = X509CertificateHolder(certBytes) val cert = JcaX509CertificateConverter().setProvider(BouncyCastleProvider()).getCertificate(certHolder) + getPublicKeyFromCert(arrayOf(cert)) val peerIdFromCert = verifyAndExtractPeerId(arrayOf(cert)) val expectedPeerId = PeerId.fromBase58("12D3KooWJRSrypvnpHgc6ZAgyCni4KcSmbV7uGRaMw5LgMKT18fq") assertEquals(peerIdFromCert, expectedPeerId) @@ -29,6 +30,7 @@ class CertificatesTest { val certBytes = Hex.decode(hex) val certHolder = X509CertificateHolder(certBytes) val cert = JcaX509CertificateConverter().setProvider(BouncyCastleProvider()).getCertificate(certHolder) + getPublicKeyFromCert(arrayOf(cert)) val peerIdFromCert = verifyAndExtractPeerId(arrayOf(cert)) val expectedPeerId = PeerId.fromBase58("QmZcrvr3r4S3QvwFdae3c2EWTfo792Y14UpzCZurhmiWeX") assertEquals(peerIdFromCert, expectedPeerId) @@ -40,11 +42,24 @@ class CertificatesTest { val certBytes = Hex.decode(hex) val certHolder = X509CertificateHolder(certBytes) val cert = JcaX509CertificateConverter().setProvider(BouncyCastleProvider()).getCertificate(certHolder) + getPublicKeyFromCert(arrayOf(cert)) val peerIdFromCert = verifyAndExtractPeerId(arrayOf(cert)) val expectedPeerId = PeerId.fromBase58("16Uiu2HAm2dSCBFxuge46aEt7U1oejtYuBUZXxASHqmcfVmk4gsbx") assertEquals(peerIdFromCert, expectedPeerId) } + @Test + fun rustCert() { + val hex = "3082018230820129a00302010202144d1178a3bb828459ce1e266baa234ed8f0615c06300a06082a8648ce3d04030230003020170d3735303130313030303030305a180f34303936303130313030303030305a30003059301306072a8648ce3d020106082a8648ce3d03010703420004089ff3ab6e4b42cb2252a41aff3b8cb7c6f71f7050f6604ff138219f35652de1f6a006f487cd15d88db31e12dcd3b080cd53aa5869a649a13762b6193029f61ca37f307d307b060a2b0601040183a25a01010101ff046a30680424080112207f249e77411a3fa0c3f6305a8446cd45f9fb73ae2412f230f21943cf15dabc3d044025544b48ff50963b5f26b277906a08ba3f231d2d80f399801f856e21e3d9ec2b84c51f8063eb4ae70e52cd940ff82a5aa29b82f3f82b5fb2ae67a9d5bba75c0b300a06082a8648ce3d0403020347003044022031580479526dd6a38a3cc1e90122ac9437d3633aa63f697165099e3d3c4cb3b70220525a60d13802089a9cbb0752646a2801df74d06d6f7785ff21931dca4e188e16" + val certBytes = Hex.decode(hex) + val certHolder = X509CertificateHolder(certBytes) + val cert = JcaX509CertificateConverter().setProvider(BouncyCastleProvider()).getCertificate(certHolder) + getPublicKeyFromCert(arrayOf(cert)) + val peerIdFromCert = verifyAndExtractPeerId(arrayOf(cert)) + val expectedPeerId = PeerId.fromBase58("12D3KooWJNgLEeuYt54A58gcnsggjHhVt6YBsrK71QRXTzK9WABn") + assertEquals(peerIdFromCert, expectedPeerId) + } + @Test fun invalidCert() { val hex = "308201773082011da003020102020830a73c5d896a1109300a06082a8648ce3d04030230003020170d3735303130313030303030305a180f34303936303130313030303030305a30003059301306072a8648ce3d020106082a8648ce3d03010703420004bbe62df9a7c1c46b7f1f21d556deec5382a36df146fb29c7f1240e60d7d5328570e3b71d99602b77a65c9b3655f62837f8d66b59f1763b8c9beba3be07778043a37f307d307b060a2b0601040183a25a01010101ff046a3068042408011220ec8094573afb9728088860864f7bcea2d4fd412fef09a8e2d24d482377c20db60440ecabae8354afa2f0af4b8d2ad871e865cb5a7c0c8d3dbdbf42de577f92461a0ebb0a28703e33581af7d2a4f2270fc37aec6261fcc95f8af08f3f4806581c730a300a06082a8648ce3d040302034800304502202dfb17a6fa0f94ee0e2e6a3b9fb6e986f311dee27392058016464bd130930a61022100ba4b937a11c8d3172b81e7cd04aedb79b978c4379c2b5b24d565dd5d67d3cb3c" @@ -59,6 +74,7 @@ class CertificatesTest { val host = generateEd25519KeyPair() val conn = generateEd25519KeyPair() val cert = buildCert(host.first, conn.first) + getPublicKeyFromCert(arrayOf(cert)) val peerIdFromCert = verifyAndExtractPeerId(arrayOf(cert)) val expectedPeerId = PeerId.fromPubKey(host.second) assertEquals(peerIdFromCert, expectedPeerId) diff --git a/libp2p/src/test/kotlin/io/libp2p/protocol/Blob.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/protocol/Blob.kt similarity index 100% rename from libp2p/src/test/kotlin/io/libp2p/protocol/Blob.kt rename to libp2p/src/testFixtures/kotlin/io/libp2p/protocol/Blob.kt diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/protocol/OneShotPing.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/protocol/OneShotPing.kt new file mode 100644 index 000000000..1c85f912f --- /dev/null +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/protocol/OneShotPing.kt @@ -0,0 +1,113 @@ +package io.libp2p.protocol + +import io.libp2p.core.ConnectionClosedException +import io.libp2p.core.Libp2pException +import io.libp2p.core.Stream +import io.libp2p.core.multistream.StrictProtocolBinding +import io.libp2p.etc.types.toByteBuf +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled +import java.util.concurrent.CompletableFuture + +interface OneShotPingController { + fun ping(): CompletableFuture +} + +/** + * Ping responder responds only once when initiator closes the stream for write + */ +class OneShotPing(pingSize: Int) : OneShotPingBinding(OneShotPingProtocol(pingSize)) { + constructor() : this(32) +} + +open class OneShotPingBinding(ping: OneShotPingProtocol) : + StrictProtocolBinding("/ipfs/one-shot-ping/1.0.0", ping) + +open class OneShotPingProtocol(var pingSize: Int) : ProtocolHandler(Long.MAX_VALUE, Long.MAX_VALUE) { + + constructor() : this(32) + + override fun onStartInitiator(stream: Stream): CompletableFuture { + val handler = OneShotPingInitiator() + stream.pushHandler(handler) + return handler.activeFuture + } + + override fun onStartResponder(stream: Stream): CompletableFuture { + val handler = OneShotPingResponder() + stream.pushHandler(handler) + return CompletableFuture.completedFuture(handler) + } + + open inner class OneShotPingResponder : ProtocolMessageHandler, OneShotPingController { + lateinit var stream: Stream + val outBuf = Unpooled.buffer() + + override fun onActivated(stream: Stream) { + println("OneShotPingResponder: onActivated") + this.stream = stream + } + + override fun onMessage(stream: Stream, msg: ByteBuf) { + println("OneShotPingResponder: onMessage $msg") + outBuf.writeBytes(msg) + } + + override fun onReadClosed(stream: Stream) { + println("OneShotPingResponder: onReadClosed") + stream.writeAndFlush(outBuf) + stream.closeWrite() + } + + override fun onClosed(stream: Stream) { + println("OneShotPingResponder: onClosed") + } + + override fun onException(cause: Throwable?) { + println("OneShotPingResponder: onException: $cause") + } + + override fun ping(): CompletableFuture { + throw Libp2pException("This is ping responder only") + } + } + + open inner class OneShotPingInitiator : ProtocolMessageHandler, OneShotPingController { + val activeFuture = CompletableFuture() + val responseFuture = CompletableFuture() + lateinit var stream: Stream + var closed = false + + override fun onActivated(stream: Stream) { + println("OneShotPingInitiator: onActivated") + this.stream = stream + activeFuture.complete(this) + } + + override fun onMessage(stream: Stream, msg: ByteBuf) { + println("OneShotPingInitiator: onMessage $msg") + responseFuture.complete(null) + } + + override fun onReadClosed(stream: Stream) { + println("OneShotPingInitiator: onReadClosed") + } + + override fun onClosed(stream: Stream) { + println("OneShotPingInitiator: onClosed") + activeFuture.completeExceptionally(ConnectionClosedException()) + } + + override fun onException(cause: Throwable?) { + println("OneShotPingInitiator: onException: $cause") + } + + override fun ping(): CompletableFuture { + println("OneShotPingInitiator: ping") + val data = ByteArray(pingSize) + stream.writeAndFlush(data.toByteBuf()) + stream.closeWrite() + return responseFuture + } + } +} diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/DeterministicFuzz.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/DeterministicFuzz.kt index 646ee5c5c..e5050b20e 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/DeterministicFuzz.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/DeterministicFuzz.kt @@ -47,7 +47,7 @@ class DeterministicFuzz { { executor, curTime, random -> routerBuilderFactory().also { it.scheduledAsyncExecutor = executor - it.currentTimeSuppluer = curTime + it.currentTimeSupplier = curTime it.random = random }.build() } diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/MockRouter.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/MockRouter.kt index d214fd7bb..9df88819b 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/MockRouter.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/MockRouter.kt @@ -8,6 +8,8 @@ import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.TimeUnit import java.util.concurrent.TimeoutException +private const val DEFAULT_WAIT_FOR_MESSAGE_TIMEOUT_IN_MILLIS = 5000L + open class MockRouter(executor: ScheduledExecutorService) : AbstractRouter( protocol = PubsubProtocol.Floodsub, executor = executor, @@ -26,9 +28,16 @@ open class MockRouter(executor: ScheduledExecutorService) : AbstractRouter( } fun waitForMessage(predicate: (Rpc.RPC) -> Boolean): Rpc.RPC { + return waitForMessage(predicate, DEFAULT_WAIT_FOR_MESSAGE_TIMEOUT_IN_MILLIS) + } + + fun waitForMessage( + predicate: (Rpc.RPC) -> Boolean, + timeoutInMillis: Long = DEFAULT_WAIT_FOR_MESSAGE_TIMEOUT_IN_MILLIS + ): Rpc.RPC { var cnt = 0 while (true) { - val msg = inboundMessages.poll(5, TimeUnit.SECONDS) + val msg = inboundMessages.poll(timeoutInMillis, TimeUnit.MILLISECONDS) ?: throw TimeoutException("No matching message received among $cnt") if (predicate(msg)) return msg cnt++ @@ -47,4 +56,5 @@ open class MockRouter(executor: ScheduledExecutorService) : AbstractRouter( override fun broadcastOutbound(msg: PubsubMessage): CompletableFuture = CompletableFuture.completedFuture(null) override fun broadcastInbound(msgs: List, receivedFrom: PeerHandler) {} override fun processControl(ctrl: Rpc.ControlMessage, receivedFrom: PeerHandler) {} + override fun processExtensions(msg: Rpc.RPC, receivedFrom: PeerHandler) {} } diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/HostFactory.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/HostFactory.kt index 5cf216865..efb48db64 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/HostFactory.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/HostFactory.kt @@ -20,12 +20,13 @@ import io.libp2p.protocol.Ping import io.libp2p.security.noise.NoiseXXSecureChannel import io.libp2p.transport.tcp.TcpTransport import io.netty.handler.logging.LogLevel +import java.util.* import java.util.concurrent.TimeUnit class HostFactory { var keyType = KeyType.ECDSA - var tcpPort = 5000 + var tcpPort = Random().nextInt(10_000) + 6000 var transportCtor = ::TcpTransport var secureCtor: SecureChannelCtor = ::NoiseXXSecureChannel var mplexCtor = ::MplexStreamMuxer diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/NullTransport.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/NullTransport.kt index f50b4e201..81c10441f 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/NullTransport.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/NullTransport.kt @@ -5,10 +5,13 @@ import io.libp2p.core.Connection import io.libp2p.core.ConnectionHandler import io.libp2p.core.P2PChannel import io.libp2p.core.multiformats.Multiaddr -import io.libp2p.core.transport.Transport +import io.libp2p.etc.util.MultiaddrUtils +import io.libp2p.transport.implementation.NettyTransport +import io.netty.channel.Channel +import java.net.InetSocketAddress import java.util.concurrent.CompletableFuture -class NullTransport : Transport { +class NullTransport : NettyTransport { override val activeConnections: Int get() = stub() override val activeListeners: Int @@ -22,14 +25,22 @@ class NullTransport : Transport { connHandler: ConnectionHandler, preHandler: ChannelVisitor? ): CompletableFuture = stub() + override fun unlisten(addr: Multiaddr): CompletableFuture = stub() override fun dial( addr: Multiaddr, connHandler: ConnectionHandler, preHandler: ChannelVisitor? ): CompletableFuture = stub() + override fun handles(addr: Multiaddr): Boolean = stub() + override fun localAddress(nettyChannel: Channel): Multiaddr = + MultiaddrUtils.inetSocketAddressToTcpMultiaddr(nettyChannel.localAddress() as InetSocketAddress) + + override fun remoteAddress(nettyChannel: Channel): Multiaddr = + MultiaddrUtils.inetSocketAddressToTcpMultiaddr(nettyChannel.remoteAddress() as InetSocketAddress) + private fun stub(): Nothing { throw NotImplementedError("Test stub") } diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/TCPProxy.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/TCPProxy.kt index eb77980d7..6cb5c2d7b 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/TCPProxy.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/TCPProxy.kt @@ -7,7 +7,8 @@ import io.netty.channel.ChannelFuture import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelInboundHandlerAdapter import io.netty.channel.ChannelOption -import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.MultiThreadIoEventLoopGroup +import io.netty.channel.nio.NioIoHandler import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.socket.nio.NioSocketChannel import io.netty.handler.logging.LogLevel @@ -19,7 +20,7 @@ class TCPProxy { fun start(listenPort: Int, dialHost: String, dialPort: Int): ChannelFuture { val future = ServerBootstrap().apply { - group(NioEventLoopGroup()) + group(MultiThreadIoEventLoopGroup(NioIoHandler.newFactory())) channel(NioServerSocketChannel::class.java) childHandler( nettyInitializer { @@ -29,7 +30,7 @@ class TCPProxy { serverCtx.channel().pipeline().addFirst(LoggingHandler("server", LogLevel.INFO)) Bootstrap().apply { - group(NioEventLoopGroup()) + group(MultiThreadIoEventLoopGroup(NioIoHandler.newFactory())) channel(NioSocketChannel::class.java) option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5 * 1000) handler(object : ChannelInboundHandlerAdapter() { @@ -38,7 +39,6 @@ class TCPProxy { } override fun channelActive(ctx: ChannelHandlerContext) { -// serverCtx.channel().pipeline().addFirst(LoggingHandler("client", LogLevel.INFO)) client.complete(ctx) } diff --git a/settings.gradle b/settings.gradle index 3b2d71fae..158428ac4 100644 --- a/settings.gradle +++ b/settings.gradle @@ -23,6 +23,7 @@ include ':tools:simulator' include ':examples:chatter' include ':examples:cli-chatter' include ':examples:pinger' +include 'interop-test-client' def getAndroidSdkDir() { def localPropertiesSdkDir = null diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimNetwork.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimNetwork.kt index e2aa4f0a9..3e874e1b3 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimNetwork.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimNetwork.kt @@ -45,7 +45,7 @@ class GossipSimNetwork( protected fun createSimPeer(number: Int): GossipSimPeer { val router = routerFactory(number).also { - it.currentTimeSuppluer = { timeController.time } + it.currentTimeSupplier = { timeController.time } it.serializeMessagesToBytes = false } diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimPeer.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimPeer.kt index 08eadfaaa..018a4b632 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimPeer.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimPeer.kt @@ -23,7 +23,7 @@ class GossipSimPeer( routerBuilder.also { it.name = name it.scheduledAsyncExecutor = simExecutor - it.currentTimeSuppluer = { currentTime() } + it.currentTimeSupplier = { currentTime() } it.random = random }.build() } diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/router/SimGossipRouter.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/router/SimGossipRouter.kt index b6f830f91..2acc1569d 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/router/SimGossipRouter.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/router/SimGossipRouter.kt @@ -33,6 +33,7 @@ class SimGossipRouter( name, mCache, score, + gossipExtensionsConfig = GossipExtensionsConfig(), subscriptionTopicSubscriptionFilter, protocol, executor, diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/router/SimGossipRouterBuilder.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/router/SimGossipRouterBuilder.kt index f1ee3f82c..7096b2fdc 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/router/SimGossipRouterBuilder.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/router/SimGossipRouterBuilder.kt @@ -10,12 +10,12 @@ class SimGossipRouterBuilder : GossipRouterBuilder() { override fun createGossipRouter(): GossipRouter { val gossipScore = - scoreFactory(scoreParams, scheduledAsyncExecutor, currentTimeSuppluer) { gossipRouterEventListeners += it } + scoreFactory(scoreParams, scheduledAsyncExecutor, currentTimeSupplier) { gossipRouterEventListeners += it } val router = SimGossipRouter( params = params, scoreParams = scoreParams, - currentTimeSupplier = currentTimeSuppluer, + currentTimeSupplier = currentTimeSupplier, random = random, name = name, mCache = mCache, diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/util/NullTransport.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/util/NullTransport.kt index 538a9c273..91a51699d 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/util/NullTransport.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/util/NullTransport.kt @@ -5,10 +5,13 @@ import io.libp2p.core.Connection import io.libp2p.core.ConnectionHandler import io.libp2p.core.P2PChannel import io.libp2p.core.multiformats.Multiaddr -import io.libp2p.core.transport.Transport +import io.libp2p.etc.util.MultiaddrUtils +import io.libp2p.transport.implementation.NettyTransport +import io.netty.channel.Channel +import java.net.InetSocketAddress import java.util.concurrent.CompletableFuture -class NullTransport : Transport { +class NullTransport : NettyTransport { override val activeConnections: Int get() = stub() override val activeListeners: Int @@ -30,6 +33,12 @@ class NullTransport : Transport { ): CompletableFuture = stub() override fun handles(addr: Multiaddr): Boolean = stub() + override fun localAddress(nettyChannel: Channel): Multiaddr = + MultiaddrUtils.inetSocketAddressToTcpMultiaddr(nettyChannel.localAddress() as InetSocketAddress) + + override fun remoteAddress(nettyChannel: Channel): Multiaddr = + MultiaddrUtils.inetSocketAddressToTcpMultiaddr(nettyChannel.remoteAddress() as InetSocketAddress) + private fun stub(): Nothing { throw NotImplementedError("Test stub") } diff --git a/versions.gradle b/versions.gradle index 94fa371b0..c6e59e148 100644 --- a/versions.gradle +++ b/versions.gradle @@ -9,36 +9,39 @@ dependencyManagement { dependency "com.google.guava:guava:33.3.1-jre" dependency "org.slf4j:slf4j-api:2.0.9" - dependencySet(group: 'org.apache.logging.log4j', version: '2.24.1') { + dependencySet(group: 'org.apache.logging.log4j', version: '2.25.4') { entry 'log4j-core' entry 'log4j-slf4j2-impl' } - dependencySet(group: 'org.junit.jupiter', version: '5.11.3') { + dependencySet(group: 'org.junit.jupiter', version: '5.13.4') { entry 'junit-jupiter-api' entry 'junit-jupiter-engine' entry 'junit-jupiter-params' } - dependency "io.mockk:mockk:1.13.3" - dependency "org.assertj:assertj-core:3.26.3" + dependencySet(group: 'org.junit.platform', version: '1.13.4') { + entry 'junit-platform-launcher' + entry 'junit-platform-engine' - dependencySet(group: "org.openjdk.jmh", version: "1.37") { - entry 'jmh-core' - entry 'jmh-generator-annprocess' } + dependency "io.mockk:mockk:1.13.3" + dependency "org.assertj:assertj-core:3.27.4" dependencySet(group: "com.google.protobuf", version: "3.25.5") { entry 'protobuf-java' entry 'protoc' } - dependencySet(group: "io.netty", version: "4.1.115.Final") { + dependencySet(group: "io.netty", version: "4.2.10.Final") { entry 'netty-common' entry 'netty-handler' entry 'netty-transport' entry 'netty-buffer' entry 'netty-codec-http' + entry 'netty-codec-protobuf' + entry 'netty-codec-native-quic' entry 'netty-transport-classes-epoll' } + dependency "io.netty:netty-tcnative-boringssl-static:2.0.76.Final" dependency "com.github.multiformats:java-multibase:v1.1.1" dependency "tech.pegasys:noise-java:22.1.0" dependencySet(group: "org.bouncycastle", version: "1.78.1") { @@ -47,4 +50,4 @@ dependencyManagement { entry 'bctls-jdk18on' } } -} \ No newline at end of file +}