diff --git a/.github/deps.json b/.github/deps.json new file mode 100644 index 00000000..12de2df9 --- /dev/null +++ b/.github/deps.json @@ -0,0 +1,174 @@ +{ + "$schema": "https://pilotprotocol.network/.well-known/deps.schema.json", + "version": 1, + "description": "Canonical dependency graph of every Pilot Protocol repo. Read by .github/workflows/orchestrator.yml to compute which downstream nodes need a bump when any package ships a release. This is the SINGLE source of truth — keep it in sync whenever a repo gains or drops an upstream dependency.", + "nodes": { + "web4": { + "repo": "TeoSlayer/pilotprotocol", + "type": "hub", + "go_module": "github.com/TeoSlayer/pilotprotocol", + "depends_on": [] + }, + + "beacon": { + "repo": "pilot-protocol/beacon", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/beacon", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "dataexchange": { + "repo": "pilot-protocol/dataexchange", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/dataexchange", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "eventstream": { + "repo": "pilot-protocol/eventstream", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/eventstream", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "examples": { + "repo": "pilot-protocol/examples", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/examples", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "gateway": { + "repo": "pilot-protocol/gateway", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/gateway", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "handshake": { + "repo": "pilot-protocol/handshake", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/handshake", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "nameserver": { + "repo": "pilot-protocol/nameserver", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/nameserver", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "policy": { + "repo": "pilot-protocol/policy", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/policy", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "rendezvous": { + "repo": "pilot-protocol/rendezvous", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/rendezvous", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "runtime": { + "repo": "pilot-protocol/runtime", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/runtime", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "skillinject": { + "repo": "pilot-protocol/skillinject", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/skillinject", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "trustedagents": { + "repo": "pilot-protocol/trustedagents", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/trustedagents", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "updater": { + "repo": "pilot-protocol/updater", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/updater", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "webhook": { + "repo": "pilot-protocol/webhook", + "type": "go-sibling", + "go_module": "github.com/TeoSlayer/pilotprotocol/webhook", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "app-store": { + "repo": "pilot-protocol/app-store", + "type": "go-app", + "go_module": "github.com/TeoSlayer/pilotprotocol/app-store/integration", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + + "libpilot": { + "repo": "pilot-protocol/libpilot", + "type": "ffi-fan-in", + "go_module": "github.com/TeoSlayer/pilotprotocol/libpilot", + "depends_on": ["web4", "handshake", "policy", "runtime", "skillinject"], + "bump_policy": { + "stable_only": true, + "open_pr_instead_of_main": true, + "_note": "FFI fan-in feeds all 3 SDKs at build time — a bad bump cascades to every SDK. Use PR + CI gate." + } + }, + + "sdk-node": { + "repo": "pilot-protocol/sdk-node", + "type": "sdk", + "package_manager": "npm", + "depends_on": ["libpilot"], + "bump_policy": { "stable_only": true, "open_pr_instead_of_main": true } + }, + "sdk-python": { + "repo": "pilot-protocol/sdk-python", + "type": "sdk", + "package_manager": "pypi", + "depends_on": ["libpilot"], + "bump_policy": { "stable_only": true, "open_pr_instead_of_main": true } + }, + "sdk-swift": { + "repo": "pilot-protocol/sdk-swift", + "type": "sdk", + "package_manager": "swiftpm", + "depends_on": ["libpilot"], + "bump_policy": { "stable_only": true, "open_pr_instead_of_main": true } + }, + + "homebrew-pilot": { + "repo": "TeoSlayer/homebrew-pilot", + "type": "package-tap", + "depends_on": ["web4"], + "bump_policy": { "stable_only": true, "auto_commit_main": true } + }, + "website": { + "repo": "pilot-protocol/website", + "type": "surface", + "depends_on": ["web4"], + "bump_policy": { + "stable_only": false, + "auto_commit_main": true, + "_note": "Edge channel also publishes; manifest preserves latest_stable when upstream is prerelease." + } + }, + + "cosift": { "repo": "pilot-protocol/cosift", "type": "freeloader", "depends_on": [] }, + "pilot-ca": { "repo": "pilot-protocol/pilot-ca", "type": "freeloader", "depends_on": [] }, + "wallet": { "repo": "pilot-protocol/wallet", "type": "indirect", "depends_on": ["app-store"] } + } +} diff --git a/.github/workflows/_template-bump-upstream.yml b/.github/workflows/_template-bump-upstream.yml new file mode 100644 index 00000000..c98a0818 --- /dev/null +++ b/.github/workflows/_template-bump-upstream.yml @@ -0,0 +1,139 @@ +name: bump-upstream + +# TEMPLATE — copy this file into a Go-sibling repo's `.github/workflows/` +# and rename to `bump-upstream.yml`. It receives a `bump-upstream` dispatch +# from web4's orchestrator and rewrites the local `go.mod` to require the +# new version of the upstream that just shipped. +# +# Behavior: +# - If client_payload.is_prerelease == true and this repo's policy is +# stable_only (the orchestrator already filtered, but we re-check), skip. +# - `go mod edit -require=@` +# - `go mod tidy` +# - Commit to main OR open a PR depending on BUMP_MODE below. +# +# Tunables (set as repo variables or hardcode): +# BUMP_MODE auto-commit | pr (default: auto-commit) +# UPSTREAM_MODULE_PATH the import path of the upstream — e.g. +# "github.com/TeoSlayer/pilotprotocol". Each +# sibling has the same hub path; FFI nodes like +# libpilot may receive multiple upstreams (the +# orchestrator sends one dispatch per upstream). +# +# This template targets Go modules. SDK repos (sdk-node, sdk-python, +# sdk-swift) need their own variants — see the per-SDK bump templates. + +on: + repository_dispatch: + types: [bump-upstream] + workflow_dispatch: + inputs: + upstream: + description: 'Upstream node name (informational)' + required: true + version: + description: 'Upstream version (e.g. v1.10.6)' + required: true + +permissions: + contents: write + pull-requests: write + +env: + BUMP_MODE: auto-commit + UPSTREAM_MODULE_PATH: github.com/TeoSlayer/pilotprotocol + +jobs: + bump: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Resolve incoming dispatch + id: input + env: + DISPATCH_UP: ${{ github.event.client_payload.upstream }} + DISPATCH_VER: ${{ github.event.client_payload.version }} + DISPATCH_PRE: ${{ github.event.client_payload.is_prerelease }} + MANUAL_UP: ${{ inputs.upstream }} + MANUAL_VER: ${{ inputs.version }} + run: | + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + UP="$MANUAL_UP"; VER="$MANUAL_VER"; PRE="false" + else + UP="$DISPATCH_UP"; VER="$DISPATCH_VER"; PRE="$DISPATCH_PRE" + fi + # Re-check the prerelease gate — defense in depth in case a + # mis-configured orchestrator forgets to filter. + if [ "$PRE" = "true" ]; then + echo "::notice::upstream ${UP} ${VER} is a prerelease; skipping (stable_only policy)" + echo "skip=1" >> "$GITHUB_OUTPUT" + exit 0 + fi + echo "upstream=$UP" >> "$GITHUB_OUTPUT" + echo "version=$VER" >> "$GITHUB_OUTPUT" + echo "skip=0" >> "$GITHUB_OUTPUT" + + - name: Bump go.mod + if: steps.input.outputs.skip != '1' + env: + VER: ${{ steps.input.outputs.version }} + run: | + # The upstream module path is the hub module — siblings require + # it directly. Sub-package nodes (handshake, policy, ...) are + # currently part of the same Go module, so the require line + # always points at the hub path. If a sibling later becomes its + # own top-level Go module, set UPSTREAM_MODULE_PATH per-repo. + go mod edit -require="${UPSTREAM_MODULE_PATH}@${VER}" + go mod tidy + echo "=== updated go.mod ===" + grep "${UPSTREAM_MODULE_PATH}" go.mod || true + + - name: Verify build still works + if: steps.input.outputs.skip != '1' + run: go build ./... || (echo "::error::build broke after bump"; exit 1) + + - name: Commit (auto-commit mode) + if: steps.input.outputs.skip != '1' && env.BUMP_MODE == 'auto-commit' + env: + UP: ${{ steps.input.outputs.upstream }} + VER: ${{ steps.input.outputs.version }} + run: | + if git diff --quiet go.mod go.sum; then + echo "go.mod unchanged — nothing to commit." + exit 0 + fi + git config user.name "pilot-release-bot" + git config user.email "release-bot@pilotprotocol.network" + git add go.mod go.sum + git commit -m "deps: bump ${UP} to ${VER}" \ + -m "Triggered by orchestrator dispatch from upstream release." + git push origin HEAD:main + + - name: Open PR (pr mode) + if: steps.input.outputs.skip != '1' && env.BUMP_MODE == 'pr' + env: + GH_TOKEN: ${{ github.token }} + UP: ${{ steps.input.outputs.upstream }} + VER: ${{ steps.input.outputs.version }} + run: | + if git diff --quiet go.mod go.sum; then + echo "go.mod unchanged — nothing to commit." + exit 0 + fi + BRANCH="bump-${UP}-${VER}" + git config user.name "pilot-release-bot" + git config user.email "release-bot@pilotprotocol.network" + git checkout -b "$BRANCH" + git add go.mod go.sum + git commit -m "deps: bump ${UP} to ${VER}" + git push origin "$BRANCH" + gh pr create \ + --title "deps: bump ${UP} to ${VER}" \ + --body "Orchestrator dispatch from upstream release ${UP} ${VER}." \ + --base main \ + --head "$BRANCH" diff --git a/.github/workflows/_template-emit-release.yml b/.github/workflows/_template-emit-release.yml new file mode 100644 index 00000000..19127c36 --- /dev/null +++ b/.github/workflows/_template-emit-release.yml @@ -0,0 +1,85 @@ +name: notify-orchestrator + +# TEMPLATE — copy this file into any sibling repo's `.github/workflows/` and +# rename to `notify-orchestrator.yml`. The repo's `release.yml` (or whatever +# workflow tags releases) gains this step at its end. Whenever the repo tags +# a release, this workflow notifies the central orchestrator at +# `web4/.github/workflows/orchestrator.yml`, which then computes the set of +# downstream nodes that need a bump and dispatches to each. +# +# Required secret on this repo: +# ORCHESTRATOR_DISPATCH_TOKEN — repository_dispatch scope on web4. Prefer a +# GitHub App token; one App installation can cover every sibling so the +# secret is the same value everywhere. +# +# Placeholders to fill before adopting: +# — the package name as it appears in deps.json (e.g. "policy", +# "handshake", "libpilot"). Must match exactly. +# +# Example wiring at the end of a sibling's release.yml: +# +# notify: +# needs: release +# uses: ./.github/workflows/notify-orchestrator.yml +# with: +# node: policy +# secrets: +# ORCHESTRATOR_DISPATCH_TOKEN: ${{ secrets.ORCHESTRATOR_DISPATCH_TOKEN }} + +on: + workflow_call: + inputs: + node: + description: 'Node name in web4/.github/deps.json' + required: true + type: string + secrets: + ORCHESTRATOR_DISPATCH_TOKEN: + required: true + workflow_dispatch: + inputs: + node: + description: 'Node name in web4/.github/deps.json' + required: true + version: + description: 'Tag to notify about (e.g. v1.0.5). Defaults to github.ref_name when called from a release workflow.' + required: false + +jobs: + notify: + runs-on: ubuntu-latest + steps: + - name: Emit package-released + env: + TOKEN: ${{ secrets.ORCHESTRATOR_DISPATCH_TOKEN }} + NODE: ${{ inputs.node }} + VERSION: ${{ inputs.version || github.ref_name }} + run: | + if [ -z "$TOKEN" ]; then + echo "::warning::ORCHESTRATOR_DISPATCH_TOKEN unset — skipping notification." + exit 0 + fi + IS_PRE=false + case "$VERSION" in + *-rc*|*-beta*|*-alpha*) IS_PRE=true ;; + esac + + payload=$(jq -nc \ + --arg pkg "$NODE" \ + --arg ver "$VERSION" \ + --argjson pre "$IS_PRE" \ + '{event_type:"package-released", client_payload:{package:$pkg, version:$ver, is_prerelease:$pre}}') + + HTTP_CODE=$(curl -sS -o /tmp/resp -w '%{http_code}' \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${TOKEN}" \ + "https://api.github.com/repos/TeoSlayer/pilotprotocol/dispatches" \ + -d "$payload") + + if [ "$HTTP_CODE" != "204" ]; then + echo "orchestrator notify failed (HTTP $HTTP_CODE):" + cat /tmp/resp + exit 1 + fi + echo "orchestrator notified: ${NODE} ${VERSION} (is_prerelease=${IS_PRE})" diff --git a/.github/workflows/orchestrator.yml b/.github/workflows/orchestrator.yml new file mode 100644 index 00000000..35924c5f --- /dev/null +++ b/.github/workflows/orchestrator.yml @@ -0,0 +1,204 @@ +name: orchestrator + +# Central release orchestrator. Any repo in the constellation that ships a +# release `repository_dispatch`es to this workflow with: +# +# { event_type: "package-released", +# client_payload: { package: "", version: "v1.2.3", is_prerelease: false } } +# +# The orchestrator reads `.github/deps.json` (the canonical dependency graph), +# computes the reverse-transitive closure of `package` — i.e. every node that +# imports the released one, directly or through a chain — and emits a +# `bump-upstream` dispatch to each one. The receivers (one per repo) decide +# whether to commit-to-main, open a PR, or skip based on their own bump_policy. +# +# Why centralize: +# - The dependency graph lives in exactly one place; no repo needs to know +# who its dependents are. Add or remove an edge by editing deps.json here. +# - Topological order is computed once, not negotiated by receivers. +# - One token (SHOCKWAVE_DISPATCH_TOKEN) covers every fan-out hop. +# +# Why deps.json over a fancier graph store: +# - Reviewable: changes show up as PR diffs. +# - Auditable: `git log` tells you when an edge was added. +# - No runtime infra: this workflow is the whole system. + +on: + repository_dispatch: + types: [package-released] + workflow_dispatch: + inputs: + package: + description: 'Upstream package node name (must exist in deps.json)' + required: true + version: + description: 'Upstream version (e.g. v1.10.6)' + required: true + is_prerelease: + description: 'true if the upstream tag is a prerelease' + required: false + default: 'false' + dry_run: + description: 'When true, print the closure but do not dispatch' + required: false + default: 'false' + +permissions: + contents: read + +jobs: + fan-out: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Resolve inputs + id: input + env: + DISPATCH_PKG: ${{ github.event.client_payload.package }} + DISPATCH_VER: ${{ github.event.client_payload.version }} + DISPATCH_PRE: ${{ github.event.client_payload.is_prerelease }} + MANUAL_PKG: ${{ inputs.package }} + MANUAL_VER: ${{ inputs.version }} + MANUAL_PRE: ${{ inputs.is_prerelease }} + MANUAL_DRYRUN: ${{ inputs.dry_run }} + run: | + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + echo "package=$MANUAL_PKG" >> "$GITHUB_OUTPUT" + echo "version=$MANUAL_VER" >> "$GITHUB_OUTPUT" + echo "is_pre=$MANUAL_PRE" >> "$GITHUB_OUTPUT" + echo "dry_run=$MANUAL_DRYRUN" >> "$GITHUB_OUTPUT" + else + echo "package=$DISPATCH_PKG" >> "$GITHUB_OUTPUT" + echo "version=$DISPATCH_VER" >> "$GITHUB_OUTPUT" + echo "is_pre=$DISPATCH_PRE" >> "$GITHUB_OUTPUT" + echo "dry_run=false" >> "$GITHUB_OUTPUT" + fi + + - name: Compute reverse-transitive closure + id: closure + env: + PKG: ${{ steps.input.outputs.package }} + run: | + python3 - <<'PY' > /tmp/closure.json + import json, sys, os + from collections import defaultdict, deque + + deps = json.load(open(".github/deps.json")) + pkg = os.environ["PKG"] + if pkg not in deps["nodes"]: + print(f"::error::package '{pkg}' not in deps.json", file=sys.stderr) + sys.exit(1) + + # Build child map: children[upstream] = [downstream, ...] + children = defaultdict(list) + for node, info in deps["nodes"].items(): + for up in info.get("depends_on", []): + children[up].append(node) + + # BFS by depth so we can emit a topologically-ordered list. + depth = {pkg: 0} + order = [] + q = deque([pkg]) + while q: + n = q.popleft() + for c in children.get(n, []): + # Use max depth across all paths to root — guarantees a + # node's upstream peers in the closure are all visited + # before it is. + new_d = depth[n] + 1 + if c not in depth or new_d > depth[c]: + depth[c] = new_d + q.append(c) + + targets = sorted( + (c for c in depth if c != pkg), + key=lambda c: (depth[c], c), + ) + + out = [] + for t in targets: + info = deps["nodes"][t] + out.append({ + "node": t, + "repo": info["repo"], + "depth": depth[t], + "bump_policy": info.get("bump_policy", {}), + "type": info.get("type", "unknown"), + }) + + json.dump({"upstream": pkg, "targets": out}, sys.stdout, indent=2) + PY + + cat /tmp/closure.json + echo "payload=$(jq -c . /tmp/closure.json)" >> "$GITHUB_OUTPUT" + + - name: Fan-out dispatch (or dry-run report) + if: steps.input.outputs.dry_run != 'true' + env: + TOKEN: ${{ secrets.SHOCKWAVE_DISPATCH_TOKEN }} + UPSTREAM: ${{ steps.input.outputs.package }} + VERSION: ${{ steps.input.outputs.version }} + IS_PRE: ${{ steps.input.outputs.is_pre }} + CLOSURE: ${{ steps.closure.outputs.payload }} + run: | + if [ -z "$TOKEN" ]; then + echo "::warning::SHOCKWAVE_DISPATCH_TOKEN unset — skipping fan-out." + echo "Action required: set the secret and re-run with workflow_dispatch." + exit 0 + fi + + summary="| node | depth | type | result |\n|---|---|---|---|\n" + # Iterate the closure JSON's targets array. + echo "$CLOSURE" | jq -c '.targets[]' | while read -r t; do + REPO=$(echo "$t" | jq -r '.repo') + NODE=$(echo "$t" | jq -r '.node') + DEPTH=$(echo "$t" | jq -r '.depth') + TYPE=$(echo "$t" | jq -r '.type') + STABLE_ONLY=$(echo "$t" | jq -r '.bump_policy.stable_only // false') + + # Per-receiver stable-only gate: orchestrator could skip the + # dispatch entirely when the upstream is a prerelease AND the + # receiver says stable_only. Saves the receiver an API call. + if [ "$IS_PRE" = "true" ] && [ "$STABLE_ONLY" = "true" ]; then + printf " ~ skip %-22s depth=%s (stable_only and upstream is prerelease)\n" "$NODE" "$DEPTH" + summary="${summary}| ${NODE} | ${DEPTH} | ${TYPE} | skipped (prerelease) |\n" + continue + fi + + payload=$(jq -nc \ + --arg upstream "$UPSTREAM" \ + --arg version "$VERSION" \ + --argjson pre "$IS_PRE" \ + '{event_type:"bump-upstream", client_payload:{upstream:$upstream, version:$version, is_prerelease:$pre}}') + + HTTP_CODE=$(curl -sS -o /tmp/resp -w '%{http_code}' \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${TOKEN}" \ + "https://api.github.com/repos/${REPO}/dispatches" \ + -d "$payload") + + if [ "$HTTP_CODE" = "204" ]; then + printf " ✓ dispatched to %-25s depth=%s\n" "$REPO" "$DEPTH" + summary="${summary}| ${NODE} | ${DEPTH} | ${TYPE} | ✓ dispatched |\n" + else + printf " ✗ %-25s HTTP=%s\n" "$REPO" "$HTTP_CODE" + cat /tmp/resp + summary="${summary}| ${NODE} | ${DEPTH} | ${TYPE} | ✗ HTTP ${HTTP_CODE} |\n" + fi + done + + printf "## Orchestrator fan-out for %s %s\n\nUpstream is_prerelease: %s\n\n%s" \ + "$UPSTREAM" "$VERSION" "$IS_PRE" "$summary" >> "$GITHUB_STEP_SUMMARY" + + - name: Dry-run summary only + if: steps.input.outputs.dry_run == 'true' + env: + CLOSURE: ${{ steps.closure.outputs.payload }} + run: | + echo "## Dry-run closure for ${{ steps.input.outputs.package }}" \ + >> "$GITHUB_STEP_SUMMARY" + echo '```json' >> "$GITHUB_STEP_SUMMARY" + echo "$CLOSURE" | jq . >> "$GITHUB_STEP_SUMMARY" + echo '```' >> "$GITHUB_STEP_SUMMARY" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index deabf21e..d0c991d7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,3 +1,40 @@ +# ---------------------------------------------------------------------------- +# Releases continue from this repo as normal — both before and after the +# planned move to `pilot-protocol/`. This comment is just a reminder for +# migration day; it does NOT gate any release. +# +# Why the reminder: GitHub repo transfers do not carry secrets across orgs. +# The current setup only needs GITHUB_TOKEN (auto-issued), so a transfer +# today would not break anything. But future work will re-introduce +# secrets, and they must be re-created on the destination org BEFORE the +# transfer flips DNS — otherwise the first release after the move silently +# falls back to no-op or fails. +# +# Pending (not blocking releases now — will block IF the linked work +# lands before the org migration): +# +# HOMEBREW_TAP_TOKEN - re-introduced when Homebrew auto-publish +# returns. Prefer a GitHub App over a PAT +# via `actions/create-github-app-token@v1`. +# NPM_TOKEN - if PILOT-203 lands sdk-node auto-publish. +# PYPI_TOKEN - if PILOT-203 lands sdk-python auto-publish. +# COSIGN_KEY / COSIGN_PASS - if PILOT-114 lands updater binary signing. +# +# Migration steps (when the day comes): +# 1. List secrets on the source org with `gh secret list --repo `. +# 2. For each non-auto-issued secret, recreate it on the destination +# using the original cleartext value (GitHub never reveals existing +# secret values). +# 3. Transfer the repo via Settings → "Transfer ownership" or +# `gh api repos//transfer -f new_owner=`. +# 4. Re-verify a release tag triggers this workflow successfully. +# +# Track the migration in the org-move runbook. Do NOT delete this comment +# until either: (a) the migration has completed and every reintroduced +# secret is wired against the destination org, or (b) auto-publish and +# binary signing have been formally retired. +# ---------------------------------------------------------------------------- + name: Release on: @@ -240,3 +277,166 @@ jobs: generate_release_notes: true draft: false prerelease: ${{ contains(github.ref_name, '-rc') || contains(github.ref_name, '-beta') }} + + # ---------------------------------------------------------------------------- + # publish-manifest + # + # Regenerates `pilotprotocol.network/.well-known/latest.json` from the tag we + # just shipped, then hands it off to `pilot-protocol/website` via + # `repository_dispatch`. The website side commits the JSON to main, which + # triggers the Cloudflare Pages deploy. + # + # The manifest is the single source of truth that every install surface + # (install.sh, Homebrew formula, SDK release helpers) reads to decide which + # version is current. Failing here does NOT roll the release back — the + # GitHub release is already live — but it does mean install.sh will keep + # serving the old tag until the manifest is republished. The step is best- + # effort and prints a clear hint when the dispatch token is missing. + # ---------------------------------------------------------------------------- + publish-manifest: + name: Publish version manifest + needs: release + runs-on: ubuntu-latest + steps: + - name: Build manifest JSON from this release + id: build + env: + TAG: ${{ github.ref_name }} + IS_PRE: ${{ contains(github.ref_name, '-rc') || contains(github.ref_name, '-beta') }} + run: | + # Pull the just-published checksums.txt directly from the GitHub + # release (it landed there in the previous job). + curl -fsSL -o checksums.txt \ + "https://github.com/${GITHUB_REPOSITORY}/releases/download/${TAG}/checksums.txt" + + sha_for() { + grep " $1\$" checksums.txt | awk '{print $1}' + } + + DARWIN_AMD64=$(sha_for "pilot-darwin-amd64.tar.gz") + DARWIN_ARM64=$(sha_for "pilot-darwin-arm64.tar.gz") + LINUX_AMD64=$(sha_for "pilot-linux-amd64.tar.gz") + LINUX_ARM64=$(sha_for "pilot-linux-arm64.tar.gz") + + # Refuse to publish a manifest with missing checksums — install.sh + # would silently skip verification. + for v in "$DARWIN_AMD64" "$DARWIN_ARM64" "$LINUX_AMD64" "$LINUX_ARM64"; do + if [ -z "$v" ]; then + echo "error: checksums.txt missing one or more platform entries" + cat checksums.txt + exit 1 + fi + done + + # When the new tag is a prerelease, leave latest_stable alone and + # bump only latest_prerelease + channels.edge. The website receiver + # merges into the existing manifest. + if [ "$IS_PRE" = "true" ]; then + STABLE="" ; EDGE="$TAG" + else + STABLE="$TAG" ; EDGE="$TAG" + fi + + UPDATED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ) + OWNER_REPO="${GITHUB_REPOSITORY}" + + cat > manifest.json <> "$GITHUB_OUTPUT" + + - name: Dispatch to website + env: + # SHOCKWAVE_DISPATCH_TOKEN must have `repository_dispatch` scope on + # pilot-protocol/website. Prefer a GitHub App token over a PAT — + # see actions/create-github-app-token@v1. + SHOCKWAVE_TOKEN: ${{ secrets.SHOCKWAVE_DISPATCH_TOKEN }} + run: | + if [ -z "$SHOCKWAVE_TOKEN" ]; then + echo "::warning::SHOCKWAVE_DISPATCH_TOKEN secret unset — skipping manifest dispatch." + echo "Action required: set the secret and re-run this workflow to publish ${{ github.ref_name }}." + exit 0 + fi + HTTP_CODE=$(curl -sS -o /tmp/resp -w '%{http_code}' \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${SHOCKWAVE_TOKEN}" \ + "https://api.github.com/repos/pilot-protocol/website/dispatches" \ + -d '${{ steps.build.outputs.payload }}') + if [ "$HTTP_CODE" != "204" ]; then + echo "dispatch failed (HTTP $HTTP_CODE):" + cat /tmp/resp + exit 1 + fi + echo "manifest dispatch accepted (HTTP 204)" + + # ---------------------------------------------------------------------------- + # shockwave + # + # Notify every package that derives from web4 (Homebrew formula + SDKs) that + # a new release exists. Each downstream repo runs its own bump workflow on + # receiving `repository_dispatch` event_type=upstream-release. + # + # Receivers (each must have a workflow listening for `upstream-release`): + # - pilot-protocol/homebrew-pilot → bump Formula/pilot.rb + # - pilot-protocol/sdk-node → bump pkg version + npm publish + # - pilot-protocol/sdk-python → bump pyproject + PyPI publish + # - pilot-protocol/sdk-swift → bump Package.swift binaryTarget + # + # Soft-fail per receiver: a missing token or a 404 on one repo MUST NOT block + # the others. The job summary at the end lists which targets succeeded so a + # missed dispatch is visible without grepping logs. + # ---------------------------------------------------------------------------- + shockwave: + name: Shockwave fan-out + needs: release + runs-on: ubuntu-latest + steps: + - name: Dispatch to downstream consumers + env: + SHOCKWAVE_TOKEN: ${{ secrets.SHOCKWAVE_DISPATCH_TOKEN }} + TAG: ${{ github.ref_name }} + IS_PRE: ${{ contains(github.ref_name, '-rc') || contains(github.ref_name, '-beta') }} + run: | + if [ -z "$SHOCKWAVE_TOKEN" ]; then + echo "::warning::SHOCKWAVE_DISPATCH_TOKEN secret unset — skipping fan-out." + exit 0 + fi + summary="" + for repo in homebrew-pilot sdk-node sdk-python sdk-swift; do + payload=$(jq -nc --arg tag "$TAG" --argjson pre "$IS_PRE" \ + '{event_type:"upstream-release", client_payload:{tag:$tag, is_prerelease:$pre}}') + HTTP_CODE=$(curl -sS -o /tmp/resp -w '%{http_code}' \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${SHOCKWAVE_TOKEN}" \ + "https://api.github.com/repos/pilot-protocol/${repo}/dispatches" \ + -d "$payload") + if [ "$HTTP_CODE" = "204" ]; then + summary="${summary} ✓ ${repo}\n" + else + summary="${summary} ✗ ${repo} (HTTP ${HTTP_CODE})\n" + echo "::warning::shockwave dispatch failed for ${repo}: HTTP ${HTTP_CODE}" + cat /tmp/resp + fi + done + printf "Shockwave fan-out summary:\n${summary}" >> "$GITHUB_STEP_SUMMARY" diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go index fe1658de..6c387468 100644 --- a/cmd/daemon/main.go +++ b/cmd/daemon/main.go @@ -15,10 +15,10 @@ import ( "syscall" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/config" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - "github.com/TeoSlayer/pilotprotocol/pkg/logging" + "github.com/pilot-protocol/common/config" + "github.com/pilot-protocol/common/driver" + "github.com/pilot-protocol/common/logging" // L11 plugin imports — cmd/daemon (L12) is the only place these // are allowed. The daemon proper imports only pkg/coreapi diff --git a/cmd/pilotctl/main.go b/cmd/pilotctl/main.go index 9bc8ecb7..f9c3d6c5 100644 --- a/cmd/pilotctl/main.go +++ b/cmd/pilotctl/main.go @@ -21,9 +21,9 @@ import ( "syscall" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + "github.com/pilot-protocol/common/driver" + "github.com/pilot-protocol/common/protocol" + registry "github.com/pilot-protocol/common/registry/client" "github.com/pilot-protocol/dataexchange" "github.com/pilot-protocol/eventstream" "github.com/pilot-protocol/policy/policylang" @@ -189,6 +189,100 @@ func fatal(format string, args ...interface{}) { fatalCode("internal", format, args...) } +// nearestCommand returns the closest match from candidates by case-folded +// Levenshtein distance, but only if it's within a tolerance proportional +// to the input length. Returns "" when no suggestion is useful — typing +// "potato" should not suggest "ping". Empty input or empty candidates +// also produce "". +func nearestCommand(input string, candidates []string) string { + if input == "" || len(candidates) == 0 { + return "" + } + in := strings.ToLower(input) + best := "" + bestDist := -1 + for _, c := range candidates { + d := levenshteinDistance(in, strings.ToLower(c)) + if bestDist == -1 || d < bestDist { + best = c + bestDist = d + } + } + // Only suggest if distance is small. Tolerate one typo for short + // inputs (≤3 chars) and up to two for longer. + threshold := 2 + if len(in) <= 3 { + threshold = 1 + } + if bestDist > threshold { + return "" + } + return best +} + +// levenshteinDistance returns the edit distance between a and b. Used by +// nearestCommand to surface "did you mean" suggestions for typo'd +// subcommand names. +func levenshteinDistance(a, b string) int { + if a == b { + return 0 + } + if len(a) == 0 { + return len(b) + } + if len(b) == 0 { + return len(a) + } + prev := make([]int, len(b)+1) + curr := make([]int, len(b)+1) + for j := range prev { + prev[j] = j + } + for i := 1; i <= len(a); i++ { + curr[0] = i + for j := 1; j <= len(b); j++ { + cost := 1 + if a[i-1] == b[j-1] { + cost = 0 + } + del := prev[j] + 1 + ins := curr[j-1] + 1 + sub := prev[j-1] + cost + m := del + if ins < m { + m = ins + } + if sub < m { + m = sub + } + curr[j] = m + } + prev, curr = curr, prev + } + return prev[len(b)] +} + +// knownTopLevelCommands returns the set of top-level command tokens +// pilotctl recognizes, derived from the commandHelp registry. Compound +// keys like "daemon start" contribute only their first token, so the +// suggestion engine matches what the user types at position 0. +func knownTopLevelCommands() []string { + seen := make(map[string]struct{}, len(commandHelp)) + out := make([]string, 0, len(commandHelp)) + for k := range commandHelp { + first := k + if i := strings.IndexByte(k, ' '); i > 0 { + first = k[:i] + } + if _, ok := seen[first]; ok { + continue + } + seen[first] = struct{}{} + out = append(out, first) + } + return out +} + // parseNodeID parses a string as a uint32 node ID or exits with an error (M18 fix). func parseNodeID(s string) uint32 { v, err := strconv.ParseUint(s, 10, 32) @@ -1556,12 +1650,17 @@ dispatch: runDaemonInternal(cmdArgs) default: + hint := "run 'pilotctl' for the full command list" if jsonOutput { - fatalHint("invalid_argument", - "run 'pilotctl context' for the full command list", - "unknown command: %s", cmd) + hint = "run 'pilotctl context' for the full command list" + } + if suggestion := nearestCommand(cmd, knownTopLevelCommands()); suggestion != "" { + hint = "did you mean 'pilotctl " + suggestion + "'? " + hint + } + if jsonOutput { + fatalHint("invalid_argument", hint, "unknown command: %s", cmd) } - fmt.Fprintf(os.Stderr, "unknown command: %s\n\n", cmd) + fmt.Fprintf(os.Stderr, "unknown command: %s\nhint: %s\n\n", cmd, hint) usage() } } diff --git a/cmd/pilotctl/zz_fake_daemon_test.go b/cmd/pilotctl/zz_fake_daemon_test.go index 89615bcb..6dfb154d 100644 --- a/cmd/pilotctl/zz_fake_daemon_test.go +++ b/cmd/pilotctl/zz_fake_daemon_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" + "github.com/pilot-protocol/common/ipcutil" ) // IPC cmd codes — must match pkg/driver/ipc.go and cmd/daemon/ipc.go. diff --git a/cmd/pilotctl/zz_lifecycle_test.go b/cmd/pilotctl/zz_lifecycle_test.go index 27d4a347..5db73327 100644 --- a/cmd/pilotctl/zz_lifecycle_test.go +++ b/cmd/pilotctl/zz_lifecycle_test.go @@ -9,7 +9,7 @@ import ( "strings" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // withTempHomeFull isolates HOME so config/socket/registry helpers don't diff --git a/cmd/pilotctl/zz_parsers_test.go b/cmd/pilotctl/zz_parsers_test.go index 07948c31..51732753 100644 --- a/cmd/pilotctl/zz_parsers_test.go +++ b/cmd/pilotctl/zz_parsers_test.go @@ -300,3 +300,50 @@ func TestClassifyDaemonError(t *testing.T) { type simpleErr struct{ s string } func (e *simpleErr) Error() string { return e.s } + +func TestNearestCommandSuggestion(t *testing.T) { + t.Parallel() + cands := []string{"ping", "peers", "info", "init", "handshake", "send-message", "trust"} + cases := []struct { + in string + want string + }{ + {"pin", "ping"}, // one deletion: ping → pin + {"pings", "ping"}, // one insertion + {"Peers", "peers"}, // case-insensitive + {"handshak", "handshake"}, + {"init", "init"}, // exact match + {"send-mesage", "send-message"}, + {"potato", ""}, // too far from anything + {"x", ""}, // too short to match anything in cands + {"", ""}, // empty input + } + for _, tc := range cases { + if got := nearestCommand(tc.in, cands); got != tc.want { + t.Errorf("nearestCommand(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestKnownTopLevelCommandsDedupesCompounds(t *testing.T) { + t.Parallel() + out := knownTopLevelCommands() + seen := make(map[string]int) + for _, c := range out { + seen[c]++ + if strings.Contains(c, " ") { + t.Errorf("compound key leaked into top-level list: %q", c) + } + } + for c, n := range seen { + if n > 1 { + t.Errorf("duplicate top-level token %q (count=%d)", c, n) + } + } + // Sanity: a few well-known commands must surface. + for _, must := range []string{"daemon", "send-message", "ping", "appstore"} { + if _, ok := seen[must]; !ok { + t.Errorf("expected %q in knownTopLevelCommands(), got %v", must, out) + } + } +} diff --git a/cmd/pilotctl/zz_stream_daemon_test.go b/cmd/pilotctl/zz_stream_daemon_test.go index ef5818e7..26903d41 100644 --- a/cmd/pilotctl/zz_stream_daemon_test.go +++ b/cmd/pilotctl/zz_stream_daemon_test.go @@ -13,8 +13,8 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/ipcutil" + "github.com/pilot-protocol/common/protocol" ) // Round-3 coverage push: drive cmdConnect/cmdSend/cmdRecv/cmdDgram/ diff --git a/go.mod b/go.mod index 2aa2003a..ef403784 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,9 @@ require ( github.com/coder/websocket v1.8.14 github.com/pilot-protocol/app-store v0.1.0 github.com/pilot-protocol/beacon v0.1.0 - github.com/pilot-protocol/common v0.1.0 + github.com/pilot-protocol/common v0.2.0 github.com/pilot-protocol/dataexchange v0.1.0 github.com/pilot-protocol/eventstream v0.1.0 - github.com/pilot-protocol/gateway v0.1.0 github.com/pilot-protocol/handshake v0.1.0 github.com/pilot-protocol/nameserver v0.1.0 github.com/pilot-protocol/policy v0.1.0 @@ -25,3 +24,35 @@ require ( golang.org/x/net v0.55.0 // indirect golang.org/x/sys v0.45.0 // indirect ) + +replace github.com/pilot-protocol/common => ../common + +replace github.com/pilot-protocol/beacon => ../beacon + +replace github.com/pilot-protocol/dataexchange => ../dataexchange + +replace github.com/pilot-protocol/eventstream => ../eventstream + +replace github.com/pilot-protocol/gateway => ../gateway + +replace github.com/pilot-protocol/nameserver => ../nameserver + +replace github.com/pilot-protocol/policy => ../policy + +replace github.com/pilot-protocol/rendezvous => ../rendezvous + +replace github.com/pilot-protocol/skillinject => ../skillinject + +replace github.com/pilot-protocol/trustedagents => ../trustedagents + +replace github.com/pilot-protocol/webhook => ../webhook + +replace github.com/pilot-protocol/app-store => ../app-store + +replace github.com/pilot-protocol/updater => ../updater + +replace github.com/pilot-protocol/handshake => ../handshake + +replace github.com/pilot-protocol/runtime => ../runtime + +replace github.com/pilot-protocol/libpilot => ../libpilot diff --git a/go.sum b/go.sum index b19df09c..0cedd3a5 100644 --- a/go.sum +++ b/go.sum @@ -2,34 +2,6 @@ github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9 github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/expr-lang/expr v1.17.8 h1:W1loDTT+0PQf5YteHSTpju2qfUfNoBt4yw9+wOEU9VM= github.com/expr-lang/expr v1.17.8/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= -github.com/pilot-protocol/app-store v0.1.0 h1:mMEbr04GURXWuFd4kQBONZZK+AMrXxdVt+IujeySfo8= -github.com/pilot-protocol/app-store v0.1.0/go.mod h1:0fo1XjzzLHmRMGuTc22aOLAseQzms7qM4QXfGilmMWY= -github.com/pilot-protocol/beacon v0.1.0 h1:jXO8duAzzpB8K+9It0QwR9BRupgKZ8IQhuwqy7rqtmk= -github.com/pilot-protocol/beacon v0.1.0/go.mod h1:PejZP5sZ4s5Lrtc0wdYHSEJVc7cn6E8yqo0R4CV+iUo= -github.com/pilot-protocol/common v0.1.0 h1:m8mZZATgeBiFoqhWXPnskw2u0lNkWxHp0IagZK35V1g= -github.com/pilot-protocol/common v0.1.0/go.mod h1:4YZWHK5nhM+4RLmYTspLxxAFbyBII7yzQDAHq3Ul2ck= -github.com/pilot-protocol/dataexchange v0.1.0 h1:JJ29lL/LxDd1+szKFoEVxakzT93Tid3zoaAUBVGNKV4= -github.com/pilot-protocol/dataexchange v0.1.0/go.mod h1:mXD17Vh0Eup+M//YhCm4j6D/DyFaZM7JzN5xexBdfLs= -github.com/pilot-protocol/eventstream v0.1.0 h1:uHNTNTMA9MasBBpi2nCRkvFmFvTxbKCk7azPfuILkvY= -github.com/pilot-protocol/eventstream v0.1.0/go.mod h1:sFhEh/YP76Sjhn8kNz7SOvqgk0vaJEIFkXMXAKlPqnY= -github.com/pilot-protocol/gateway v0.1.0 h1:mCbMMO4N7hkkuFELHCaBKGwmGXIeWW6fR8YMS9aEBgk= -github.com/pilot-protocol/gateway v0.1.0/go.mod h1:NvzDHFgPDno8ftTMDeFUnU7yZxXEx/3ds1cX3IN5YV0= -github.com/pilot-protocol/handshake v0.1.0 h1:TmqIglsimTynKtE5hLpCt/SZmmBYs8OCn4qn755fmew= -github.com/pilot-protocol/handshake v0.1.0/go.mod h1:FIIMTgRcMIMEim/1d7F5f6YenJC+3xl53QMEmnnJY+0= -github.com/pilot-protocol/nameserver v0.1.0 h1:91R7g36eIXMKX7Ld1YObtNsDh72smw+eD1rIVxvutQM= -github.com/pilot-protocol/nameserver v0.1.0/go.mod h1:6o01gsjvw4LqYIAxI5sD8OHfaiWv78jC8aPtoxV7nJA= -github.com/pilot-protocol/policy v0.1.0 h1:Eh0CfCZDEX8UCkMPi2MrNrhCe8c15a/Bqf4eaKUhyis= -github.com/pilot-protocol/policy v0.1.0/go.mod h1:IMVm7IQhgLtH/iXow2AWFuLl+sKrxJ4mGs4EKLoHop8= -github.com/pilot-protocol/rendezvous v0.1.0 h1:vOBD7CnRY8uU8vma0Vfcr0aPSQ54qNuxppNUiljzk9Y= -github.com/pilot-protocol/rendezvous v0.1.0/go.mod h1:g3/IYBykbU5m9jeprSCrmuoDpaqROO4Lu/+ecKVIF3A= -github.com/pilot-protocol/runtime v0.1.0 h1:TyerRWKVN38WM2RAPR5bhCdY5cR7d3UYg5neUK0pdZk= -github.com/pilot-protocol/runtime v0.1.0/go.mod h1:X1sImTG8xu6HkvKimu8Eq91HmDKQt6GHEWju7HxofEQ= -github.com/pilot-protocol/skillinject v0.1.0 h1:gs912gqmxl0ifIvswefjx8BzPCmoBWS77mSp/RC1YEY= -github.com/pilot-protocol/skillinject v0.1.0/go.mod h1:303GIB6j95ZhnoYeTlYzlBDhUbO01PB/6KGohm4DcJs= -github.com/pilot-protocol/trustedagents v0.1.0 h1:rCX0IQxfZ84Q4dSgw01WJgjHUODRnI3iAon1t+NuFGE= -github.com/pilot-protocol/trustedagents v0.1.0/go.mod h1:uVySmuMPb6N7AOCnvLHN2I9C9ggqEpfBmAZwVuP5Xaw= -github.com/pilot-protocol/webhook v0.1.0 h1:SnIcn+IdHvzoJt+OFzGJbkBHurjNBD0xSc6HAUvExAg= -github.com/pilot-protocol/webhook v0.1.0/go.mod h1:c0du05MMy8FYnlc2YGqUWxRgTP8pRWTrtIdNuRhf6uk= golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= diff --git a/internal/ipcutil/ipcutil.go b/internal/ipcutil/ipcutil.go deleted file mode 100644 index 5939d86a..00000000 --- a/internal/ipcutil/ipcutil.go +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package ipcutil - -import ( - "encoding/binary" - "fmt" - "io" -) - -// MaxMessageSize is the maximum IPC message size (1MB). -const MaxMessageSize = 1 << 20 - -// Read reads a length-prefixed IPC message from r. -func Read(r io.Reader) ([]byte, error) { - var lenBuf [4]byte - if _, err := io.ReadFull(r, lenBuf[:]); err != nil { - return nil, err - } - length := binary.BigEndian.Uint32(lenBuf[:]) - if length > MaxMessageSize { - return nil, fmt.Errorf("ipc message too large: %d bytes (max %d)", length, MaxMessageSize) - } - buf := make([]byte, length) - if _, err := io.ReadFull(r, buf); err != nil { - return nil, err - } - return buf, nil -} - -// Write writes a length-prefixed IPC message to w. -func Write(w io.Writer, data []byte) error { - var lenBuf [4]byte - binary.BigEndian.PutUint32(lenBuf[:], uint32(len(data))) - if _, err := w.Write(lenBuf[:]); err != nil { - return err - } - _, err := w.Write(data) - return err -} diff --git a/internal/ipcutil/zz_test.go b/internal/ipcutil/zz_test.go deleted file mode 100644 index 6df60be4..00000000 --- a/internal/ipcutil/zz_test.go +++ /dev/null @@ -1,137 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package ipcutil - -import ( - "bytes" - "encoding/binary" - "io" - "strings" - "testing" -) - -func TestReadWriteRoundTrip(t *testing.T) { - t.Parallel() - for _, payload := range [][]byte{ - nil, - {}, - []byte("hello"), - bytes.Repeat([]byte{0xAB}, 10000), - } { - var buf bytes.Buffer - if err := Write(&buf, payload); err != nil { - t.Fatalf("Write(%d bytes): %v", len(payload), err) - } - got, err := Read(&buf) - if err != nil { - t.Fatalf("Read: %v", err) - } - if !bytes.Equal(got, payload) { - t.Fatalf("round-trip mismatch: got %d bytes, want %d bytes", len(got), len(payload)) - } - } -} - -func TestWriteLengthPrefix(t *testing.T) { - t.Parallel() - var buf bytes.Buffer - payload := []byte("abcde") - if err := Write(&buf, payload); err != nil { - t.Fatal(err) - } - if buf.Len() != 4+len(payload) { - t.Fatalf("buf len = %d, want %d", buf.Len(), 4+len(payload)) - } - length := binary.BigEndian.Uint32(buf.Bytes()[:4]) - if length != uint32(len(payload)) { - t.Fatalf("length prefix = %d, want %d", length, len(payload)) - } -} - -func TestReadTooLargeRejected(t *testing.T) { - t.Parallel() - var buf bytes.Buffer - // Write length prefix claiming > MaxMessageSize - lenBuf := make([]byte, 4) - binary.BigEndian.PutUint32(lenBuf, MaxMessageSize+1) - buf.Write(lenBuf) - - _, err := Read(&buf) - if err == nil { - t.Fatal("expected too-large error") - } - if !strings.Contains(err.Error(), "too large") { - t.Fatalf("error %q missing 'too large'", err) - } -} - -func TestReadExactlyMaxSizeAccepted(t *testing.T) { - t.Parallel() - var buf bytes.Buffer - // Length exactly == max, followed by that many zero bytes - data := make([]byte, MaxMessageSize) - if err := Write(&buf, data); err != nil { - t.Fatal(err) - } - got, err := Read(&buf) - if err != nil { - t.Fatalf("max-size read should succeed: %v", err) - } - if len(got) != MaxMessageSize { - t.Fatalf("len = %d, want %d", len(got), MaxMessageSize) - } -} - -func TestReadTruncatedLength(t *testing.T) { - t.Parallel() - buf := bytes.NewReader([]byte{0x00, 0x00}) // only 2 bytes of length prefix - _, err := Read(buf) - if err == nil { - t.Fatal("expected error on truncated length prefix") - } -} - -func TestReadTruncatedPayload(t *testing.T) { - t.Parallel() - var buf bytes.Buffer - lenBuf := make([]byte, 4) - binary.BigEndian.PutUint32(lenBuf, 100) - buf.Write(lenBuf) - buf.Write([]byte("only 20 bytes here..")) // payload truncated - _, err := Read(&buf) - if err == nil { - t.Fatal("expected truncation error") - } -} - -// errWriter fails every write — exercises the Write error paths. -type errWriter struct { - failAfter int - calls int -} - -func (w *errWriter) Write(p []byte) (int, error) { - w.calls++ - if w.calls > w.failAfter { - return 0, io.ErrShortWrite - } - return len(p), nil -} - -func TestWriteErrorOnLengthPrefix(t *testing.T) { - t.Parallel() - w := &errWriter{failAfter: 0} - err := Write(w, []byte("data")) - if err == nil { - t.Fatal("expected error from failing writer on length prefix") - } -} - -func TestWriteErrorOnPayload(t *testing.T) { - t.Parallel() - w := &errWriter{failAfter: 1} // first write (length) succeeds, second (payload) fails - err := Write(w, []byte("data")) - if err == nil { - t.Fatal("expected error from failing writer on payload") - } -} diff --git a/pkg/config/config.go b/pkg/config/config.go deleted file mode 100644 index 5a58118f..00000000 --- a/pkg/config/config.go +++ /dev/null @@ -1,59 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package config - -import ( - "encoding/json" - "flag" - "fmt" - "os" - "strings" -) - -// Load reads a JSON config file and returns it as a map. -func Load(path string) (map[string]interface{}, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - - var cfg map[string]interface{} - if err := json.NewDecoder(f).Decode(&cfg); err != nil { - return nil, err - } - return cfg, nil -} - -// ApplyToFlags overrides flag defaults from config for any flag not -// explicitly set on the command line. Call this AFTER flag.Parse(). -// Keys in the config can use either hyphens or underscores (e.g. -// "log-level" or "log_level" both match the -log-level flag). -func ApplyToFlags(cfg map[string]interface{}) { - explicit := make(map[string]bool) - flag.Visit(func(f *flag.Flag) { - explicit[f.Name] = true - }) - - flag.VisitAll(func(f *flag.Flag) { - if explicit[f.Name] { - return - } - val, ok := cfg[f.Name] - if !ok { - // Try underscore variant: log-level → log_level - val, ok = cfg[strings.ReplaceAll(f.Name, "-", "_")] - } - if !ok { - return - } - switch v := val.(type) { - case string: - f.Value.Set(v) - case float64: - f.Value.Set(fmt.Sprintf("%v", v)) - case bool: - f.Value.Set(fmt.Sprintf("%v", v)) - } - }) -} diff --git a/pkg/config/zz_config_test.go b/pkg/config/zz_config_test.go deleted file mode 100644 index 89a74b81..00000000 --- a/pkg/config/zz_config_test.go +++ /dev/null @@ -1,180 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package config_test - -import ( - "flag" - "os" - "path/filepath" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/config" -) - -func TestLoadValidJSON(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "cfg.json") - body := `{"log_level":"debug","port":8080,"verbose":true}` - if err := os.WriteFile(path, []byte(body), 0644); err != nil { - t.Fatal(err) - } - cfg, err := config.Load(path) - if err != nil { - t.Fatalf("Load: %v", err) - } - if cfg["log_level"] != "debug" { - t.Errorf("log_level = %v, want debug", cfg["log_level"]) - } - if cfg["port"].(float64) != 8080 { - t.Errorf("port = %v, want 8080", cfg["port"]) - } - if cfg["verbose"] != true { - t.Errorf("verbose = %v, want true", cfg["verbose"]) - } -} - -func TestLoadMissingFile(t *testing.T) { - _, err := config.Load("/nonexistent/path/cfg.json") - if err == nil { - t.Fatal("expected error for missing file") - } -} - -func TestLoadMalformedJSON(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "bad.json") - if err := os.WriteFile(path, []byte("{not json"), 0644); err != nil { - t.Fatal(err) - } - _, err := config.Load(path) - if err == nil { - t.Fatal("expected parse error") - } -} - -// ApplyToFlags tests must serialize because flag package has global state. -// We use a dedicated FlagSet per test, but package-level flag.Visit reads -// flag.CommandLine — so we temporarily swap it. -func withFreshCommandLine(t *testing.T) *flag.FlagSet { - t.Helper() - saved := flag.CommandLine - flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) - t.Cleanup(func() { flag.CommandLine = saved }) - return flag.CommandLine -} - -func TestApplyToFlagsSetsUnsetFlags(t *testing.T) { - fs := withFreshCommandLine(t) - var level string - var port int - var verbose bool - fs.StringVar(&level, "log-level", "info", "") - fs.IntVar(&port, "port", 9000, "") - fs.BoolVar(&verbose, "verbose", false, "") - - // Parse with no args so nothing is explicitly set - if err := fs.Parse(nil); err != nil { - t.Fatalf("Parse: %v", err) - } - - cfg := map[string]interface{}{ - "log-level": "debug", - "port": float64(8080), - "verbose": true, - } - config.ApplyToFlags(cfg) - - if level != "debug" { - t.Errorf("log-level = %q, want debug", level) - } - if port != 8080 { - t.Errorf("port = %d, want 8080", port) - } - if verbose != true { - t.Errorf("verbose = %v, want true", verbose) - } -} - -func TestApplyToFlagsPreservesExplicitlySetFlags(t *testing.T) { - fs := withFreshCommandLine(t) - var level string - fs.StringVar(&level, "log-level", "info", "") - - // Explicitly set on the command line — config must NOT override. - if err := fs.Parse([]string{"-log-level=warn"}); err != nil { - t.Fatalf("Parse: %v", err) - } - - cfg := map[string]interface{}{"log-level": "debug"} - config.ApplyToFlags(cfg) - - if level != "warn" { - t.Errorf("log-level = %q, want warn (explicit flag must win over config)", level) - } -} - -func TestApplyToFlagsUnderscoreVariantMatches(t *testing.T) { - fs := withFreshCommandLine(t) - var level string - fs.StringVar(&level, "log-level", "info", "") - if err := fs.Parse(nil); err != nil { - t.Fatal(err) - } - - // Config uses underscore; flag uses hyphen. ApplyToFlags should match them. - cfg := map[string]interface{}{"log_level": "debug"} - config.ApplyToFlags(cfg) - - if level != "debug" { - t.Errorf("log-level = %q, want debug (underscore→hyphen match)", level) - } -} - -func TestApplyToFlagsHyphenVariantTakesPrecedenceOverUnderscore(t *testing.T) { - fs := withFreshCommandLine(t) - var level string - fs.StringVar(&level, "log-level", "info", "") - if err := fs.Parse(nil); err != nil { - t.Fatal(err) - } - - // If both keys present, the exact flag-name match (log-level) must win. - cfg := map[string]interface{}{ - "log-level": "debug", - "log_level": "warn", - } - config.ApplyToFlags(cfg) - - if level != "debug" { - t.Errorf("log-level = %q, want debug (exact match wins)", level) - } -} - -func TestApplyToFlagsIgnoresUnknownKeys(t *testing.T) { - fs := withFreshCommandLine(t) - var level string - fs.StringVar(&level, "log-level", "info", "") - if err := fs.Parse(nil); err != nil { - t.Fatal(err) - } - config.ApplyToFlags(map[string]interface{}{"unrelated-flag": "xyz"}) - if level != "info" { - t.Errorf("log-level changed unexpectedly: %q", level) - } -} - -func TestApplyToFlagsSkipsUnsupportedTypes(t *testing.T) { - fs := withFreshCommandLine(t) - var level string - fs.StringVar(&level, "log-level", "info", "") - if err := fs.Parse(nil); err != nil { - t.Fatal(err) - } - // Nested map / array — should be silently skipped (not panic) - config.ApplyToFlags(map[string]interface{}{ - "log-level": map[string]interface{}{"nested": "value"}, - }) - if level != "info" { - t.Errorf("log-level changed from nested map: %q (unsupported type should skip)", level) - } -} diff --git a/pkg/coreapi/doc.go b/pkg/coreapi/doc.go deleted file mode 100644 index 65245564..00000000 --- a/pkg/coreapi/doc.go +++ /dev/null @@ -1,19 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -// Package coreapi defines the L10 plugin runtime contract. -// -// The interfaces in this package are the only surface a plugin -// (L11) ever sees of the daemon. Plugins import coreapi; the daemon -// implements coreapi; the bridge happens at lifecycle bootstrap -// (cmd/daemon/main.go registers concrete plugins against the -// daemon's coreapi implementations). -// -// See docs/architecture/01-LAYERS.md §10 for the layer's role, -// docs/architecture/03-INVARIANTS.md for the principles this -// package enforces, and docs/architecture/06-CHANGES.md §2 for -// the rationale of each interface signature. -// -// Stability contract: every exported identifier in this package is -// part of the daemon-plugin ABI. Removing or renaming any of them -// breaks every plugin. Additions are forward-compatible. -package coreapi diff --git a/pkg/coreapi/errors.go b/pkg/coreapi/errors.go deleted file mode 100644 index 214313ac..00000000 --- a/pkg/coreapi/errors.go +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi - -import "errors" - -// Sentinel errors returned by the L10 surface. -var ( - // ErrRegistryStarted is returned by ServiceRegistry.Register and - // ServiceRegistry.StartAll when StartAll has already been called. - // Plugins must register before bootstrap. - ErrRegistryStarted = errors.New("coreapi: service registry already started") - - // ErrServiceNotReady indicates a Service.Start call was made on a - // dependency that itself hasn't completed Start. Surface only — - // Service implementations shouldn't return this; the registry will. - ErrServiceNotReady = errors.New("coreapi: dependency service not ready") - - // ErrPeerNotFound is the canonical "directory has no record" error - // from PeerResolver. Plugins should match on errors.Is. - ErrPeerNotFound = errors.New("coreapi: peer not found") -) diff --git a/pkg/coreapi/events.go b/pkg/coreapi/events.go deleted file mode 100644 index 495bcf6d..00000000 --- a/pkg/coreapi/events.go +++ /dev/null @@ -1,31 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi - -import "time" - -// Event is one item published to the EventBus. Topics are -// dot-namespaced (e.g., "tunnel.established", "security.nonce_replay"). -// Payload keys/values are plugin-defined; subscribers parse them. -type Event struct { - Topic string - NodeID uint32 - Time time.Time - Payload map[string]any -} - -// EventBus is the publish/subscribe channel that replaces inline -// webhook.Emit calls inside core layers. Core (L2-L7) publishes; -// the webhook plugin (and any other observability plugin) subscribes. -// -// Publish is non-blocking. If the bus is over capacity, the event is -// dropped (and a metric counter is incremented inside the daemon -// implementation). This keeps L2 readLoop / L6 decrypt latency bounded. -// -// Subscribe returns a buffered channel and an unsubscribe func. Pattern -// is a glob: "tunnel.*" matches "tunnel.established" but not -// "security.nonce_replay". -type EventBus interface { - Publish(topic string, payload map[string]any) - Subscribe(pattern string) (<-chan Event, func()) -} diff --git a/pkg/coreapi/identity.go b/pkg/coreapi/identity.go deleted file mode 100644 index 741a3a3d..00000000 --- a/pkg/coreapi/identity.go +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi - -import "crypto/ed25519" - -// Identity is the daemon's own identity — its Ed25519 keypair, its -// stable nodeID, its 48-bit address. Plugins may sign arbitrary bytes -// (e.g., for plugin-level auth proofs) but cannot replace the identity. -type Identity interface { - NodeID() uint32 - Address() Addr - PublicKey() ed25519.PublicKey - Sign(msg []byte) ([]byte, error) -} diff --git a/pkg/coreapi/lifecycle.go b/pkg/coreapi/lifecycle.go deleted file mode 100644 index 5482accd..00000000 --- a/pkg/coreapi/lifecycle.go +++ /dev/null @@ -1,149 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi - -import ( - "context" - "fmt" - "log/slog" - "sort" - "sync" -) - -// Service is the lifecycle contract every L11 plugin implements. -// -// Order determines the start sequence. Lower numbers start first; -// higher numbers stop first. Suggested ranges: -// -// 10-49 Foundation (none today) -// 50-79 Trust / identity-adjacent (trustedagents) -// 80-99 Observability (webhook) -// 100-199 Application services (dataexchange, eventstream, tasks) -// 200-249 Sidecars (skillinject) -// 250+ Tooling-bound (updater) -// -// Start receives Deps (the L10 surface). Implementations must NOT -// retain references to anything outside Deps — that's the whole -// extraction contract. -// -// Stop should drain in-flight work, close listeners, and signal -// background goroutines to exit. It must return within 5 seconds -// or the daemon shutdown gate will fail. -type Service interface { - Name() string - Order() int - Start(ctx context.Context, deps Deps) error - Stop(ctx context.Context) error -} - -// Deps is the bag of capabilities a plugin can use. Optional fields -// may be nil if the corresponding plugin isn't loaded; plugins that -// hard-depend on them should error in Start(). -type Deps struct { - Streams Streams - Identity Identity - Resolver PeerResolver - Events EventBus - Logger *slog.Logger - - // Optional — nil if the plugin providing them isn't registered. - Trust TrustChecker -} - -// ServiceRegistry coordinates plugin lifecycle. cmd/daemon/main.go -// constructs one, registers each plugin, and hands it to the daemon. -// The daemon calls StartAll during bootstrap and StopAll during -// shutdown. -type ServiceRegistry struct { - mu sync.Mutex - services []Service - started []Service // start order, used to stop in reverse -} - -// Register adds a service. Must be called before StartAll. After -// StartAll runs, Register is a no-op error. -func (sr *ServiceRegistry) Register(s Service) error { - sr.mu.Lock() - defer sr.mu.Unlock() - if len(sr.started) > 0 { - return ErrRegistryStarted - } - sr.services = append(sr.services, s) - return nil -} - -// StartAll sorts by Order and starts every service in sequence. -// The first failing Start aborts and returns its error; previously- -// started services are NOT auto-stopped (the caller's job, via Stop() -// or by passing a context that cancels). -func (sr *ServiceRegistry) StartAll(ctx context.Context, deps Deps) error { - sr.mu.Lock() - if len(sr.started) > 0 { - sr.mu.Unlock() - return ErrRegistryStarted - } - sort.SliceStable(sr.services, func(i, j int) bool { - return sr.services[i].Order() < sr.services[j].Order() - }) - queue := append([]Service(nil), sr.services...) - sr.mu.Unlock() - - for _, s := range queue { - if err := startWithPanicRecovery(ctx, s, deps); err != nil { - return err - } - sr.mu.Lock() - sr.started = append(sr.started, s) - sr.mu.Unlock() - } - return nil -} - -// startWithPanicRecovery calls s.Start(ctx, deps) inside a defer -// recover() so a buggy plugin panicking during initialization (nil -// deref, index OOB, channel-send on nil, etc.) surfaces as a normal -// Start error rather than crashing the entire daemon process. -// -// Without this wrapper, every plugin's Init bug becomes a single- -// point-of-failure for the host: the whole daemon dies, every OTHER -// plugin goes offline with it, and the operator's only signal is a -// stack trace. -// -// Behaviour preserved on normal error returns: the surrounding -// StartAll loop still aborts on first failure and leaves earlier -// services running for the caller's Stop() to drain. -func startWithPanicRecovery(ctx context.Context, s Service, deps Deps) (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("plugin %q Start panicked: %v", s.Name(), r) - } - }() - return s.Start(ctx, deps) -} - -// StopAll stops every started service in reverse order. Errors from -// individual Stop calls are collected; the first one is returned but -// every service still gets its Stop call invoked. -func (sr *ServiceRegistry) StopAll(ctx context.Context) error { - sr.mu.Lock() - queue := append([]Service(nil), sr.started...) - sr.started = nil - sr.mu.Unlock() - - var firstErr error - for i := len(queue) - 1; i >= 0; i-- { - if err := queue[i].Stop(ctx); err != nil && firstErr == nil { - firstErr = err - } - } - return firstErr -} - -// All returns a snapshot of the registered services in start order. -func (sr *ServiceRegistry) All() []Service { - sr.mu.Lock() - defer sr.mu.Unlock() - out := make([]Service, len(sr.services)) - copy(out, sr.services) - return out -} diff --git a/pkg/coreapi/peers.go b/pkg/coreapi/peers.go deleted file mode 100644 index 0017a9b1..00000000 --- a/pkg/coreapi/peers.go +++ /dev/null @@ -1,30 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi - -import ( - "context" - "crypto/ed25519" - "net" -) - -// PeerInfo is the directory record for a remote node. Returned by -// PeerResolver.Resolve and PeerResolver.ListByNetwork. -type PeerInfo struct { - NodeID uint32 - Addr Addr - Endpoint *net.UDPAddr // best-known reachable endpoint, or nil - PubKey ed25519.PublicKey - Public bool - Hostname string - RelayOnly bool -} - -// PeerResolver is the L8 directory surface. The daemon's -// implementation talks to the registry over the bootstrap TCP -// side-channel (see 01-LAYERS §L8). -type PeerResolver interface { - Resolve(ctx context.Context, nodeID uint32) (PeerInfo, error) - ResolveHostname(ctx context.Context, name string) (uint32, error) - ListByNetwork(ctx context.Context, networkID uint32) ([]PeerInfo, error) -} diff --git a/pkg/coreapi/policy.go b/pkg/coreapi/policy.go deleted file mode 100644 index 9d409801..00000000 --- a/pkg/coreapi/policy.go +++ /dev/null @@ -1,78 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi - -// PolicyEventType is the kind of protocol event a policy is evaluated -// against. Type alias to string so daemon-local primitive interfaces -// can satisfy plugin signatures via structural typing without importing -// this package (T7.1). -type PolicyEventType = string - -const ( - PolicyEventConnect = "connect" - PolicyEventDial = "dial" - PolicyEventDatagram = "datagram" - PolicyEventJoin = "join" - PolicyEventLeave = "leave" - PolicyEventCycle = "cycle" -) - -// PolicyRunner is the daemon-facing surface of a single network's -// running policy. The plugin's concrete *PolicyRunner type implements -// this. The daemon never holds the concrete type — only this interface. -type PolicyRunner interface { - NetworkID() uint16 - - // HasMember returns true if peerNodeID is in this runner's - // per-peer state. The daemon iterates all runners to consult - // every network the peer belongs to (deny wins across networks). - HasMember(peerNodeID uint32) bool - - // EvaluatePortGate is the daemon-facing gate API for inbound SYN - // (Connect), outbound SYN (Dial), and datagram (in/out) events. - // The plugin builds the per-peer ctx internally (peer_age_s, - // peer_tags, members) using its peer state and the - // daemon-supplied localTags + nodeInfoTags. Returns the - // allow/deny verdict (default allow on no explicit deny). - EvaluatePortGate(eventType PolicyEventType, port uint16, peerNodeID uint32, payloadSize int, direction string, localTags, nodeInfoTags []string) bool - - // EvaluateActions runs an action-event (cycle/join/leave) with a - // caller-built ctx. Side-effect-only: no return value. - EvaluateActions(eventType PolicyEventType, ctx map[string]any) - - Status() map[string]any - PeerList() []map[string]any - ForceCycle() map[string]any - ReconcileNow() - - // PolicyJSON returns the marshaled policy document. Used by IPC - // handlers that read the current policy back to admin tools. - PolicyJSON() ([]byte, error) - - Stop() -} - -// PolicyManager owns the per-network registry of policy runners. The -// daemon holds it as an interface field; cmd/daemon (L12) constructs -// the concrete plugin and calls Daemon.RegisterPolicyManager. -type PolicyManager interface { - // Start compiles a policy JSON for the given network and registers - // a runner. Returns the runner handle; existing runners for the - // same network are stopped first. - Start(netID uint16, policyJSON []byte) (PolicyRunner, error) - - // Stop stops the runner for netID (no-op if absent). - Stop(netID uint16) - - // Get returns the runner for netID or nil. - Get(netID uint16) PolicyRunner - - // All returns a snapshot of all running runners. - All() []PolicyRunner - - // StopAll stops every runner. Called during daemon shutdown. - StopAll() - - // LoadPersisted runs at daemon-Start to restore runners from disk. - LoadPersisted() error -} diff --git a/pkg/coreapi/recover.go b/pkg/coreapi/recover.go deleted file mode 100644 index 7c771d4d..00000000 --- a/pkg/coreapi/recover.go +++ /dev/null @@ -1,83 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi - -import ( - "fmt" - "log/slog" - "runtime/debug" - "sync/atomic" -) - -// pluginRecoveredPanicCount is the L11 counterpart to the daemon's -// internal recoveredPanicCount. Tracks how many panics have been -// caught at plugin entry points (acceptLoop, handleConn, Service.Start -// goroutines). Exposed via PluginRecoveredPanicCount. -var pluginRecoveredPanicCount atomic.Uint64 - -// PluginRecoveredPanicCount returns the total number of panics -// swallowed by RecoverPlugin since process start. -func PluginRecoveredPanicCount() uint64 { - return pluginRecoveredPanicCount.Load() -} - -// ResetPluginRecoveredPanicCountForTest is test-only. -func ResetPluginRecoveredPanicCountForTest() { - pluginRecoveredPanicCount.Store(0) -} - -// RecoverPlugin is the L11 panic-recovery shim used at the top of -// every plugin entrypoint goroutine: Service.Start helper goroutines, -// acceptLoop, and per-connection handlers. Usage: -// -// defer coreapi.RecoverPlugin("eventstream", "acceptLoop", events, nil) -// -// On panic it: -// 1. Recovers (caller goroutine continues / loop iteration is dropped) -// 2. Logs at ERROR with structured plugin/op fields, panic value, and -// full goroutine stack trace -// 3. Increments PluginRecoveredPanicCount -// 4. Publishes a "plugin..panic" event on the bus (if -// events != nil) so observability subscribers see the recovery -// 5. Calls onPanic(r) if non-nil — typical use is per-conn close, -// or signaling a future per-plugin supervisor for restart -// -// TODO(03-INVARIANTS.md §8): per-plugin supervisor not yet implemented. -// Today the boundary just survives + logs. A future tier will signal a -// restart of the panicked plugin via the onPanic callback. -// -// This must be the OUTERMOST defer in the goroutine: defers run LIFO, -// so other defers (conn.Close, mu.Unlock, removeSub) run first. -func RecoverPlugin(plugin, op string, events EventBus, onPanic func(any)) { - r := recover() - if r == nil { - return - } - count := pluginRecoveredPanicCount.Add(1) - slog.Error("plugin panic recovered", - "layer", "L11", - "plugin", plugin, - "op", op, - "panic", r, - "recovered_total", count, - "stack", string(debug.Stack()), - ) - if events != nil { - // Defensive: a publisher that itself panics must not propagate. - func() { - defer func() { _ = recover() }() - events.Publish("plugin."+plugin+".panic", map[string]any{ - "plugin": plugin, - "op": op, - "panic": fmt.Sprintf("%v", r), - "recovered_total": count, - }) - }() - } - if onPanic != nil { - func() { - defer func() { _ = recover() }() - onPanic(r) - }() - } -} diff --git a/pkg/coreapi/streams.go b/pkg/coreapi/streams.go deleted file mode 100644 index 86da40cd..00000000 --- a/pkg/coreapi/streams.go +++ /dev/null @@ -1,49 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi - -import ( - "context" - "io" - - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" -) - -// Addr is the 48-bit virtual address used throughout the protocol. -// Re-exported here so plugins can stay free of pkg/protocol if they want. -type Addr = protocol.Addr - -// Stream is one bidirectional ordered byte stream between two -// (Addr, port) endpoints. It satisfies io.ReadWriteCloser with -// Pilot Protocol addressing extensions. Deadline methods are -// intentionally excluded — the runtime currently cannot honor -// them, and removing them from the interface forces callers to -// get a compile-time signal rather than a silent no-op. -type Stream interface { - io.ReadWriteCloser - - LocalAddr() Addr - LocalPort() uint16 - RemoteAddr() Addr - RemotePort() uint16 -} - -// Listener accepts inbound streams on a single well-known or ephemeral -// port. Returned by Streams.Listen. -type Listener interface { - Accept() (Stream, error) - Close() error - Addr() Addr - Port() uint16 -} - -// Streams is the L7 surface plugins consume. The daemon-side -// implementation routes through L7 → L6 → L5 → L4 → L2. -// -// SendDatagram is the connectionless variant (one packet, no ACK, -// no retransmit). Used by plugins that don't need stream semantics. -type Streams interface { - Dial(ctx context.Context, dst Addr, port uint16) (Stream, error) - Listen(port uint16) (Listener, error) - SendDatagram(ctx context.Context, dst Addr, port uint16, data []byte) error -} diff --git a/pkg/coreapi/trust.go b/pkg/coreapi/trust.go deleted file mode 100644 index c96f8328..00000000 --- a/pkg/coreapi/trust.go +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi - -// TrustChecker is the trusted-agents gate consumed by L11/tasks (and -// any other plugin that gates on peer reputation). -// -// IsTrusted: returns true if the peer is on the auto-approve allowlist -// (loaded from the trusted-agents JSON, refreshed hourly). -type TrustChecker interface { - // IsTrusted reports whether the peer is on the auto-approve allowlist. - // Returns the agent's display name when known. Both return values are - // zero on miss. - IsTrusted(nodeID uint32) (name string, ok bool) -} diff --git a/pkg/coreapi/zz_lifecycle_edge_test.go b/pkg/coreapi/zz_lifecycle_edge_test.go deleted file mode 100644 index 567231cc..00000000 --- a/pkg/coreapi/zz_lifecycle_edge_test.go +++ /dev/null @@ -1,75 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi_test - -import ( - "context" - "errors" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/coreapi" -) - -func TestServiceRegistry_StartAllTwiceReturnsErrRegistryStarted(t *testing.T) { - t.Parallel() - sr := &coreapi.ServiceRegistry{} - _ = sr.Register(&fakeService{name: "a", order: 1}) - if err := sr.StartAll(context.Background(), coreapi.Deps{}); err != nil { - t.Fatalf("first StartAll: %v", err) - } - err := sr.StartAll(context.Background(), coreapi.Deps{}) - if !errors.Is(err, coreapi.ErrRegistryStarted) { - t.Errorf("second StartAll = %v, want ErrRegistryStarted", err) - } -} - -func TestServiceRegistry_StopAllSurfacesFirstError(t *testing.T) { - t.Parallel() - sr := &coreapi.ServiceRegistry{} - a := &fakeService{name: "a", order: 1, stopErr: errors.New("stop-a-failed")} - b := &fakeService{name: "b", order: 2, stopErr: errors.New("stop-b-failed")} - _ = sr.Register(a) - _ = sr.Register(b) - if err := sr.StartAll(context.Background(), coreapi.Deps{}); err != nil { - t.Fatalf("StartAll: %v", err) - } - // b stops first (reverse order), so its error is "first" returned. - err := sr.StopAll(context.Background()) - if err == nil || err.Error() != "stop-b-failed" { - t.Errorf("StopAll = %v, want stop-b-failed", err) - } -} - -func TestServiceRegistry_StopAllStopsAllEvenAfterError(t *testing.T) { - t.Parallel() - sr := &coreapi.ServiceRegistry{} - aStopped := false - bStopped := false - a := &recordingStopWithErr{name: "a", order: 1, stopped: &aStopped} - b := &recordingStopWithErr{name: "b", order: 2, stopped: &bStopped, err: errors.New("b-failed")} - _ = sr.Register(a) - _ = sr.Register(b) - _ = sr.StartAll(context.Background(), coreapi.Deps{}) - _ = sr.StopAll(context.Background()) - if !aStopped { - t.Error("service a was not stopped despite b's error") - } - if !bStopped { - t.Error("service b was not stopped") - } -} - -type recordingStopWithErr struct { - name string - order int - stopped *bool - err error -} - -func (r *recordingStopWithErr) Name() string { return r.name } -func (r *recordingStopWithErr) Order() int { return r.order } -func (r *recordingStopWithErr) Start(ctx context.Context, deps coreapi.Deps) error { return nil } -func (r *recordingStopWithErr) Stop(ctx context.Context) error { - *r.stopped = true - return r.err -} diff --git a/pkg/coreapi/zz_lifecycle_test.go b/pkg/coreapi/zz_lifecycle_test.go deleted file mode 100644 index b3bd9952..00000000 --- a/pkg/coreapi/zz_lifecycle_test.go +++ /dev/null @@ -1,114 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi_test - -import ( - "context" - "errors" - "fmt" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/coreapi" -) - -type fakeService struct { - name string - order int - startErr error - stopErr error - startedAt int // sequence number, set by harness -} - -func (f *fakeService) Name() string { return f.name } -func (f *fakeService) Order() int { return f.order } -func (f *fakeService) Start(ctx context.Context, deps coreapi.Deps) error { - return f.startErr -} -func (f *fakeService) Stop(ctx context.Context) error { return f.stopErr } - -func TestServiceRegistry_StartOrder(t *testing.T) { - t.Parallel() - sr := &coreapi.ServiceRegistry{} - a := &fakeService{name: "a", order: 200} - b := &fakeService{name: "b", order: 100} - c := &fakeService{name: "c", order: 50} - for _, s := range []coreapi.Service{a, b, c} { - if err := sr.Register(s); err != nil { - t.Fatalf("register %s: %v", s.Name(), err) - } - } - if err := sr.StartAll(context.Background(), coreapi.Deps{}); err != nil { - t.Fatalf("StartAll: %v", err) - } - got := sr.All() - want := []string{"c", "b", "a"} - for i, s := range got { - if s.Name() != want[i] { - t.Errorf("position %d: got %s, want %s", i, s.Name(), want[i]) - } - } -} - -func TestServiceRegistry_StartFailureAborts(t *testing.T) { - t.Parallel() - sr := &coreapi.ServiceRegistry{} - a := &fakeService{name: "a", order: 10} - boom := &fakeService{name: "boom", order: 20, startErr: errors.New("boom")} - c := &fakeService{name: "c", order: 30} - for _, s := range []coreapi.Service{a, boom, c} { - _ = sr.Register(s) - } - err := sr.StartAll(context.Background(), coreapi.Deps{}) - if err == nil || err.Error() != "boom" { - t.Fatalf("want boom, got %v", err) - } - // `c` should NOT have been started after boom failed; verify by - // calling StopAll and checking only the started ones rolled back. - // (We can't directly observe started state, but the registry should - // not crash and Stop should return nil on the un-started services.) - if err := sr.StopAll(context.Background()); err != nil { - t.Errorf("StopAll after partial start: %v", err) - } -} - -func TestServiceRegistry_StopReverseOrder(t *testing.T) { - t.Parallel() - sr := &coreapi.ServiceRegistry{} - stops := []string{} - a := &recordingStop{name: "a", order: 10, stops: &stops} - b := &recordingStop{name: "b", order: 20, stops: &stops} - c := &recordingStop{name: "c", order: 30, stops: &stops} - for _, s := range []coreapi.Service{a, b, c} { - _ = sr.Register(s) - } - _ = sr.StartAll(context.Background(), coreapi.Deps{}) - _ = sr.StopAll(context.Background()) - want := []string{"c", "b", "a"} - if fmt.Sprint(stops) != fmt.Sprint(want) { - t.Errorf("stop order: got %v, want %v", stops, want) - } -} - -func TestServiceRegistry_RegisterAfterStart(t *testing.T) { - t.Parallel() - sr := &coreapi.ServiceRegistry{} - _ = sr.Register(&fakeService{name: "a", order: 10}) - _ = sr.StartAll(context.Background(), coreapi.Deps{}) - if err := sr.Register(&fakeService{name: "late", order: 50}); !errors.Is(err, coreapi.ErrRegistryStarted) { - t.Errorf("want ErrRegistryStarted, got %v", err) - } -} - -type recordingStop struct { - name string - order int - stops *[]string -} - -func (r *recordingStop) Name() string { return r.name } -func (r *recordingStop) Order() int { return r.order } -func (r *recordingStop) Start(ctx context.Context, deps coreapi.Deps) error { return nil } -func (r *recordingStop) Stop(ctx context.Context) error { - *r.stops = append(*r.stops, r.name) - return nil -} diff --git a/pkg/coreapi/zz_panic_recovery_test.go b/pkg/coreapi/zz_panic_recovery_test.go deleted file mode 100644 index 62e3d33e..00000000 --- a/pkg/coreapi/zz_panic_recovery_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi_test - -// Regression for P1 plugin-crash DoS: StartAll invokes each plugin's -// Start() directly with no recover wrapper. A plugin that panics -// during Start (nil-deref, index OOB, etc.) crashes the entire daemon -// process — operator's recourse is to find the buggy plugin via -// stack trace and disable it, while every other plugin is offline. -// -// Fix: StartAll wraps each plugin Start() in defer recover(), converts -// the panic to an error like any other Start failure. The error path -// (return on first failure, previously-started plugins NOT auto- -// stopped) is preserved — the caller's Stop() handles cleanup. - -import ( - "context" - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/coreapi" -) - -// panickingService panics during Start with the given message. -type panickingService struct{ msg string } - -func (p *panickingService) Name() string { return "panicker" } -func (p *panickingService) Order() int { return 100 } -func (p *panickingService) Start(_ context.Context, _ coreapi.Deps) error { panic(p.msg) } -func (p *panickingService) Stop(_ context.Context) error { return nil } - -func TestServiceRegistry_StartAllRecoversFromPluginPanic(t *testing.T) { - t.Parallel() - - sr := &coreapi.ServiceRegistry{} - if err := sr.Register(&panickingService{msg: "boom from a buggy plugin"}); err != nil { - t.Fatalf("Register: %v", err) - } - - // Without the recover wrapper, this CRASHES the test process. - err := sr.StartAll(context.Background(), coreapi.Deps{}) - - if err == nil { - t.Fatal("StartAll returned nil for panicking plugin — recover wrapper missing") - } - if !strings.Contains(err.Error(), "panic") { - t.Errorf("expected error to mention 'panic'; got %q", err.Error()) - } - if !strings.Contains(err.Error(), "boom from a buggy plugin") { - t.Errorf("expected error to include the panic message; got %q", err.Error()) - } -} diff --git a/pkg/coreapi/zz_recover_edge_test.go b/pkg/coreapi/zz_recover_edge_test.go deleted file mode 100644 index 78e279e9..00000000 --- a/pkg/coreapi/zz_recover_edge_test.go +++ /dev/null @@ -1,98 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi_test - -import ( - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/coreapi" -) - -func TestPluginRecoveredPanicCountAndReset(t *testing.T) { - // Not parallel — touches a package-level counter. - coreapi.ResetPluginRecoveredPanicCountForTest() - if got := coreapi.PluginRecoveredPanicCount(); got != 0 { - t.Fatalf("after reset = %d, want 0", got) - } - - // Induce a panic and let RecoverPlugin swallow it. - func() { - defer coreapi.RecoverPlugin("test-plugin", "test-op", nil, nil) - panic("synthetic") - }() - if got := coreapi.PluginRecoveredPanicCount(); got != 1 { - t.Errorf("after one panic = %d, want 1", got) - } - - // Another with onPanic callback exercised. - called := false - func() { - defer coreapi.RecoverPlugin("p2", "op", nil, func(_ any) { called = true }) - panic("two") - }() - if !called { - t.Errorf("onPanic callback not invoked") - } - if got := coreapi.PluginRecoveredPanicCount(); got != 2 { - t.Errorf("after two panics = %d, want 2", got) - } - - // Reset works after non-zero count. - coreapi.ResetPluginRecoveredPanicCountForTest() - if got := coreapi.PluginRecoveredPanicCount(); got != 0 { - t.Errorf("second reset = %d, want 0", got) - } -} - -func TestRecoverPlugin_NoPanicIsNoOp(t *testing.T) { - t.Parallel() - // The early-return path when recover() returns nil. No counter bump. - before := coreapi.PluginRecoveredPanicCount() - func() { - defer coreapi.RecoverPlugin("clean", "op", nil, nil) - }() - if got := coreapi.PluginRecoveredPanicCount(); got != before { - t.Errorf("counter changed without a panic: %d → %d", before, got) - } -} - -// fakeBusPanics publishes that itself panics — RecoverPlugin must -// shield itself from a nested publisher panic. -type fakeBusPanics struct{} - -func (fakeBusPanics) Publish(string, map[string]any) { panic("nested-publish-panic") } -func (fakeBusPanics) Subscribe(string) (<-chan coreapi.Event, func()) { return nil, func() {} } - -func TestRecoverPlugin_NestedPublishPanicSwallowed(t *testing.T) { - // Not parallel — touches counter. - coreapi.ResetPluginRecoveredPanicCountForTest() - defer func() { - if r := recover(); r != nil { - t.Fatalf("nested publish panic escaped: %v", r) - } - }() - func() { - defer coreapi.RecoverPlugin("p", "op", fakeBusPanics{}, nil) - panic("trigger") - }() - if got := coreapi.PluginRecoveredPanicCount(); got != 1 { - t.Errorf("counter = %d, want 1", got) - } -} - -func TestRecoverPlugin_NestedOnPanicSwallowed(t *testing.T) { - // Not parallel — touches counter. - coreapi.ResetPluginRecoveredPanicCountForTest() - defer func() { - if r := recover(); r != nil { - t.Fatalf("nested onPanic panic escaped: %v", r) - } - }() - func() { - defer coreapi.RecoverPlugin("p", "op", nil, func(_ any) { panic("nested-cb-panic") }) - panic("trigger") - }() - if got := coreapi.PluginRecoveredPanicCount(); got != 1 { - t.Errorf("counter = %d, want 1", got) - } -} diff --git a/pkg/coreapi/zz_recover_test.go b/pkg/coreapi/zz_recover_test.go deleted file mode 100644 index b8a40adf..00000000 --- a/pkg/coreapi/zz_recover_test.go +++ /dev/null @@ -1,130 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package coreapi - -import ( - "sync" - "testing" -) - -// fakeBus implements EventBus for the panic-survival test. Records -// every published topic so the test can assert the boundary emitted -// the expected event. -type fakeBus struct { - mu sync.Mutex - topics []string -} - -func (b *fakeBus) Publish(topic string, _ map[string]any) { - b.mu.Lock() - defer b.mu.Unlock() - b.topics = append(b.topics, topic) -} - -func (b *fakeBus) Subscribe(_ string) (<-chan Event, func()) { - ch := make(chan Event) - return ch, func() {} -} - -func (b *fakeBus) latest() []string { - b.mu.Lock() - defer b.mu.Unlock() - out := make([]string, len(b.topics)) - copy(out, b.topics) - return out -} - -// TestL11PluginPanicSurvival exercises the L11 boundary -// (RecoverPlugin) by inducing a panic in a goroutine guarded by it. -// Verifies: -// 1. process survives -// 2. PluginRecoveredPanicCount increments -// 3. a "plugin..panic" event lands on the bus -// 4. the onPanic callback fires with the panic value -func TestL11PluginPanicSurvival(t *testing.T) { - t.Parallel() - before := PluginRecoveredPanicCount() - bus := &fakeBus{} - - var ( - gotPanicValue any - callbackCount int - mu sync.Mutex - ) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - defer RecoverPlugin("testplugin", "acceptLoop", bus, func(r any) { - mu.Lock() - defer mu.Unlock() - gotPanicValue = r - callbackCount++ - }) - panic("L11 boundary test panic") - }() - wg.Wait() - - if PluginRecoveredPanicCount() <= before { - t.Fatal("L11 boundary did not record the panic") - } - - mu.Lock() - defer mu.Unlock() - if callbackCount != 1 { - t.Fatalf("onPanic callback fired %d times, want 1", callbackCount) - } - if s, ok := gotPanicValue.(string); !ok || s != "L11 boundary test panic" { - t.Fatalf("onPanic got %v (%T), want string 'L11 boundary test panic'", gotPanicValue, gotPanicValue) - } - - // Bus event should be "plugin.testplugin.panic". - found := false - for _, topic := range bus.latest() { - if topic == "plugin.testplugin.panic" { - found = true - break - } - } - if !found { - t.Fatalf("plugin.testplugin.panic event not on bus: got %v", bus.latest()) - } -} - -// TestL11PluginPanicNilBus confirms the boundary is nil-safe when no -// bus is provided (e.g., the standalone nameserver binary). -func TestL11PluginPanicNilBus(t *testing.T) { - t.Parallel() - before := PluginRecoveredPanicCount() - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - defer RecoverPlugin("nullbus", "op", nil, nil) - panic("nil-bus panic") - }() - wg.Wait() - if PluginRecoveredPanicCount() <= before { - t.Fatal("L11 boundary did not record nil-bus panic") - } -} - -// TestL11PluginPanicCallbackPanicSwallowed checks the defensive guard -// against a panicking onPanic callback. -func TestL11PluginPanicCallbackPanicSwallowed(t *testing.T) { - t.Parallel() - before := PluginRecoveredPanicCount() - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - defer RecoverPlugin("buggy", "op", nil, func(_ any) { - panic("callback-itself-panics") - }) - panic("primary panic") - }() - wg.Wait() - if PluginRecoveredPanicCount() <= before { - t.Fatal("L11 boundary did not record the primary panic") - } -} diff --git a/pkg/daemon/daemon.go b/pkg/daemon/daemon.go index 798c45c8..b527dd8c 100644 --- a/pkg/daemon/daemon.go +++ b/pkg/daemon/daemon.go @@ -25,11 +25,11 @@ import ( "github.com/TeoSlayer/pilotprotocol/internal/account" "github.com/TeoSlayer/pilotprotocol/internal/transport/compat" "github.com/TeoSlayer/pilotprotocol/internal/validate" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" - registrywire "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" "github.com/pilot-protocol/common/crypto" "github.com/pilot-protocol/common/fsutil" + "github.com/pilot-protocol/common/protocol" + registry "github.com/pilot-protocol/common/registry/client" + registrywire "github.com/pilot-protocol/common/registry/wire" "github.com/pilot-protocol/trustedagents" ) diff --git a/pkg/daemon/envelope/envelope.go b/pkg/daemon/envelope/envelope.go index 7b6cbfda..b176dfa3 100644 --- a/pkg/daemon/envelope/envelope.go +++ b/pkg/daemon/envelope/envelope.go @@ -33,7 +33,7 @@ import ( "sync/atomic" "github.com/TeoSlayer/pilotprotocol/pkg/daemon/keyexchange" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // DecryptResult is the outcome of DecryptFrame. diff --git a/pkg/daemon/envelope/zz_envelope_test.go b/pkg/daemon/envelope/zz_envelope_test.go index 69aa5394..fc57f7ae 100644 --- a/pkg/daemon/envelope/zz_envelope_test.go +++ b/pkg/daemon/envelope/zz_envelope_test.go @@ -15,7 +15,7 @@ import ( "github.com/TeoSlayer/pilotprotocol/pkg/daemon/envelope" "github.com/TeoSlayer/pilotprotocol/pkg/daemon/keyexchange" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // peerSetup holds two L5 Managers + Stores with mutually derived Crypto diff --git a/pkg/daemon/ipc.go b/pkg/daemon/ipc.go index 6bf8d04d..5513fc1d 100644 --- a/pkg/daemon/ipc.go +++ b/pkg/daemon/ipc.go @@ -15,8 +15,8 @@ import ( "sync" "time" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/ipcutil" + "github.com/pilot-protocol/common/protocol" ) // IPC commands (daemon ↔ driver) diff --git a/pkg/daemon/keyexchange/frame.go b/pkg/daemon/keyexchange/frame.go index ca6bb2f3..70a99edf 100644 --- a/pkg/daemon/keyexchange/frame.go +++ b/pkg/daemon/keyexchange/frame.go @@ -5,7 +5,7 @@ package keyexchange import ( "encoding/binary" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // Wire-format constants. Frozen — golden-corpus tested. diff --git a/pkg/daemon/keyexchange/zz_asymmetric_recovery_test.go b/pkg/daemon/keyexchange/zz_asymmetric_recovery_test.go index 73175e49..3bf21362 100644 --- a/pkg/daemon/keyexchange/zz_asymmetric_recovery_test.go +++ b/pkg/daemon/keyexchange/zz_asymmetric_recovery_test.go @@ -24,7 +24,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // frameRecorder captures frames sent via Manager.SetSender so tests can diff --git a/pkg/daemon/keyexchange/zz_keyexchange_test.go b/pkg/daemon/keyexchange/zz_keyexchange_test.go index 671ce7b1..6bfbe9a3 100644 --- a/pkg/daemon/keyexchange/zz_keyexchange_test.go +++ b/pkg/daemon/keyexchange/zz_keyexchange_test.go @@ -35,8 +35,8 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon/keyexchange" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" icrypto "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/protocol" ) // ---------- helpers ---------- diff --git a/pkg/daemon/managed.go b/pkg/daemon/managed.go index 2093bd8d..ce0db7d4 100644 --- a/pkg/daemon/managed.go +++ b/pkg/daemon/managed.go @@ -13,8 +13,8 @@ import ( "sync" "time" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" "github.com/pilot-protocol/common/fsutil" + registry "github.com/pilot-protocol/common/registry/wire" ) // ManagedEngine runs the managed network cycle for a single network. diff --git a/pkg/daemon/ports.go b/pkg/daemon/ports.go index 29132dda..06148098 100644 --- a/pkg/daemon/ports.go +++ b/pkg/daemon/ports.go @@ -11,7 +11,7 @@ import ( "sync" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // ErrEphemeralExhausted is returned by callers of AllocEphemeralPort when diff --git a/pkg/daemon/routing/beacon.go b/pkg/daemon/routing/beacon.go index af70921a..0b7c90a1 100644 --- a/pkg/daemon/routing/beacon.go +++ b/pkg/daemon/routing/beacon.go @@ -7,7 +7,7 @@ import ( "log/slog" "net" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // RegisterWithBeacon sends a MsgDiscover to the beacon from the tunnel diff --git a/pkg/daemon/routing/discover.go b/pkg/daemon/routing/discover.go index 8238c5c0..ca38bd2e 100644 --- a/pkg/daemon/routing/discover.go +++ b/pkg/daemon/routing/discover.go @@ -8,7 +8,7 @@ import ( "net" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // DiscoverEndpoint sends a STUN-style discover to the beacon over the diff --git a/pkg/daemon/routing/writeframe.go b/pkg/daemon/routing/writeframe.go index efcdcf23..60b27d21 100644 --- a/pkg/daemon/routing/writeframe.go +++ b/pkg/daemon/routing/writeframe.go @@ -8,7 +8,7 @@ import ( "net" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // WriteFrame ships a raw UDP frame to a peer. Routes through the beacon diff --git a/pkg/daemon/routing/zz_routing_test.go b/pkg/daemon/routing/zz_routing_test.go index fce2ccf1..d0589d9e 100644 --- a/pkg/daemon/routing/zz_routing_test.go +++ b/pkg/daemon/routing/zz_routing_test.go @@ -14,7 +14,7 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon/routing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // ---------------------------------------------------------------------------- diff --git a/pkg/daemon/services.go b/pkg/daemon/services.go index 7bc9c651..eaa7b316 100644 --- a/pkg/daemon/services.go +++ b/pkg/daemon/services.go @@ -12,7 +12,7 @@ import ( "sync/atomic" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // connAdapter wraps a daemon *Connection as a net.Conn so that existing diff --git a/pkg/daemon/tunnel.go b/pkg/daemon/tunnel.go index 81dc7610..34b99976 100644 --- a/pkg/daemon/tunnel.go +++ b/pkg/daemon/tunnel.go @@ -24,8 +24,8 @@ import ( "github.com/TeoSlayer/pilotprotocol/pkg/daemon/transport" wssTransport "github.com/TeoSlayer/pilotprotocol/pkg/daemon/transport/wss" "github.com/TeoSlayer/pilotprotocol/pkg/daemon/udpio" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/protocol" ) // Type aliases letting existing pkg/daemon code (tests + L5/L7) refer diff --git a/pkg/daemon/webhook_url.go b/pkg/daemon/webhook_url.go index a79f73f6..4d3da04d 100644 --- a/pkg/daemon/webhook_url.go +++ b/pkg/daemon/webhook_url.go @@ -2,7 +2,7 @@ package daemon -import "github.com/TeoSlayer/pilotprotocol/pkg/urlvalidate" +import "github.com/pilot-protocol/common/urlvalidate" // ValidateWebhookURL is a thin shim over urlvalidate.Validate. Lives // in pkg/daemon (rather than plugins/webhook) so the IPC handler in diff --git a/pkg/daemon/zz_accept_queue_bug_test.go b/pkg/daemon/zz_accept_queue_bug_test.go index 259084ce..9b794526 100644 --- a/pkg/daemon/zz_accept_queue_bug_test.go +++ b/pkg/daemon/zz_accept_queue_bug_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSYNDropOnFullAcceptCqueueIsInvisible reproduces the diff --git a/pkg/daemon/zz_accessors_test.go b/pkg/daemon/zz_accessors_test.go index 7e45b777..ed175fe9 100644 --- a/pkg/daemon/zz_accessors_test.go +++ b/pkg/daemon/zz_accessors_test.go @@ -8,8 +8,8 @@ import ( "reflect" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" + "github.com/pilot-protocol/common/protocol" + registry "github.com/pilot-protocol/common/registry/wire" ) // --- SetMemberTags / GetMemberTags --- diff --git a/pkg/daemon/zz_addrfamily_portinuse_test.go b/pkg/daemon/zz_addrfamily_portinuse_test.go index d6caecad..f38aead7 100644 --- a/pkg/daemon/zz_addrfamily_portinuse_test.go +++ b/pkg/daemon/zz_addrfamily_portinuse_test.go @@ -5,7 +5,7 @@ package daemon import ( "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --------------------------------------------------------------------------- diff --git a/pkg/daemon/zz_alloc_ephemeral_exhausted_bug_test.go b/pkg/daemon/zz_alloc_ephemeral_exhausted_bug_test.go index 5a2d0155..43cf5180 100644 --- a/pkg/daemon/zz_alloc_ephemeral_exhausted_bug_test.go +++ b/pkg/daemon/zz_alloc_ephemeral_exhausted_bug_test.go @@ -5,7 +5,7 @@ package daemon import ( "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestAllocEphemeralPortReturnsZeroOnUInt16Overflow verifies that diff --git a/pkg/daemon/zz_broadcast_test.go b/pkg/daemon/zz_broadcast_test.go index 3c91e8c6..b6e46050 100644 --- a/pkg/daemon/zz_broadcast_test.go +++ b/pkg/daemon/zz_broadcast_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/protocol" ) // broadcastFixture sets up a registry with a network that has the daemon's diff --git a/pkg/daemon/zz_close_connection_seq_race_test.go b/pkg/daemon/zz_close_connection_seq_race_test.go index eaf84350..003693d2 100644 --- a/pkg/daemon/zz_close_connection_seq_race_test.go +++ b/pkg/daemon/zz_close_connection_seq_race_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestConcurrentSendAndCloseConnectionDuplicateSeq verifies that a concurrent diff --git a/pkg/daemon/zz_coverage_pkg_daemon_round2_test.go b/pkg/daemon/zz_coverage_pkg_daemon_round2_test.go index 25a7ce97..7f142a25 100644 --- a/pkg/daemon/zz_coverage_pkg_daemon_round2_test.go +++ b/pkg/daemon/zz_coverage_pkg_daemon_round2_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - registrywire "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" "github.com/pilot-protocol/common/crypto" + registrywire "github.com/pilot-protocol/common/registry/wire" ) // Round-2 coverage push for pkg/daemon. Round-1 (zz_coverage_pkg_daemon_test.go) diff --git a/pkg/daemon/zz_coverage_pkg_daemon_test.go b/pkg/daemon/zz_coverage_pkg_daemon_test.go index 96848d0b..921c73b1 100644 --- a/pkg/daemon/zz_coverage_pkg_daemon_test.go +++ b/pkg/daemon/zz_coverage_pkg_daemon_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - registrywire "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" + registrywire "github.com/pilot-protocol/common/registry/wire" ) // This file targets ~32% of pkg/daemon statements that were uncovered at diff --git a/pkg/daemon/zz_daemon_packet_dispatch_test.go b/pkg/daemon/zz_daemon_packet_dispatch_test.go index 9be9cd07..ff8d10e0 100644 --- a/pkg/daemon/zz_daemon_packet_dispatch_test.go +++ b/pkg/daemon/zz_daemon_packet_dispatch_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + "github.com/pilot-protocol/common/protocol" + registry "github.com/pilot-protocol/common/registry/client" ) // newPacketDaemon fully wires a Daemon for packet-dispatch tests: real UDP diff --git a/pkg/daemon/zz_daemon_retx_test.go b/pkg/daemon/zz_daemon_retx_test.go index 311e0c0e..d670726f 100644 --- a/pkg/daemon/zz_daemon_retx_test.go +++ b/pkg/daemon/zz_daemon_retx_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // newDaemonRetxConn returns a Connection prepared for retransmission tests with diff --git a/pkg/daemon/zz_daemon_senddata_test.go b/pkg/daemon/zz_daemon_senddata_test.go index 384af9dd..41aae1a8 100644 --- a/pkg/daemon/zz_daemon_senddata_test.go +++ b/pkg/daemon/zz_daemon_senddata_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // readLargeFrame reads one UDP datagram with a 64KB buffer (sufficient for diff --git a/pkg/daemon/zz_dial_dedup_bug_test.go b/pkg/daemon/zz_dial_dedup_bug_test.go index 095e83e8..6fcddbca 100644 --- a/pkg/daemon/zz_dial_dedup_bug_test.go +++ b/pkg/daemon/zz_dial_dedup_bug_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestDialConcurrentToSamePeerCreatesIndependentSynSent pins the diff --git a/pkg/daemon/zz_dial_orphan_syn_sent_test.go b/pkg/daemon/zz_dial_orphan_syn_sent_test.go index 6425321f..9e0317d8 100644 --- a/pkg/daemon/zz_dial_orphan_syn_sent_test.go +++ b/pkg/daemon/zz_dial_orphan_syn_sent_test.go @@ -20,7 +20,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestDialRetryBudgetCleansUpOrphanSynSent dials an unresponsive peer diff --git a/pkg/daemon/zz_dial_precancelled_ctx_bug_test.go b/pkg/daemon/zz_dial_precancelled_ctx_bug_test.go index 5d647d9c..fe4cd8f0 100644 --- a/pkg/daemon/zz_dial_precancelled_ctx_bug_test.go +++ b/pkg/daemon/zz_dial_precancelled_ctx_bug_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestDialContextPreCancelledStillSendsSYN reproduces the diff --git a/pkg/daemon/zz_dup_ack_empty_unacked_recovery_exit_bug_test.go b/pkg/daemon/zz_dup_ack_empty_unacked_recovery_exit_bug_test.go index dd53b2b6..3d94724c 100644 --- a/pkg/daemon/zz_dup_ack_empty_unacked_recovery_exit_bug_test.go +++ b/pkg/daemon/zz_dup_ack_empty_unacked_recovery_exit_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRecoveryExitDeflatesWhenNoRecoveryEntered verifies that the fast-recovery diff --git a/pkg/daemon/zz_dup_ack_in_recovery_ssthresh_halving_bug_test.go b/pkg/daemon/zz_dup_ack_in_recovery_ssthresh_halving_bug_test.go index 6f1cf018..2197f54e 100644 --- a/pkg/daemon/zz_dup_ack_in_recovery_ssthresh_halving_bug_test.go +++ b/pkg/daemon/zz_dup_ack_in_recovery_ssthresh_halving_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestDupAckFastRetransmitInRecoveryDoesNotRehalveSSThresh verifies that when diff --git a/pkg/daemon/zz_dup_ack_in_timeout_recovery_additional_inflate_bug_test.go b/pkg/daemon/zz_dup_ack_in_timeout_recovery_additional_inflate_bug_test.go index d1d37ec3..74f4c4cc 100644 --- a/pkg/daemon/zz_dup_ack_in_timeout_recovery_additional_inflate_bug_test.go +++ b/pkg/daemon/zz_dup_ack_in_timeout_recovery_additional_inflate_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestAdditionalDupAckInTimeoutRecoveryDoesNotInflateConn verifies that when diff --git a/pkg/daemon/zz_dup_ack_in_timeout_recovery_cwnd_reinflation_bug_test.go b/pkg/daemon/zz_dup_ack_in_timeout_recovery_cwnd_reinflation_bug_test.go index d07a766c..a11711a6 100644 --- a/pkg/daemon/zz_dup_ack_in_timeout_recovery_cwnd_reinflation_bug_test.go +++ b/pkg/daemon/zz_dup_ack_in_timeout_recovery_cwnd_reinflation_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestDupAckInTimeoutRecoveryDoesNotReinflateConn verifies that when exactly diff --git a/pkg/daemon/zz_dup_ack_new_episode_in_recovery_ssthresh_bug_test.go b/pkg/daemon/zz_dup_ack_new_episode_in_recovery_ssthresh_bug_test.go index 01e006ab..4e61af59 100644 --- a/pkg/daemon/zz_dup_ack_new_episode_in_recovery_ssthresh_bug_test.go +++ b/pkg/daemon/zz_dup_ack_new_episode_in_recovery_ssthresh_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestNewEpisodeDupAcksInRecoveryHalveSSThresh verifies that when three dup ACKs diff --git a/pkg/daemon/zz_dup_ack_timeout_recovery_fast_recovery_flag_bug_test.go b/pkg/daemon/zz_dup_ack_timeout_recovery_fast_recovery_flag_bug_test.go index d276d160..5036ac75 100644 --- a/pkg/daemon/zz_dup_ack_timeout_recovery_fast_recovery_flag_bug_test.go +++ b/pkg/daemon/zz_dup_ack_timeout_recovery_fast_recovery_flag_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSameEpisodeDupAcksDoNotSetFastRecoveryFlag verifies that same-episode diff --git a/pkg/daemon/zz_ensuretunnel_test.go b/pkg/daemon/zz_ensuretunnel_test.go index 01f9c971..3ea20522 100644 --- a/pkg/daemon/zz_ensuretunnel_test.go +++ b/pkg/daemon/zz_ensuretunnel_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --- ensureTunnel --- diff --git a/pkg/daemon/zz_fast_recovery_cwnd_inflation_windowch_bug_test.go b/pkg/daemon/zz_fast_recovery_cwnd_inflation_windowch_bug_test.go index 9d236c93..b6e9dd2b 100644 --- a/pkg/daemon/zz_fast_recovery_cwnd_inflation_windowch_bug_test.go +++ b/pkg/daemon/zz_fast_recovery_cwnd_inflation_windowch_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRecoveryExtraACKInflationSignalsWindowCh verifies that the CongWin diff --git a/pkg/daemon/zz_fast_recovery_exit_cwnd_bug_test.go b/pkg/daemon/zz_fast_recovery_exit_cwnd_bug_test.go index ba76f360..62fa99f3 100644 --- a/pkg/daemon/zz_fast_recovery_exit_cwnd_bug_test.go +++ b/pkg/daemon/zz_fast_recovery_exit_cwnd_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRecoveryExitDeflatesCongWin verifies that when the first new ACK diff --git a/pkg/daemon/zz_fast_recovery_exit_deflation_noop_bug_test.go b/pkg/daemon/zz_fast_recovery_exit_deflation_noop_bug_test.go index 265d5b0d..c71e139c 100644 --- a/pkg/daemon/zz_fast_recovery_exit_deflation_noop_bug_test.go +++ b/pkg/daemon/zz_fast_recovery_exit_deflation_noop_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRecoveryExitDeflationRequiresActualRecovery verifies that the diff --git a/pkg/daemon/zz_fast_recovery_inflation_without_recovery_bug_test.go b/pkg/daemon/zz_fast_recovery_inflation_without_recovery_bug_test.go index 923cc589..5bc8ffd7 100644 --- a/pkg/daemon/zz_fast_recovery_inflation_without_recovery_bug_test.go +++ b/pkg/daemon/zz_fast_recovery_inflation_without_recovery_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRecoveryInflationRequiresInRecovery verifies that the DupAckCount > 3 diff --git a/pkg/daemon/zz_fast_recovery_partial_ack_aimd_inflation_bug_test.go b/pkg/daemon/zz_fast_recovery_partial_ack_aimd_inflation_bug_test.go index 4b399b98..6d5516a1 100644 --- a/pkg/daemon/zz_fast_recovery_partial_ack_aimd_inflation_bug_test.go +++ b/pkg/daemon/zz_fast_recovery_partial_ack_aimd_inflation_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRecoveryPartialAckNoAIMDInflation verifies that a partial ACK during diff --git a/pkg/daemon/zz_fast_recovery_post_partial_ack_dup_inflation_bug_test.go b/pkg/daemon/zz_fast_recovery_post_partial_ack_dup_inflation_bug_test.go index e2c70c64..2a076339 100644 --- a/pkg/daemon/zz_fast_recovery_post_partial_ack_dup_inflation_bug_test.go +++ b/pkg/daemon/zz_fast_recovery_post_partial_ack_dup_inflation_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRecoveryPostPartialAckDupAckInflation verifies that each duplicate diff --git a/pkg/daemon/zz_fast_recovery_third_dup_ack_same_episode_inflation_bug_test.go b/pkg/daemon/zz_fast_recovery_third_dup_ack_same_episode_inflation_bug_test.go index c1cf8987..2d092fdb 100644 --- a/pkg/daemon/zz_fast_recovery_third_dup_ack_same_episode_inflation_bug_test.go +++ b/pkg/daemon/zz_fast_recovery_third_dup_ack_same_episode_inflation_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRecoveryThirdDupAckSameEpisodeInflation verifies that the 3rd diff --git a/pkg/daemon/zz_fast_retransmit_entry_windowch_bug_test.go b/pkg/daemon/zz_fast_retransmit_entry_windowch_bug_test.go index 277530c0..5a855080 100644 --- a/pkg/daemon/zz_fast_retransmit_entry_windowch_bug_test.go +++ b/pkg/daemon/zz_fast_retransmit_entry_windowch_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRetransmitEntryInflatesWindowAndSignalsWindowCh verifies that when diff --git a/pkg/daemon/zz_fast_retransmit_fin_as_data_bug_test.go b/pkg/daemon/zz_fast_retransmit_fin_as_data_bug_test.go index a3984930..34a1baf7 100644 --- a/pkg/daemon/zz_fast_retransmit_fin_as_data_bug_test.go +++ b/pkg/daemon/zz_fast_retransmit_fin_as_data_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRetransmitFINEntryAsFlagACK verifies that when 3 dup ACKs fire fast diff --git a/pkg/daemon/zz_fast_retransmit_max_attempts_bug_test.go b/pkg/daemon/zz_fast_retransmit_max_attempts_bug_test.go index e185211c..5ff19799 100644 --- a/pkg/daemon/zz_fast_retransmit_max_attempts_bug_test.go +++ b/pkg/daemon/zz_fast_retransmit_max_attempts_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRetransmitStopsAtMaxAttempts verifies that fastRetransmit does NOT diff --git a/pkg/daemon/zz_fast_retransmit_noop_congestion_state_bug_test.go b/pkg/daemon/zz_fast_retransmit_noop_congestion_state_bug_test.go index f5e8646b..97cb9926 100644 --- a/pkg/daemon/zz_fast_retransmit_noop_congestion_state_bug_test.go +++ b/pkg/daemon/zz_fast_retransmit_noop_congestion_state_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRetransmitNoopDoesNotAdjustCongestionState verifies that when diff --git a/pkg/daemon/zz_fast_retransmit_sets_in_recovery_bug_test.go b/pkg/daemon/zz_fast_retransmit_sets_in_recovery_bug_test.go index da0ebf1f..2f6528ed 100644 --- a/pkg/daemon/zz_fast_retransmit_sets_in_recovery_bug_test.go +++ b/pkg/daemon/zz_fast_retransmit_sets_in_recovery_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRetransmitSetsInRecoveryPreventsDoubleSSThreshHalving verifies that diff --git a/pkg/daemon/zz_fin_ack_aimd_sentinel_byte_bug_test.go b/pkg/daemon/zz_fin_ack_aimd_sentinel_byte_bug_test.go index 8b606d04..c7af9a97 100644 --- a/pkg/daemon/zz_fin_ack_aimd_sentinel_byte_bug_test.go +++ b/pkg/daemon/zz_fin_ack_aimd_sentinel_byte_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFINAckDoesNotDriveAIMD verifies that when a FIN-ACK arrives and the diff --git a/pkg/daemon/zz_fin_retransmit_data_as_fin_bug_test.go b/pkg/daemon/zz_fin_retransmit_data_as_fin_bug_test.go index d60cb024..d306a43d 100644 --- a/pkg/daemon/zz_fin_retransmit_data_as_fin_bug_test.go +++ b/pkg/daemon/zz_fin_retransmit_data_as_fin_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestRetransmitUnackedFinWaitDataSentAsFIN verifies that when a data segment diff --git a/pkg/daemon/zz_handlers_test.go b/pkg/daemon/zz_handlers_test.go index 310ac388..df6a77ff 100644 --- a/pkg/daemon/zz_handlers_test.go +++ b/pkg/daemon/zz_handlers_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // setupDaemonWithPeer builds a Daemon that is wired to send packets to `peerConn`. diff --git a/pkg/daemon/zz_helpers_test.go b/pkg/daemon/zz_helpers_test.go index 977d2f09..5433b3f9 100644 --- a/pkg/daemon/zz_helpers_test.go +++ b/pkg/daemon/zz_helpers_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --------------------------------------------------------------------------- @@ -405,9 +405,9 @@ func TestProcessSACKPartialOverlap(t *testing.T) { // A segment at [200, 205) should be marked sacked even when // the SACK block only partially overlaps (e.g. [195, 203)). c := &Connection{} - c.TrackSend(100, []byte("hello")) // [100, 105) - c.TrackSend(200, []byte("world")) // [200, 205) - c.TrackSend(300, []byte("!!!")) // [300, 303) + c.TrackSend(100, []byte("hello")) // [100, 105) + c.TrackSend(200, []byte("world")) // [200, 205) + c.TrackSend(300, []byte("!!!")) // [300, 303) // SACK block partially overlaps [200, 205) from the left c.ProcessSACK([]SACKBlock{{Left: 195, Right: 203}}) diff --git a/pkg/daemon/zz_ipc_async_write_test.go b/pkg/daemon/zz_ipc_async_write_test.go index c8f7b812..a7943510 100644 --- a/pkg/daemon/zz_ipc_async_write_test.go +++ b/pkg/daemon/zz_ipc_async_write_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" + "github.com/pilot-protocol/common/ipcutil" ) // pairedConn returns two ends of a Unix-style pipe wrapped as net.Conn so we diff --git a/pkg/daemon/zz_ipc_bind_accept_test.go b/pkg/daemon/zz_ipc_bind_accept_test.go index 53841f3b..f4b0d181 100644 --- a/pkg/daemon/zz_ipc_bind_accept_test.go +++ b/pkg/daemon/zz_ipc_bind_accept_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/ipcutil" + "github.com/pilot-protocol/common/protocol" ) // Iter-110 coverage for the remaining gaps in ipc.go handleBind (53.6%), diff --git a/pkg/daemon/zz_ipc_conncount_stale_bug_test.go b/pkg/daemon/zz_ipc_conncount_stale_bug_test.go index e5dcf00b..28bdb014 100644 --- a/pkg/daemon/zz_ipc_conncount_stale_bug_test.go +++ b/pkg/daemon/zz_ipc_conncount_stale_bug_test.go @@ -6,7 +6,7 @@ import ( "net" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestIPCConnCountIncludesClosedConns verifies that connCount() reflects diff --git a/pkg/daemon/zz_ipc_helpers_test.go b/pkg/daemon/zz_ipc_helpers_test.go index 72598e9d..eba0f24e 100644 --- a/pkg/daemon/zz_ipc_helpers_test.go +++ b/pkg/daemon/zz_ipc_helpers_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/ipcutil" + "github.com/pilot-protocol/common/protocol" ) // newPipePair returns a connected pair. The server side is wrapped in ipcConn diff --git a/pkg/daemon/zz_ipc_recv_pusher_test.go b/pkg/daemon/zz_ipc_recv_pusher_test.go index d2f8b129..87aef5f0 100644 --- a/pkg/daemon/zz_ipc_recv_pusher_test.go +++ b/pkg/daemon/zz_ipc_recv_pusher_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" + "github.com/pilot-protocol/common/ipcutil" ) // Iter-109 coverage for startRecvPusher (0% baseline) — the goroutine diff --git a/pkg/daemon/zz_ipc_simple_handlers_test.go b/pkg/daemon/zz_ipc_simple_handlers_test.go index b044e90b..f524ca3c 100644 --- a/pkg/daemon/zz_ipc_simple_handlers_test.go +++ b/pkg/daemon/zz_ipc_simple_handlers_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registry "github.com/pilot-protocol/common/registry/client" ) // newSimpleHandlerDaemon wires the minimum set of fields needed to exercise diff --git a/pkg/daemon/zz_ipc_socket_lifecycle_test.go b/pkg/daemon/zz_ipc_socket_lifecycle_test.go index e4e6d93a..f7e9b575 100644 --- a/pkg/daemon/zz_ipc_socket_lifecycle_test.go +++ b/pkg/daemon/zz_ipc_socket_lifecycle_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" + "github.com/pilot-protocol/common/ipcutil" ) // macOS sun_path is ~104 bytes; Go test temp dirs blow past it. Mint short diff --git a/pkg/daemon/zz_ipc_test.go b/pkg/daemon/zz_ipc_test.go index 71d05420..82f8ad52 100644 --- a/pkg/daemon/zz_ipc_test.go +++ b/pkg/daemon/zz_ipc_test.go @@ -10,9 +10,9 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/ipcutil" + "github.com/pilot-protocol/common/protocol" ) // newIPCTestConn returns (serverConn, clientConn). Wrap serverConn in ipcConn diff --git a/pkg/daemon/zz_keepalive_fin_not_rst_test.go b/pkg/daemon/zz_keepalive_fin_not_rst_test.go index 346deb2d..17aafd71 100644 --- a/pkg/daemon/zz_keepalive_fin_not_rst_test.go +++ b/pkg/daemon/zz_keepalive_fin_not_rst_test.go @@ -27,7 +27,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestKeepaliveDeadPeerSendsFinNotRst is the smoking-gun regression: diff --git a/pkg/daemon/zz_keepalive_zero_window_probe_bug_test.go b/pkg/daemon/zz_keepalive_zero_window_probe_bug_test.go index e71c5489..7a274dbb 100644 --- a/pkg/daemon/zz_keepalive_zero_window_probe_bug_test.go +++ b/pkg/daemon/zz_keepalive_zero_window_probe_bug_test.go @@ -5,7 +5,7 @@ package daemon import ( "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestKeepaliveProbeWithWindowDoesNotStallPeer verifies that a keepalive probe diff --git a/pkg/daemon/zz_managed_test.go b/pkg/daemon/zz_managed_test.go index 47f36dc3..6a199bec 100644 --- a/pkg/daemon/zz_managed_test.go +++ b/pkg/daemon/zz_managed_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" + registry "github.com/pilot-protocol/common/registry/wire" ) func testRules() *registry.NetworkRules { diff --git a/pkg/daemon/zz_nagle_all_sacked_hasunacked_bug_test.go b/pkg/daemon/zz_nagle_all_sacked_hasunacked_bug_test.go index 9cce08b8..19c69001 100644 --- a/pkg/daemon/zz_nagle_all_sacked_hasunacked_bug_test.go +++ b/pkg/daemon/zz_nagle_all_sacked_hasunacked_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestNagleFlushAllSackedSendsImmediately verifies that nagleFlush does NOT diff --git a/pkg/daemon/zz_nagle_buf_bug_test.go b/pkg/daemon/zz_nagle_buf_bug_test.go index 6e0647f1..a3a95a44 100644 --- a/pkg/daemon/zz_nagle_buf_bug_test.go +++ b/pkg/daemon/zz_nagle_buf_bug_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSendDataNagleBufGrowsUnbounded reproduces the NagleBuf-OOM bug. diff --git a/pkg/daemon/zz_nat_keepalive_bug_test.go b/pkg/daemon/zz_nat_keepalive_bug_test.go index 18aac69c..eac8c59d 100644 --- a/pkg/daemon/zz_nat_keepalive_bug_test.go +++ b/pkg/daemon/zz_nat_keepalive_bug_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestKeepaliveAbsentForIdlePeer reproduces the "NAT mapping idle-times diff --git a/pkg/daemon/zz_panic_survival_test.go b/pkg/daemon/zz_panic_survival_test.go index 257852da..506e07c5 100644 --- a/pkg/daemon/zz_panic_survival_test.go +++ b/pkg/daemon/zz_panic_survival_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestL1UnmarshalPanicSurvival drives protocol.Unmarshal through its diff --git a/pkg/daemon/zz_peer_recv_win_growth_windowch_bug_test.go b/pkg/daemon/zz_peer_recv_win_growth_windowch_bug_test.go index 520060bd..9ba1b85e 100644 --- a/pkg/daemon/zz_peer_recv_win_growth_windowch_bug_test.go +++ b/pkg/daemon/zz_peer_recv_win_growth_windowch_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestPeerRecvWinGrowthSignalsWindowCh verifies that when a pure window diff --git a/pkg/daemon/zz_ports_logic_test.go b/pkg/daemon/zz_ports_logic_test.go index 0918d7cf..0f53d5a0 100644 --- a/pkg/daemon/zz_ports_logic_test.go +++ b/pkg/daemon/zz_ports_logic_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --- DecodeSACK --- diff --git a/pkg/daemon/zz_process_ack_partial_nagle_ch_bug_test.go b/pkg/daemon/zz_process_ack_partial_nagle_ch_bug_test.go index 102d5d9a..76436751 100644 --- a/pkg/daemon/zz_process_ack_partial_nagle_ch_bug_test.go +++ b/pkg/daemon/zz_process_ack_partial_nagle_ch_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestProcessAckPartialRemovesLastUnsackedSignalsNagleCh verifies that when diff --git a/pkg/daemon/zz_process_ack_resets_sack_state_bug_test.go b/pkg/daemon/zz_process_ack_resets_sack_state_bug_test.go index 936b6dc8..4a1f4f0f 100644 --- a/pkg/daemon/zz_process_ack_resets_sack_state_bug_test.go +++ b/pkg/daemon/zz_process_ack_resets_sack_state_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestProcessAckPartialDoesNotResetSACKedState verifies that when a partial diff --git a/pkg/daemon/zz_process_ack_zero_guard_bug_test.go b/pkg/daemon/zz_process_ack_zero_guard_bug_test.go index 1e3d8227..8f7eef17 100644 --- a/pkg/daemon/zz_process_ack_zero_guard_bug_test.go +++ b/pkg/daemon/zz_process_ack_zero_guard_bug_test.go @@ -5,7 +5,7 @@ package daemon import ( "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestHandleStreamPacketZeroAckSkipsWraparound verifies that an ACK packet diff --git a/pkg/daemon/zz_process_sack_wraparound_bug_test.go b/pkg/daemon/zz_process_sack_wraparound_bug_test.go index 9c48d72d..378ee66e 100644 --- a/pkg/daemon/zz_process_sack_wraparound_bug_test.go +++ b/pkg/daemon/zz_process_sack_wraparound_bug_test.go @@ -5,7 +5,7 @@ package daemon import ( "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestProcessSACKWraparound verifies that ProcessSACK correctly marks diff --git a/pkg/daemon/zz_retransmit_resets_dup_ack_count_bug_test.go b/pkg/daemon/zz_retransmit_resets_dup_ack_count_bug_test.go index 8aeef4a2..c5aa2dab 100644 --- a/pkg/daemon/zz_retransmit_resets_dup_ack_count_bug_test.go +++ b/pkg/daemon/zz_retransmit_resets_dup_ack_count_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestRetransmitUnackedResetsDupAckCount verifies that retransmitUnacked resets diff --git a/pkg/daemon/zz_retransmit_timeout_ssthresh_flightsize_bug_test.go b/pkg/daemon/zz_retransmit_timeout_ssthresh_flightsize_bug_test.go index 0dc68d85..46d2e0ad 100644 --- a/pkg/daemon/zz_retransmit_timeout_ssthresh_flightsize_bug_test.go +++ b/pkg/daemon/zz_retransmit_timeout_ssthresh_flightsize_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestTimeoutSSThreshUsesFlightSizeNotCongWin verifies that when the RTO diff --git a/pkg/daemon/zz_retx_test.go b/pkg/daemon/zz_retx_test.go index 4d872c83..3b096cfe 100644 --- a/pkg/daemon/zz_retx_test.go +++ b/pkg/daemon/zz_retx_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --- SetWebhookURL --- diff --git a/pkg/daemon/zz_rto_backoff_in_recovery_bug_test.go b/pkg/daemon/zz_rto_backoff_in_recovery_bug_test.go index 6f7fc9f4..1abd7239 100644 --- a/pkg/daemon/zz_rto_backoff_in_recovery_bug_test.go +++ b/pkg/daemon/zz_rto_backoff_in_recovery_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestRetransmitUnackedDoublesRTOOnEachTimeout verifies that the diff --git a/pkg/daemon/zz_rtt_multiple_samples_per_ack_bug_test.go b/pkg/daemon/zz_rtt_multiple_samples_per_ack_bug_test.go index 707d5b33..38101fb3 100644 --- a/pkg/daemon/zz_rtt_multiple_samples_per_ack_bug_test.go +++ b/pkg/daemon/zz_rtt_multiple_samples_per_ack_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestProcessAckTakesOnlyOneRTTSamplePerACK verifies that a single cumulative diff --git a/pkg/daemon/zz_rtt_sacked_segment_skipped_bug_test.go b/pkg/daemon/zz_rtt_sacked_segment_skipped_bug_test.go index 0fdc959a..215cac51 100644 --- a/pkg/daemon/zz_rtt_sacked_segment_skipped_bug_test.go +++ b/pkg/daemon/zz_rtt_sacked_segment_skipped_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestRTTUpdateSkippedForSackedSegments verifies that a once-sent segment diff --git a/pkg/daemon/zz_sack_all_sacked_cumulative_ack_no_aimd_bug_test.go b/pkg/daemon/zz_sack_all_sacked_cumulative_ack_no_aimd_bug_test.go index b1242e18..397b726b 100644 --- a/pkg/daemon/zz_sack_all_sacked_cumulative_ack_no_aimd_bug_test.go +++ b/pkg/daemon/zz_sack_all_sacked_cumulative_ack_no_aimd_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSACKAllSackedCumulativeAckDoesNotGrowCongWin verifies that when a diff --git a/pkg/daemon/zz_sack_all_sacked_nagle_ch_bug_test.go b/pkg/daemon/zz_sack_all_sacked_nagle_ch_bug_test.go index d89e6f97..871b73da 100644 --- a/pkg/daemon/zz_sack_all_sacked_nagle_ch_bug_test.go +++ b/pkg/daemon/zz_sack_all_sacked_nagle_ch_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestProcessSACKAllSackedSignalsNagleCh verifies that when ProcessSACK marks diff --git a/pkg/daemon/zz_sack_all_sacked_spurious_fast_retx_bug_test.go b/pkg/daemon/zz_sack_all_sacked_spurious_fast_retx_bug_test.go index 133b7fa8..6fa980af 100644 --- a/pkg/daemon/zz_sack_all_sacked_spurious_fast_retx_bug_test.go +++ b/pkg/daemon/zz_sack_all_sacked_spurious_fast_retx_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestAllSACKedUnackedDoesNotTriggerSpuriousFastRetransmit verifies that diff --git a/pkg/daemon/zz_sack_blocks_truncation_bug_test.go b/pkg/daemon/zz_sack_blocks_truncation_bug_test.go index 6a2b4d56..89f98ca1 100644 --- a/pkg/daemon/zz_sack_blocks_truncation_bug_test.go +++ b/pkg/daemon/zz_sack_blocks_truncation_bug_test.go @@ -5,7 +5,7 @@ package daemon import ( "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSACKBlocksTruncationKeepsHighestSeq verifies that when OOOBuf holds diff --git a/pkg/daemon/zz_sack_blocks_wraparound_bug_test.go b/pkg/daemon/zz_sack_blocks_wraparound_bug_test.go index 94d96c10..fd01a336 100644 --- a/pkg/daemon/zz_sack_blocks_wraparound_bug_test.go +++ b/pkg/daemon/zz_sack_blocks_wraparound_bug_test.go @@ -5,7 +5,7 @@ package daemon import ( "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSACKBlocksWraparound verifies that SACKBlocks correctly builds separate diff --git a/pkg/daemon/zz_sack_bytes_in_flight_windowch_bug_test.go b/pkg/daemon/zz_sack_bytes_in_flight_windowch_bug_test.go index 19f875fd..b7d2ba18 100644 --- a/pkg/daemon/zz_sack_bytes_in_flight_windowch_bug_test.go +++ b/pkg/daemon/zz_sack_bytes_in_flight_windowch_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSACKReducesBytesInFlightAndSignalsWindowCh verifies that after diff --git a/pkg/daemon/zz_sack_cumulative_ack_aimd_overcounting_bug_test.go b/pkg/daemon/zz_sack_cumulative_ack_aimd_overcounting_bug_test.go index c2ae7065..b2d07b07 100644 --- a/pkg/daemon/zz_sack_cumulative_ack_aimd_overcounting_bug_test.go +++ b/pkg/daemon/zz_sack_cumulative_ack_aimd_overcounting_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSACKCumulativeAckDoesNotInflateBytesAcked verifies that when a diff --git a/pkg/daemon/zz_sack_dup_ack_pureack_bug_test.go b/pkg/daemon/zz_sack_dup_ack_pureack_bug_test.go index da87d246..240f0880 100644 --- a/pkg/daemon/zz_sack_dup_ack_pureack_bug_test.go +++ b/pkg/daemon/zz_sack_dup_ack_pureack_bug_test.go @@ -5,7 +5,7 @@ package daemon import ( "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSACKPayloadSuppressesDupACKCounting verifies that a SACK-carrying ACK diff --git a/pkg/daemon/zz_send_segment_error_tracksend_bug_test.go b/pkg/daemon/zz_send_segment_error_tracksend_bug_test.go index ce1021ed..c6f5d9d9 100644 --- a/pkg/daemon/zz_send_segment_error_tracksend_bug_test.go +++ b/pkg/daemon/zz_send_segment_error_tracksend_bug_test.go @@ -5,7 +5,7 @@ package daemon import ( "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSendSegmentTunnelErrorLosesData verifies that a segment whose tunnel diff --git a/pkg/daemon/zz_send_segment_seq_race_test.go b/pkg/daemon/zz_send_segment_seq_race_test.go index c94778e5..220479bc 100644 --- a/pkg/daemon/zz_send_segment_seq_race_test.go +++ b/pkg/daemon/zz_send_segment_seq_race_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestConcurrentNoDelaySendDuplicateSeq verifies that two goroutines calling diff --git a/pkg/daemon/zz_send_test.go b/pkg/daemon/zz_send_test.go index 07dc2cb9..e40a96e2 100644 --- a/pkg/daemon/zz_send_test.go +++ b/pkg/daemon/zz_send_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --- routeLoop --- diff --git a/pkg/daemon/zz_sendbuf_caller_bug_test.go b/pkg/daemon/zz_sendbuf_caller_bug_test.go index 4aa2b484..65f72950 100644 --- a/pkg/daemon/zz_sendbuf_caller_bug_test.go +++ b/pkg/daemon/zz_sendbuf_caller_bug_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestConnAdapterWriteSurfacesErrSendBufFullToCaller reproduces the diff --git a/pkg/daemon/zz_senddata_test.go b/pkg/daemon/zz_senddata_test.go index 0505f28f..9cf82771 100644 --- a/pkg/daemon/zz_senddata_test.go +++ b/pkg/daemon/zz_senddata_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --- helpers --- diff --git a/pkg/daemon/zz_sendpath_test.go b/pkg/daemon/zz_sendpath_test.go index e1b322cb..a35761d1 100644 --- a/pkg/daemon/zz_sendpath_test.go +++ b/pkg/daemon/zz_sendpath_test.go @@ -7,9 +7,9 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/TeoSlayer/pilotprotocol/tests/regtestutil" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/protocol" ) // startTestRegistry is a thin alias for regtestutil.StartTestRegistry so diff --git a/pkg/daemon/zz_services_bootstrap_test.go b/pkg/daemon/zz_services_bootstrap_test.go index 71df3c3a..a80482d7 100644 --- a/pkg/daemon/zz_services_bootstrap_test.go +++ b/pkg/daemon/zz_services_bootstrap_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // Iter-112 coverage for the built-in service bootstrap path at services.go: diff --git a/pkg/daemon/zz_services_builtins_test.go b/pkg/daemon/zz_services_builtins_test.go index eeb7b319..c2f015a7 100644 --- a/pkg/daemon/zz_services_builtins_test.go +++ b/pkg/daemon/zz_services_builtins_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // stopDaemonOnce closes d.stopCh safely from t.Cleanup and waits briefly so diff --git a/pkg/daemon/zz_services_logic_test.go b/pkg/daemon/zz_services_logic_test.go index 7a4d41bd..7cb32f63 100644 --- a/pkg/daemon/zz_services_logic_test.go +++ b/pkg/daemon/zz_services_logic_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --- pilotAddr --- diff --git a/pkg/daemon/zz_services_test.go b/pkg/daemon/zz_services_test.go index 04259e0d..f12c5513 100644 --- a/pkg/daemon/zz_services_test.go +++ b/pkg/daemon/zz_services_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // ---------- pilotAddr ---------- diff --git a/pkg/daemon/zz_ssthresh_congwin_vs_flightsize_bug_test.go b/pkg/daemon/zz_ssthresh_congwin_vs_flightsize_bug_test.go index 7fca1eda..fb3f2084 100644 --- a/pkg/daemon/zz_ssthresh_congwin_vs_flightsize_bug_test.go +++ b/pkg/daemon/zz_ssthresh_congwin_vs_flightsize_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRetransmitSSThreshUsesFlightSizeNotCongWin verifies that when 3 dup diff --git a/pkg/daemon/zz_ssthresh_floor_two_mss_bug_test.go b/pkg/daemon/zz_ssthresh_floor_two_mss_bug_test.go index 33486c44..3f4e8b6b 100644 --- a/pkg/daemon/zz_ssthresh_floor_two_mss_bug_test.go +++ b/pkg/daemon/zz_ssthresh_floor_two_mss_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFastRetransmitSSThreshFloorTwoSMSS verifies that when 3 dup ACKs trigger diff --git a/pkg/daemon/zz_startmanaged_test.go b/pkg/daemon/zz_startmanaged_test.go index 83eac48e..23f22ee9 100644 --- a/pkg/daemon/zz_startmanaged_test.go +++ b/pkg/daemon/zz_startmanaged_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" "github.com/pilot-protocol/common/crypto" + registry "github.com/pilot-protocol/common/registry/wire" ) // --- startManaged happy-path --- diff --git a/pkg/daemon/zz_streampacket_test.go b/pkg/daemon/zz_streampacket_test.go index 337dd964..a4708d17 100644 --- a/pkg/daemon/zz_streampacket_test.go +++ b/pkg/daemon/zz_streampacket_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --- helpers --- diff --git a/pkg/daemon/zz_syn_timewait_port_reuse_bug_test.go b/pkg/daemon/zz_syn_timewait_port_reuse_bug_test.go index d1d812a1..7cc91f08 100644 --- a/pkg/daemon/zz_syn_timewait_port_reuse_bug_test.go +++ b/pkg/daemon/zz_syn_timewait_port_reuse_bug_test.go @@ -5,7 +5,7 @@ package daemon import ( "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestSYNToTimeWaitConnBlocksNewConnection verifies that a fresh SYN on a diff --git a/pkg/daemon/zz_test_helpers_registry_test.go b/pkg/daemon/zz_test_helpers_registry_test.go index 7b023987..25f99af1 100644 --- a/pkg/daemon/zz_test_helpers_registry_test.go +++ b/pkg/daemon/zz_test_helpers_registry_test.go @@ -8,8 +8,8 @@ import ( "sync" "testing" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + "github.com/pilot-protocol/common/ipcutil" + registry "github.com/pilot-protocol/common/registry/client" ) // startFakeRegistry stands up an in-process TCP listener that speaks the diff --git a/pkg/daemon/zz_throughput_bench_test.go b/pkg/daemon/zz_throughput_bench_test.go index ab2f7c87..25236f0f 100644 --- a/pkg/daemon/zz_throughput_bench_test.go +++ b/pkg/daemon/zz_throughput_bench_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // Throughput benchmarks for the Pilot congestion control stack. diff --git a/pkg/daemon/zz_timeout_cwnd_reset_bug_test.go b/pkg/daemon/zz_timeout_cwnd_reset_bug_test.go index de3b8780..5aeaaf02 100644 --- a/pkg/daemon/zz_timeout_cwnd_reset_bug_test.go +++ b/pkg/daemon/zz_timeout_cwnd_reset_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestTimeoutResetsCongWinTo1SMSS verifies that when the retransmission timer diff --git a/pkg/daemon/zz_timeout_ssthresh_in_recovery_bug_test.go b/pkg/daemon/zz_timeout_ssthresh_in_recovery_bug_test.go index db41cd1b..a626d5ba 100644 --- a/pkg/daemon/zz_timeout_ssthresh_in_recovery_bug_test.go +++ b/pkg/daemon/zz_timeout_ssthresh_in_recovery_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestTimeoutDuringFastRecoveryRecomputesSSThresh verifies that when the diff --git a/pkg/daemon/zz_tunnel_beacon_test.go b/pkg/daemon/zz_tunnel_beacon_test.go index 5f11a5f2..ac6e9b3e 100644 --- a/pkg/daemon/zz_tunnel_beacon_test.go +++ b/pkg/daemon/zz_tunnel_beacon_test.go @@ -7,7 +7,7 @@ import ( "net" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) func TestRegisterWithBeaconNilBeaconAddrIsNoop(t *testing.T) { diff --git a/pkg/daemon/zz_tunnel_frames_test.go b/pkg/daemon/zz_tunnel_frames_test.go index df604420..81fd9ae2 100644 --- a/pkg/daemon/zz_tunnel_frames_test.go +++ b/pkg/daemon/zz_tunnel_frames_test.go @@ -10,8 +10,8 @@ import ( "encoding/binary" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/protocol" ) func TestBuildKeyExchangeFrameWithoutPubKeyReturnsNil(t *testing.T) { diff --git a/pkg/daemon/zz_tunnel_handle_test.go b/pkg/daemon/zz_tunnel_handle_test.go index 05216d18..299c997c 100644 --- a/pkg/daemon/zz_tunnel_handle_test.go +++ b/pkg/daemon/zz_tunnel_handle_test.go @@ -14,8 +14,8 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon/udpio" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/protocol" ) // --- handleKeyExchange --- diff --git a/pkg/daemon/zz_tunnel_send_test.go b/pkg/daemon/zz_tunnel_send_test.go index a131b85e..ce7c0dc3 100644 --- a/pkg/daemon/zz_tunnel_send_test.go +++ b/pkg/daemon/zz_tunnel_send_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) func newPacket(payload string) *protocol.Packet { diff --git a/pkg/daemon/zz_window_update_dup_ack_count_bug_test.go b/pkg/daemon/zz_window_update_dup_ack_count_bug_test.go index f1178a97..7f68aa77 100644 --- a/pkg/daemon/zz_window_update_dup_ack_count_bug_test.go +++ b/pkg/daemon/zz_window_update_dup_ack_count_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestWindowUpdateDoesNotIncrementDupAckCount verifies that a pure window diff --git a/pkg/daemon/zz_window_update_wakeup_bug_test.go b/pkg/daemon/zz_window_update_wakeup_bug_test.go index ec5eb9f8..00b96613 100644 --- a/pkg/daemon/zz_window_update_wakeup_bug_test.go +++ b/pkg/daemon/zz_window_update_wakeup_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestWindowUpdateDoesNotWakeSender verifies that when the peer transitions diff --git a/pkg/daemon/zz_wire_helpers_test.go b/pkg/daemon/zz_wire_helpers_test.go index 038524b3..b2b5a5c3 100644 --- a/pkg/daemon/zz_wire_helpers_test.go +++ b/pkg/daemon/zz_wire_helpers_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + "github.com/pilot-protocol/common/protocol" + registry "github.com/pilot-protocol/common/registry/client" ) // newWireDaemon wires a Daemon with a tunnel bound to a real 127.0.0.1 UDP diff --git a/pkg/driver/conn.go b/pkg/driver/conn.go deleted file mode 100644 index d8a267b6..00000000 --- a/pkg/driver/conn.go +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package driver - -import ( - "encoding/binary" - "io" - "net" - "os" - "sync" - "time" - - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" -) - -// maxSendChunk is the largest payload we will pack into one cmdSend IPC -// message. IPC messages are capped at ipcutil.MaxMessageSize; we reserve -// 5 bytes for the cmdSend+conn_id header and leave a small safety margin. -const maxSendChunk = ipcutil.MaxMessageSize - 64 - -// Conn implements net.Conn over a Pilot Protocol stream. -type Conn struct { - id uint32 - localAddr protocol.SocketAddr - remoteAddr protocol.SocketAddr - ipc *ipcClient - recvCh chan []byte - recvBuf []byte // leftover from previous read - closed bool - - mu sync.Mutex - readDeadline time.Time - deadlineCh chan struct{} // closed when deadline is set/changed -} - -func (c *Conn) Read(b []byte) (int, error) { - // Drain leftover first - if len(c.recvBuf) > 0 { - n := copy(b, c.recvBuf) - c.recvBuf = c.recvBuf[n:] - return n, nil - } - - c.mu.Lock() - dl := c.readDeadline - dch := c.deadlineCh - c.mu.Unlock() - - // Check if deadline already passed - if !dl.IsZero() && !time.Now().Before(dl) { - return 0, os.ErrDeadlineExceeded - } - - // Set up timer if deadline is set - var timer <-chan time.Time - if !dl.IsZero() { - t := time.NewTimer(time.Until(dl)) - defer t.Stop() - timer = t.C - } - - select { - case data, ok := <-c.recvCh: - if !ok { - return 0, io.EOF - } - n := copy(b, data) - if n < len(data) { - c.recvBuf = data[n:] - } - return n, nil - case <-timer: - return 0, os.ErrDeadlineExceeded - case <-dch: - // Deadline was changed, re-check - return 0, os.ErrDeadlineExceeded - } -} - -func (c *Conn) Write(b []byte) (int, error) { - c.mu.Lock() - if c.closed { - c.mu.Unlock() - return 0, protocol.ErrConnClosed - } - c.mu.Unlock() - - total := len(b) - written := 0 - for written < total { - chunk := total - written - if chunk > maxSendChunk { - chunk = maxSendChunk - } - msg := make([]byte, 1+4+chunk) - msg[0] = cmdSend - binary.BigEndian.PutUint32(msg[1:5], c.id) - copy(msg[5:], b[written:written+chunk]) - if err := c.ipc.send(msg); err != nil { - return written, err - } - written += chunk - } - return written, nil -} - -func (c *Conn) Close() error { - c.mu.Lock() - if c.closed { - c.mu.Unlock() - return nil - } - c.closed = true - c.mu.Unlock() - c.ipc.unregisterRecvCh(c.id) - - msg := make([]byte, 5) - msg[0] = cmdClose - binary.BigEndian.PutUint32(msg[1:5], c.id) - return c.ipc.send(msg) -} - -func (c *Conn) LocalAddr() net.Addr { return pilotAddr(c.localAddr) } -func (c *Conn) RemoteAddr() net.Addr { return pilotAddr(c.remoteAddr) } - -func (c *Conn) SetDeadline(t time.Time) error { - c.SetReadDeadline(t) - return nil -} - -func (c *Conn) SetReadDeadline(t time.Time) error { - c.mu.Lock() - c.readDeadline = t - // Signal any blocked Read to re-check - if c.deadlineCh != nil { - close(c.deadlineCh) - } - c.deadlineCh = make(chan struct{}) - c.mu.Unlock() - return nil -} - -func (c *Conn) SetWriteDeadline(t time.Time) error { return nil } - -// pilotAddr wraps SocketAddr to satisfy net.Addr. -type pilotAddr protocol.SocketAddr - -func (a pilotAddr) Network() string { return "pilot" } -func (a pilotAddr) String() string { return protocol.SocketAddr(a).String() } diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go deleted file mode 100644 index 5d2b6208..00000000 --- a/pkg/driver/driver.go +++ /dev/null @@ -1,495 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package driver - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "os" - "path/filepath" - "runtime" - "time" - - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" -) - -// DefaultSocketPath returns the default Unix socket path for IPC. -// On Linux it prefers $XDG_RUNTIME_DIR (typically /run/user/, -// which is private to the user); falls back to /tmp/pilot.sock. -// On macOS /tmp is already per-user via SIP, so /tmp/pilot.sock is safe. -func DefaultSocketPath() string { - if runtime.GOOS == "linux" { - if xdg := os.Getenv("XDG_RUNTIME_DIR"); xdg != "" { - return filepath.Join(xdg, "pilot.sock") - } - } - return "/tmp/pilot.sock" -} - -// Handshake sub-commands (must match daemon SubHandshake* constants) -const ( - subHandshakeSend byte = 0x01 - subHandshakeApprove byte = 0x02 - subHandshakeReject byte = 0x03 - subHandshakePending byte = 0x04 - subHandshakeTrusted byte = 0x05 - subHandshakeRevoke byte = 0x06 - subHandshakeWait byte = 0x07 -) - -// jsonRPC sends an IPC message, waits for the expected response, and -// unmarshals the JSON payload. Most driver methods follow this pattern. -func (d *Driver) jsonRPC(msg []byte, expectCmd byte, label string) (map[string]interface{}, error) { - resp, err := d.ipc.sendAndWait(msg, expectCmd) - if err != nil { - return nil, fmt.Errorf("%s: %w", label, err) - } - var result map[string]interface{} - if err := json.Unmarshal(resp, &result); err != nil { - return nil, fmt.Errorf("%s unmarshal: %w", label, err) - } - return result, nil -} - -// Driver is the main entry point for the Pilot Protocol SDK. -type Driver struct { - ipc *ipcClient - socketPath string -} - -// Connect creates a new driver connected to the local daemon. -func Connect(socketPath string) (*Driver, error) { - if socketPath == "" { - socketPath = DefaultSocketPath() - } - - ipc, err := newIPCClient(socketPath) - if err != nil { - return nil, err - } - - return &Driver{ipc: ipc, socketPath: socketPath}, nil -} - -// Dial opens a stream connection to a remote address:port. -// addr format: "N:XXXX.YYYY.YYYY:PORT" -func (d *Driver) Dial(addr string) (*Conn, error) { - sa, err := protocol.ParseSocketAddr(addr) - if err != nil { - return nil, fmt.Errorf("parse address: %w", err) - } - - return d.DialAddr(sa.Addr, sa.Port) -} - -// DialAddr opens a stream connection to a remote Addr + port. -func (d *Driver) DialAddr(dst protocol.Addr, port uint16) (*Conn, error) { - msg := make([]byte, 1+protocol.AddrSize+2) - msg[0] = cmdDial - dst.MarshalTo(msg, 1) - binary.BigEndian.PutUint16(msg[1+protocol.AddrSize:], port) - - resp, err := d.ipc.sendAndWait(msg, cmdDialOK) - if err != nil { - return nil, fmt.Errorf("dial: %w", err) - } - - if len(resp) < 4 { - return nil, fmt.Errorf("invalid dial response") - } - - connID := binary.BigEndian.Uint32(resp[0:4]) - recvCh := d.ipc.registerRecvCh(connID) - - return &Conn{ - id: connID, - remoteAddr: protocol.SocketAddr{Addr: dst, Port: port}, - ipc: d.ipc, - recvCh: recvCh, - deadlineCh: make(chan struct{}), - }, nil -} - -// DialAddrTimeout opens a stream connection with a client-side timeout. -// If the daemon does not respond within the timeout, the dial is cancelled. -func (d *Driver) DialAddrTimeout(dst protocol.Addr, port uint16, timeout time.Duration) (*Conn, error) { - msg := make([]byte, 1+protocol.AddrSize+2) - msg[0] = cmdDial - dst.MarshalTo(msg, 1) - binary.BigEndian.PutUint16(msg[1+protocol.AddrSize:], port) - - resp, err := d.ipc.sendAndWaitTimeout(msg, cmdDialOK, timeout) - if err != nil { - return nil, fmt.Errorf("dial: %w", err) - } - - if len(resp) < 4 { - return nil, fmt.Errorf("invalid dial response") - } - - connID := binary.BigEndian.Uint32(resp[0:4]) - recvCh := d.ipc.registerRecvCh(connID) - - return &Conn{ - id: connID, - remoteAddr: protocol.SocketAddr{Addr: dst, Port: port}, - ipc: d.ipc, - recvCh: recvCh, - deadlineCh: make(chan struct{}), - }, nil -} - -// Listen binds a port and returns a Listener that accepts connections. -func (d *Driver) Listen(port uint16) (*Listener, error) { - msg := make([]byte, 3) - msg[0] = cmdBind - binary.BigEndian.PutUint16(msg[1:3], port) - - resp, err := d.ipc.sendAndWait(msg, cmdBindOK) - if err != nil { - return nil, fmt.Errorf("bind: %w", err) - } - - boundPort := binary.BigEndian.Uint16(resp[0:2]) - - // H12 fix: register per-port accept channel - acceptCh := d.ipc.registerAcceptCh(boundPort) - - return &Listener{ - port: boundPort, - ipc: d.ipc, - acceptCh: acceptCh, - done: make(chan struct{}), - }, nil -} - -// SendTo sends an unreliable unicast datagram to the given address:port. -// Broadcast addresses (Node=0xFFFFFFFF) are not accepted on this path; use -// Broadcast, which requires the daemon's admin token. -func (d *Driver) SendTo(dst protocol.Addr, port uint16, data []byte) error { - if dst.IsBroadcast() { - return fmt.Errorf("broadcast address requires admin token: use Driver.Broadcast") - } - msg := make([]byte, 1+protocol.AddrSize+2+len(data)) - msg[0] = cmdSendTo - dst.MarshalTo(msg, 1) - binary.BigEndian.PutUint16(msg[1+protocol.AddrSize:], port) - copy(msg[1+protocol.AddrSize+2:], data) - return d.ipc.send(msg) -} - -// Broadcast fans an unreliable datagram out to every member of a network. -// The admin token must match the daemon's configured Config.AdminToken; an -// empty token or mismatched token is rejected. Permitted on every network -// including network 0 (backbone). Sender membership is not required. -func (d *Driver) Broadcast(netID uint16, port uint16, data []byte, adminToken string) error { - tokenBytes := []byte(adminToken) - msg := make([]byte, 1+2+2+2+len(tokenBytes)+len(data)) - msg[0] = cmdBroadcast - binary.BigEndian.PutUint16(msg[1:3], netID) - binary.BigEndian.PutUint16(msg[3:5], port) - binary.BigEndian.PutUint16(msg[5:7], uint16(len(tokenBytes))) - copy(msg[7:7+len(tokenBytes)], tokenBytes) - copy(msg[7+len(tokenBytes):], data) - if _, err := d.ipc.sendAndWait(msg, cmdBroadcastOK); err != nil { - return err - } - return nil -} - -// RecvFrom receives the next incoming datagram. -func (d *Driver) RecvFrom() (*Datagram, error) { - dg, ok := <-d.ipc.dgCh - if !ok { - return nil, fmt.Errorf("driver closed") - } - return dg, nil -} - -// Info returns the daemon's status information. -func (d *Driver) Info() (map[string]interface{}, error) { - return d.jsonRPC([]byte{cmdInfo}, cmdInfoOK, "info") -} - -// Health returns a lightweight health check from the daemon. -func (d *Driver) Health() (map[string]interface{}, error) { - return d.jsonRPC([]byte{cmdHealth}, cmdHealthOK, "health") -} - -// Handshake sends a trust handshake request to a remote node. -func (d *Driver) Handshake(nodeID uint32, justification string) (map[string]interface{}, error) { - msg := make([]byte, 1+1+4+len(justification)) - msg[0] = cmdHandshake - msg[1] = subHandshakeSend - binary.BigEndian.PutUint32(msg[2:6], nodeID) - copy(msg[6:], justification) - return d.jsonRPC(msg, cmdHandshakeOK, "handshake") -} - -// ApproveHandshake approves a pending trust handshake request. -func (d *Driver) ApproveHandshake(nodeID uint32) (map[string]interface{}, error) { - msg := make([]byte, 6) - msg[0] = cmdHandshake - msg[1] = subHandshakeApprove - binary.BigEndian.PutUint32(msg[2:6], nodeID) - return d.jsonRPC(msg, cmdHandshakeOK, "approve") -} - -// RejectHandshake rejects a pending trust handshake request. -func (d *Driver) RejectHandshake(nodeID uint32, reason string) (map[string]interface{}, error) { - msg := make([]byte, 1+1+4+len(reason)) - msg[0] = cmdHandshake - msg[1] = subHandshakeReject - binary.BigEndian.PutUint32(msg[2:6], nodeID) - copy(msg[6:], reason) - return d.jsonRPC(msg, cmdHandshakeOK, "reject") -} - -// PendingHandshakes returns pending trust handshake requests. -func (d *Driver) PendingHandshakes() (map[string]interface{}, error) { - return d.jsonRPC([]byte{cmdHandshake, subHandshakePending}, cmdHandshakeOK, "pending") -} - -// WaitForTrust blocks (in the daemon) until the peer transitions to trusted -// or the timeout elapses. Single IPC roundtrip — the daemon-side -// HandshakeService.WaitForTrust waits on a per-node channel that is closed -// the moment trust is granted, so wakeup latency is sub-millisecond. -// -// Backward compatibility: an old daemon (no SubHandshakeWait) returns an -// "unknown sub-command" error; callers should treat that as "wait skipped" -// and proceed. -func (d *Driver) WaitForTrust(nodeID uint32, timeoutMs uint32) (map[string]interface{}, error) { - msg := make([]byte, 1+1+4+4) - msg[0] = cmdHandshake - msg[1] = subHandshakeWait - binary.BigEndian.PutUint32(msg[2:6], nodeID) - binary.BigEndian.PutUint32(msg[6:10], timeoutMs) - return d.jsonRPC(msg, cmdHandshakeOK, "wait") -} - -// TrustedPeers returns all trusted peers from the handshake protocol. -func (d *Driver) TrustedPeers() (map[string]interface{}, error) { - return d.jsonRPC([]byte{cmdHandshake, subHandshakeTrusted}, cmdHandshakeOK, "trusted") -} - -// RevokeTrust removes a peer from the trusted set and notifies the registry. -func (d *Driver) RevokeTrust(nodeID uint32) (map[string]interface{}, error) { - msg := make([]byte, 6) - msg[0] = cmdHandshake - msg[1] = subHandshakeRevoke - binary.BigEndian.PutUint32(msg[2:6], nodeID) - return d.jsonRPC(msg, cmdHandshakeOK, "revoke") -} - -// ResolveHostname resolves a hostname to node info via the daemon. -func (d *Driver) ResolveHostname(hostname string) (map[string]interface{}, error) { - msg := make([]byte, 1+len(hostname)) - msg[0] = cmdResolveHostname - copy(msg[1:], hostname) - return d.jsonRPC(msg, cmdResolveHostnameOK, "resolve_hostname") -} - -// SetHostname sets or clears the daemon's hostname via the registry. -func (d *Driver) SetHostname(hostname string) (map[string]interface{}, error) { - msg := make([]byte, 1+len(hostname)) - msg[0] = cmdSetHostname - copy(msg[1:], hostname) - return d.jsonRPC(msg, cmdSetHostnameOK, "set_hostname") -} - -// SetVisibility sets the daemon's visibility on the registry. -func (d *Driver) SetVisibility(public bool) (map[string]interface{}, error) { - msg := make([]byte, 2) - msg[0] = cmdSetVisibility - if public { - msg[1] = 1 - } - return d.jsonRPC(msg, cmdSetVisibilityOK, "set_visibility") -} - -// Deregister removes the daemon from the registry. -func (d *Driver) Deregister() (map[string]interface{}, error) { - return d.jsonRPC([]byte{cmdDeregister}, cmdDeregisterOK, "deregister") -} - -// SetTags sets the capability tags for this daemon's node. -func (d *Driver) SetTags(tags []string) (map[string]interface{}, error) { - data, _ := json.Marshal(tags) - msg := make([]byte, 1+len(data)) - msg[0] = cmdSetTags - copy(msg[1:], data) - return d.jsonRPC(msg, cmdSetTagsOK, "set_tags") -} - -// SetWebhook sets or clears the daemon's webhook URL at runtime. -// An empty URL disables the webhook. -func (d *Driver) SetWebhook(url string) (map[string]interface{}, error) { - msg := make([]byte, 1+len(url)) - msg[0] = cmdSetWebhook - copy(msg[1:], url) - return d.jsonRPC(msg, cmdSetWebhookOK, "set_webhook") -} - -// RotateKey asks the daemon to rotate its Ed25519 identity at the registry. -// The daemon generates a new keypair, signs proof of the current key, calls -// registry.RotateKey, then atomically swaps and persists the new identity. -func (d *Driver) RotateKey() (map[string]interface{}, error) { - return d.jsonRPC([]byte{cmdRotateKey}, cmdRotateKeyOK, "rotate_key") -} - -// Disconnect closes a connection by ID. Used by administrative tools. -// Fire-and-forget: the daemon always responds CmdCloseOK regardless of -// whether the connID exists, so there is no error to propagate. Using -// sendAndWait here would corrupt a concurrent sendAndWait for a different -// command if a server-pushed cmdCloseOK (remote FIN) arrived simultaneously. -func (d *Driver) Disconnect(connID uint32) error { - msg := make([]byte, 5) - msg[0] = cmdClose - binary.BigEndian.PutUint32(msg[1:5], connID) - return d.ipc.send(msg) -} - -// NetworkList returns all networks known to the registry. -func (d *Driver) NetworkList() (map[string]interface{}, error) { - return d.jsonRPC([]byte{cmdNetwork, subNetworkList}, cmdNetworkOK, "network list") -} - -// NetworkJoin joins a network by ID, optionally using a token for token-gated networks. -func (d *Driver) NetworkJoin(networkID uint16, token string) (map[string]interface{}, error) { - msg := make([]byte, 1+1+2+len(token)) - msg[0] = cmdNetwork - msg[1] = subNetworkJoin - binary.BigEndian.PutUint16(msg[2:4], networkID) - copy(msg[4:], token) - return d.jsonRPC(msg, cmdNetworkOK, "network join") -} - -// NetworkLeave leaves a network by ID. -func (d *Driver) NetworkLeave(networkID uint16) (map[string]interface{}, error) { - msg := make([]byte, 4) - msg[0] = cmdNetwork - msg[1] = subNetworkLeave - binary.BigEndian.PutUint16(msg[2:4], networkID) - return d.jsonRPC(msg, cmdNetworkOK, "network leave") -} - -// NetworkMembers lists all members of a network. -func (d *Driver) NetworkMembers(networkID uint16) (map[string]interface{}, error) { - msg := make([]byte, 4) - msg[0] = cmdNetwork - msg[1] = subNetworkMembers - binary.BigEndian.PutUint16(msg[2:4], networkID) - return d.jsonRPC(msg, cmdNetworkOK, "network members") -} - -// NetworkInvite invites a target node to a network (requires admin token on daemon). -func (d *Driver) NetworkInvite(networkID uint16, targetNodeID uint32) (map[string]interface{}, error) { - msg := make([]byte, 8) - msg[0] = cmdNetwork - msg[1] = subNetworkInvite - binary.BigEndian.PutUint16(msg[2:4], networkID) - binary.BigEndian.PutUint32(msg[4:8], targetNodeID) - return d.jsonRPC(msg, cmdNetworkOK, "network invite") -} - -// NetworkPollInvites returns pending network invites for this node. -func (d *Driver) NetworkPollInvites() (map[string]interface{}, error) { - return d.jsonRPC([]byte{cmdNetwork, subNetworkPollInvites}, cmdNetworkOK, "network poll-invites") -} - -// NetworkRespondInvite accepts or rejects a pending network invite. -func (d *Driver) NetworkRespondInvite(networkID uint16, accept bool) (map[string]interface{}, error) { - msg := make([]byte, 5) - msg[0] = cmdNetwork - msg[1] = subNetworkRespondInvite - binary.BigEndian.PutUint16(msg[2:4], networkID) - if accept { - msg[4] = 1 - } - return d.jsonRPC(msg, cmdNetworkOK, "network respond-invite") -} - -// ManagedStatus returns the status of a managed network engine. -func (d *Driver) ManagedStatus(networkID uint16) (map[string]interface{}, error) { - msg := make([]byte, 4) - msg[0] = cmdManaged - msg[1] = subManagedStatus - binary.BigEndian.PutUint16(msg[2:4], networkID) - return d.jsonRPC(msg, cmdManagedOK, "managed status") -} - -// ManagedForceCycle forces a prune/fill cycle in a managed network. -func (d *Driver) ManagedForceCycle(networkID uint16) (map[string]interface{}, error) { - msg := make([]byte, 4) - msg[0] = cmdManaged - msg[1] = subManagedCycle - binary.BigEndian.PutUint16(msg[2:4], networkID) - return d.jsonRPC(msg, cmdManagedOK, "managed cycle") -} - -// ManagedReconcile asks the daemon's policy runner for networkID to -// poll the registry and refresh its peer set — without running a -// policy cycle. Returns {network_id, peers}. -func (d *Driver) ManagedReconcile(networkID uint16) (map[string]interface{}, error) { - msg := make([]byte, 4) - msg[0] = cmdManaged - msg[1] = subManagedReconcile - binary.BigEndian.PutUint16(msg[2:4], networkID) - return d.jsonRPC(msg, cmdManagedOK, "managed reconcile") -} - -// PolicyGet retrieves the active policy for a network from the daemon. -func (d *Driver) PolicyGet(networkID uint16) (map[string]interface{}, error) { - msg := make([]byte, 4) - msg[0] = cmdManaged - msg[1] = subManagedPolicy - msg[2] = 0x00 // get - // Shift: need [cmd][sub][action][netID_hi][netID_lo] - msg = make([]byte, 5) - msg[0] = cmdManaged - msg[1] = subManagedPolicy - msg[2] = 0x00 // get - binary.BigEndian.PutUint16(msg[3:5], networkID) - return d.jsonRPC(msg, cmdManagedOK, "policy get") -} - -// PolicySet sends a policy document to the daemon for immediate application. -func (d *Driver) PolicySet(networkID uint16, policyJSON []byte) (map[string]interface{}, error) { - msg := make([]byte, 5+len(policyJSON)) - msg[0] = cmdManaged - msg[1] = subManagedPolicy - msg[2] = 0x01 // set - binary.BigEndian.PutUint16(msg[3:5], networkID) - copy(msg[5:], policyJSON) - return d.jsonRPC(msg, cmdManagedOK, "policy set") -} - -// MemberTagsGet retrieves admin-assigned member tags for a node in a network. -func (d *Driver) MemberTagsGet(networkID uint16, nodeID uint32) (map[string]interface{}, error) { - msg := make([]byte, 9) - msg[0] = cmdManaged - msg[1] = subManagedMemberTags - msg[2] = 0x00 // get - binary.BigEndian.PutUint16(msg[3:5], networkID) - binary.BigEndian.PutUint32(msg[5:9], nodeID) - return d.jsonRPC(msg, cmdManagedOK, "member-tags get") -} - -// MemberTagsSet sets admin-assigned member tags for a node in a network. -func (d *Driver) MemberTagsSet(networkID uint16, nodeID uint32, tags []string) (map[string]interface{}, error) { - tagsJSON, _ := json.Marshal(tags) - msg := make([]byte, 9+len(tagsJSON)) - msg[0] = cmdManaged - msg[1] = subManagedMemberTags - msg[2] = 0x01 // set - binary.BigEndian.PutUint16(msg[3:5], networkID) - binary.BigEndian.PutUint32(msg[5:9], nodeID) - copy(msg[9:], tagsJSON) - return d.jsonRPC(msg, cmdManagedOK, "member-tags set") -} - -// Close disconnects from the daemon. -func (d *Driver) Close() error { - return d.ipc.close() -} diff --git a/pkg/driver/ipc.go b/pkg/driver/ipc.go deleted file mode 100644 index 5419ddf2..00000000 --- a/pkg/driver/ipc.go +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package driver - -import ( - "encoding/binary" - "fmt" - "net" - "sync" - "time" - - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" -) - -// IPC commands (must match daemon/ipc.go) -const ( - cmdBind byte = 0x01 - cmdBindOK byte = 0x02 - cmdDial byte = 0x03 - cmdDialOK byte = 0x04 - cmdAccept byte = 0x05 - cmdSend byte = 0x06 - cmdRecv byte = 0x07 - cmdClose byte = 0x08 - cmdCloseOK byte = 0x09 - cmdError byte = 0x0A - cmdSendTo byte = 0x0B - cmdRecvFrom byte = 0x0C - cmdInfo byte = 0x0D - cmdInfoOK byte = 0x0E - cmdHandshake byte = 0x0F - cmdHandshakeOK byte = 0x10 - cmdResolveHostname byte = 0x11 - cmdResolveHostnameOK byte = 0x12 - cmdSetHostname byte = 0x13 - cmdSetHostnameOK byte = 0x14 - cmdSetVisibility byte = 0x15 - cmdSetVisibilityOK byte = 0x16 - cmdDeregister byte = 0x17 - cmdDeregisterOK byte = 0x18 - cmdSetTags byte = 0x19 - cmdSetTagsOK byte = 0x1A - cmdSetWebhook byte = 0x1B - cmdSetWebhookOK byte = 0x1C - cmdNetwork byte = 0x1F - cmdNetworkOK byte = 0x20 - cmdHealth byte = 0x21 - cmdHealthOK byte = 0x22 - cmdManaged byte = 0x23 - cmdManagedOK byte = 0x24 - cmdRotateKey byte = 0x25 - cmdRotateKeyOK byte = 0x26 - cmdBroadcast byte = 0x29 - cmdBroadcastOK byte = 0x2A -) - -// Network sub-commands (must match daemon SubNetwork* constants) -const ( - subNetworkList byte = 0x01 - subNetworkJoin byte = 0x02 - subNetworkLeave byte = 0x03 - subNetworkMembers byte = 0x04 - subNetworkInvite byte = 0x05 - subNetworkPollInvites byte = 0x06 - subNetworkRespondInvite byte = 0x07 -) - -// Managed sub-commands (must match daemon SubManaged* constants) -const ( - subManagedStatus byte = 0x02 - subManagedCycle byte = 0x04 - subManagedPolicy byte = 0x05 - subManagedMemberTags byte = 0x06 - subManagedReconcile byte = 0x07 -) - -// ipcEnvelopeHeaderSize matches daemon.IPCEnvelopeHeaderSize: 1 byte cmd. -const ipcEnvelopeHeaderSize = 1 - -// Datagram represents a received unreliable datagram. -type Datagram struct { - SrcAddr protocol.Addr - SrcPort uint16 - DstPort uint16 - Data []byte -} - -// pendingResponse carries the response to a sendAndWait waiter — either -// the cmd-OK payload (ok=true) or the error text from cmdError. -type pendingResponse struct { - cmd byte - payload []byte -} - -type ipcClient struct { - conn net.Conn - - // writeMu serializes frame writes so concurrent goroutines don't - // interleave bytes on the wire. Held only for the write itself. - writeMu sync.Mutex - - // waitSem is a channel-based semaphore (capacity 1) that ensures at - // most one request/reply pair is in-flight at a time. Using a channel - // instead of sync.Mutex lets goroutines waiting for the semaphore be - // woken on doneCh close, preventing a deadlock when the daemon closes - // while many goroutines are queued behind a slow sendAndWait. - waitSem chan struct{} // capacity 1 - pending chan *pendingResponse // capacity 16; buffers reply frames from readLoop - - recvMu sync.Mutex - recvChs map[uint32]chan []byte // conn_id → data channel - pendRecv map[uint32][][]byte // conn_id → buffered data before recvCh registered - pendAccept map[uint16][][]byte // port → buffered cmdAccept payloads before acceptCh registered (post-#99 race fix) - - acceptMu sync.Mutex - acceptChs map[uint16]chan []byte // H12 fix: per-port accept channels - - dgCh chan *Datagram // incoming datagrams - doneCh chan struct{} // closed when readLoop exits - - closeOnce sync.Once -} - -func newIPCClient(socketPath string) (*ipcClient, error) { - conn, err := net.Dial("unix", socketPath) - if err != nil { - return nil, fmt.Errorf("connect to daemon: %w", err) - } - - c := &ipcClient{ - conn: conn, - waitSem: make(chan struct{}, 1), - pending: make(chan *pendingResponse, 16), - recvChs: make(map[uint32]chan []byte), - pendRecv: make(map[uint32][][]byte), - pendAccept: make(map[uint16][][]byte), - acceptChs: make(map[uint16]chan []byte), - dgCh: make(chan *Datagram, 256), - doneCh: make(chan struct{}), - } - - go c.readLoop() - return c, nil -} - -func (c *ipcClient) close() error { - var err error - c.closeOnce.Do(func() { - err = c.conn.Close() - }) - return err -} - -// readLoop demultiplexes incoming envelopes. Wire format: -// -// [uint32-len][uint8-cmd][payload...] -// -// Server-pushed frames (cmdRecv, cmdCloseOK, cmdRecvFrom, cmdAccept) are -// routed by cmd to their per-connection channels. cmdCloseOK is always -// a server-push (remote FIN); Driver.Disconnect uses send() not -// sendAndWait() so it never waits for cmdCloseOK in pending. -// Known response cmds are forwarded to c.pending for sendAndWait. -// Unknown cmds are silently dropped — they never reach pending, so -// sendAndWaitTimeout can use a single read without a discard loop. -func (c *ipcClient) readLoop() { - defer c.cleanup() - for { - msg, err := ipcutil.Read(c.conn) - if err != nil { - return - } - if len(msg) < ipcEnvelopeHeaderSize { - continue - } - - cmd := msg[0] - payload := msg[ipcEnvelopeHeaderSize:] - - switch cmd { - case cmdRecv, cmdRecvFrom, cmdAccept, cmdCloseOK: - // Server-pushed frames: route to per-connection channels. - c.dispatchPush(cmd, payload) - case cmdBindOK, cmdDialOK, cmdError, cmdInfoOK, cmdHandshakeOK, - cmdResolveHostnameOK, cmdSetHostnameOK, cmdSetVisibilityOK, - cmdDeregisterOK, cmdSetTagsOK, cmdSetWebhookOK, cmdNetworkOK, - cmdHealthOK, cmdManagedOK, cmdRotateKeyOK, cmdBroadcastOK: - // Known response cmds: route to pending for the in-flight sendAndWait. - select { - case c.pending <- &pendingResponse{cmd: cmd, payload: append([]byte(nil), payload...)}: - default: - } - // default: unknown cmd — silently drop (version mismatch, test injection, etc.) - } - } -} - -// dispatchPush routes server-pushed (reqID==0) frames to their per-cmd -// destination. CmdRecv and CmdCloseOK route by conn ID; CmdAccept by -// listener port; CmdRecvFrom into the global datagram channel. -func (c *ipcClient) dispatchPush(cmd byte, payload []byte) { - switch cmd { - case cmdRecv: - if len(payload) >= 4 { - connID := binary.BigEndian.Uint32(payload[0:4]) - data := append([]byte(nil), payload[4:]...) - c.recvMu.Lock() - ch, ok := c.recvChs[connID] - if ok { - c.recvMu.Unlock() - // Drop the recvMu BEFORE blocking on the channel send - // so Conn.Close() / unregisterRecvCh can take the lock - // while readLoop is parked. Without this, a slow Conn - // holds recvMu indirectly (through readLoop) and other - // IPC operations stall. - ch <- data - } else { - c.pendRecv[connID] = append(c.pendRecv[connID], data) - c.recvMu.Unlock() - } - } - case cmdCloseOK: - // Server-pushed CmdCloseOK fires from recvPusher when the remote - // FINs. Close the per-conn recv channel so blocked reads see EOF. - if len(payload) >= 4 { - connID := binary.BigEndian.Uint32(payload[0:4]) - c.recvMu.Lock() - ch, ok := c.recvChs[connID] - if ok { - delete(c.recvChs, connID) - close(ch) - } - c.recvMu.Unlock() - } - case cmdRecvFrom: - if len(payload) >= protocol.AddrSize+4 { - srcAddr := protocol.UnmarshalAddr(payload[0:protocol.AddrSize]) - srcPort := binary.BigEndian.Uint16(payload[protocol.AddrSize:]) - dstPort := binary.BigEndian.Uint16(payload[protocol.AddrSize+2:]) - data := append([]byte(nil), payload[protocol.AddrSize+4:]...) - select { - case c.dgCh <- &Datagram{SrcAddr: srcAddr, SrcPort: srcPort, DstPort: dstPort, Data: data}: - default: - } - } - case cmdAccept: - if len(payload) >= 2 { - port := binary.BigEndian.Uint16(payload[0:2]) - rest := append([]byte(nil), payload[2:]...) - c.acceptMu.Lock() - ch, ok := c.acceptChs[port] - if ok { - c.acceptMu.Unlock() - select { - case ch <- rest: - default: - } - } else { - // Buffer until registerAcceptCh is called. The race - // (post-#99): with concurrent daemon dispatch, the - // daemon can push cmdAccept BEFORE the driver registers - // acceptChs[port] — Listen() registers AFTER the - // cmdBind reply, but a peer's dial can race the bind - // reply through different worker goroutines on the - // daemon side. Same pattern as pendRecv for cmdRecv. - c.pendAccept[port] = append(c.pendAccept[port], rest) - c.acceptMu.Unlock() - } - } - default: - // Unknown unsolicited cmd — drop. The daemon should never send - // reqID=0 with a cmd outside this set; if a test or future - // addition does, dropping is the safe default. - } -} - -// cleanup closes channels when readLoop exits (daemon disconnect). -func (c *ipcClient) cleanup() { - close(c.doneCh) - - // Drain all buffered responses. - for { - select { - case <-c.pending: - default: - goto drained - } - } -drained: - - // Close all receive channels - c.recvMu.Lock() - for id, ch := range c.recvChs { - close(ch) - delete(c.recvChs, id) - } - c.recvMu.Unlock() - - // Close all accept channels (H12 fix) - c.acceptMu.Lock() - for port, ch := range c.acceptChs { - close(ch) - delete(c.acceptChs, port) - } - c.acceptMu.Unlock() -} - -// writeFrame builds a `[cmd][body...]` envelope and writes it under -// writeMu so frames don't interleave on the wire. -func (c *ipcClient) writeFrame(cmd byte, body []byte) error { - buf := make([]byte, ipcEnvelopeHeaderSize+len(body)) - buf[0] = cmd - copy(buf[1:], body) - c.writeMu.Lock() - defer c.writeMu.Unlock() - return ipcutil.Write(c.conn, buf) -} - -// send is a fire-and-forget write — used for cmdSend/cmdSendTo where -// the daemon does not reply. Acquires only writeMu (not waitMu), so -// concurrent fire-and-forget sends are never blocked behind a reply wait. -func (c *ipcClient) send(data []byte) error { - if len(data) < 1 { - return fmt.Errorf("ipc: empty message") - } - return c.writeFrame(data[0], data[1:]) -} - -// sendAndWait sends a request and waits for the reply. -func (c *ipcClient) sendAndWait(data []byte, expectCmd byte) ([]byte, error) { - return c.sendAndWaitTimeout(data, expectCmd, 0) -} - -// sendAndWaitTimeout serialises at most one request/reply pair at a time -// via waitSem. timeout=0 means wait forever. The timer is started BEFORE -// acquiring the semaphore so the timeout applies to queue wait + reply -// wait combined — without this, goroutines queued behind the semaphore -// can't time out and pile up indefinitely under high concurrency. -func (c *ipcClient) sendAndWaitTimeout(data []byte, expectCmd byte, timeout time.Duration) ([]byte, error) { - if len(data) < 1 { - return nil, fmt.Errorf("ipc: empty request") - } - - // Start the timer before acquiring the semaphore so queued goroutines - // can bail out instead of waiting forever. - var timer <-chan time.Time - if timeout > 0 { - t := time.NewTimer(timeout) - defer t.Stop() - timer = t.C - } - - // Acquire the serialisation semaphore. Channel-based (not sync.Mutex) - // so goroutines blocked here are woken by doneCh or timer. - select { - case c.waitSem <- struct{}{}: - case <-c.doneCh: - return nil, fmt.Errorf("daemon disconnected") - case <-timer: - return nil, fmt.Errorf("dial timeout") - } - defer func() { <-c.waitSem }() - - // Drain all stale replies buffered before this request was sent. - for { - select { - case <-c.pending: - default: - goto drained - } - } -drained: - - if err := c.writeFrame(data[0], data[1:]); err != nil { - return nil, err - } - - // Unknown cmds are dropped in readLoop, so the first frame in pending - // is always either the expected response or cmdError. - select { - case resp := <-c.pending: - if resp.cmd == cmdError { - if len(resp.payload) >= 2 { - return nil, fmt.Errorf("daemon: %s", string(resp.payload[2:])) - } - return nil, fmt.Errorf("daemon error") - } - if resp.cmd != expectCmd { - return nil, fmt.Errorf("ipc: unexpected reply 0x%02X (want 0x%02X)", resp.cmd, expectCmd) - } - return resp.payload, nil - case <-c.doneCh: - return nil, fmt.Errorf("daemon disconnected") - case <-timer: - return nil, fmt.Errorf("dial timeout") - } -} - -// H12 fix: per-port accept channel management. -// Drains any cmdAccept frames buffered in pendAccept (the post-#99 -// race window between cmdBind reply and acceptChs registration). -func (c *ipcClient) registerAcceptCh(port uint16) chan []byte { - ch := make(chan []byte, 64) - c.acceptMu.Lock() - c.acceptChs[port] = ch - pending := c.pendAccept[port] - delete(c.pendAccept, port) - c.acceptMu.Unlock() - for _, data := range pending { - select { - case ch <- data: - default: - } - } - return ch -} - -func (c *ipcClient) registerRecvCh(connID uint32) chan []byte { - ch := make(chan []byte, 256) - c.recvMu.Lock() - c.recvChs[connID] = ch - // Drain any data that arrived before registration. Hold recvMu - // across the drain so a concurrent dispatchPush(cmdCloseOK) for the - // same connID can't race with these sends — without this guard, the - // FIN handler at dispatchPush:250 closes the channel mid-drain and - // chansend1 panics on a closed channel (issue #105 §4.8 race). - // The drain is bounded by len(pendRecv[connID]) which is small — - // data only buffers in pendRecv during the brief window between - // the daemon dispatching cmdRecv and the driver's Accept calling - // registerRecvCh, and never exceeds a single slow-path frame batch. - pending := c.pendRecv[connID] - delete(c.pendRecv, connID) - for _, data := range pending { - ch <- data - } - c.recvMu.Unlock() - return ch -} - -func (c *ipcClient) unregisterRecvCh(connID uint32) { - c.recvMu.Lock() - defer c.recvMu.Unlock() - delete(c.recvChs, connID) -} diff --git a/pkg/driver/listener.go b/pkg/driver/listener.go deleted file mode 100644 index 73fb5573..00000000 --- a/pkg/driver/listener.go +++ /dev/null @@ -1,79 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package driver - -import ( - "encoding/binary" - "fmt" - "net" - "sync" - - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" -) - -// Listener implements net.Listener over a Pilot Protocol port. -type Listener struct { - port uint16 - ipc *ipcClient - acceptCh chan []byte // H12 fix: per-port accept channel - mu sync.Mutex - closed bool - done chan struct{} // closed on Close() to unblock Accept (H13 fix) -} - -func (l *Listener) Accept() (net.Conn, error) { - l.mu.Lock() - if l.closed { - l.mu.Unlock() - return nil, fmt.Errorf("listener closed") - } - l.mu.Unlock() - - // H12 fix: wait on per-port accept channel - var payload []byte - var ok bool - select { - case payload, ok = <-l.acceptCh: - if !ok { - return nil, fmt.Errorf("listener closed") - } - case <-l.done: - return nil, fmt.Errorf("listener closed") - } - - // Parse: [4 bytes conn_id][6 bytes remote addr][2 bytes remote port] - if len(payload) < 4+protocol.AddrSize+2 { - return nil, fmt.Errorf("invalid accept payload") - } - - connID := binary.BigEndian.Uint32(payload[0:4]) - remoteAddr := protocol.UnmarshalAddr(payload[4 : 4+protocol.AddrSize]) - remotePort := binary.BigEndian.Uint16(payload[4+protocol.AddrSize:]) - - recvCh := l.ipc.registerRecvCh(connID) - - conn := &Conn{ - id: connID, - localAddr: protocol.SocketAddr{Port: l.port}, - remoteAddr: protocol.SocketAddr{Addr: remoteAddr, Port: remotePort}, - ipc: l.ipc, - recvCh: recvCh, - deadlineCh: make(chan struct{}), - } - - return conn, nil -} - -func (l *Listener) Close() error { - l.mu.Lock() - if !l.closed { - l.closed = true - close(l.done) // unblock Accept() (H13 fix) - } - l.mu.Unlock() - return nil -} - -func (l *Listener) Addr() net.Addr { - return pilotAddr(protocol.SocketAddr{Port: l.port}) -} diff --git a/pkg/driver/zz_conn_test.go b/pkg/driver/zz_conn_test.go deleted file mode 100644 index 4bcc66c2..00000000 --- a/pkg/driver/zz_conn_test.go +++ /dev/null @@ -1,299 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package driver - -import ( - "errors" - "io" - "net" - "os" - "testing" - "time" - - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" -) - -// --------------------------------------------------------------------------- -// pilotAddr — net.Addr implementation -// --------------------------------------------------------------------------- - -func TestPilotAddrNetwork(t *testing.T) { - t.Parallel() - a := pilotAddr(protocol.SocketAddr{Port: 80}) - if got := a.Network(); got != "pilot" { - t.Errorf("Network() = %q, want %q", got, "pilot") - } -} - -func TestPilotAddrString(t *testing.T) { - t.Parallel() - addr, _ := protocol.ParseAddr("1:0001.0002.0003") - a := pilotAddr(protocol.SocketAddr{Addr: addr, Port: 7}) - got := a.String() - want := protocol.SocketAddr{Addr: addr, Port: 7}.String() - if got != want { - t.Errorf("String() = %q, want %q", got, want) - } -} - -// --------------------------------------------------------------------------- -// Conn — read leftover and deadline behaviour exercise the in-memory -// branches that don't require live IPC. -// --------------------------------------------------------------------------- - -func TestConnReadDrainsLeftover(t *testing.T) { - t.Parallel() - c := &Conn{ - recvBuf: []byte("hello world"), - recvCh: make(chan []byte), - deadlineCh: make(chan struct{}), - } - got := make([]byte, 5) - n, err := c.Read(got) - if err != nil { - t.Fatal(err) - } - if n != 5 || string(got) != "hello" { - t.Fatalf("first read: n=%d got=%q", n, got) - } - // Second read drains the rest of the leftover (no IPC needed). - got2 := make([]byte, 6) - n2, err := c.Read(got2) - if err != nil { - t.Fatal(err) - } - if n2 != 6 || string(got2) != " world" { - t.Fatalf("second read: n=%d got=%q", n2, got2) - } -} - -func TestConnReadDeadlineAlreadyPassed(t *testing.T) { - t.Parallel() - c := &Conn{ - recvCh: make(chan []byte), - deadlineCh: make(chan struct{}), - readDeadline: time.Now().Add(-time.Second), // already in past - } - _, err := c.Read(make([]byte, 1)) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("got %v, want ErrDeadlineExceeded", err) - } -} - -func TestConnReadEOFOnClosedRecvCh(t *testing.T) { - t.Parallel() - ch := make(chan []byte) - close(ch) - c := &Conn{ - recvCh: ch, - deadlineCh: make(chan struct{}), - } - _, err := c.Read(make([]byte, 1)) - if !errors.Is(err, io.EOF) { - t.Errorf("got %v, want io.EOF", err) - } -} - -func TestConnReadDelivers(t *testing.T) { - t.Parallel() - ch := make(chan []byte, 1) - ch <- []byte("xy") - c := &Conn{ - recvCh: ch, - deadlineCh: make(chan struct{}), - } - buf := make([]byte, 2) - n, err := c.Read(buf) - if err != nil || n != 2 || string(buf) != "xy" { - t.Fatalf("got n=%d err=%v buf=%q", n, err, buf) - } -} - -func TestConnReadStoresLeftoverWhenBufferTooSmall(t *testing.T) { - t.Parallel() - ch := make(chan []byte, 1) - ch <- []byte("12345") - c := &Conn{ - recvCh: ch, - deadlineCh: make(chan struct{}), - } - buf := make([]byte, 2) - n, err := c.Read(buf) - if err != nil || n != 2 || string(buf) != "12" { - t.Fatalf("first read got n=%d err=%v buf=%q", n, err, buf) - } - // Remaining 3 bytes should be in recvBuf - rest := make([]byte, 3) - n2, err := c.Read(rest) - if err != nil || n2 != 3 || string(rest) != "345" { - t.Fatalf("leftover read got n=%d err=%v buf=%q", n2, err, rest) - } -} - -func TestConnReadTimerExpires(t *testing.T) { - t.Parallel() - c := &Conn{ - recvCh: make(chan []byte), - deadlineCh: make(chan struct{}), - readDeadline: time.Now().Add(20 * time.Millisecond), - } - start := time.Now() - _, err := c.Read(make([]byte, 1)) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("got %v, want ErrDeadlineExceeded", err) - } - if elapsed := time.Since(start); elapsed < 20*time.Millisecond { - t.Errorf("returned too early: %v", elapsed) - } -} - -func TestSetReadDeadlineUnblocksReader(t *testing.T) { - t.Parallel() - c := &Conn{ - recvCh: make(chan []byte), - deadlineCh: make(chan struct{}), - } - done := make(chan error, 1) - go func() { - _, err := c.Read(make([]byte, 1)) - done <- err - }() - // Give Read a moment to enter the select. - time.Sleep(10 * time.Millisecond) - c.SetReadDeadline(time.Now().Add(time.Hour)) // closes the old deadlineCh - select { - case err := <-done: - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("got %v, want ErrDeadlineExceeded", err) - } - case <-time.After(time.Second): - t.Fatal("Read did not unblock after SetReadDeadline") - } -} - -func TestSetDeadlineDelegatesToRead(t *testing.T) { - t.Parallel() - c := &Conn{ - recvCh: make(chan []byte), - deadlineCh: make(chan struct{}), - } - dl := time.Now().Add(time.Hour) - if err := c.SetDeadline(dl); err != nil { - t.Fatal(err) - } - if !c.readDeadline.Equal(dl) { - t.Errorf("readDeadline = %v, want %v", c.readDeadline, dl) - } -} - -func TestSetWriteDeadlineNoop(t *testing.T) { - t.Parallel() - c := &Conn{} - if err := c.SetWriteDeadline(time.Now()); err != nil { - t.Errorf("expected nil, got %v", err) - } -} - -func TestConnAddrs(t *testing.T) { - t.Parallel() - addr, _ := protocol.ParseAddr("1:0001.0002.0003") - c := &Conn{ - localAddr: protocol.SocketAddr{Port: 80}, - remoteAddr: protocol.SocketAddr{Addr: addr, Port: 7}, - } - if c.LocalAddr().Network() != "pilot" { - t.Errorf("LocalAddr().Network() unexpected") - } - if c.RemoteAddr().Network() != "pilot" { - t.Errorf("RemoteAddr().Network() unexpected") - } -} - -// --------------------------------------------------------------------------- -// Listener — Accept payload parsing and Close behavior -// --------------------------------------------------------------------------- - -func TestListenerCloseUnblocksAccept(t *testing.T) { - t.Parallel() - l := &Listener{ - port: 80, - acceptCh: make(chan []byte), - done: make(chan struct{}), - } - type r struct{ err error } - ch := make(chan r, 1) - go func() { - _, err := l.Accept() - ch <- r{err} - }() - time.Sleep(10 * time.Millisecond) - if err := l.Close(); err != nil { - t.Fatal(err) - } - select { - case got := <-ch: - if got.err == nil { - t.Fatal("expected error after close") - } - case <-time.After(time.Second): - t.Fatal("Accept did not unblock after Close") - } -} - -func TestListenerAcceptOnAlreadyClosed(t *testing.T) { - t.Parallel() - l := &Listener{ - port: 80, - acceptCh: make(chan []byte), - done: make(chan struct{}), - } - if err := l.Close(); err != nil { - t.Fatal(err) - } - _, err := l.Accept() - if err == nil { - t.Fatal("expected closed error") - } -} - -func TestListenerCloseIdempotent(t *testing.T) { - t.Parallel() - l := &Listener{ - port: 80, - acceptCh: make(chan []byte), - done: make(chan struct{}), - } - if err := l.Close(); err != nil { - t.Fatal(err) - } - // Second Close must not panic on closed channel - if err := l.Close(); err != nil { - t.Errorf("second Close: %v", err) - } -} - -func TestListenerAddr(t *testing.T) { - t.Parallel() - l := &Listener{port: 8080} - a := l.Addr() - if a.Network() != "pilot" { - t.Errorf("Network() = %q", a.Network()) - } -} - -func TestListenerAcceptInvalidPayload(t *testing.T) { - t.Parallel() - l := &Listener{ - port: 80, - acceptCh: make(chan []byte, 1), - done: make(chan struct{}), - } - l.acceptCh <- []byte{0x01, 0x02} // way too short - _, err := l.Accept() - if err == nil { - t.Fatal("expected invalid-payload error") - } -} - -// satisfy unused import detector if SDK isn't otherwise used here -var _ net.Listener = (*Listener)(nil) diff --git a/pkg/driver/zz_conn_write_test.go b/pkg/driver/zz_conn_write_test.go deleted file mode 100644 index 0e77ceea..00000000 --- a/pkg/driver/zz_conn_write_test.go +++ /dev/null @@ -1,167 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package driver - -import ( - "encoding/binary" - "io" - "net" - "sync" - "testing" - "time" - - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" -) - -// TestConnWriteChunksLargePayload verifies that Conn.Write splits payloads -// larger than the IPC message cap into multiple cmdSend messages so the -// daemon side never rejects oversized frames. -func TestConnWriteChunksLargePayload(t *testing.T) { - t.Parallel() - clientSide, serverSide := net.Pipe() - defer clientSide.Close() - defer serverSide.Close() - - ipc := &ipcClient{ - conn: clientSide, - waitSem: make(chan struct{}, 1), - pending: make(chan *pendingResponse, 16), - recvChs: make(map[uint32]chan []byte), - pendRecv: make(map[uint32][][]byte), - acceptChs: make(map[uint16]chan []byte), - dgCh: make(chan *Datagram, 1), - doneCh: make(chan struct{}), - } - - const connID uint32 = 42 - c := &Conn{id: connID, ipc: ipc, deadlineCh: make(chan struct{})} - - const payloadSize = 5 * 1024 * 1024 // 5 MB - payload := make([]byte, payloadSize) - for i := range payload { - payload[i] = byte(i) - } - - // Wire format (issue #99): [cmd(1)][reqID(8)][connID(4)][data...]. - // Each cmdSend frame carries 13 bytes of header before the payload. - const sendHdr = ipcEnvelopeHeaderSize + 4 - - var got []byte - var chunks int - var readErr error - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - _ = serverSide.SetReadDeadline(time.Now().Add(5 * time.Second)) - for len(got) < payloadSize { - msg, err := ipcutil.Read(serverSide) - if err != nil { - if err != io.EOF { - readErr = err - } - return - } - if len(msg) < sendHdr { - readErr = io.ErrShortBuffer - return - } - if msg[0] != cmdSend { - readErr = io.ErrUnexpectedEOF - return - } - gotID := binary.BigEndian.Uint32(msg[ipcEnvelopeHeaderSize : ipcEnvelopeHeaderSize+4]) - if gotID != connID { - readErr = io.ErrUnexpectedEOF - return - } - if len(msg) > ipcutil.MaxMessageSize { - readErr = io.ErrShortBuffer - return - } - chunks++ - got = append(got, msg[sendHdr:]...) - } - }() - - n, err := c.Write(payload) - if err != nil { - t.Fatalf("Write returned err: %v", err) - } - if n != payloadSize { - t.Fatalf("Write returned n=%d, want %d", n, payloadSize) - } - - // Close to unblock reader if it got everything already. - _ = clientSide.Close() - wg.Wait() - - if readErr != nil { - t.Fatalf("reader err: %v", readErr) - } - if len(got) != payloadSize { - t.Fatalf("reader got %d bytes, want %d", len(got), payloadSize) - } - if chunks < 2 { - t.Fatalf("expected >=2 chunks for 5MB payload, got %d", chunks) - } - for i, b := range got { - if b != byte(i) { - t.Fatalf("byte %d: got %d, want %d", i, b, byte(i)) - } - } -} - -// TestConnWriteSinglePayloadNotSplit verifies that payloads that fit in one -// IPC message are still sent as a single cmdSend message. -func TestConnWriteSinglePayloadNotSplit(t *testing.T) { - t.Parallel() - clientSide, serverSide := net.Pipe() - defer clientSide.Close() - defer serverSide.Close() - - ipc := &ipcClient{ - conn: clientSide, - waitSem: make(chan struct{}, 1), - pending: make(chan *pendingResponse, 16), - recvChs: make(map[uint32]chan []byte), - pendRecv: make(map[uint32][][]byte), - acceptChs: make(map[uint16]chan []byte), - dgCh: make(chan *Datagram, 1), - doneCh: make(chan struct{}), - } - - const connID uint32 = 7 - c := &Conn{id: connID, ipc: ipc, deadlineCh: make(chan struct{})} - - payload := []byte("hello world") - - // Wire format: [cmd(1)][connID(4)][data...] - const sendHdr = ipcEnvelopeHeaderSize + 4 - - var got []byte - var chunks int - done := make(chan struct{}) - go func() { - defer close(done) - _ = serverSide.SetReadDeadline(time.Now().Add(2 * time.Second)) - msg, err := ipcutil.Read(serverSide) - if err != nil { - return - } - chunks++ - got = append(got, msg[sendHdr:]...) - }() - - if _, err := c.Write(payload); err != nil { - t.Fatalf("Write err: %v", err) - } - <-done - - if chunks != 1 { - t.Fatalf("expected 1 chunk, got %d", chunks) - } - if string(got) != string(payload) { - t.Fatalf("got %q, want %q", got, payload) - } -} diff --git a/pkg/driver/zz_driver_simple_ops_test.go b/pkg/driver/zz_driver_simple_ops_test.go deleted file mode 100644 index d2e82b58..00000000 --- a/pkg/driver/zz_driver_simple_ops_test.go +++ /dev/null @@ -1,134 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package driver - -import ( - "testing" -) - -// TestDriverClose covers the trivial Close() forwarder. -func TestDriverClose(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("Connect: %v", err) - } - if err := drv.Close(); err != nil { - t.Errorf("Close: %v", err) - } -} - -// TestDriverBroadcast covers the happy-path Broadcast (network + port + -// admin token + data → cmdBroadcastOK). -func TestDriverBroadcast(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - - d.onCmd(cmdBroadcast, func(frame []byte) [][]byte { - return [][]byte{{cmdBroadcastOK}} - }) - - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("Connect: %v", err) - } - defer drv.Close() - - if err := drv.Broadcast(1, 8080, []byte("hello"), "admin-token"); err != nil { - t.Fatalf("Broadcast: %v", err) - } -} - -// TestConnClose covers Conn.Close (cmdClose fire-and-forget) and its -// idempotency — second Close is a no-op. -func TestConnClose(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - - // Dial to get a Conn back. - d.onCmd(cmdDial, func(frame []byte) [][]byte { - resp := make([]byte, 5) - resp[0] = cmdDialOK - resp[1] = 0x00 - resp[2] = 0x00 - resp[3] = 0x00 - resp[4] = 0x42 - return [][]byte{resp} - }) - - drv, _ := Connect(d.path) - defer drv.Close() - - conn, err := drv.Dial("0:0000.0000.0001:80") - if err != nil { - t.Fatalf("Dial: %v", err) - } - - if err := conn.Close(); err != nil { - t.Errorf("Close: %v", err) - } - // Second Close is idempotent. - if err := conn.Close(); err != nil { - t.Errorf("second Close: %v", err) - } -} - -// TestDriverWaitForTrust covers the handshake-wait JSON-RPC. -func TestDriverWaitForTrust(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - - d.onCmd(cmdHandshake, func(frame []byte) [][]byte { - // Verify the sub-command byte (0x07 = subHandshakeWait). - if len(frame) < 2 || frame[1] != subHandshakeWait { - return [][]byte{{cmdError, 'b', 'a', 'd'}} - } - body := []byte(`{"trusted":true}`) - return [][]byte{append([]byte{cmdHandshakeOK}, body...)} - }) - - drv, _ := Connect(d.path) - defer drv.Close() - - result, err := drv.WaitForTrust(0xCAFE, 5000) - if err != nil { - t.Fatalf("WaitForTrust: %v", err) - } - if result == nil { - t.Errorf("result is nil") - } -} - -// TestDriverRotateKey covers RotateKey's JSON-RPC roundtrip. -func TestDriverRotateKey(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - - d.onCmd(cmdRotateKey, func(frame []byte) [][]byte { - // jsonRPC expects [cmdRotateKeyOK][JSON body] - body := []byte(`{"old_node_id":1,"new_node_id":2}`) - resp := append([]byte{cmdRotateKeyOK}, body...) - return [][]byte{resp} - }) - - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("Connect: %v", err) - } - defer drv.Close() - - result, err := drv.RotateKey() - if err != nil { - t.Fatalf("RotateKey: %v", err) - } - if result == nil { - t.Errorf("RotateKey result is nil") - } -} diff --git a/pkg/driver/zz_driver_test.go b/pkg/driver/zz_driver_test.go deleted file mode 100644 index 7d2070f9..00000000 --- a/pkg/driver/zz_driver_test.go +++ /dev/null @@ -1,739 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package driver - -import ( - "crypto/rand" - "encoding/binary" - "encoding/hex" - "net" - "os" - "path/filepath" - "strings" - "sync" - "testing" - "time" - - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" -) - -// shortSocketPath returns a /tmp path short enough for macOS unix socket -// length limit (~104 chars). t.TempDir() paths exceed this on darwin. -func shortSocketPath(t *testing.T) string { - t.Helper() - var b [6]byte - if _, err := rand.Read(b[:]); err != nil { - t.Fatal(err) - } - p := filepath.Join("/tmp", "ps-"+hex.EncodeToString(b[:])+".sock") - t.Cleanup(func() { _ = os.Remove(p) }) - return p -} - -// fakeDaemon is a minimal test harness that simulates the Pilot daemon's -// IPC wire protocol. It listens on a unix socket, records incoming frames, -// and replies with configured responses. Sufficient for verifying each -// driver.* method's request encoding and response decoding end-to-end. -type fakeDaemon struct { - t *testing.T - ln net.Listener - path string - conn net.Conn - connSet chan struct{} // closed once conn is stored in acceptLoop - mu sync.Mutex - received [][]byte // all frames received - handlers map[byte]func(frame []byte) [][]byte -} - -func newFakeDaemon(t *testing.T) *fakeDaemon { - t.Helper() - path := shortSocketPath(t) - ln, err := net.Listen("unix", path) - if err != nil { - t.Fatalf("listen unix: %v", err) - } - d := &fakeDaemon{ - t: t, - ln: ln, - path: path, - connSet: make(chan struct{}), - handlers: make(map[byte]func(frame []byte) [][]byte), - } - go d.acceptLoop() - return d -} - -func (d *fakeDaemon) acceptLoop() { - conn, err := d.ln.Accept() - if err != nil { - return - } - d.mu.Lock() - d.conn = conn - d.mu.Unlock() - close(d.connSet) // signal that conn is stored and ready to be closed - - // Wire format: [cmd(1)][payload...] — matches driver.ipcEnvelopeHeaderSize. - for { - frame, err := ipcutil.Read(conn) - if err != nil { - return - } - d.mu.Lock() - var resp [][]byte - if len(frame) >= 1 { - cmd := frame[0] - d.received = append(d.received, frame) - if h, ok := d.handlers[cmd]; ok { - resp = h(frame) - } - } - d.mu.Unlock() - for _, r := range resp { - _ = ipcutil.Write(conn, r) - } - } -} - -func (d *fakeDaemon) onCmd(cmd byte, respond func(frame []byte) [][]byte) { - d.mu.Lock() - defer d.mu.Unlock() - d.handlers[cmd] = respond -} - -func (d *fakeDaemon) lastFrame() []byte { - d.mu.Lock() - defer d.mu.Unlock() - if len(d.received) == 0 { - return nil - } - return d.received[len(d.received)-1] -} - -func (d *fakeDaemon) allFrames() [][]byte { - d.mu.Lock() - defer d.mu.Unlock() - out := make([][]byte, len(d.received)) - copy(out, d.received) - return out -} - -func (d *fakeDaemon) closeConn() { - d.mu.Lock() - c := d.conn - d.mu.Unlock() - if c != nil { - _ = c.Close() - } -} - -func (d *fakeDaemon) close() { - d.ln.Close() - // Wait for acceptLoop to store d.conn before closing it. - // Without this, close() races with acceptLoop: d.conn may still be - // nil when closeConn() runs, leaving the accepted socket open and - // blocking the driver's readLoop indefinitely. - select { - case <-d.connSet: - case <-time.After(100 * time.Millisecond): - } - d.closeConn() -} - -// waitFor polls until cond is true or deadline is reached. -func waitFor(t *testing.T, max time.Duration, cond func() bool, what string) { - t.Helper() - deadline := time.Now().Add(max) - for time.Now().Before(deadline) { - if cond() { - return - } - time.Sleep(5 * time.Millisecond) - } - t.Fatalf("timeout waiting for %s", what) -} - -// jsonOK returns a [cmd][json-body] frame. -func jsonOK(cmd byte, body string) []byte { - out := make([]byte, 1+len(body)) - out[0] = cmd - copy(out[1:], body) - return out -} - -// ---------- Connect / Close ---------- - -func TestConnectNonExistentSocketReturnsError(t *testing.T) { - t.Parallel() - _, err := Connect("/tmp/definitely-not-a-real-pilot-socket-xxx.sock") - if err == nil { - t.Fatal("expected error") - } -} - -func TestConnectEmptySocketFallsBackToDefault(t *testing.T) { - t.Parallel() - // DefaultSocketPath is /tmp/pilot.sock — almost certainly not present - // in a test env. We just assert the fall-through path is taken and - // returns an error (no panic on empty input). - _, err := Connect("") - if err == nil { - t.Log("Connect(\"\") succeeded — a daemon is running on default path; not an error") - return - } -} - -func TestConnectAndCloseSuccess(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("Connect: %v", err) - } - if drv.socketPath != d.path { - t.Errorf("socketPath = %q, want %q", drv.socketPath, d.path) - } - if err := drv.Close(); err != nil { - t.Errorf("Close: %v", err) - } -} - -// ---------- DialAddr / Dial ---------- - -func TestDialAddrHappyPath(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - - d.onCmd(cmdDial, func(frame []byte) [][]byte { - resp := make([]byte, 1+4) - resp[0] = cmdDialOK - binary.BigEndian.PutUint32(resp[1:5], 0xDEADBEEF) - return [][]byte{resp} - }) - - drv, err := Connect(d.path) - if err != nil { - t.Fatal(err) - } - defer drv.Close() - - dst := protocol.Addr{Network: 1, Node: 0x0102_0304} - conn, err := drv.DialAddr(dst, 7) - if err != nil { - t.Fatalf("DialAddr: %v", err) - } - if conn.id != 0xDEADBEEF { - t.Errorf("conn.id = %#x, want 0xDEADBEEF", conn.id) - } - if conn.remoteAddr.Addr != dst || conn.remoteAddr.Port != 7 { - t.Errorf("remoteAddr = %+v, want {%+v, 7}", conn.remoteAddr, dst) - } -} - -func TestDialParsesAddressString(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - d.onCmd(cmdDial, func(frame []byte) [][]byte { - resp := make([]byte, 5) - resp[0] = cmdDialOK - binary.BigEndian.PutUint32(resp[1:5], 42) - return [][]byte{resp} - }) - - drv, _ := Connect(d.path) - defer drv.Close() - - conn, err := drv.Dial("1:0001.AAAA.BBBB:80") - if err != nil { - t.Fatalf("Dial: %v", err) - } - if conn.id != 42 { - t.Errorf("id = %d, want 42", conn.id) - } -} - -func TestDialBadAddressReturnsParseError(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, _ := Connect(d.path) - defer drv.Close() - if _, err := drv.Dial("not-a-valid-addr"); err == nil { - t.Fatal("expected parse error") - } -} - -func TestDialAddrTimeoutFiresWhenDaemonSilent(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - // No handler for cmdDial → daemon never responds - drv, _ := Connect(d.path) - defer drv.Close() - - start := time.Now() - _, err := drv.DialAddrTimeout(protocol.Addr{Network: 1, Node: 1}, 1, 100*time.Millisecond) - elapsed := time.Since(start) - if err == nil { - t.Fatal("expected timeout error") - } - if elapsed < 80*time.Millisecond || elapsed > 500*time.Millisecond { - t.Errorf("elapsed = %v (expected ~100ms)", elapsed) - } -} - -// ---------- Listen ---------- - -func TestListenHappyPath(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - - d.onCmd(cmdBind, func(frame []byte) [][]byte { - resp := make([]byte, 3) - resp[0] = cmdBindOK - binary.BigEndian.PutUint16(resp[1:3], 7) // echoed port - return [][]byte{resp} - }) - - drv, _ := Connect(d.path) - defer drv.Close() - - ln, err := drv.Listen(7) - if err != nil { - t.Fatalf("Listen: %v", err) - } - if ln.port != 7 { - t.Errorf("port = %d, want 7", ln.port) - } - _ = ln.Close() -} - -// ---------- SendTo / RecvFrom ---------- - -func TestSendToWritesFrame(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, _ := Connect(d.path) - defer drv.Close() - - dst := protocol.Addr{Network: 2, Node: 0x0A0B_0C0D} - if err := drv.SendTo(dst, 100, []byte("hi")); err != nil { - t.Fatalf("SendTo: %v", err) - } - - waitFor(t, time.Second, func() bool { - return d.lastFrame() != nil - }, "daemon to receive frame") - frame := d.lastFrame() - // d.received stores frames as-is: [cmd(1)][body...]. - if frame[0] != cmdSendTo { - t.Errorf("cmd = %#x, want %#x", frame[0], cmdSendTo) - } - if len(frame) != 1+protocol.AddrSize+2+2 { - t.Errorf("len = %d", len(frame)) - } - gotPort := binary.BigEndian.Uint16(frame[1+protocol.AddrSize:]) - if gotPort != 100 { - t.Errorf("port = %d, want 100", gotPort) - } - if string(frame[1+protocol.AddrSize+2:]) != "hi" { - t.Errorf("payload = %q", frame[1+protocol.AddrSize+2:]) - } -} - -func TestRecvFromDeliversDatagram(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, _ := Connect(d.path) - defer drv.Close() - - // Inject a cmdRecvFrom frame from the daemon - src := protocol.Addr{Network: 1, Node: 0x1122_3344} - payload := make([]byte, protocol.AddrSize+4+3) - src.MarshalTo(payload, 0) - binary.BigEndian.PutUint16(payload[protocol.AddrSize:], 200) - binary.BigEndian.PutUint16(payload[protocol.AddrSize+2:], 300) - copy(payload[protocol.AddrSize+4:], "abc") - frame := append([]byte{cmdRecvFrom}, payload...) - - // Use pushFromDaemon to write the frame through the daemon-side conn. - pushFromDaemon(t, d, frame) - - dg, err := drv.RecvFrom() - if err != nil { - t.Fatalf("RecvFrom: %v", err) - } - if dg.SrcAddr != src || dg.SrcPort != 200 || dg.DstPort != 300 || string(dg.Data) != "abc" { - t.Errorf("got %+v", dg) - } -} - -func TestRecvFromErrorAfterClose(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - drv, _ := Connect(d.path) - - // Close the daemon so readLoop exits and drains dgCh - d.close() - - // Give the readLoop time to exit - waitFor(t, time.Second, func() bool { - select { - case <-drv.ipc.doneCh: - return true - default: - return false - } - }, "readLoop exit") - - // dgCh is not explicitly closed; RecvFrom blocks until dgCh closes OR - // until we push. Since it's buffered but not closed, this would hang. - // Instead we verify the doneCh path by calling Close on the driver. - _ = drv.Close() -} - -// ---------- Info / Health ---------- - -func TestInfoAndHealthReturnParsedJSON(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - d.onCmd(cmdInfo, func(_ []byte) [][]byte { - return [][]byte{jsonOK(cmdInfoOK, `{"node_id": 42, "addr": "1:0001.0002.0003"}`)} - }) - d.onCmd(cmdHealth, func(_ []byte) [][]byte { - return [][]byte{jsonOK(cmdHealthOK, `{"ok": true}`)} - }) - drv, _ := Connect(d.path) - defer drv.Close() - - info, err := drv.Info() - if err != nil { - t.Fatalf("Info: %v", err) - } - if info["node_id"].(float64) != 42 { - t.Errorf("node_id = %v", info["node_id"]) - } - - h, err := drv.Health() - if err != nil { - t.Fatalf("Health: %v", err) - } - if h["ok"] != true { - t.Errorf("ok = %v", h["ok"]) - } -} - -func TestJsonRPCUnmarshalErrorSurfaced(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - d.onCmd(cmdInfo, func(_ []byte) [][]byte { - return [][]byte{jsonOK(cmdInfoOK, `not-json`)} - }) - drv, _ := Connect(d.path) - defer drv.Close() - - if _, err := drv.Info(); err == nil { - t.Fatal("expected unmarshal error") - } -} - -func TestSendAndWaitSurfacesDaemonErrorFrame(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - d.onCmd(cmdInfo, func(_ []byte) [][]byte { - // cmdError frame: first byte cmdError, then 2 bytes code, then msg - body := []byte{cmdError, 0, 0} - body = append(body, []byte("boom")...) - return [][]byte{body} - }) - drv, _ := Connect(d.path) - defer drv.Close() - - _, err := drv.Info() - if err == nil || !strings.Contains(err.Error(), "boom") { - t.Fatalf("err = %v, want boom", err) - } -} - -// ---------- Handshake family ---------- - -func TestHandshakeFamilyRoundTrips(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - d.onCmd(cmdHandshake, func(frame []byte) [][]byte { - return [][]byte{jsonOK(cmdHandshakeOK, `{"ok": true}`)} - }) - drv, _ := Connect(d.path) - defer drv.Close() - - if _, err := drv.Handshake(99, "please"); err != nil { - t.Fatalf("Handshake: %v", err) - } - if _, err := drv.ApproveHandshake(100); err != nil { - t.Fatalf("Approve: %v", err) - } - if _, err := drv.RejectHandshake(101, "no"); err != nil { - t.Fatalf("Reject: %v", err) - } - if _, err := drv.PendingHandshakes(); err != nil { - t.Fatalf("Pending: %v", err) - } - if _, err := drv.TrustedPeers(); err != nil { - t.Fatalf("Trusted: %v", err) - } - if _, err := drv.RevokeTrust(102); err != nil { - t.Fatalf("Revoke: %v", err) - } - - frames := d.allFrames() - if len(frames) != 6 { - t.Fatalf("expected 6 handshake frames, got %d", len(frames)) - } - expectSub := []byte{subHandshakeSend, subHandshakeApprove, subHandshakeReject, - subHandshakePending, subHandshakeTrusted, subHandshakeRevoke} - for i, want := range expectSub { - if frames[i][0] != cmdHandshake || frames[i][1] != want { - t.Errorf("frame[%d] = %v, want cmd=%#x sub=%#x", i, frames[i][:2], cmdHandshake, want) - } - } -} - -// ---------- Registry-modifying wrappers ---------- - -func TestRegistryWrappersEncodeCorrectly(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - - okCommands := map[byte]byte{ - cmdResolveHostname: cmdResolveHostnameOK, - cmdSetHostname: cmdSetHostnameOK, - cmdSetVisibility: cmdSetVisibilityOK, - cmdDeregister: cmdDeregisterOK, - cmdSetTags: cmdSetTagsOK, - cmdSetWebhook: cmdSetWebhookOK, - } - for req, ok := range okCommands { - req, ok := req, ok - d.onCmd(req, func(_ []byte) [][]byte { - return [][]byte{jsonOK(ok, `{"ok":true}`)} - }) - } - - drv, _ := Connect(d.path) - defer drv.Close() - - if _, err := drv.ResolveHostname("myhost"); err != nil { - t.Fatalf("ResolveHostname: %v", err) - } - if _, err := drv.SetHostname("myhost"); err != nil { - t.Fatalf("SetHostname: %v", err) - } - if _, err := drv.SetVisibility(true); err != nil { - t.Fatalf("SetVisibility: %v", err) - } - if _, err := drv.Deregister(); err != nil { - t.Fatalf("Deregister: %v", err) - } - if _, err := drv.SetTags([]string{"a", "b"}); err != nil { - t.Fatalf("SetTags: %v", err) - } - if _, err := drv.SetWebhook("https://x/y"); err != nil { - t.Fatalf("SetWebhook: %v", err) - } - - // Check visibility byte=1 for enabled - for _, f := range d.allFrames() { - switch f[0] { - case cmdSetVisibility: - if f[1] != 1 { - t.Errorf("visibility byte = %d, want 1", f[1]) - } - case cmdResolveHostname: - if string(f[1:]) != "myhost" { - t.Errorf("ResolveHostname host = %q", f[1:]) - } - case cmdSetWebhook: - if string(f[1:]) != "https://x/y" { - t.Errorf("SetWebhook url = %q", f[1:]) - } - } - } -} - -func TestSetVisibilityFalsePath(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - d.onCmd(cmdSetVisibility, func(_ []byte) [][]byte { - return [][]byte{jsonOK(cmdSetVisibilityOK, `{}`)} - }) - drv, _ := Connect(d.path) - defer drv.Close() - - if _, err := drv.SetVisibility(false); err != nil { - t.Fatal(err) - } - - frames := d.allFrames() - if frames[0][1] != 0 { - t.Errorf("visibility false byte = %d, want 0", frames[0][1]) - } -} - -// ---------- Disconnect / cmdClose ---------- - -func TestDisconnectSendsCmdClose(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - d.onCmd(cmdClose, func(frame []byte) [][]byte { - resp := make([]byte, 5) - resp[0] = cmdCloseOK - binary.BigEndian.PutUint32(resp[1:5], 77) - return [][]byte{resp} - }) - - drv, _ := Connect(d.path) - defer drv.Close() - - if err := drv.Disconnect(77); err != nil { - t.Fatalf("Disconnect: %v", err) - } - // Disconnect is fire-and-forget; wait for the daemon to receive the frame. - waitFor(t, time.Second, func() bool { return d.lastFrame() != nil }, "daemon to receive cmdClose") - frame := d.lastFrame() - if frame[0] != cmdClose { - t.Errorf("cmd = %#x, want %#x", frame[0], cmdClose) - } - if connID := binary.BigEndian.Uint32(frame[1:5]); connID != 77 { - t.Errorf("connID = %d", connID) - } -} - -// ---------- Network family ---------- - -func TestNetworkFamilyRoundTrips(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - d.onCmd(cmdNetwork, func(_ []byte) [][]byte { - return [][]byte{jsonOK(cmdNetworkOK, `{"ok":true}`)} - }) - - drv, _ := Connect(d.path) - defer drv.Close() - - if _, err := drv.NetworkList(); err != nil { - t.Fatal(err) - } - if _, err := drv.NetworkJoin(5, "token"); err != nil { - t.Fatal(err) - } - if _, err := drv.NetworkLeave(5); err != nil { - t.Fatal(err) - } - if _, err := drv.NetworkMembers(5); err != nil { - t.Fatal(err) - } - if _, err := drv.NetworkInvite(5, 100); err != nil { - t.Fatal(err) - } - if _, err := drv.NetworkPollInvites(); err != nil { - t.Fatal(err) - } - if _, err := drv.NetworkRespondInvite(5, true); err != nil { - t.Fatal(err) - } - if _, err := drv.NetworkRespondInvite(5, false); err != nil { - t.Fatal(err) - } - - frames := d.allFrames() - wantSubs := []byte{subNetworkList, subNetworkJoin, subNetworkLeave, subNetworkMembers, - subNetworkInvite, subNetworkPollInvites, subNetworkRespondInvite, subNetworkRespondInvite} - if len(frames) != len(wantSubs) { - t.Fatalf("got %d frames, want %d", len(frames), len(wantSubs)) - } - for i, want := range wantSubs { - if frames[i][1] != want { - t.Errorf("frame[%d] sub = %#x, want %#x", i, frames[i][1], want) - } - } - // Respond-invite accept vs reject byte - // Accept frame is 7th (index 6), reject is 8th (index 7) - if frames[6][4] != 1 { - t.Errorf("accept byte = %d, want 1", frames[6][4]) - } - if frames[7][4] != 0 { - t.Errorf("reject byte = %d, want 0", frames[7][4]) - } -} - -// ---------- Managed family ---------- - -func TestManagedFamilyRoundTrips(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - d.onCmd(cmdManaged, func(_ []byte) [][]byte { - return [][]byte{jsonOK(cmdManagedOK, `{"ok":true}`)} - }) - - drv, _ := Connect(d.path) - defer drv.Close() - - if _, err := drv.ManagedStatus(5); err != nil { - t.Fatal(err) - } - if _, err := drv.ManagedForceCycle(5); err != nil { - t.Fatal(err) - } - if _, err := drv.PolicyGet(5); err != nil { - t.Fatal(err) - } - if _, err := drv.PolicySet(5, []byte(`{"version":1}`)); err != nil { - t.Fatal(err) - } - if _, err := drv.MemberTagsGet(5, 99); err != nil { - t.Fatal(err) - } - if _, err := drv.MemberTagsSet(5, 99, []string{"x", "y"}); err != nil { - t.Fatal(err) - } - - frames := d.allFrames() - wantSubs := []byte{subManagedStatus, subManagedCycle, - subManagedPolicy, subManagedPolicy, subManagedMemberTags, subManagedMemberTags} - for i, want := range wantSubs { - if frames[i][1] != want { - t.Errorf("frame[%d] sub = %#x, want %#x", i, frames[i][1], want) - } - } - // PolicyGet action byte is 0x00, PolicySet 0x01 - if frames[2][2] != 0x00 { - t.Errorf("PolicyGet action byte = %#x, want 0x00", frames[2][2]) - } - if frames[3][2] != 0x01 { - t.Errorf("PolicySet action byte = %#x, want 0x01", frames[3][2]) - } - // MemberTagsGet 0x00, Set 0x01 - if frames[4][2] != 0x00 { - t.Errorf("MemberTagsGet action byte = %#x", frames[4][2]) - } - if frames[5][2] != 0x01 { - t.Errorf("MemberTagsSet action byte = %#x", frames[5][2]) - } -} diff --git a/pkg/driver/zz_ipc_listener_test.go b/pkg/driver/zz_ipc_listener_test.go deleted file mode 100644 index 78958086..00000000 --- a/pkg/driver/zz_ipc_listener_test.go +++ /dev/null @@ -1,510 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package driver - -import ( - "encoding/binary" - "testing" - "time" - - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" -) - -// pushFromDaemon writes an unsolicited frame from the fakeDaemon side to -// exercise driver.readLoop dispatch. Waits briefly for the daemon conn to -// be accepted first. -// -// frame must be [cmd][payload...]. Wire format is [cmd][payload...] with no reqID. -func pushFromDaemon(t *testing.T, d *fakeDaemon, frame []byte) { - t.Helper() - waitFor(t, 2*time.Second, func() bool { - d.mu.Lock() - defer d.mu.Unlock() - return d.conn != nil - }, "daemon accept") - d.mu.Lock() - conn := d.conn - d.mu.Unlock() - if err := ipcutil.Write(conn, frame); err != nil { - t.Fatalf("write from daemon: %v", err) - } -} - -// ---------- readLoop dispatch ---------- - -func TestReadLoopRecvDeliversToRegisteredChannel(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - connID := uint32(42) - ch := drv.ipc.registerRecvCh(connID) - - frame := make([]byte, 1+4+5) - frame[0] = cmdRecv - binary.BigEndian.PutUint32(frame[1:5], connID) - copy(frame[5:], "hello") - pushFromDaemon(t, d, frame) - - select { - case data := <-ch: - if string(data) != "hello" { - t.Errorf("got %q, want hello", data) - } - case <-time.After(2 * time.Second): - t.Fatal("no data delivered to recvCh") - } -} - -func TestReadLoopRecvBuffersWhenChannelNotRegistered(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - connID := uint32(99) - frame := make([]byte, 1+4+3) - frame[0] = cmdRecv - binary.BigEndian.PutUint32(frame[1:5], connID) - copy(frame[5:], "buf") - pushFromDaemon(t, d, frame) - - // Wait for readLoop to process the frame (no recvCh registered yet). - waitFor(t, 2*time.Second, func() bool { - drv.ipc.recvMu.Lock() - defer drv.ipc.recvMu.Unlock() - return len(drv.ipc.pendRecv[connID]) == 1 - }, "pendRecv buffered") - - // Registering now should drain the buffered data. - ch := drv.ipc.registerRecvCh(connID) - select { - case data := <-ch: - if string(data) != "buf" { - t.Errorf("got %q, want buf", data) - } - case <-time.After(1 * time.Second): - t.Fatal("registerRecvCh did not drain pendRecv") - } -} - -func TestReadLoopRecvShortPayloadDropped(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - // <4 bytes after cmd byte → dropped. - pushFromDaemon(t, d, []byte{cmdRecv, 0x01}) - - // Sanity: no crash, no data buffered, no recv channels created. - time.Sleep(50 * time.Millisecond) - drv.ipc.recvMu.Lock() - defer drv.ipc.recvMu.Unlock() - if len(drv.ipc.pendRecv) != 0 { - t.Errorf("short cmdRecv should not buffer, got %d entries", len(drv.ipc.pendRecv)) - } -} - -func TestReadLoopCloseOKClosesRegisteredChannel(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - connID := uint32(7) - ch := drv.ipc.registerRecvCh(connID) - - frame := make([]byte, 1+4) - frame[0] = cmdCloseOK - binary.BigEndian.PutUint32(frame[1:], connID) - pushFromDaemon(t, d, frame) - - select { - case _, ok := <-ch: - if ok { - t.Error("expected channel closed, got value") - } - case <-time.After(2 * time.Second): - t.Fatal("channel not closed") - } - - drv.ipc.recvMu.Lock() - _, stillThere := drv.ipc.recvChs[connID] - drv.ipc.recvMu.Unlock() - if stillThere { - t.Error("recvCh entry should be deleted") - } -} - -func TestReadLoopCloseOKShortPayloadDropped(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - // payload < 4 — must not panic, must not disturb recvChs. - connID := uint32(8) - ch := drv.ipc.registerRecvCh(connID) - pushFromDaemon(t, d, []byte{cmdCloseOK, 0x00}) - - time.Sleep(50 * time.Millisecond) - select { - case <-ch: - t.Error("channel should not close on short CloseOK") - default: - } -} - -func TestReadLoopRecvFromDeliversDatagram(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - srcAddr, err := protocol.ParseAddr("1:0001.AAAA.BBBB") - if err != nil { - t.Fatal(err) - } - frame := make([]byte, 1+protocol.AddrSize+4+5) - frame[0] = cmdRecvFrom - srcAddr.MarshalTo(frame, 1) - binary.BigEndian.PutUint16(frame[1+protocol.AddrSize:], 111) - binary.BigEndian.PutUint16(frame[1+protocol.AddrSize+2:], 222) - copy(frame[1+protocol.AddrSize+4:], "ping!") - pushFromDaemon(t, d, frame) - - dg, err := drv.RecvFrom() - if err != nil { - t.Fatalf("RecvFrom: %v", err) - } - if dg.SrcPort != 111 || dg.DstPort != 222 || string(dg.Data) != "ping!" { - t.Errorf("datagram = %+v, data=%q", dg, string(dg.Data)) - } - if dg.SrcAddr != srcAddr { - t.Errorf("src addr = %v, want %v", dg.SrcAddr, srcAddr) - } -} - -func TestReadLoopRecvFromShortPayloadDropped(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - // AddrSize=6, need +4 for ports — send just 5 bytes payload → drop. - pushFromDaemon(t, d, append([]byte{cmdRecvFrom}, make([]byte, 5)...)) - - // If it were dispatched, it'd land on dgCh. Confirm nothing arrives. - select { - case dg := <-drv.ipc.dgCh: - t.Errorf("unexpected datagram: %+v", dg) - case <-time.After(100 * time.Millisecond): - } -} - -func TestReadLoopAcceptShortPayloadDropped(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - // <2 bytes after cmd byte → dropped. - pushFromDaemon(t, d, []byte{cmdAccept, 0x01}) - time.Sleep(50 * time.Millisecond) // assertion: no crash -} - -func TestReadLoopEmptyFrameContinues(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - // Zero-length frame is skipped; readLoop must keep running. - pushFromDaemon(t, d, []byte{}) - - // Follow it with a valid cmdRecvFrom to prove readLoop is still alive. - srcAddr, err := protocol.ParseAddr("1:0001.CCCC.DDDD") - if err != nil { - t.Fatal(err) - } - frame := make([]byte, 1+protocol.AddrSize+4+2) - frame[0] = cmdRecvFrom - srcAddr.MarshalTo(frame, 1) - copy(frame[1+protocol.AddrSize+4:], "ok") - pushFromDaemon(t, d, frame) - - if _, err := drv.RecvFrom(); err != nil { - t.Fatalf("readLoop died after empty frame: %v", err) - } -} - -func TestReadLoopUnknownCmdWithNoHandlerDropped(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - // cmd 0xFE not in any handler map — readLoop default branch, no waiter, drop. - pushFromDaemon(t, d, []byte{0xFE, 0x01, 0x02}) - - // Prove readLoop still alive by exchanging Info. - d.onCmd(cmdInfo, func(_ []byte) [][]byte { - return [][]byte{jsonOK(cmdInfoOK, `{"ok":true}`)} - }) - if _, err := drv.Info(); err != nil { - t.Fatalf("Info after unknown cmd: %v", err) - } -} - -// ---------- Listener.Accept branches ---------- - -func TestListenerAcceptDeliversConn(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - // Bind port 5000. - d.onCmd(cmdBind, func(frame []byte) [][]byte { - return [][]byte{{cmdBindOK, frame[1], frame[2]}} - }) - ln, err := drv.Listen(5000) - if err != nil { - t.Fatalf("Listen: %v", err) - } - defer ln.Close() - - // Push an unsolicited cmdAccept: [port][connID][addr][port] - remoteAddr, _ := protocol.ParseAddr("1:0002.1111.2222") - payload := make([]byte, 2+4+protocol.AddrSize+2) - binary.BigEndian.PutUint16(payload[0:2], 5000) - binary.BigEndian.PutUint32(payload[2:6], 1234) - remoteAddr.MarshalTo(payload, 6) - binary.BigEndian.PutUint16(payload[6+protocol.AddrSize:], 99) - pushFromDaemon(t, d, append([]byte{cmdAccept}, payload...)) - - // Accept should return a Conn with the parsed fields. - done := make(chan *Conn, 1) - errCh := make(chan error, 1) - go func() { - c, err := ln.Accept() - if err != nil { - errCh <- err - return - } - done <- c.(*Conn) - }() - - select { - case c := <-done: - if c.id != 1234 { - t.Errorf("conn.id = %d, want 1234", c.id) - } - if c.remoteAddr.Port != 99 { - t.Errorf("remote port = %d, want 99", c.remoteAddr.Port) - } - if c.remoteAddr.Addr != remoteAddr { - t.Errorf("remote addr = %v, want %v", c.remoteAddr.Addr, remoteAddr) - } - case err := <-errCh: - t.Fatalf("Accept: %v", err) - case <-time.After(2 * time.Second): - t.Fatal("Accept did not complete") - } -} - -func TestListenerAcceptInvalidPayloadReturnsError(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - d.onCmd(cmdBind, func(frame []byte) [][]byte { - return [][]byte{{cmdBindOK, frame[1], frame[2]}} - }) - ln, err := drv.Listen(5001) - if err != nil { - t.Fatalf("Listen: %v", err) - } - defer ln.Close() - - // cmdAccept with port=5001 but truncated tail. - pushFromDaemon(t, d, []byte{cmdAccept, 0x13, 0x89, 0x00}) - - _, err = ln.Accept() - if err == nil { - t.Fatal("expected invalid payload error") - } -} - -func TestListenerAcceptUnblocksOnClose(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - d.onCmd(cmdBind, func(frame []byte) [][]byte { - return [][]byte{{cmdBindOK, frame[1], frame[2]}} - }) - ln, err := drv.Listen(5002) - if err != nil { - t.Fatalf("Listen: %v", err) - } - - errCh := make(chan error, 1) - go func() { - _, err := ln.Accept() - errCh <- err - }() - - // Give Accept a moment to enter the select, then close. - time.Sleep(50 * time.Millisecond) - _ = ln.Close() - - select { - case err := <-errCh: - if err == nil { - t.Fatal("expected error after Close") - } - case <-time.After(2 * time.Second): - t.Fatal("Accept did not unblock on Close") - } -} - -// ---------- ipcClient helpers ---------- - -func TestUnregisterRecvChRemovesEntry(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - connID := uint32(55) - _ = drv.ipc.registerRecvCh(connID) - - drv.ipc.recvMu.Lock() - if _, ok := drv.ipc.recvChs[connID]; !ok { - drv.ipc.recvMu.Unlock() - t.Fatal("recvCh not present after registerRecvCh") - } - drv.ipc.recvMu.Unlock() - - drv.ipc.unregisterRecvCh(connID) - - drv.ipc.recvMu.Lock() - _, ok := drv.ipc.recvChs[connID] - drv.ipc.recvMu.Unlock() - if ok { - t.Error("recvCh still present after unregisterRecvCh") - } -} - -func TestSendAndWaitTimeoutFires(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - defer d.close() - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - // No handler for cmdInfo — daemon accepts the frame and never replies. - start := time.Now() - _, err = drv.ipc.sendAndWaitTimeout([]byte{cmdInfo}, cmdInfoOK, 80*time.Millisecond) - elapsed := time.Since(start) - if err == nil { - t.Fatal("expected timeout error") - } - if elapsed < 50*time.Millisecond || elapsed > 500*time.Millisecond { - t.Errorf("unexpected elapsed %v (want ~80ms)", elapsed) - } -} - -func TestSendAndWaitReturnsWhenDaemonDisconnects(t *testing.T) { - t.Parallel() - d := newFakeDaemon(t) - drv, err := Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - defer drv.Close() - - errCh := make(chan error, 1) - go func() { - _, err := drv.ipc.sendAndWait([]byte{cmdInfo}, cmdInfoOK) - errCh <- err - }() - - // Give the request time to enqueue its handler, then yank the daemon. - time.Sleep(50 * time.Millisecond) - d.close() - - select { - case err := <-errCh: - if err == nil { - t.Fatal("expected error on daemon disconnect") - } - case <-time.After(2 * time.Second): - t.Fatal("sendAndWait did not unblock on disconnect") - } -} diff --git a/pkg/logging/logging.go b/pkg/logging/logging.go deleted file mode 100644 index 7d92d8c7..00000000 --- a/pkg/logging/logging.go +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package logging - -import ( - "io" - "log/slog" - "os" - "strings" -) - -// Setup configures the default slog logger with the given level and format. -// format can be "text" (human-readable) or "json" (machine-parseable). -// level can be "debug", "info", "warn", "error". -func Setup(level, format string) { - SetupWriter(os.Stderr, level, format) -} - -// SetupWriter configures the default slog logger writing to w. -func SetupWriter(w io.Writer, level, format string) { - var lvl slog.Level - switch strings.ToLower(level) { - case "debug": - lvl = slog.LevelDebug - case "warn", "warning": - lvl = slog.LevelWarn - case "error": - lvl = slog.LevelError - default: - lvl = slog.LevelInfo - } - - opts := &slog.HandlerOptions{Level: lvl} - - var handler slog.Handler - switch strings.ToLower(format) { - case "json": - handler = slog.NewJSONHandler(w, opts) - default: - handler = slog.NewTextHandler(w, opts) - } - - slog.SetDefault(slog.New(handler)) -} diff --git a/pkg/logging/zz_logging_test.go b/pkg/logging/zz_logging_test.go deleted file mode 100644 index 01a49504..00000000 --- a/pkg/logging/zz_logging_test.go +++ /dev/null @@ -1,161 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package logging_test - -import ( - "bytes" - "encoding/json" - "log/slog" - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/logging" -) - -func TestSetupWriterJSONFormat(t *testing.T) { - // Cannot run in parallel because this mutates the package-global default logger. - saved := slog.Default() - defer slog.SetDefault(saved) - - var buf bytes.Buffer - logging.SetupWriter(&buf, "info", "json") - slog.Info("hello", "k", "v") - - line := strings.TrimSpace(buf.String()) - if line == "" { - t.Fatal("no output emitted") - } - var m map[string]interface{} - if err := json.Unmarshal([]byte(line), &m); err != nil { - t.Fatalf("output not valid JSON: %v\n%s", err, line) - } - if m["msg"] != "hello" { - t.Errorf("msg = %v, want hello", m["msg"]) - } - if m["k"] != "v" { - t.Errorf("attr k = %v, want v", m["k"]) - } - if m["level"] != "INFO" { - t.Errorf("level = %v, want INFO", m["level"]) - } -} - -func TestSetupWriterTextFormat(t *testing.T) { - saved := slog.Default() - defer slog.SetDefault(saved) - - var buf bytes.Buffer - logging.SetupWriter(&buf, "info", "text") - slog.Info("hello", "k", "v") - - out := buf.String() - if out == "" { - t.Fatal("no output") - } - // text format output should NOT parse as JSON - if err := json.Unmarshal([]byte(strings.TrimSpace(out)), &map[string]interface{}{}); err == nil { - t.Fatalf("text output unexpectedly parsed as JSON: %s", out) - } - if !strings.Contains(out, "hello") || !strings.Contains(out, "k=v") { - t.Errorf("missing expected content: %s", out) - } -} - -func TestSetupWriterDefaultFormatIsText(t *testing.T) { - saved := slog.Default() - defer slog.SetDefault(saved) - - var buf bytes.Buffer - logging.SetupWriter(&buf, "info", "unknown-format") - slog.Info("msg") - - out := strings.TrimSpace(buf.String()) - // default (unknown format) → text handler, not JSON - if strings.HasPrefix(out, "{") { - t.Fatalf("unknown format should default to text, got JSON: %s", out) - } -} - -func TestSetupWriterLevelsGateOutput(t *testing.T) { - cases := []struct { - level string - wantDebug bool - wantInfo bool - wantWarn bool - wantError bool - }{ - {"debug", true, true, true, true}, - {"info", false, true, true, true}, - {"warn", false, false, true, true}, - {"warning", false, false, true, true}, - {"error", false, false, false, true}, - {"unknown", false, true, true, true}, // default → info - } - for _, tc := range cases { - t.Run(tc.level, func(t *testing.T) { - saved := slog.Default() - defer slog.SetDefault(saved) - - var buf bytes.Buffer - logging.SetupWriter(&buf, tc.level, "text") - - slog.Debug("D") - slog.Info("I") - slog.Warn("W") - slog.Error("E") - - out := buf.String() - if strings.Contains(out, "\"D\"") != tc.wantDebug && strings.Contains(out, "D") != tc.wantDebug { - hasD := strings.Contains(out, "msg=D") - if hasD != tc.wantDebug { - t.Errorf("debug output present=%v, want=%v\n%s", hasD, tc.wantDebug, out) - } - } - hasInfo := strings.Contains(out, "msg=I") - if hasInfo != tc.wantInfo { - t.Errorf("info output present=%v, want=%v\n%s", hasInfo, tc.wantInfo, out) - } - hasWarn := strings.Contains(out, "msg=W") - if hasWarn != tc.wantWarn { - t.Errorf("warn output present=%v, want=%v\n%s", hasWarn, tc.wantWarn, out) - } - hasError := strings.Contains(out, "msg=E") - if hasError != tc.wantError { - t.Errorf("error output present=%v, want=%v\n%s", hasError, tc.wantError, out) - } - }) - } -} - -func TestSetupCaseInsensitive(t *testing.T) { - saved := slog.Default() - defer slog.SetDefault(saved) - - var buf bytes.Buffer - logging.SetupWriter(&buf, "DEBUG", "JSON") - slog.Debug("dbg-msg") - line := strings.TrimSpace(buf.String()) - if !strings.HasPrefix(line, "{") { - t.Fatalf("uppercase JSON should produce JSON, got: %s", line) - } - var m map[string]interface{} - if err := json.Unmarshal([]byte(line), &m); err != nil { - t.Fatalf("not JSON: %v", err) - } - if m["level"] != "DEBUG" { - t.Errorf("level = %v, want DEBUG", m["level"]) - } -} - -func TestSetupUsesStderrByDefault(t *testing.T) { - // Smoke test: Setup() picks the stderr path. We don't capture stderr (too - // invasive) but we verify the call doesn't panic and leaves slog.Default - // non-nil. - saved := slog.Default() - defer slog.SetDefault(saved) - - logging.Setup("info", "text") - if slog.Default() == nil { - t.Fatal("slog.Default() is nil after Setup") - } -} diff --git a/pkg/protocol/address.go b/pkg/protocol/address.go deleted file mode 100644 index 74cd57f4..00000000 --- a/pkg/protocol/address.go +++ /dev/null @@ -1,151 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package protocol - -import ( - "encoding/binary" - "fmt" - "strconv" - "strings" -) - -const AddrSize = 6 // 48 bits: 2 bytes network + 4 bytes node - -// Addr is a 48-bit Pilot Protocol virtual address. -// Layout: [16-bit Network ID][32-bit Node ID] -// Text format: N:NNNN.HHHH.LLLL -// -// N = network ID in decimal -// NNNN = network ID in hex (redundant, for readability) -// HHHH = node ID high 16 bits in hex -// LLLL = node ID low 16 bits in hex -type Addr struct { - Network uint16 - Node uint32 -} - -var ( - AddrRegistry = Addr{0, 1} - AddrBeacon = Addr{0, 2} - AddrNameserver = Addr{0, 3} -) - -// ZeroAddr returns the zero-value address ({0, 0}). It exists as a -// function rather than a package-level var so callers cannot mutate a -// shared sentinel (P3 — no cross-layer mutable globals). The returned -// value is freshly constructed on each call. -func ZeroAddr() Addr { return Addr{} } - -// BroadcastAddr returns the broadcast address for a given network. -func BroadcastAddr(network uint16) Addr { - return Addr{Network: network, Node: 0xFFFFFFFF} -} - -func (a Addr) IsZero() bool { return a.Network == 0 && a.Node == 0 } -func (a Addr) IsBroadcast() bool { return a.Node == 0xFFFFFFFF } - -// Marshal writes the address as 6 bytes (big-endian). -func (a Addr) Marshal() []byte { - b := make([]byte, AddrSize) - a.MarshalTo(b, 0) - return b -} - -// MarshalTo writes the address into buf at the given offset. -func (a Addr) MarshalTo(buf []byte, offset int) { - binary.BigEndian.PutUint16(buf[offset:], a.Network) - binary.BigEndian.PutUint32(buf[offset+2:], a.Node) -} - -// UnmarshalAddr reads a 6-byte address from buf. -// Returns a zero address if buf is shorter than AddrSize (6 bytes), -// rather than panicking on the out-of-bounds slice (PILOT-133). -func UnmarshalAddr(buf []byte) Addr { - if len(buf) < AddrSize { - return Addr{} - } - return Addr{ - Network: binary.BigEndian.Uint16(buf[0:2]), - Node: binary.BigEndian.Uint32(buf[2:6]), - } -} - -// String returns the text representation: N:NNNN.HHHH.LLLL -func (a Addr) String() string { - return fmt.Sprintf("%d:%04X.%04X.%04X", a.Network, a.Network, (a.Node>>16)&0xFFFF, a.Node&0xFFFF) -} - -// ParseAddr parses "0:0000.0000.0001" or "1:00A3.F291.0004" into an Addr. -func ParseAddr(s string) (Addr, error) { - parts := strings.SplitN(s, ":", 2) - if len(parts) != 2 { - return Addr{}, fmt.Errorf("invalid address: %q (expected N:XXXX.YYYY.YYYY)", s) - } - - networkDec, err := strconv.ParseUint(parts[0], 10, 16) - if err != nil { - return Addr{}, fmt.Errorf("invalid network ID: %q: %w", parts[0], err) - } - - hexGroups := strings.Split(parts[1], ".") - if len(hexGroups) != 3 { - return Addr{}, fmt.Errorf("invalid address: %q (expected 3 dot-separated hex groups)", parts[1]) - } - for _, h := range hexGroups { - if len(h) != 4 { - return Addr{}, fmt.Errorf("invalid hex group: %q (expected 4 digits)", h) - } - } - - netHex, err := strconv.ParseUint(hexGroups[0], 16, 16) - if err != nil { - return Addr{}, fmt.Errorf("invalid hex group: %q: %w", hexGroups[0], err) - } - if netHex != networkDec { - return Addr{}, fmt.Errorf("network mismatch: decimal %d != hex 0x%04X", networkDec, netHex) - } - - nodeHigh, err := strconv.ParseUint(hexGroups[1], 16, 16) - if err != nil { - return Addr{}, fmt.Errorf("invalid hex group: %q: %w", hexGroups[1], err) - } - nodeLow, err := strconv.ParseUint(hexGroups[2], 16, 16) - if err != nil { - return Addr{}, fmt.Errorf("invalid hex group: %q: %w", hexGroups[2], err) - } - - return Addr{ - Network: uint16(networkDec), - Node: uint32(nodeHigh)<<16 | uint32(nodeLow), - }, nil -} - -// SocketAddr is a full endpoint: virtual address + port. -type SocketAddr struct { - Addr Addr - Port uint16 -} - -func (sa SocketAddr) String() string { - return fmt.Sprintf("%s:%d", sa.Addr.String(), sa.Port) -} - -// ParseSocketAddr parses "N:XXXX.YYYY.YYYY:PORT". -func ParseSocketAddr(s string) (SocketAddr, error) { - lastColon := strings.LastIndex(s, ":") - if lastColon == -1 { - return SocketAddr{}, fmt.Errorf("invalid socket address: %q (no port)", s) - } - - addr, err := ParseAddr(s[:lastColon]) - if err != nil { - return SocketAddr{}, err - } - - port, err := strconv.ParseUint(s[lastColon+1:], 10, 16) - if err != nil { - return SocketAddr{}, fmt.Errorf("invalid port: %q: %w", s[lastColon+1:], err) - } - - return SocketAddr{Addr: addr, Port: uint16(port)}, nil -} diff --git a/pkg/protocol/checksum.go b/pkg/protocol/checksum.go deleted file mode 100644 index 39a3ca8d..00000000 --- a/pkg/protocol/checksum.go +++ /dev/null @@ -1,12 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package protocol - -import "hash/crc32" - -var crcTable = crc32.MakeTable(crc32.IEEE) - -// Checksum computes CRC32 (IEEE) over the given data. -func Checksum(data []byte) uint32 { - return crc32.Checksum(data, crcTable) -} diff --git a/pkg/protocol/header.go b/pkg/protocol/header.go deleted file mode 100644 index edc3dd14..00000000 --- a/pkg/protocol/header.go +++ /dev/null @@ -1,87 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package protocol - -import "errors" - -// Protocol version -const Version uint8 = 1 - -// Sentinel errors shared across packages. -var ( - ErrNodeNotFound = errors.New("node not found") - ErrNetworkNotFound = errors.New("network not found") - ErrConnClosed = errors.New("connection closed") - ErrConnRefused = errors.New("connection refused") - ErrDialTimeout = errors.New("dial timeout") - ErrChecksumMismatch = errors.New("checksum mismatch") - // ErrMalformedPacket is returned by Marshal/Unmarshal's L1 panic - // boundary when a panic is recovered during wire-format decode/encode. - // Wraps the original panic value via fmt.Errorf("%w: %v", ...). - ErrMalformedPacket = errors.New("malformed packet") -) - -// Flags (4 bits, stored in lower nibble of first byte alongside version) -const ( - FlagSYN uint8 = 0x1 - FlagACK uint8 = 0x2 - FlagFIN uint8 = 0x4 - FlagRST uint8 = 0x8 -) - -// Protocol types -const ( - ProtoStream uint8 = 0x01 // Reliable, ordered (TCP-like) - ProtoDatagram uint8 = 0x02 // Unreliable, unordered (UDP-like) - ProtoControl uint8 = 0x03 // Internal control -) - -// Well-known ports -const ( - PortPing uint16 = 0 - PortControl uint16 = 1 - PortEcho uint16 = 7 - PortNameserver uint16 = 53 - PortHTTP uint16 = 80 - PortSecure uint16 = 443 - PortStdIO uint16 = 1000 - PortDataExchange uint16 = 1001 - PortEventStream uint16 = 1002 -) - -// Port ranges -const ( - PortReservedMax uint16 = 1023 - PortRegisteredMax uint16 = 49151 - PortEphemeralMin uint16 = 49152 - PortEphemeralMax uint16 = 65535 -) - -// Tunnel magic bytes: "PILT" (0x50494C54) -var TunnelMagic = [4]byte{0x50, 0x49, 0x4C, 0x54} - -// Tunnel magic bytes for encrypted packets: "PILS" (0x50494C53) -var TunnelMagicSecure = [4]byte{0x50, 0x49, 0x4C, 0x53} - -// Tunnel magic bytes for key exchange: "PILK" (0x50494C4B) -var TunnelMagicKeyEx = [4]byte{0x50, 0x49, 0x4C, 0x4B} - -// Tunnel magic bytes for authenticated key exchange: "PILA" (0x50494C41) -var TunnelMagicAuthEx = [4]byte{0x50, 0x49, 0x4C, 0x41} - -// Tunnel magic bytes for NAT punch packet: "PILP" (0x50494C50) -var TunnelMagicPunch = [4]byte{0x50, 0x49, 0x4C, 0x50} - -// Well-known port for handshake requests -const PortHandshake uint16 = 444 - -// Beacon message types (single-byte codes, all < 0x10 to avoid collision with tunnel magic) -const ( - BeaconMsgDiscover byte = 0x01 - BeaconMsgDiscoverReply byte = 0x02 - BeaconMsgPunchRequest byte = 0x03 - BeaconMsgPunchCommand byte = 0x04 - BeaconMsgRelay byte = 0x05 - BeaconMsgRelayDeliver byte = 0x06 - BeaconMsgSync byte = 0x07 // gossip: beacon-to-beacon node list exchange -) diff --git a/pkg/protocol/packet.go b/pkg/protocol/packet.go deleted file mode 100644 index 190f5010..00000000 --- a/pkg/protocol/packet.go +++ /dev/null @@ -1,158 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package protocol - -import ( - "encoding/binary" - "fmt" -) - -// Wire layout (34 bytes): -// -// Byte 0: [Version:4][Flags:4] -// Byte 1: Protocol -// Byte 2-3: Payload Length -// Byte 4-9: Source Address (6 bytes) -// Byte 10-15: Destination Address (6 bytes) -// Byte 16-17: Source Port -// Byte 18-19: Destination Port -// Byte 20-23: Sequence Number -// Byte 24-27: Acknowledgment Number -// Byte 28-29: Window (receive window in segments, 0 = no flow control) -// Byte 30-33: Checksum (CRC32) -const packetHeaderSize = 34 - -type Packet struct { - Version uint8 - Flags uint8 - Protocol uint8 - - Src Addr - Dst Addr - SrcPort uint16 - DstPort uint16 - - Seq uint32 - Ack uint32 - Window uint16 // advertised receive window (in segments; 0 = no limit) - - Payload []byte -} - -func (p *Packet) HasFlag(f uint8) bool { return p.Flags&f != 0 } -func (p *Packet) SetFlag(f uint8) { p.Flags |= f } -func (p *Packet) ClearFlag(f uint8) { p.Flags &^= f } - -// Marshal serializes the packet to wire format with checksum. -// -// L1 panic boundary (architecture-notes/03-INVARIANTS.md §8): -// the explicit length-check below covers the only known caller-induced -// failure (oversize payload), but a nil-pointer Packet receiver or -// future bug could trigger a panic mid-encode. The deferred recover -// converts any panic into ErrMalformedPacket so callers (Send paths) -// drop the frame instead of crashing the daemon. -func (p *Packet) Marshal() (out []byte, err error) { - defer func() { - if r := recover(); r != nil { - out = nil - err = fmt.Errorf("%w: panic during encode: %v", ErrMalformedPacket, r) - } - }() - - payloadLen := len(p.Payload) - if payloadLen > 0xFFFF { - return nil, fmt.Errorf("payload too large: %d bytes (max 65535)", payloadLen) - } - - totalLen := packetHeaderSize + payloadLen // safe: payloadLen ≤ 0xFFFF (checked above) - buf := make([]byte, totalLen) - - buf[0] = (p.Version << 4) | (p.Flags & 0x0F) - buf[1] = p.Protocol - binary.BigEndian.PutUint16(buf[2:4], uint16(payloadLen)) - p.Src.MarshalTo(buf, 4) - p.Dst.MarshalTo(buf, 10) - binary.BigEndian.PutUint16(buf[16:18], p.SrcPort) - binary.BigEndian.PutUint16(buf[18:20], p.DstPort) - binary.BigEndian.PutUint32(buf[20:24], p.Seq) - binary.BigEndian.PutUint32(buf[24:28], p.Ack) - binary.BigEndian.PutUint16(buf[28:30], p.Window) - - if payloadLen > 0 { - copy(buf[packetHeaderSize:], p.Payload) - } - - // Checksum: CRC32 over header (with checksum field zeroed) + payload. - // Field is already zero from make(). - binary.BigEndian.PutUint32(buf[30:34], Checksum(buf)) - - return buf, nil -} - -// Unmarshal deserializes a packet from wire bytes. -// -// L1 panic boundary (architecture-notes/03-INVARIANTS.md §8): -// the explicit length-checks below cover all *known* malformed inputs, -// but a future caller could pass a slice that aliases a buffer being -// concurrently mutated, or a malformed input not yet enumerated, causing -// an out-of-bounds slice expression to panic. The deferred recover -// converts any such panic into a structured error so callers (the -// tunnel readLoop, relay path) drop the frame instead of taking down -// the whole daemon. Returns ErrMalformedPacket on panic; the original -// panic value is wrapped via fmt.Errorf for diagnostics. -func Unmarshal(data []byte) (p *Packet, err error) { - defer func() { - if r := recover(); r != nil { - p = nil - err = fmt.Errorf("%w: panic during decode: %v", ErrMalformedPacket, r) - } - }() - - if len(data) < packetHeaderSize { - return nil, fmt.Errorf("packet too short: %d bytes (min %d)", len(data), packetHeaderSize) - } - - payloadLen := binary.BigEndian.Uint16(data[2:4]) - total := packetHeaderSize + int(payloadLen) - if len(data) < total { - return nil, fmt.Errorf("packet truncated: have %d bytes, need %d", len(data), total) - } - - // Verify checksum before parsing. - wireChecksum := binary.BigEndian.Uint32(data[30:34]) - binary.BigEndian.PutUint32(data[30:34], 0) // zero for computation - computed := Checksum(data[:total]) - binary.BigEndian.PutUint32(data[30:34], wireChecksum) // restore - - if computed != wireChecksum { - return nil, ErrChecksumMismatch - } - - // Validate protocol version. - wireVersion := (data[0] >> 4) & 0x0F - if wireVersion != Version { - return nil, fmt.Errorf("unsupported protocol version %d (expected %d)", wireVersion, Version) - } - - p = &Packet{ - Version: (data[0] >> 4) & 0x0F, - Flags: data[0] & 0x0F, - Protocol: data[1], - Src: UnmarshalAddr(data[4:10]), - Dst: UnmarshalAddr(data[10:16]), - SrcPort: binary.BigEndian.Uint16(data[16:18]), - DstPort: binary.BigEndian.Uint16(data[18:20]), - Seq: binary.BigEndian.Uint32(data[20:24]), - Ack: binary.BigEndian.Uint32(data[24:28]), - Window: binary.BigEndian.Uint16(data[28:30]), - } - - if payloadLen > 0 { - p.Payload = make([]byte, payloadLen) - copy(p.Payload, data[packetHeaderSize:total]) - } - - return p, nil -} - -func PacketHeaderSize() int { return packetHeaderSize } diff --git a/pkg/protocol/zz_fuzz_packet_test.go b/pkg/protocol/zz_fuzz_packet_test.go deleted file mode 100644 index 99818ffa..00000000 --- a/pkg/protocol/zz_fuzz_packet_test.go +++ /dev/null @@ -1,200 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package protocol - -import ( - "bytes" - "testing" -) - -// FuzzUnmarshalPacket exercises the packet decoder with arbitrary bytes. -// The decoder is documented as having an L1 panic boundary (deferred -// recover → ErrMalformedPacket), so a literal panic should never escape -// to the fuzz harness — any escape is a real bug. PILOT-133 noted a -// latent unmarshal panic in UnmarshalAddr; this target should reproduce -// such issues quickly. -// -// Seeds include valid frames at multiple sizes, header-only inputs, all -// flag combinations, malformed length fields, and a few adversarial -// envelopes (length field claims much more than the buffer holds). -func FuzzUnmarshalPacket(f *testing.F) { - // Seed 1: minimal valid (no payload) packet. - { - p := &Packet{Version: Version, Protocol: ProtoStream} - b, err := p.Marshal() - if err == nil { - f.Add(b) - } - } - // Seed 2: valid packet with small payload. - { - p := &Packet{ - Version: Version, Flags: FlagSYN | FlagACK, Protocol: ProtoStream, - Src: Addr{1, 0xDEADBEEF}, Dst: Addr{2, 0xCAFEBABE}, - SrcPort: 1234, DstPort: 5678, - Seq: 0x11223344, Ack: 0x55667788, Window: 16, - Payload: []byte("hello"), - } - b, err := p.Marshal() - if err == nil { - f.Add(b) - } - } - // Seed 3: control proto, all flags. - { - p := &Packet{ - Version: Version, Flags: FlagSYN | FlagACK | FlagFIN | FlagRST, - Protocol: ProtoControl, Payload: bytes.Repeat([]byte{0xAB}, 64), - } - b, err := p.Marshal() - if err == nil { - f.Add(b) - } - } - // Seed 4: datagram with binary payload. - { - p := &Packet{ - Version: Version, Protocol: ProtoDatagram, - Payload: []byte{0x00, 0xFF, 0x7F, 0x80, 0x01, 0xFE}, - } - b, err := p.Marshal() - if err == nil { - f.Add(b) - } - } - // Seed 5: broadcast destination. - { - p := &Packet{ - Version: Version, Protocol: ProtoDatagram, - Dst: BroadcastAddr(7), DstPort: PortPing, - } - b, err := p.Marshal() - if err == nil { - f.Add(b) - } - } - // Seed 6: exactly header-sized (34 bytes of zeros). - f.Add(make([]byte, packetHeaderSize)) - // Seed 7: shorter than header. - f.Add(make([]byte, packetHeaderSize-1)) - // Seed 8: empty. - f.Add([]byte{}) - // Seed 9: header claims big payload but buffer truncated. - { - b := make([]byte, packetHeaderSize) - b[0] = Version << 4 // valid version - b[2], b[3] = 0xFF, 0xFF - f.Add(b) - } - // Seed 10: header with unsupported version. - { - b := make([]byte, packetHeaderSize) - b[0] = 0xF0 // version 0x0F (unsupported) - f.Add(b) - } - - f.Fuzz(func(t *testing.T, data []byte) { - // Defensive: literal panic out of Unmarshal is the find. - defer func() { - if r := recover(); r != nil { - t.Errorf("panic on input %x: %v", data, r) - } - }() - - // Unmarshal mutates data[30:34] briefly during checksum verify; - // pass a copy so the fuzzer's input slice is not aliased. - buf := make([]byte, len(data)) - copy(buf, data) - - p, err := Unmarshal(buf) - if err != nil { - return // expected on most random input - } - - // Round-trip property: a successfully decoded packet should - // re-encode to bytes that decode back to an equivalent struct. - re, err := p.Marshal() - if err != nil { - t.Errorf("decode-then-encode failed: %v (orig=%x)", err, data) - return - } - p2, err := Unmarshal(re) - if err != nil { - t.Errorf("re-decode failed: %v (re=%x)", err, re) - return - } - if p.Seq != p2.Seq || p.Ack != p2.Ack || p.SrcPort != p2.SrcPort || - p.DstPort != p2.DstPort || p.Window != p2.Window || - p.Protocol != p2.Protocol || p.Flags != p2.Flags || - p.Version != p2.Version || p.Src != p2.Src || p.Dst != p2.Dst { - t.Errorf("round-trip header mismatch: %+v vs %+v", p, p2) - } - if !bytes.Equal(p.Payload, p2.Payload) { - t.Errorf("round-trip payload mismatch: %x vs %x", p.Payload, p2.Payload) - } - }) -} - -// FuzzUnmarshalAddr targets the 6-byte address decoder directly. -// PILOT-133 specifically flagged this function; UnmarshalAddr does NOT -// have a defer/recover, so out-of-bounds slicing would propagate as a -// real panic. A naive call with len(buf) < AddrSize panics, so the -// fuzzer must include the bounds check the harness uses in practice. -func FuzzUnmarshalAddr(f *testing.F) { - f.Add(make([]byte, AddrSize)) - f.Add([]byte{0x00, 0x01, 0xDE, 0xAD, 0xBE, 0xEF}) - f.Add([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) - - f.Fuzz(func(t *testing.T, data []byte) { - defer func() { - if r := recover(); r != nil { - t.Errorf("panic on input %x: %v", data, r) - } - }() - - // UnmarshalAddr's contract is "exactly 6 bytes". Callers that - // pass less are the bug shape PILOT-133 was concerned with — - // fuzz both the contract-respecting path and the over-long path. - if len(data) >= AddrSize { - a := UnmarshalAddr(data[:AddrSize]) - b := a.Marshal() - a2 := UnmarshalAddr(b) - if a != a2 { - t.Errorf("addr round-trip: %v != %v", a, a2) - } - } - }) -} - -// FuzzParseAddr exercises the text-form address parser. -func FuzzParseAddr(f *testing.F) { - f.Add("0:0000.0000.0001") - f.Add("1:0001.DEAD.BEEF") - f.Add("65535:FFFF.FFFF.FFFF") - f.Add("") - f.Add(":") - f.Add("garbage") - f.Add("0:0000.0000") - f.Add("0:0000.0000.0000.0000") - - f.Fuzz(func(t *testing.T, s string) { - defer func() { - if r := recover(); r != nil { - t.Errorf("panic on input %q: %v", s, r) - } - }() - a, err := ParseAddr(s) - if err != nil { - return - } - // Round-trip: String() of a parsed addr must re-parse equal. - a2, err := ParseAddr(a.String()) - if err != nil { - t.Errorf("re-parse of %q (= %v) failed: %v", s, a, err) - return - } - if a != a2 { - t.Errorf("round-trip mismatch: %v != %v (input %q)", a, a2, s) - } - }) -} diff --git a/pkg/protocol/zz_protocol_test.go b/pkg/protocol/zz_protocol_test.go deleted file mode 100644 index 13fffd1e..00000000 --- a/pkg/protocol/zz_protocol_test.go +++ /dev/null @@ -1,442 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package protocol - -import ( - "bytes" - "encoding/binary" - "errors" - "strings" - "testing" -) - -// --------------------------------------------------------------------------- -// Addr -// --------------------------------------------------------------------------- - -func TestAddrIsZero(t *testing.T) { - t.Parallel() - if !ZeroAddr().IsZero() { - t.Fatal("ZeroAddr().IsZero() should be true") - } - if AddrRegistry.IsZero() { - t.Fatal("AddrRegistry should not be zero") - } - if (Addr{Network: 1}).IsZero() { - t.Fatal("non-zero network should not be zero") - } - if (Addr{Node: 1}).IsZero() { - t.Fatal("non-zero node should not be zero") - } -} - -func TestAddrIsBroadcast(t *testing.T) { - t.Parallel() - b := BroadcastAddr(5) - if !b.IsBroadcast() { - t.Fatalf("BroadcastAddr(5).IsBroadcast() = false") - } - if b.Network != 5 { - t.Fatalf("BroadcastAddr network = %d, want 5", b.Network) - } - if (Addr{Node: 0xFFFFFFFE}).IsBroadcast() { - t.Fatal("non-broadcast node should not be broadcast") - } -} - -func TestAddrMarshalUnmarshalRoundTrip(t *testing.T) { - t.Parallel() - cases := []Addr{ - ZeroAddr(), - AddrRegistry, - {Network: 0xABCD, Node: 0x12345678}, - {Network: 0xFFFF, Node: 0xFFFFFFFF}, - } - for _, in := range cases { - buf := in.Marshal() - if len(buf) != AddrSize { - t.Fatalf("Marshal len = %d, want %d", len(buf), AddrSize) - } - out := UnmarshalAddr(buf) - if out != in { - t.Fatalf("round-trip: got %+v, want %+v", out, in) - } - } -} - -func TestAddrMarshalToOffset(t *testing.T) { - t.Parallel() - a := Addr{Network: 0xCAFE, Node: 0xDEADBEEF} - buf := make([]byte, 20) - a.MarshalTo(buf, 8) - out := UnmarshalAddr(buf[8:14]) - if out != a { - t.Fatalf("MarshalTo offset round-trip: got %+v, want %+v", out, a) - } - // Bytes outside the 6-byte window must remain zero - for i := 0; i < 8; i++ { - if buf[i] != 0 { - t.Errorf("byte %d not zero before offset: 0x%02x", i, buf[i]) - } - } - for i := 14; i < 20; i++ { - if buf[i] != 0 { - t.Errorf("byte %d not zero after offset: 0x%02x", i, buf[i]) - } - } -} - -func TestUnmarshalAddrShortBuffer(t *testing.T) { - t.Parallel() - // PILOT-133: UnmarshalAddr must not panic on short buffers. - // It should return a zero address instead of indexing out of bounds. - - // 0 bytes - func() { - defer func() { - if r := recover(); r != nil { - t.Errorf("UnmarshalAddr panicked on 0-byte buffer: %v", r) - } - }() - a := UnmarshalAddr([]byte{}) - if !a.IsZero() { - t.Errorf("expected zero addr for empty buffer, got %+v", a) - } - }() - - // 3 bytes - func() { - defer func() { - if r := recover(); r != nil { - t.Errorf("UnmarshalAddr panicked on 3-byte buffer: %v", r) - } - }() - a := UnmarshalAddr([]byte{0x00, 0x01, 0x02}) - if !a.IsZero() { - t.Errorf("expected zero addr for short buffer, got %+v", a) - } - }() - - // 5 bytes (one short) - func() { - defer func() { - if r := recover(); r != nil { - t.Errorf("UnmarshalAddr panicked on 5-byte buffer: %v", r) - } - }() - a := UnmarshalAddr([]byte{0x00, 0x01, 0xDE, 0xAD, 0xBE}) - if !a.IsZero() { - t.Errorf("expected zero addr for 5-byte buffer, got %+v", a) - } - }() - - // 6 bytes (valid, should work normally) - a := UnmarshalAddr([]byte{0x00, 0x01, 0xDE, 0xAD, 0xBE, 0xEF}) - want := Addr{Network: 0x0001, Node: 0xDEADBEEF} - if a != want { - t.Errorf("valid 6-byte buffer: got %+v, want %+v", a, want) - } -} - -func TestAddrStringFormat(t *testing.T) { - t.Parallel() - a := Addr{Network: 0x00A3, Node: 0xF2910004} - got := a.String() - want := "163:00A3.F291.0004" - if got != want { - t.Fatalf("String() = %q, want %q", got, want) - } -} - -func TestParseAddrValid(t *testing.T) { - t.Parallel() - in := "163:00A3.F291.0004" - a, err := ParseAddr(in) - if err != nil { - t.Fatalf("ParseAddr: %v", err) - } - if a.Network != 0x00A3 || a.Node != 0xF2910004 { - t.Fatalf("parsed addr wrong: %+v", a) - } - // Round-trip via String must equal input - if a.String() != in { - t.Fatalf("round-trip: %q != %q", a.String(), in) - } -} - -func TestParseAddrErrors(t *testing.T) { - t.Parallel() - cases := []struct { - in string - wantSub string - }{ - {"no-colon", "expected N:XXXX"}, - {"abc:0000.0000.0000", "invalid network ID"}, - {"1:0000.0000", "expected 3 dot-separated"}, - {"1:000.0000.0000", "expected 4 digits"}, - {"1:GGGG.0000.0000", "invalid hex group"}, - {"1:0001.GGGG.0000", "invalid hex group"}, // network matches so we reach the high-group check - {"1:0001.0000.GGGG", "invalid hex group"}, // network matches so we reach the low-group check - {"1:0002.0000.0000", "network mismatch"}, - } - for _, tc := range cases { - t.Run(tc.in, func(t *testing.T) { - t.Parallel() - _, err := ParseAddr(tc.in) - if err == nil { - t.Fatalf("expected error for %q", tc.in) - } - if !strings.Contains(err.Error(), tc.wantSub) { - t.Fatalf("error %q missing substring %q", err.Error(), tc.wantSub) - } - }) - } -} - -// --------------------------------------------------------------------------- -// SocketAddr -// --------------------------------------------------------------------------- - -func TestSocketAddrStringAndParse(t *testing.T) { - t.Parallel() - in := SocketAddr{Addr: Addr{Network: 1, Node: 0x00010001}, Port: 8080} - str := in.String() - if str != "1:0001.0001.0001:8080" { - t.Fatalf("String() = %q", str) - } - out, err := ParseSocketAddr(str) - if err != nil { - t.Fatalf("ParseSocketAddr: %v", err) - } - if out != in { - t.Fatalf("round-trip: got %+v, want %+v", out, in) - } -} - -func TestParseSocketAddrErrors(t *testing.T) { - t.Parallel() - cases := []struct { - in string - wantSub string - }{ - {"noport", "no port"}, - {"bad-addr:80", "invalid address"}, - {"1:0001.0001.0001:notanumber", "invalid port"}, - } - for _, tc := range cases { - t.Run(tc.in, func(t *testing.T) { - _, err := ParseSocketAddr(tc.in) - if err == nil { - t.Fatalf("expected error for %q", tc.in) - } - if !strings.Contains(err.Error(), tc.wantSub) { - t.Fatalf("error %q missing substring %q", err.Error(), tc.wantSub) - } - }) - } -} - -// --------------------------------------------------------------------------- -// Packet flags + Marshal/Unmarshal -// --------------------------------------------------------------------------- - -func TestPacketFlagOps(t *testing.T) { - t.Parallel() - p := &Packet{} - if p.HasFlag(FlagSYN) { - t.Fatal("fresh packet should have no flags") - } - p.SetFlag(FlagSYN) - p.SetFlag(FlagACK) - if !p.HasFlag(FlagSYN) || !p.HasFlag(FlagACK) { - t.Fatalf("set flags not detected: flags=0x%x", p.Flags) - } - if p.HasFlag(FlagFIN) { - t.Fatal("FIN should not be set") - } - p.ClearFlag(FlagSYN) - if p.HasFlag(FlagSYN) { - t.Fatal("SYN should be cleared") - } - if !p.HasFlag(FlagACK) { - t.Fatal("ACK should still be set") - } -} - -func TestPacketHeaderSize(t *testing.T) { - t.Parallel() - if PacketHeaderSize() != 34 { - t.Fatalf("PacketHeaderSize() = %d, want 34", PacketHeaderSize()) - } -} - -func TestPacketMarshalUnmarshalRoundTrip(t *testing.T) { - t.Parallel() - in := &Packet{ - Version: Version, - Flags: FlagACK | FlagSYN, - Protocol: ProtoStream, - Src: Addr{Network: 1, Node: 0x12345678}, - Dst: Addr{Network: 2, Node: 0xABCDEF01}, - SrcPort: 4040, - DstPort: PortHTTP, - Seq: 1234, - Ack: 5678, - Window: 64, - Payload: []byte("hello world"), - } - buf, err := in.Marshal() - if err != nil { - t.Fatalf("Marshal: %v", err) - } - if len(buf) != 34+len(in.Payload) { - t.Fatalf("buf len = %d, want %d", len(buf), 34+len(in.Payload)) - } - out, err := Unmarshal(buf) - if err != nil { - t.Fatalf("Unmarshal: %v", err) - } - if out.Version != in.Version || out.Flags != in.Flags || out.Protocol != in.Protocol { - t.Fatalf("header mismatch: got %+v", out) - } - if out.Src != in.Src || out.Dst != in.Dst { - t.Fatalf("addr mismatch: got src=%v dst=%v", out.Src, out.Dst) - } - if out.SrcPort != in.SrcPort || out.DstPort != in.DstPort { - t.Fatalf("port mismatch: got src=%d dst=%d", out.SrcPort, out.DstPort) - } - if out.Seq != in.Seq || out.Ack != in.Ack || out.Window != in.Window { - t.Fatalf("seq/ack/window mismatch: got seq=%d ack=%d window=%d", out.Seq, out.Ack, out.Window) - } - if !bytes.Equal(out.Payload, in.Payload) { - t.Fatalf("payload mismatch: got %q", out.Payload) - } -} - -func TestPacketMarshalEmptyPayload(t *testing.T) { - t.Parallel() - in := &Packet{Version: Version, Protocol: ProtoControl} - buf, err := in.Marshal() - if err != nil { - t.Fatalf("Marshal: %v", err) - } - if len(buf) != 34 { - t.Fatalf("len = %d, want 34", len(buf)) - } - out, err := Unmarshal(buf) - if err != nil { - t.Fatalf("Unmarshal: %v", err) - } - if len(out.Payload) != 0 { - t.Fatalf("expected empty payload, got %d bytes", len(out.Payload)) - } -} - -func TestPacketMarshalPayloadTooLarge(t *testing.T) { - t.Parallel() - in := &Packet{Version: Version, Payload: make([]byte, 0x10000)} // 65536, exceeds 0xFFFF - _, err := in.Marshal() - if err == nil { - t.Fatal("expected payload-too-large error") - } - if !strings.Contains(err.Error(), "payload too large") { - t.Fatalf("error %q missing substring", err) - } -} - -func TestUnmarshalTooShort(t *testing.T) { - t.Parallel() - _, err := Unmarshal(make([]byte, 33)) - if err == nil { - t.Fatal("expected too-short error") - } - if !strings.Contains(err.Error(), "too short") { - t.Fatalf("error %q missing substring", err) - } -} - -func TestUnmarshalTruncatedPayload(t *testing.T) { - t.Parallel() - // Build a header claiming 100-byte payload but send only the header. - buf := make([]byte, 34) - buf[0] = Version << 4 - binary.BigEndian.PutUint16(buf[2:4], 100) - _, err := Unmarshal(buf) - if err == nil { - t.Fatal("expected truncated error") - } - if !strings.Contains(err.Error(), "truncated") { - t.Fatalf("error %q missing 'truncated'", err) - } -} - -func TestUnmarshalChecksumMismatch(t *testing.T) { - t.Parallel() - in := &Packet{Version: Version, Protocol: ProtoStream, Payload: []byte("abc")} - buf, _ := in.Marshal() - // Corrupt one byte in the payload - buf[34] ^= 0xFF - _, err := Unmarshal(buf) - if !errors.Is(err, ErrChecksumMismatch) { - t.Fatalf("expected ErrChecksumMismatch, got %v", err) - } -} - -func TestUnmarshalUnsupportedVersion(t *testing.T) { - t.Parallel() - in := &Packet{Version: Version, Payload: []byte("x")} - buf, _ := in.Marshal() - // Flip the version nibble to a different value, then re-checksum so we - // hit the version check rather than the checksum check. - buf[0] = (0xA << 4) | (buf[0] & 0x0F) // version = 0xA - binary.BigEndian.PutUint32(buf[30:34], 0) - cs := Checksum(buf) - binary.BigEndian.PutUint32(buf[30:34], cs) - _, err := Unmarshal(buf) - if err == nil { - t.Fatal("expected unsupported version error") - } - if !strings.Contains(err.Error(), "unsupported protocol version") { - t.Fatalf("error %q missing substring", err) - } -} - -func TestUnmarshalRestoresChecksumBytes(t *testing.T) { - t.Parallel() - // Verify Unmarshal does not permanently mutate the caller's buffer - // (it temporarily zeroes the checksum field for computation, then restores). - in := &Packet{Version: Version, Protocol: ProtoStream, Payload: []byte("xyz")} - buf, _ := in.Marshal() - original := append([]byte(nil), buf...) - if _, err := Unmarshal(buf); err != nil { - t.Fatalf("Unmarshal: %v", err) - } - if !bytes.Equal(buf, original) { - t.Fatalf("Unmarshal mutated caller buffer: original checksum bytes %x, current %x", - original[30:34], buf[30:34]) - } -} - -// --------------------------------------------------------------------------- -// Checksum -// --------------------------------------------------------------------------- - -func TestChecksumDeterministic(t *testing.T) { - t.Parallel() - data := []byte("the quick brown fox") - c1 := Checksum(data) - c2 := Checksum(data) - if c1 != c2 { - t.Fatalf("Checksum non-deterministic: %d != %d", c1, c2) - } -} - -func TestChecksumDiffersOnSingleBitFlip(t *testing.T) { - t.Parallel() - a := []byte("payload") - b := append([]byte(nil), a...) - b[0] ^= 0x01 - if Checksum(a) == Checksum(b) { - t.Fatal("checksum did not detect 1-bit flip") - } -} diff --git a/pkg/registry/client/binary_client.go b/pkg/registry/client/binary_client.go deleted file mode 100644 index 98534359..00000000 --- a/pkg/registry/client/binary_client.go +++ /dev/null @@ -1,278 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package client - -import ( - "context" - "encoding/json" - "fmt" - "log/slog" - "net" - "sync" - "time" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -// BinaryClient talks to a registry server using the binary wire protocol. -// It provides native binary encoding for hot-path operations (heartbeat, lookup, -// resolve) and JSON-over-binary passthrough for all other operations. -type BinaryClient struct { - conn net.Conn - mu sync.Mutex - addr string - closed bool -} - -// DialBinary connects to a registry server and negotiates the binary wire protocol. -// The server detects the magic bytes and switches to binary mode for this connection. -func DialBinary(addr string) (*BinaryClient, error) { - conn, err := net.DialTimeout("tcp", addr, 5*time.Second) - if err != nil { - return nil, fmt.Errorf("dial registry: %w", err) - } - - // Send magic + version to negotiate binary protocol - var handshake [5]byte - copy(handshake[:4], wire.Magic[:]) - handshake[4] = wire.Version - if _, err := conn.Write(handshake[:]); err != nil { - conn.Close() - return nil, fmt.Errorf("binary handshake: %w", err) - } - - return &BinaryClient{conn: conn, addr: addr}, nil -} - -// Close shuts down the binary client connection. -func (c *BinaryClient) Close() error { - c.mu.Lock() - c.closed = true - conn := c.conn - c.mu.Unlock() - if conn != nil { - return conn.Close() - } - return nil -} - -// Addr returns the registry address this client is connected to. -func (c *BinaryClient) Addr() string { - return c.addr -} - -// reconnect re-establishes the binary connection. Must be called with c.mu held. -func (c *BinaryClient) reconnect() error { - if c.closed { - return fmt.Errorf("client closed") - } - if c.conn != nil { - c.conn.Close() - } - - backoff := 500 * time.Millisecond - maxBackoff := 10 * time.Second - var lastErr error - - for attempts := 0; attempts < 5; attempts++ { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", c.addr) - cancel() - if err != nil { - lastErr = err - slog.Warn("binary client reconnect failed", "attempt", attempts+1, "err", err) - time.Sleep(backoff) - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } - continue - } - - // Re-negotiate binary protocol - var handshake [5]byte - copy(handshake[:4], wire.Magic[:]) - handshake[4] = wire.Version - if _, err := conn.Write(handshake[:]); err != nil { - conn.Close() - lastErr = err - continue - } - - c.conn = conn - slog.Info("binary client reconnected", "addr", c.addr) - return nil - } - return fmt.Errorf("reconnect failed after 5 attempts: %w", lastErr) -} - -// Heartbeat sends a binary heartbeat and returns the server time and key expiry warning. -func (c *BinaryClient) Heartbeat(nodeID uint32, sig []byte) (unixTime int64, keyExpiryWarning bool, err error) { - c.mu.Lock() - defer c.mu.Unlock() - - unixTime, keyExpiryWarning, err = c.heartbeatLocked(nodeID, sig) - if err != nil && !c.closed { - // Connection-level failure — reconnect and retry once - if reconnErr := c.reconnect(); reconnErr != nil { - return 0, false, fmt.Errorf("heartbeat failed and reconnect failed: %w", err) - } - unixTime, keyExpiryWarning, err = c.heartbeatLocked(nodeID, sig) - } - return -} - -func (c *BinaryClient) heartbeatLocked(nodeID uint32, sig []byte) (int64, bool, error) { - if err := wire.WriteFrame(c.conn, wire.MsgHeartbeat, wire.EncodeHeartbeatReq(nodeID, sig)); err != nil { - return 0, false, fmt.Errorf("send heartbeat: %w", err) - } - - c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) - msgType, payload, err := wire.ReadFrame(c.conn) - c.conn.SetReadDeadline(time.Time{}) - if err != nil { - return 0, false, fmt.Errorf("recv heartbeat: %w", err) - } - - if msgType == wire.MsgError { - return 0, false, fmt.Errorf("registry: %s", wire.DecodeError(payload)) - } - if msgType != wire.MsgHeartbeatOK { - return 0, false, fmt.Errorf("unexpected response type 0x%02x", msgType) - } - - return wire.DecodeHeartbeatResp(payload) -} - -// Lookup sends a binary lookup request and returns the decoded result. -func (c *BinaryClient) Lookup(nodeID uint32) (*wire.LookupResult, error) { - c.mu.Lock() - defer c.mu.Unlock() - - result, err := c.lookupLocked(nodeID) - if err != nil && !c.closed { - if reconnErr := c.reconnect(); reconnErr != nil { - return nil, fmt.Errorf("lookup failed and reconnect failed: %w", err) - } - result, err = c.lookupLocked(nodeID) - } - return result, err -} - -func (c *BinaryClient) lookupLocked(nodeID uint32) (*wire.LookupResult, error) { - if err := wire.WriteFrame(c.conn, wire.MsgLookup, wire.EncodeLookupReq(nodeID)); err != nil { - return nil, fmt.Errorf("send lookup: %w", err) - } - - c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) - msgType, payload, err := wire.ReadFrame(c.conn) - c.conn.SetReadDeadline(time.Time{}) - if err != nil { - return nil, fmt.Errorf("recv lookup: %w", err) - } - - if msgType == wire.MsgError { - return nil, fmt.Errorf("registry: %s", wire.DecodeError(payload)) - } - if msgType != wire.MsgLookupOK { - return nil, fmt.Errorf("unexpected response type 0x%02x", msgType) - } - - result, err := wire.DecodeLookupResp(payload) - if err != nil { - return nil, fmt.Errorf("decode lookup response: %w", err) - } - return &result, nil -} - -// Resolve sends a binary resolve request and returns the decoded result. -func (c *BinaryClient) Resolve(nodeID, requesterID uint32, sig []byte) (*wire.ResolveResult, error) { - c.mu.Lock() - defer c.mu.Unlock() - - result, err := c.resolveLocked(nodeID, requesterID, sig) - if err != nil && !c.closed { - if reconnErr := c.reconnect(); reconnErr != nil { - return nil, fmt.Errorf("resolve failed and reconnect failed: %w", err) - } - result, err = c.resolveLocked(nodeID, requesterID, sig) - } - return result, err -} - -func (c *BinaryClient) resolveLocked(nodeID, requesterID uint32, sig []byte) (*wire.ResolveResult, error) { - if err := wire.WriteFrame(c.conn, wire.MsgResolve, wire.EncodeResolveReq(nodeID, requesterID, sig)); err != nil { - return nil, fmt.Errorf("send resolve: %w", err) - } - - c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) - msgType, payload, err := wire.ReadFrame(c.conn) - c.conn.SetReadDeadline(time.Time{}) - if err != nil { - return nil, fmt.Errorf("recv resolve: %w", err) - } - - if msgType == wire.MsgError { - return nil, fmt.Errorf("registry: %s", wire.DecodeError(payload)) - } - if msgType != wire.MsgResolveOK { - return nil, fmt.Errorf("unexpected response type 0x%02x", msgType) - } - - result, err := wire.DecodeResolveResp(payload) - if err != nil { - return nil, fmt.Errorf("decode resolve response: %w", err) - } - return &result, nil -} - -// SendJSON sends a JSON message over the binary protocol using JSON passthrough. -// This allows any registry operation to be used without a native binary encoding. -func (c *BinaryClient) SendJSON(msg map[string]interface{}) (map[string]interface{}, error) { - c.mu.Lock() - defer c.mu.Unlock() - - resp, err := c.sendJSONLocked(msg) - if err != nil && resp == nil && !c.closed { - if reconnErr := c.reconnect(); reconnErr != nil { - return nil, fmt.Errorf("send failed and reconnect failed: %w", err) - } - resp, err = c.sendJSONLocked(msg) - } - return resp, err -} - -func (c *BinaryClient) sendJSONLocked(msg map[string]interface{}) (map[string]interface{}, error) { - body, err := json.Marshal(msg) - if err != nil { - return nil, fmt.Errorf("json encode: %w", err) - } - - if err := wire.WriteFrame(c.conn, wire.MsgJSON, body); err != nil { - return nil, fmt.Errorf("send: %w", err) - } - - c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) - msgType, payload, readErr := wire.ReadFrame(c.conn) - c.conn.SetReadDeadline(time.Time{}) - if readErr != nil { - return nil, fmt.Errorf("recv: %w", readErr) - } - - if msgType == wire.MsgError { - errMsg := wire.DecodeError(payload) - return map[string]interface{}{"type": "error", "error": errMsg}, fmt.Errorf("registry: %s", errMsg) - } - if msgType != wire.MsgJSON { - return nil, fmt.Errorf("unexpected response type 0x%02x for JSON passthrough", msgType) - } - - var resp map[string]interface{} - if err := json.Unmarshal(payload, &resp); err != nil { - return nil, fmt.Errorf("json decode response: %w", err) - } - if errMsg, ok := resp["error"].(string); ok { - return resp, fmt.Errorf("registry: %s", errMsg) - } - return resp, nil -} diff --git a/pkg/registry/client/client.go b/pkg/registry/client/client.go deleted file mode 100644 index 1c5db362..00000000 --- a/pkg/registry/client/client.go +++ /dev/null @@ -1,1393 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package client - -import ( - "context" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "log/slog" - "net" - "sync" - "time" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -// ErrNoRegistry is returned from every exported *Client method when the -// receiver is a typed nil pointer. Callers (loadPolicyRunners, -// ManagedEngine.fetchMembers, Daemon.Info → nodeNetworks, etc.) sometimes -// invoke registry methods before the client is configured; returning this -// sentinel instead of panicking lets them treat "no registry" as a -// recoverable condition. -var ErrNoRegistry = errors.New("registry client not configured") - -// Client talks to a registry server over TCP (optionally TLS). -// It automatically reconnects if the connection drops. -// -// By default a Client owns a single TCP connection (Dial / DialTLS / -// DialTLSPinned). Each Send takes c.mu and serialises the entire -// request/response round-trip on that one conn. Under heavy concurrent -// load (the §4.8 lock-graph stress harness — 250 heartbeat goroutines -// per daemon hammering Health / Info / SetTags / ResolveHostname plus -// the per-resolve prewarm goroutines and persistHostnameCache writers, -// all funnelling through regConn.Send) that single mutex becomes the -// bottleneck: in-flight calls cannot honour shutdown signals because -// they're queued behind the mutex. -// -// DialPool / DialTLSPool create a Client backed by a small pool of -// connections (the primary c.conn plus N-1 secondary conns). Each -// concurrent Send picks a free pooled conn (blocking only if every -// conn is in use), eliminating the head-of-line wait. The primary -// c.conn / c.mu / c.closed fields are retained for backward compatibility -// with tests that touch them directly. -type Client struct { - // Primary connection. Always present; tests in this package read - // c.conn / c.mu / c.closed directly so the field set must stay stable. - conn net.Conn - mu sync.Mutex - addr string // registry address for reconnection - closed bool - tlsConfig *tls.Config - signer func(challenge string) string // H3 fix: optional message signer - - // Optional pool of secondary connections used to parallelise Send. - // nil / empty when DialPool was not used. - pool poolState -} - -// poolState holds the secondary-conn pool. The primary slot (c.conn / c.mu) -// is also represented here as the first entry, so acquireConn / releaseConn -// can pick uniformly across all conns. -type poolState struct { - // entries is the full set of conns including the primary at index 0. - // Each entry has its own mu — taking entry.mu lets one Send proceed - // without blocking other Sends on different entries. Closed when the - // Client is closed. - entries []*pooledConn - // free is a buffered channel of pointers to entries currently free. - // Capacity equals len(entries). Send: <-free; defer free<-entry. - // nil means "no pool" (legacy single-conn path via c.mu). - free chan *pooledConn - // done is closed by Close() to wake goroutines blocked on <-free and to - // signal the deferred pool-return in sendPool to drop its entry instead - // of sending on free. Using a separate done channel avoids the race - // between close(free) and concurrent sends on free. - done chan struct{} -} - -// pooledConn wraps one TCP connection plus its own mutex. The mutex -// guards both the conn pointer and any reconnect that happens through -// it; sendOnEntry takes it for the full write/read round-trip. -type pooledConn struct { - mu sync.Mutex - conn net.Conn -} - -// SetSigner sets a signing function for authenticated registry operations (H3 fix). -// The signer receives a challenge string and returns a base64-encoded Ed25519 signature. -// -// Issue #93: when the regConn is pooled (DialPool), multiple Send goroutines -// may call sign() concurrently while a parallel RotateKey path calls -// SetSigner. We guard the field with c.mu to keep that race-free; reads via -// sign() take the same lock so the loaded function pointer is consistent. -func (c *Client) SetSigner(fn func(challenge string) string) { - if c == nil { - return - } - c.mu.Lock() - c.signer = fn - c.mu.Unlock() -} - -// sign returns a signature for the challenge. It returns an error when the -// signer is unavailable or returns an empty signature. A nil receiver returns -// ErrNoRegistry so callers can rely on errors.Is. -func (c *Client) sign(challenge string) (string, error) { - if c == nil { - return "", ErrNoRegistry - } - c.mu.Lock() - fn := c.signer - c.mu.Unlock() - if fn == nil { - return "", fmt.Errorf("registry client: no signer configured (call SetSigner first)") - } - sig := fn(challenge) - if sig == "" { - return "", fmt.Errorf("registry client: signer returned empty signature for %q", challenge) - } - return sig, nil -} - -func Dial(addr string) (*Client, error) { - conn, err := net.Dial("tcp", addr) - if err != nil { - return nil, fmt.Errorf("dial registry: %w", err) - } - return &Client{conn: conn, addr: addr}, nil -} - -// DialPool connects to a registry server over plain TCP and pre-warms a -// pool of `size` connections (size >= 1). When size == 1 this is identical -// to Dial. When size > 1, additional secondary conns are dialed; concurrent -// Send calls then run in parallel up to `size` at a time, instead of all -// queueing on a single mutex. -// -// DialPool exists to fix #93 (regConn fairness under sustained load): the -// daemon's IPC handlers spawn goroutines that all call regConn.Send and -// previously serialised on c.mu. With DialPool the daemon can keep the -// same code path while letting up to `size` registry round-trips run -// concurrently. -// -// On any pool conn dial failure DialPool closes the conns it had already -// opened and returns an error. -func DialPool(addr string, size int) (*Client, error) { - if size <= 0 { - size = 1 - } - primary, err := net.Dial("tcp", addr) - if err != nil { - return nil, fmt.Errorf("dial registry: %w", err) - } - c := &Client{conn: primary, addr: addr} - if err := c.initPool(size, nil); err != nil { - primary.Close() - return nil, err - } - return c, nil -} - -// DialTLS connects to a registry server over TLS. -// A non-nil tlsConfig is required. For certificate pinning, use DialTLSPinned. -func DialTLS(addr string, tlsConfig *tls.Config) (*Client, error) { - if tlsConfig == nil { - return nil, fmt.Errorf("TLS config required; use DialTLSPinned for certificate pinning") - } - conn, err := tls.Dial("tcp", addr, tlsConfig) - if err != nil { - return nil, fmt.Errorf("dial registry TLS: %w", err) - } - return &Client{conn: conn, addr: addr, tlsConfig: tlsConfig}, nil -} - -// DialTLSPool is the TLS variant of DialPool. -func DialTLSPool(addr string, tlsConfig *tls.Config, size int) (*Client, error) { - if tlsConfig == nil { - return nil, fmt.Errorf("TLS config required; use DialTLSPinnedPool for certificate pinning") - } - if size <= 0 { - size = 1 - } - primary, err := tls.Dial("tcp", addr, tlsConfig) - if err != nil { - return nil, fmt.Errorf("dial registry TLS: %w", err) - } - c := &Client{conn: primary, addr: addr, tlsConfig: tlsConfig} - if err := c.initPool(size, tlsConfig); err != nil { - primary.Close() - return nil, err - } - return c, nil -} - -// initPool dials size-1 additional connections and registers them in c.pool. -// It assumes c.conn (primary) is already set. tlsCfg, when non-nil, is used -// for TLS dialing; otherwise plain TCP. -func (c *Client) initPool(size int, tlsCfg *tls.Config) error { - if size <= 1 { - // No secondary conns — single-conn legacy path; pool stays empty. - return nil - } - entries := make([]*pooledConn, 0, size) - entries = append(entries, &pooledConn{conn: c.conn}) - for i := 1; i < size; i++ { - var conn net.Conn - var err error - if tlsCfg != nil { - conn, err = tls.Dial("tcp", c.addr, tlsCfg) - } else { - conn, err = net.Dial("tcp", c.addr) - } - if err != nil { - // Close any conns we already opened (excluding primary — - // caller closes that on failure). - for _, e := range entries[1:] { - e.conn.Close() - } - return fmt.Errorf("dial pool conn %d: %w", i, err) - } - entries = append(entries, &pooledConn{conn: conn}) - } - free := make(chan *pooledConn, len(entries)) - for _, e := range entries { - free <- e - } - c.pool.entries = entries - c.pool.free = free - c.pool.done = make(chan struct{}) - return nil -} - -// DialTLSPinned connects to a registry server over TLS with certificate pinning. -// The fingerprint is a hex-encoded SHA-256 hash of the server's DER-encoded certificate. -func DialTLSPinned(addr, fingerprint string) (*Client, error) { - tlsConfig := &tls.Config{ - // InsecureSkipVerify disables the default CA chain check so we can - // use VerifyPeerCertificate for certificate pinning (SHA-256 fingerprint). - // This is the standard Go pattern — the custom callback below provides - // strictly stronger verification than CA-based trust. - InsecureSkipVerify: true, //nolint:gosec // cert pinning via VerifyPeerCertificate - VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - if len(rawCerts) == 0 { - return fmt.Errorf("no certificate presented") - } - hash := sha256.Sum256(rawCerts[0]) - got := hex.EncodeToString(hash[:]) - if got != fingerprint { - return fmt.Errorf("certificate fingerprint mismatch: got %s, want %s", got, fingerprint) - } - return nil - }, - } - conn, err := tls.Dial("tcp", addr, tlsConfig) - if err != nil { - return nil, fmt.Errorf("dial registry TLS pinned: %w", err) - } - return &Client{conn: conn, addr: addr, tlsConfig: tlsConfig}, nil -} - -func (c *Client) Close() error { - if c == nil { - return nil - } - c.mu.Lock() - c.closed = true - conn := c.conn - pool := c.pool.entries - c.mu.Unlock() - // Close the conn after releasing the lock; conn is captured by value - // so reconnect() can't see it after we set c.closed=true (M7 fix) - var firstErr error - if conn != nil { - if err := conn.Close(); err != nil { - firstErr = err - } - } - // Close every secondary pooled conn. The primary is already closed - // above (entries[0] holds the same fd as c.conn). Skip index 0 so - // we don't double-close. - for i := 1; i < len(pool); i++ { - e := pool[i] - // Take e.mu to coordinate with any in-flight sendOnEntry that - // holds it; once we release the mutex it'll see a closed conn - // on its next Read/Write and return an error to the caller. - e.mu.Lock() - if e.conn != nil { - if err := e.conn.Close(); err != nil && firstErr == nil { - firstErr = err - } - } - e.mu.Unlock() - } - // Close pool.done to wake any goroutine blocked on <-c.pool.free and to - // signal the deferred pool-return in sendPool to drop its entry. We never - // close pool.free itself because that would race with concurrent sends on - // it from the sendPool defer. - if c.pool.done != nil { - close(c.pool.done) - } - return firstErr -} - -// reconnect re-establishes the TCP connection to the registry. -// Must be called with c.mu held. -func (c *Client) reconnect(ctx context.Context) error { - if c.closed { - return fmt.Errorf("client closed") - } - if c.conn != nil { - c.conn.Close() - } - - var conn net.Conn - var err error - backoff := 500 * time.Millisecond - maxBackoff := 10 * time.Second - - for attempts := 0; attempts < 5; attempts++ { - if c.tlsConfig != nil { - dialer := &tls.Dialer{Config: c.tlsConfig, NetDialer: &net.Dialer{Timeout: 5 * time.Second}} - conn, err = dialer.DialContext(ctx, "tcp", c.addr) - } else { - conn, err = net.DialTimeout("tcp", c.addr, 5*time.Second) - } - if err == nil { - c.conn = conn - slog.Info("registry reconnected", "addr", c.addr) - return nil - } - slog.Warn("registry reconnect failed", "attempt", attempts+1, "err", err) - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(backoff): - } - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } - } - return fmt.Errorf("reconnect failed after 5 attempts: %w", err) -} - -// Send sends a registry message without a deadline. For shutdown-safe use -// that respects context cancellation, prefer SendContext. -func (c *Client) Send(msg map[string]interface{}) (map[string]interface{}, error) { - return c.SendContext(context.Background(), msg) -} - -// SendContext sends a registry message with context propagation through -// reconnect retries. Callers should pass a context with deadline or -// cancellation (e.g. daemon shutdown context) so that reconnect backoff -// does not block graceful stop. -func (c *Client) SendContext(ctx context.Context, msg map[string]interface{}) (map[string]interface{}, error) { - // Nil receiver — return a sentinel rather than panicking. Every - // exported wrapper method (Register, Lookup, Resolve, …) funnels - // through Send, so this single guard turns "calling a registry - // method on a nil client" into a recoverable error for every - // caller (loadPolicyRunners, ManagedEngine.fetchMembers, - // Daemon.Info → nodeNetworks, etc.). - if c == nil { - return nil, ErrNoRegistry - } - // Pool-enabled path (DialPool / DialTLSPool): pick a free conn and - // run the round-trip on it without touching c.mu. Multiple Send - // callers can run concurrently on different pooled conns. - if c.pool.free != nil { - return c.sendPool(ctx, msg) - } - - c.mu.Lock() - defer c.mu.Unlock() - - resp, err := c.sendLocked(msg) - if err != nil && resp == nil && !c.closed { - // Connection-level failure (no response received) — reconnect and retry once. - // Server error responses (resp != nil) do NOT trigger reconnection. - if reconnErr := c.reconnect(ctx); reconnErr != nil { - return nil, fmt.Errorf("send failed and reconnect failed: %w", err) - } - resp, err = c.sendLocked(msg) - } - return resp, err -} - -// sendPool runs Send on a free pooled connection. It blocks only when -// every pooled conn is busy (capacity exhausted) — one concurrent Send -// per pool entry can be in flight at a time. -func (c *Client) sendPool(ctx context.Context, msg map[string]interface{}) (map[string]interface{}, error) { - // Cheap closed check — avoids a wedged caller waiting on a free - // channel that nobody will ever return to once Close has run. - c.mu.Lock() - closed := c.closed - c.mu.Unlock() - if closed { - return nil, fmt.Errorf("client closed") - } - - var entry *pooledConn - select { - case entry = <-c.pool.free: - case <-c.pool.done: - return nil, fmt.Errorf("client closed") - } - defer func() { - select { - case c.pool.free <- entry: - case <-c.pool.done: - // pool is torn down; drop the entry - } - }() - - entry.mu.Lock() - defer entry.mu.Unlock() - - resp, err := c.sendOnEntry(entry, msg) - if err != nil && resp == nil && !c.isClosed() { - // Connection-level failure on this entry — reconnect THIS entry - // only (other pool entries are unaffected) and retry once. - if reconnErr := c.reconnectEntry(ctx, entry); reconnErr != nil { - return nil, fmt.Errorf("send failed and reconnect failed: %w", err) - } - resp, err = c.sendOnEntry(entry, msg) - } - return resp, err -} - -// sendOnEntry writes the request and reads the response on entry.conn. -// Caller must hold entry.mu. -func (c *Client) sendOnEntry(entry *pooledConn, msg map[string]interface{}) (map[string]interface{}, error) { - if err := wire.WriteMessage(entry.conn, msg); err != nil { - return nil, fmt.Errorf("send: %w", err) - } - entry.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) - resp, err := wire.ReadMessage(entry.conn) - entry.conn.SetReadDeadline(time.Time{}) - if err != nil { - return nil, fmt.Errorf("recv: %w", err) - } - if errMsg, ok := resp["error"].(string); ok { - return resp, fmt.Errorf("registry: %s", errMsg) - } - return resp, nil -} - -// reconnectEntry redials a single pool entry. Caller must hold entry.mu. -// This is the per-entry analogue of Client.reconnect. -func (c *Client) reconnectEntry(ctx context.Context, entry *pooledConn) error { - if c.isClosed() { - return fmt.Errorf("client closed") - } - if entry.conn != nil { - entry.conn.Close() - } - - var conn net.Conn - var err error - backoff := 500 * time.Millisecond - maxBackoff := 10 * time.Second - for attempts := 0; attempts < 5; attempts++ { - if c.tlsConfig != nil { - dialer := &tls.Dialer{Config: c.tlsConfig, NetDialer: &net.Dialer{Timeout: 5 * time.Second}} - conn, err = dialer.DialContext(ctx, "tcp", c.addr) - } else { - conn, err = net.DialTimeout("tcp", c.addr, 5*time.Second) - } - if err == nil { - entry.conn = conn - // Keep c.conn (primary) in sync if this is the primary entry. - // Tests in this package read c.conn directly, so we must not - // leave it pointing at a closed fd. - if entry == c.pool.entries[0] { - c.mu.Lock() - c.conn = conn - c.mu.Unlock() - } - slog.Info("registry pool conn reconnected", "addr", c.addr) - return nil - } - slog.Warn("registry pool conn reconnect failed", "attempt", attempts+1, "err", err) - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(backoff): - } - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } - } - return fmt.Errorf("reconnect failed after 5 attempts: %w", err) -} - -// isClosed returns whether Close has been called. Cheap, lock-protected. -func (c *Client) isClosed() bool { - c.mu.Lock() - defer c.mu.Unlock() - return c.closed -} - -// sendLocked sends a message and reads the response. Must be called with c.mu held. -func (c *Client) sendLocked(msg map[string]interface{}) (map[string]interface{}, error) { - if err := wire.WriteMessage(c.conn, msg); err != nil { - return nil, fmt.Errorf("send: %w", err) - } - c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) - resp, err := wire.ReadMessage(c.conn) - c.conn.SetReadDeadline(time.Time{}) - if err != nil { - return nil, fmt.Errorf("recv: %w", err) - } - if errMsg, ok := resp["error"].(string); ok { - return resp, fmt.Errorf("registry: %s", errMsg) - } - return resp, nil -} - -func (c *Client) Register(listenAddr string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "register", - "listen_addr": listenAddr, - }) -} - -// RegisterWithOwner registers a new node with an owner identifier (email/name) -// for key rotation recovery. -func (c *Client) RegisterWithOwner(listenAddr, owner string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "register", - "listen_addr": listenAddr, - "owner": owner, - }) -} - -// RegisterWithKey re-registers using an existing Ed25519 public key. -// The registry returns the same node_id if the key is known. -// lanAddrs are the node's LAN addresses for same-network peer detection. -func (c *Client) RegisterWithKey(listenAddr, publicKeyB64, owner string, lanAddrs []string, opts ...string) (map[string]interface{}, error) { - return c.RegisterWithKeyOpts(RegisterOpts{ - ListenAddr: listenAddr, - PublicKey: publicKeyB64, - Owner: owner, - LANAddrs: lanAddrs, - Version: firstNonEmpty(opts...), - }) -} - -// RegisterOpts is the full set of registration options. Lets us add -// fields (like RelayOnly for task 32) without breaking the variadic -// signature of RegisterWithKey. -type RegisterOpts struct { - ListenAddr string - PublicKey string // base64 Ed25519 - Owner string - LANAddrs []string - Version string - RelayOnly bool // task 32: hide real_addr from peers -} - -// RegisterWithKeyOpts is the structured-form register call. Existing -// callers keep using RegisterWithKey; new flags go here. -func (c *Client) RegisterWithKeyOpts(o RegisterOpts) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "register", - "listen_addr": o.ListenAddr, - "public_key": o.PublicKey, - } - if o.Owner != "" { - msg["owner"] = o.Owner - } - if len(o.LANAddrs) > 0 { - msg["lan_addrs"] = o.LANAddrs - } - if o.Version != "" { - msg["version"] = o.Version - } - if o.RelayOnly { - msg["relay_only"] = true - } - return c.Send(msg) -} - -func firstNonEmpty(s ...string) string { - for _, v := range s { - if v != "" { - return v - } - } - return "" -} - -// RotateKey requests a key rotation for a node. -// Requires a signature proving ownership of the current key and the new public key. -func (c *Client) RotateKey(nodeID uint32, signatureB64, newPubKeyB64 string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "rotate_key", - "node_id": nodeID, - } - if signatureB64 != "" { - msg["signature"] = signatureB64 - } - if newPubKeyB64 != "" { - msg["new_public_key"] = newPubKeyB64 - } - return c.Send(msg) -} - -func (c *Client) Lookup(nodeID uint32) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "lookup", - "node_id": nodeID, - }) -} - -func (c *Client) Resolve(nodeID, requesterID uint32) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "resolve", - "node_id": nodeID, - "requester_id": requesterID, - } - sig, err := c.sign(fmt.Sprintf("resolve:%d:%d", requesterID, nodeID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -func (c *Client) ReportTrust(nodeID, peerID uint32) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "report_trust", - "node_id": nodeID, - "peer_id": peerID, - } - sig, err := c.sign(fmt.Sprintf("report_trust:%d:%d", nodeID, peerID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -func (c *Client) RevokeTrust(nodeID, peerID uint32) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "revoke_trust", - "node_id": nodeID, - "peer_id": peerID, - } - sig, err := c.sign(fmt.Sprintf("revoke_trust:%d:%d", nodeID, peerID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -func (c *Client) SetVisibility(nodeID uint32, public bool) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "set_visibility", - "node_id": nodeID, - "public": public, - } - sig, err := c.sign(fmt.Sprintf("set_visibility:%d", nodeID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -func (c *Client) CreateNetwork(nodeID uint32, name, joinRule, token, adminToken string, enterprise bool, networkAdminToken ...string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "create_network", - "node_id": nodeID, - "name": name, - "join_rule": joinRule, - "token": token, - } - if adminToken != "" { - msg["admin_token"] = adminToken - } - if enterprise { - msg["enterprise"] = true - } - if len(networkAdminToken) > 0 && networkAdminToken[0] != "" { - msg["network_admin_token"] = networkAdminToken[0] - } - return c.Send(msg) -} - -// CreateManagedNetwork creates a network with managed rules. -func (c *Client) CreateManagedNetwork(nodeID uint32, name, joinRule, token, adminToken string, enterprise bool, rules string, networkAdminToken ...string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "create_network", - "node_id": nodeID, - "name": name, - "join_rule": joinRule, - "token": token, - "rules": rules, - } - if adminToken != "" { - msg["admin_token"] = adminToken - } - if enterprise { - msg["enterprise"] = true - } - if len(networkAdminToken) > 0 && networkAdminToken[0] != "" { - msg["network_admin_token"] = networkAdminToken[0] - } - return c.Send(msg) -} - -func (c *Client) JoinNetwork(nodeID uint32, networkID uint16, token string, inviterID uint32, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "join_network", - "node_id": nodeID, - "network_id": networkID, - "token": token, - "inviter_id": inviterID, - } - sig, err := c.sign(fmt.Sprintf("join_network:%d:%d", nodeID, networkID)) - if err == nil { - msg["signature"] = sig - } else if adminToken != "" { - msg["admin_token"] = adminToken - } - return c.Send(msg) -} - -func (c *Client) LeaveNetwork(nodeID uint32, networkID uint16, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "leave_network", - "node_id": nodeID, - "network_id": networkID, - } - sig, err := c.sign(fmt.Sprintf("leave_network:%d:%d", nodeID, networkID)) - if err == nil { - msg["signature"] = sig - } else if adminToken != "" { - msg["admin_token"] = adminToken - } - return c.Send(msg) -} - -func (c *Client) DeleteNetwork(networkID uint16, adminToken string, nodeID ...uint32) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "delete_network", - "network_id": networkID, - } - if adminToken != "" { - msg["admin_token"] = adminToken - } - if len(nodeID) > 0 && nodeID[0] != 0 { - msg["node_id"] = nodeID[0] - } - return c.Send(msg) -} - -func (c *Client) RenameNetwork(networkID uint16, name, adminToken string, nodeID ...uint32) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "rename_network", - "network_id": networkID, - "name": name, - } - if adminToken != "" { - msg["admin_token"] = adminToken - } - if len(nodeID) > 0 && nodeID[0] != 0 { - msg["node_id"] = nodeID[0] - } - return c.Send(msg) -} - -func (c *Client) SetNetworkEnterprise(networkID uint16, enterprise bool, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_network_enterprise", - "network_id": networkID, - "enterprise": enterprise, - "admin_token": adminToken, - }) -} - -// ListNetworks returns the registry's network catalog. Member counts -// (the `members` field on each entry) are admin-only — pass a non-empty -// adminToken to receive them; otherwise the field is omitted. -func (c *Client) ListNetworks(adminToken ...string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "list_networks", - } - if len(adminToken) > 0 && adminToken[0] != "" { - msg["admin_token"] = adminToken[0] - } - return c.Send(msg) -} - -func (c *Client) ListNodes(networkID uint16, adminToken ...string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "list_nodes", - "network_id": networkID, - } - if len(adminToken) > 0 && adminToken[0] != "" { - msg["admin_token"] = adminToken[0] - } - return c.Send(msg) -} - -func (c *Client) Deregister(nodeID uint32) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "deregister", - "node_id": nodeID, - } - sig, err := c.sign(fmt.Sprintf("deregister:%d", nodeID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -func (c *Client) Heartbeat(nodeID uint32) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "heartbeat", - "node_id": nodeID, - } - sig, err := c.sign(fmt.Sprintf("heartbeat:%d", nodeID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -func (c *Client) Punch(requesterID, nodeA, nodeB uint32) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "punch", - "requester_id": requesterID, - "node_a": nodeA, - "node_b": nodeB, - } - sig, err := c.sign(fmt.Sprintf("punch:%d:%d", nodeA, nodeB)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -// RequestHandshake relays a handshake request through the registry to a target node. -// This works even for private nodes — no IP exposure needed. -// M12 fix: includes a signature to prove sender identity. -func (c *Client) RequestHandshake(fromNodeID, toNodeID uint32, justification, signatureB64 string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "request_handshake", - "from_node_id": fromNodeID, - "to_node_id": toNodeID, - "justification": justification, - } - if signatureB64 != "" { - msg["signature"] = signatureB64 - } - return c.Send(msg) -} - -// PollHandshakes retrieves and clears pending handshake requests for a node. -// H3 fix: includes a signature to prove node identity. -func (c *Client) PollHandshakes(nodeID uint32) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "poll_handshakes", - "node_id": nodeID, - } - sig, err := c.sign(fmt.Sprintf("poll_handshakes:%d", nodeID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -// RespondHandshake approves or rejects a relayed handshake request. -// If accepted, the registry creates a mutual trust pair. -// M12 fix: includes a signature to prove responder identity. -func (c *Client) RespondHandshake(nodeID, peerID uint32, accept bool, signatureB64 string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "respond_handshake", - "node_id": nodeID, - "peer_id": peerID, - "accept": accept, - } - if signatureB64 != "" { - msg["signature"] = signatureB64 - } - return c.Send(msg) -} - -// SetHostname sets or clears the hostname for a node. -// An empty hostname clears the current hostname. -func (c *Client) SetHostname(nodeID uint32, hostname string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "set_hostname", - "node_id": nodeID, - "hostname": hostname, - } - sig, err := c.sign(fmt.Sprintf("set_hostname:%d", nodeID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -// SetTags sets the capability tags for a node. -func (c *Client) SetTags(nodeID uint32, tags []string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "set_tags", - "node_id": nodeID, - "tags": tags, - } - sig, err := c.sign(fmt.Sprintf("set_tags:%d", nodeID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -// ResolveHostname resolves a hostname to node info (node_id, address, public flag). -func (c *Client) ResolveHostname(hostname string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "resolve_hostname", - "hostname": hostname, - }) -} - -// ResolveHostnameAs resolves a hostname with a requester_id for privacy checks. -// Private nodes require the requester to have a trust pair or shared network. -func (c *Client) ResolveHostnameAs(requesterID uint32, hostname string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "resolve_hostname", - "hostname": hostname, - "requester_id": requesterID, - }) -} - -// CheckTrust checks if a trust pair or shared network exists between two nodes. -func (c *Client) CheckTrust(nodeA, nodeB uint32) (bool, error) { - if c == nil { - return false, ErrNoRegistry - } - resp, err := c.Send(map[string]interface{}{ - "type": "check_trust", - "node_id": nodeA, - "peer_id": nodeB, - }) - if err != nil { - return false, err - } - trusted, _ := resp["trusted"].(bool) - return trusted, nil -} - -// InviteToNetwork stores a pending invite for a target node to join an invite-only network. -func (c *Client) InviteToNetwork(networkID uint16, inviterID, targetNodeID uint32, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "invite_to_network", - "network_id": networkID, - "inviter_id": inviterID, - "target_node_id": targetNodeID, - } - sig, err := c.sign(fmt.Sprintf("invite:%d:%d:%d", inviterID, networkID, targetNodeID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - if adminToken != "" { - msg["admin_token"] = adminToken - } - return c.Send(msg) -} - -// PollInvites returns and clears pending network invites for a node. Signed. -func (c *Client) PollInvites(nodeID uint32) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "poll_invites", - "node_id": nodeID, - } - sig, err := c.sign(fmt.Sprintf("poll_invites:%d", nodeID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -// RespondInvite accepts or rejects a pending network invite. Signed. -func (c *Client) RespondInvite(nodeID uint32, networkID uint16, accept bool) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "respond_invite", - "node_id": nodeID, - "network_id": networkID, - "accept": accept, - } - sig, err := c.sign(fmt.Sprintf("respond_invite:%d:%d", nodeID, networkID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -// PromoteMember promotes a network member to admin. Only the owner can promote. -func (c *Client) PromoteMember(networkID uint16, nodeID, targetNodeID uint32, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "promote_member", - "network_id": networkID, - "node_id": nodeID, - "target_node_id": targetNodeID, - } - if adminToken != "" { - msg["admin_token"] = adminToken - } - return c.Send(msg) -} - -// DemoteMember demotes an admin to member. Only the owner can demote. -func (c *Client) DemoteMember(networkID uint16, nodeID, targetNodeID uint32, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "demote_member", - "network_id": networkID, - "node_id": nodeID, - "target_node_id": targetNodeID, - } - if adminToken != "" { - msg["admin_token"] = adminToken - } - return c.Send(msg) -} - -// KickMember removes a member from a network. Requires owner or admin role. -func (c *Client) KickMember(networkID uint16, nodeID, targetNodeID uint32, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "kick_member", - "network_id": networkID, - "node_id": nodeID, - "target_node_id": targetNodeID, - } - if adminToken != "" { - msg["admin_token"] = adminToken - } - return c.Send(msg) -} - -// TransferOwnership transfers network ownership to another member. Only the current owner can transfer. -func (c *Client) TransferOwnership(networkID uint16, ownerNodeID, newOwnerID uint32, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "transfer_ownership", - "network_id": networkID, - "node_id": ownerNodeID, - "new_owner_id": newOwnerID, - } - if adminToken != "" { - msg["admin_token"] = adminToken - } - return c.Send(msg) -} - -// GetMemberRole returns the RBAC role of a node in a network. -func (c *Client) GetMemberRole(networkID uint16, targetNodeID uint32) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_member_role", - "network_id": networkID, - "target_node_id": targetNodeID, - }) -} - -// SetNetworkPolicy sets or updates a network's policy. Requires owner/admin role or admin token. -func (c *Client) SetNetworkPolicy(networkID uint16, policy map[string]interface{}, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{} - for k, v := range policy { - msg[k] = v - } - msg["type"] = "set_network_policy" - msg["network_id"] = networkID - if adminToken != "" { - msg["admin_token"] = adminToken - } - return c.Send(msg) -} - -// GetNetworkPolicy returns the policy for a given network. -func (c *Client) GetNetworkPolicy(networkID uint16) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_network_policy", - "network_id": networkID, - }) -} - -// SetExprPolicy sets the programmable policy for a network. -// Requires owner/admin role or admin token. -func (c *Client) SetExprPolicy(networkID uint16, policyJSON json.RawMessage, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "set_expr_policy", - "network_id": networkID, - "expr_policy": string(policyJSON), - } - if adminToken != "" { - msg["admin_token"] = adminToken - } - return c.Send(msg) -} - -// GetExprPolicy returns the programmable policy for a network. -func (c *Client) GetExprPolicy(networkID uint16) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_expr_policy", - "network_id": networkID, - }) -} - -// SetKeyExpiry sets the key expiry time for a node. Requires signature. -func (c *Client) SetKeyExpiry(nodeID uint32, expiresAt time.Time) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "set_key_expiry", - "node_id": nodeID, - "expires_at": expiresAt.Format(time.RFC3339), - } - sig, err := c.sign(fmt.Sprintf("set_key_expiry:%d", nodeID)) - if err != nil { - return nil, err - } - msg["signature"] = sig - return c.Send(msg) -} - -// GetKeyInfo returns key lifecycle metadata for a node. -func (c *Client) GetKeyInfo(nodeID uint32) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_key_info", - "node_id": nodeID, - }) -} - -// --- Admin methods (bypass node signature, use admin_token instead) --- - -// SetHostnameAdmin sets a node's hostname using admin token auth. -func (c *Client) SetHostnameAdmin(nodeID uint32, hostname, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_hostname", - "node_id": nodeID, - "hostname": hostname, - "admin_token": adminToken, - }) -} - -// SetVisibilityAdmin sets a node's visibility using admin token auth. -func (c *Client) SetVisibilityAdmin(nodeID uint32, public bool, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_visibility", - "node_id": nodeID, - "public": public, - "admin_token": adminToken, - }) -} - -// SetTagsAdmin sets a node's tags using admin token auth. -func (c *Client) SetTagsAdmin(nodeID uint32, tags []string, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_tags", - "node_id": nodeID, - "tags": tags, - "admin_token": adminToken, - }) -} - -// SetMemberTags sets admin-assigned tags for a member within a network. -func (c *Client) SetMemberTags(netID uint16, targetNodeID uint32, tags []string, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_member_tags", - "network_id": netID, - "target_node_id": targetNodeID, - "tags": tags, - "admin_token": adminToken, - }) -} - -// GetMemberTags returns admin-assigned member tags for a node (or all members if targetNodeID=0). -func (c *Client) GetMemberTags(netID uint16, targetNodeID uint32) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_member_tags", - "network_id": netID, - "target_node_id": targetNodeID, - }) -} - -// SetKeyExpiryAdmin sets a node's key expiry using admin token auth. -func (c *Client) SetKeyExpiryAdmin(nodeID uint32, expiresAt time.Time, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_key_expiry", - "node_id": nodeID, - "expires_at": expiresAt.Format(time.RFC3339), - "admin_token": adminToken, - }) -} - -// ClearKeyExpiryAdmin removes the key expiry from a node using admin token auth. -func (c *Client) ClearKeyExpiryAdmin(nodeID uint32, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_key_expiry", - "node_id": nodeID, - "expires_at": "never", - "admin_token": adminToken, - }) -} - -// DeregisterAdmin removes a node using admin token auth. -func (c *Client) DeregisterAdmin(nodeID uint32, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "deregister", - "node_id": nodeID, - "admin_token": adminToken, - }) -} - -// GetAuditLog returns recent audit entries from the registry. -func (c *Client) GetAuditLog(networkID uint16, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "get_audit_log", - "admin_token": adminToken, - } - if networkID != 0 { - msg["network_id"] = networkID - } - return c.Send(msg) -} - -// SetWebhook configures the registry webhook URL. Pass empty string to disable. -func (c *Client) SetWebhook(url, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_webhook", - "url": url, - "admin_token": adminToken, - }) -} - -// GetWebhook returns the current webhook configuration. -func (c *Client) GetWebhook(adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_webhook", - "admin_token": adminToken, - }) -} - -// GetWebhookDLQ returns the dead letter queue (failed webhook events). -func (c *Client) GetWebhookDLQ(adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_webhook_dlq", - "admin_token": adminToken, - }) -} - -// SetIdentityWebhook configures the identity verification webhook URL. -func (c *Client) SetIdentityWebhook(url, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_identity_webhook", - "url": url, - "admin_token": adminToken, - }) -} - -// SetExternalID sets the external identity on a node. Requires admin token. -func (c *Client) SetExternalID(nodeID uint32, externalID, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_external_id", - "node_id": nodeID, - "external_id": externalID, - "admin_token": adminToken, - }) -} - -// GetIdentity returns the external identity of a node. Requires admin token. -func (c *Client) GetIdentity(nodeID uint32, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_identity", - "node_id": nodeID, - "admin_token": adminToken, - }) -} - -// ProvisionNetwork applies a network blueprint. Requires admin token. -func (c *Client) ProvisionNetwork(blueprint map[string]interface{}, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "provision_network", - "blueprint": blueprint, - "admin_token": adminToken, - }) -} - -// SetAuditExport configures the audit export adapter. Requires admin token. -func (c *Client) SetAuditExport(format, endpoint, token, index, source, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "set_audit_export", - "format": format, - "endpoint": endpoint, - "token": token, - "index": index, - "source": source, - "admin_token": adminToken, - }) -} - -// GetAuditExport returns the current audit export configuration. Requires admin token. -func (c *Client) GetAuditExport(adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_audit_export", - "admin_token": adminToken, - }) -} - -// SetIDPConfig configures the identity provider. Requires admin token. -func (c *Client) SetIDPConfig(idpType, url, issuer, clientID, tenantID, domain, adminToken string) (map[string]interface{}, error) { - msg := map[string]interface{}{ - "type": "set_idp_config", - "idp_type": idpType, - "url": url, - "admin_token": adminToken, - } - if issuer != "" { - msg["issuer"] = issuer - } - if clientID != "" { - msg["client_id"] = clientID - } - if tenantID != "" { - msg["tenant_id"] = tenantID - } - if domain != "" { - msg["domain"] = domain - } - return c.Send(msg) -} - -// GetIDPConfig returns the current identity provider configuration. Requires admin token. -func (c *Client) GetIDPConfig(adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_idp_config", - "admin_token": adminToken, - }) -} - -// GetProvisionStatus returns per-network provisioning status. Requires admin token. -func (c *Client) GetProvisionStatus(adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "get_provision_status", - "admin_token": adminToken, - }) -} - -// DirectorySync pushes a directory listing to update RBAC roles and membership. -func (c *Client) DirectorySync(networkID uint16, entries []map[string]interface{}, removeUnlisted bool, adminToken string) (map[string]interface{}, error) { - entryList := make([]interface{}, len(entries)) - for i, e := range entries { - entryList[i] = e - } - return c.Send(map[string]interface{}{ - "type": "directory_sync", - "network_id": networkID, - "entries": entryList, - "remove_unlisted": removeUnlisted, - "admin_token": adminToken, - }) -} - -// DirectoryStatus returns directory sync status for a network. -func (c *Client) DirectoryStatus(networkID uint16, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "directory_status", - "network_id": networkID, - "admin_token": adminToken, - }) -} - -// ValidateToken validates a JWT token against the configured IDP. Requires admin token. -func (c *Client) ValidateToken(token, adminToken string) (map[string]interface{}, error) { - return c.Send(map[string]interface{}{ - "type": "validate_token", - "token": token, - "admin_token": adminToken, - }) -} diff --git a/pkg/registry/client/zz_binary_client_test.go b/pkg/registry/client/zz_binary_client_test.go deleted file mode 100644 index 5366af19..00000000 --- a/pkg/registry/client/zz_binary_client_test.go +++ /dev/null @@ -1,550 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package client - -import ( - "encoding/binary" - "encoding/json" - "io" - "net" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -// Iter-116 coverage for registry/binary_client.go — 9 zero-coverage functions: -// DialBinary, Close, Addr, reconnect, Heartbeat/heartbeatLocked, Lookup/lookupLocked, -// Resolve/resolveLocked, SendJSON/sendJSONLocked. Strategy: stand up a real TCP -// listener that reads the 5-byte handshake (magic + version), then runs a -// per-test frame handler against the wire protocol via wire.ReadFrame/wire.WriteFrame. - -// --- fakeBinaryServer: minimal TCP server speaking the binary wire protocol --- - -type fakeBinaryServer struct { - ln net.Listener - handler func(msgType byte, payload []byte) (respType byte, respPayload []byte) - mu sync.Mutex - handshakes atomic.Uint32 - frames atomic.Uint32 - done chan struct{} -} - -func newFakeBinaryServer(t *testing.T, handler func(msgType byte, payload []byte) (byte, []byte)) *fakeBinaryServer { - t.Helper() - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - s := &fakeBinaryServer{ln: ln, handler: handler, done: make(chan struct{})} - go s.accept() - t.Cleanup(s.Close) - return s -} - -func (s *fakeBinaryServer) addr() string { return s.ln.Addr().String() } - -func (s *fakeBinaryServer) Close() { - s.mu.Lock() - defer s.mu.Unlock() - select { - case <-s.done: - return - default: - } - close(s.done) - s.ln.Close() -} - -func (s *fakeBinaryServer) accept() { - for { - conn, err := s.ln.Accept() - if err != nil { - return - } - go s.handle(conn) - } -} - -func (s *fakeBinaryServer) handle(conn net.Conn) { - defer conn.Close() - // Read 5-byte handshake. - var hdr [5]byte - if _, err := io.ReadFull(conn, hdr[:]); err != nil { - return - } - s.handshakes.Add(1) - // Verify magic — but don't enforce version. - for i, b := range wire.Magic { - if hdr[i] != b { - return - } - } - // Per-frame loop. - for { - msgType, payload, err := wire.ReadFrame(conn) - if err != nil { - return - } - s.frames.Add(1) - if s.handler == nil { - return - } - respType, respPayload := s.handler(msgType, payload) - if respType == 0 && respPayload == nil { - // Sentinel for "close without responding" — test uses this to force recv error. - return - } - if err := wire.WriteFrame(conn, respType, respPayload); err != nil { - return - } - } -} - -// --- DialBinary: success, dial error, handshake write error --- - -func TestDialBinarySuccess(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, nil) - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - if c.Addr() != srv.addr() { - t.Fatalf("Addr = %q, want %q", c.Addr(), srv.addr()) - } - // Wait for the server to see the handshake. - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - if srv.handshakes.Load() == 1 { - return - } - time.Sleep(10 * time.Millisecond) - } - t.Fatal("server did not receive handshake within 2s") -} - -func TestDialBinaryDialErrorWrapsMessage(t *testing.T) { - t.Parallel() - // Port 1 on 127.0.0.1 is almost certainly not listening (privileged, reserved). - // The dial will fail with ECONNREFUSED within the 5s timeout. - _, err := DialBinary("127.0.0.1:1") - if err == nil { - t.Fatal("DialBinary expected error on unreachable addr") - } - // The wrap format is `dial registry: `. We don't pin the exact text. - if len(err.Error()) == 0 { - t.Fatal("error message is empty") - } -} - -// --- Close: nil-conn path + idempotency --- - -func TestBinaryClientCloseIsSafeWithNilConn(t *testing.T) { - t.Parallel() - c := &BinaryClient{conn: nil} - if err := c.Close(); err != nil { - t.Fatalf("Close on nil conn = %v, want nil (no panic, no err)", err) - } - // Second Close is also safe — closed flag set, conn already nil. - if err := c.Close(); err != nil { - t.Fatalf("second Close = %v, want nil", err) - } -} - -// --- Addr: returns the configured addr without connection --- - -func TestBinaryClientAddrReflectsCtorValue(t *testing.T) { - t.Parallel() - c := &BinaryClient{addr: "host.example:9000"} - if got := c.Addr(); got != "host.example:9000" { - t.Fatalf("Addr = %q, want host.example:9000", got) - } -} - -// --- Heartbeat: happy path returns unixTime + warning flag --- - -func TestHeartbeatHappyPathReturnsTimeAndWarning(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - if msgType != wire.MsgHeartbeat { - return wire.MsgError, wire.EncodeError("unexpected msg") - } - req, err := wire.DecodeHeartbeatReq(payload) - if err != nil { - return wire.MsgError, wire.EncodeError(err.Error()) - } - if req.NodeID != 12345 { - return wire.MsgError, wire.EncodeError("wrong node id") - } - return wire.MsgHeartbeatOK, wire.EncodeHeartbeatResp(1_700_000_000, true) - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - sig := make([]byte, 64) - unixTime, warn, err := c.Heartbeat(12345, sig) - if err != nil { - t.Fatalf("Heartbeat: %v", err) - } - if unixTime != 1_700_000_000 { - t.Fatalf("unixTime = %d, want 1_700_000_000", unixTime) - } - if !warn { - t.Fatal("keyExpiryWarning = false, want true") - } -} - -// --- Heartbeat: server returns wire.MsgError → client surfaces "registry: " --- - -func TestHeartbeatServerErrorResponseReturnsWrappedError(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - return wire.MsgError, wire.EncodeError("node not registered") - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - _, _, err = c.Heartbeat(9999, make([]byte, 64)) - if err == nil { - t.Fatal("Heartbeat should return error when server sends wire.MsgError") - } - if got := err.Error(); got != "registry: node not registered" { - t.Fatalf("err = %q, want %q", got, "registry: node not registered") - } -} - -// --- Heartbeat: unexpected response type → error --- - -func TestHeartbeatUnexpectedResponseTypeReturnsError(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - // Respond with a LookupOK type instead of HeartbeatOK. - return wire.MsgLookupOK, []byte{0, 0, 0, 0} - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - _, _, err = c.Heartbeat(1, make([]byte, 64)) - if err == nil { - t.Fatal("expected error on unexpected response type") - } -} - -// --- Lookup: happy path decodes wire.LookupResult --- - -func TestLookupHappyPathDecodesResult(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - if msgType != wire.MsgLookup { - return wire.MsgError, wire.EncodeError("bad type") - } - return wire.MsgLookupOK, wire.EncodeLookupResp( - 42, // nodeID - true, false, // public, taskExec - []uint16{1, 2}, // networks - []byte{0xAB}, // pubkey - "host.example", // hostname - []string{"t1"}, // tags - "1.2.3.4:444", // realAddr - "ext-123", // externalID - ) - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - res, err := c.Lookup(42) - if err != nil { - t.Fatalf("Lookup: %v", err) - } - if res.NodeID != 42 { - t.Fatalf("NodeID = %d", res.NodeID) - } - if !res.Public || res.TaskExec { - t.Fatalf("flags: public=%v taskExec=%v", res.Public, res.TaskExec) - } - if len(res.Networks) != 2 || res.Networks[0] != 1 || res.Networks[1] != 2 { - t.Fatalf("Networks = %v", res.Networks) - } - if res.Hostname != "host.example" { - t.Fatalf("Hostname = %q", res.Hostname) - } - if len(res.Tags) != 1 || res.Tags[0] != "t1" { - t.Fatalf("Tags = %v", res.Tags) - } - if res.RealAddr != "1.2.3.4:444" { - t.Fatalf("RealAddr = %q", res.RealAddr) - } - if res.ExternalID != "ext-123" { - t.Fatalf("ExternalID = %q", res.ExternalID) - } -} - -// --- Lookup: unexpected response type --- - -func TestLookupUnexpectedResponseTypeReturnsError(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - return wire.MsgHeartbeatOK, wire.EncodeHeartbeatResp(0, false) - }) - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - if _, err := c.Lookup(99); err == nil { - t.Fatal("expected error on wrong response type") - } -} - -// --- Resolve: happy path decodes wire.ResolveResult --- - -func TestResolveHappyPathDecodesResult(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - if msgType != wire.MsgResolve { - return wire.MsgError, wire.EncodeError("bad type") - } - return wire.MsgResolveOK, wire.EncodeResolveResp( - 77, "10.0.0.1:5000", - []string{"192.168.1.1:5000", "192.168.1.2:5000"}, - 42, - ) - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - res, err := c.Resolve(77, 1, make([]byte, 64)) - if err != nil { - t.Fatalf("Resolve: %v", err) - } - if res.NodeID != 77 { - t.Fatalf("NodeID = %d", res.NodeID) - } - if res.RealAddr != "10.0.0.1:5000" { - t.Fatalf("RealAddr = %q", res.RealAddr) - } - if len(res.LANAddrs) != 2 { - t.Fatalf("LANAddrs = %v", res.LANAddrs) - } - if res.KeyAgeDays != 42 { - t.Fatalf("KeyAgeDays = %d", res.KeyAgeDays) - } -} - -// --- Resolve: -1 key_age_days (MaxUint32 in wire) --- - -func TestResolveMaxUint32KeyAgeMapsToNegativeOne(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - return wire.MsgResolveOK, wire.EncodeResolveResp(1, "a:1", nil, -1) - }) - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - res, err := c.Resolve(1, 1, make([]byte, 64)) - if err != nil { - t.Fatalf("Resolve: %v", err) - } - if res.KeyAgeDays != -1 { - t.Fatalf("KeyAgeDays = %d, want -1 (MaxUint32 sentinel)", res.KeyAgeDays) - } -} - -// --- SendJSON: roundtrip of a generic map --- - -func TestSendJSONRoundtripsGenericMap(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - if msgType != wire.MsgJSON { - return wire.MsgError, wire.EncodeError("bad type") - } - var req map[string]interface{} - if err := json.Unmarshal(payload, &req); err != nil { - return wire.MsgError, wire.EncodeError(err.Error()) - } - resp := map[string]interface{}{ - "type": "ok", - "echo": req["x"], - } - body, _ := json.Marshal(resp) - return wire.MsgJSON, body - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - resp, err := c.SendJSON(map[string]interface{}{"x": 7.0}) - if err != nil { - t.Fatalf("SendJSON: %v", err) - } - if resp["type"] != "ok" { - t.Fatalf("resp.type = %v, want ok", resp["type"]) - } - if got, _ := resp["echo"].(float64); got != 7 { - t.Fatalf("resp.echo = %v, want 7", resp["echo"]) - } -} - -// --- SendJSON: server returns wire.MsgError (server-side protocol error) --- - -func TestSendJSONWireMsgErrorReturnsMapWithError(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - return wire.MsgError, wire.EncodeError("rate limited") - }) - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - resp, err := c.SendJSON(map[string]interface{}{"op": "whatever"}) - if err == nil { - t.Fatal("expected error on wire.MsgError") - } - if resp == nil { - t.Fatal("resp must NOT be nil on wire.MsgError — caller relies on non-nil to skip reconnect") - } - if resp["type"] != "error" || resp["error"] != "rate limited" { - t.Fatalf("resp = %v, want type=error error=rate limited", resp) - } -} - -// --- SendJSON: application-level error field in normal JSON response --- - -func TestSendJSONReturnsErrorWhenResponseHasErrorField(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - resp := map[string]interface{}{"type": "bad", "error": "invalid op"} - body, _ := json.Marshal(resp) - return wire.MsgJSON, body - }) - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - resp, err := c.SendJSON(map[string]interface{}{"op": "x"}) - if err == nil { - t.Fatal("expected error when response has error field") - } - if got := err.Error(); got != "registry: invalid op" { - t.Fatalf("err = %q, want %q", got, "registry: invalid op") - } - if resp["type"] != "bad" { - t.Fatalf("resp.type = %v, want bad", resp["type"]) - } -} - -// --- SendJSON: unexpected response type --- - -func TestSendJSONUnexpectedResponseTypeReturnsError(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - return wire.MsgLookupOK, []byte{0, 0, 0, 0} - }) - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - _, err = c.SendJSON(map[string]interface{}{"op": "x"}) - if err == nil { - t.Fatal("expected error on wrong response type") - } -} - -// --- SendJSON: server returns malformed JSON in wire.MsgJSON → decode err --- - -func TestSendJSONMalformedResponseReturnsDecodeError(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - return wire.MsgJSON, []byte("not valid json }{") - }) - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - _, err = c.SendJSON(map[string]interface{}{"op": "x"}) - if err == nil { - t.Fatal("expected decode error on malformed JSON response") - } -} - -// --- encode/decode round-trips: sanity of our test helpers as well as SUT symmetry --- - -func TestEncodeDecodeHeartbeatReqRoundTrip(t *testing.T) { - t.Parallel() - sig := make([]byte, 64) - for i := range sig { - sig[i] = byte(i) - } - buf := wire.EncodeHeartbeatReq(0xDEADBEEF, sig) - req, err := wire.DecodeHeartbeatReq(buf) - if err != nil { - t.Fatalf("decode: %v", err) - } - if req.NodeID != 0xDEADBEEF { - t.Fatalf("NodeID = %x", req.NodeID) - } - for i := 0; i < 64; i++ { - if req.Signature[i] != byte(i) { - t.Fatalf("sig[%d] = %x, want %x", i, req.Signature[i], i) - } - } -} - -func TestDecodeWireErrorShortPayloadReturnsSentinel(t *testing.T) { - t.Parallel() - if got := wire.DecodeError([]byte{0x00}); got != "unknown error" { - t.Fatalf("wire.DecodeError(short) = %q, want unknown error", got) - } -} - -func TestDecodeWireErrorTruncatesToActualLen(t *testing.T) { - t.Parallel() - // Claim length=100 but only 5 real bytes follow — decoder clamps to available. - buf := make([]byte, 7) - binary.BigEndian.PutUint16(buf[:2], 100) - copy(buf[2:], []byte("hello")) - got := wire.DecodeError(buf) - if got != "hello" { - t.Fatalf("wire.DecodeError(truncated) = %q, want hello", got) - } -} diff --git a/pkg/registry/client/zz_client_branch_test.go b/pkg/registry/client/zz_client_branch_test.go deleted file mode 100644 index e0f69d39..00000000 --- a/pkg/registry/client/zz_client_branch_test.go +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package client - -import ( - "context" - "crypto/tls" - "encoding/json" - "net" - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -// Branch-fill tests: every wrapper that takes an optional adminToken, -// signature, or variadic flag has an untested branch when the optional -// arg is blank. This file ticks the remaining `if x != ""` / `if len(...) > 0` -// branches and the binary_client reconnect/lookup/resolve error edges. - -// --- Client member-mgmt wrappers: with-adminToken branches -------------- -// -// Existing tests cover the blank-token path; here we cover the non-blank -// branch so the `if adminToken != ""` is exercised both ways. - -func TestPromoteDemoteKickTransferIncludeAdminToken(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - cases := []struct { - name string - call func() (map[string]interface{}, error) - targetKey string - }{ - {"promote", func() (map[string]interface{}, error) { return c.PromoteMember(1, 2, 3, "ADM") }, "target_node_id"}, - {"demote", func() (map[string]interface{}, error) { return c.DemoteMember(1, 2, 3, "ADM") }, "target_node_id"}, - {"kick", func() (map[string]interface{}, error) { return c.KickMember(1, 2, 3, "ADM") }, "target_node_id"}, - {"transfer", func() (map[string]interface{}, error) { return c.TransferOwnership(1, 2, 3, "ADM") }, "new_owner_id"}, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - resp, err := tc.call() - if err != nil { - t.Fatalf("%s: %v", tc.name, err) - } - echo := assertEcho(t, resp) - if got, _ := echo["admin_token"].(string); got != "ADM" { - t.Fatalf("%s: admin_token: %q", tc.name, got) - } - if got, _ := echo[tc.targetKey].(float64); uint32(got) != 3 { - t.Fatalf("%s: %s: %v", tc.name, tc.targetKey, got) - } - }) - } -} - -// --- ReportTrust / RevokeTrust / SetVisibility: WITH-signer branch ----- -// -// Existing TestReportTrustAndRevokeTrustFormat and TestSetVisibilityPublicFlagSerialized -// drive the no-signer (sig empty) path. Cover the signer-attached branch -// so `if sig := ...; sig != ""` is hit both ways. - -func TestReportRevokeVisibilityIncludeSignatureWhenSignerSet(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - c.SetSigner(func(ch string) string { return "SIG:" + ch }) - cases := []struct { - name string - call func() (map[string]interface{}, error) - challenge string - }{ - {"report_trust", func() (map[string]interface{}, error) { return c.ReportTrust(1, 2) }, "report_trust:1:2"}, - {"revoke_trust", func() (map[string]interface{}, error) { return c.RevokeTrust(1, 2) }, "revoke_trust:1:2"}, - {"set_visibility", func() (map[string]interface{}, error) { return c.SetVisibility(9, false) }, "set_visibility:9"}, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - resp, err := tc.call() - if err != nil { - t.Fatalf("%s: %v", tc.name, err) - } - echo := assertEcho(t, resp) - if got, _ := echo["signature"].(string); got != "SIG:"+tc.challenge { - t.Fatalf("%s: signature: want SIG:%s, got %q", tc.name, tc.challenge, got) - } - }) - } -} - -// --- CreateManagedNetwork full-options branch ----------------------------- - -func TestCreateManagedNetworkFullOptions(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, err := c.CreateManagedNetwork(2, "n", "invite", "tok", "ADM", true, `{"a":1}`, "NAT") - if err != nil { - t.Fatalf("create: %v", err) - } - echo := assertEcho(t, resp) - if got, _ := echo["admin_token"].(string); got != "ADM" { - t.Fatalf("admin_token: %q", got) - } - if got, _ := echo["enterprise"].(bool); !got { - t.Fatalf("enterprise: %v", got) - } - if got, _ := echo["network_admin_token"].(string); got != "NAT" { - t.Fatalf("network_admin_token: %q", got) - } -} - -// --- ListNetworks with adminToken ---------------------------------------- - -func TestListNetworksWithAdminTokenIncludesField(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, err := c.ListNetworks("SUPER") - if err != nil { - t.Fatalf("list: %v", err) - } - echo := assertEcho(t, resp) - if got, _ := echo["admin_token"].(string); got != "SUPER" { - t.Fatalf("admin_token: %q", got) - } -} - -// --- DialTLS error wrapping happy/sad already covered; sad path only used -// "dial registry TLS" prefix once. Cover the happy-path connect branch with -// a real TLS listener that closes immediately. ----------------------------- - -// --- binary_client: reconnect after failure ------------------------------ - -// TestBinaryHeartbeatReconnectsAfterBrokenConn covers the -// `err != nil && !c.closed → reconnect → retry` branch in Heartbeat. -func TestBinaryHeartbeatReconnectsAfterBrokenConn(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - if msgType != wire.MsgHeartbeat { - return wire.MsgError, wire.EncodeError("bad") - } - return wire.MsgHeartbeatOK, wire.EncodeHeartbeatResp(123, false) - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - // Forcibly close the underlying conn so the next Heartbeat triggers - // reconnect+retry. - c.mu.Lock() - _ = c.conn.Close() - c.mu.Unlock() - - unixTime, _, err := c.Heartbeat(1, make([]byte, 64)) - if err != nil { - t.Fatalf("Heartbeat after broken conn: %v", err) - } - if unixTime != 123 { - t.Fatalf("unixTime = %d, want 123", unixTime) - } -} - -// TestBinaryLookupReconnectsAfterBrokenConn covers the reconnect branch -// inside Lookup. -func TestBinaryLookupReconnectsAfterBrokenConn(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - if msgType != wire.MsgLookup { - return wire.MsgError, wire.EncodeError("bad") - } - return wire.MsgLookupOK, wire.EncodeLookupResp( - 7, true, false, nil, nil, "h", nil, "1.1.1.1:1", "", - ) - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - c.mu.Lock() - _ = c.conn.Close() - c.mu.Unlock() - - res, err := c.Lookup(7) - if err != nil { - t.Fatalf("Lookup: %v", err) - } - if res.NodeID != 7 { - t.Fatalf("NodeID = %d", res.NodeID) - } -} - -// TestBinaryResolveReconnectsAfterBrokenConn covers the reconnect branch -// inside Resolve. -func TestBinaryResolveReconnectsAfterBrokenConn(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - if msgType != wire.MsgResolve { - return wire.MsgError, wire.EncodeError("bad") - } - return wire.MsgResolveOK, wire.EncodeResolveResp(8, "2.2.2.2:2", nil, 0) - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - c.mu.Lock() - _ = c.conn.Close() - c.mu.Unlock() - - res, err := c.Resolve(8, 1, make([]byte, 64)) - if err != nil { - t.Fatalf("Resolve: %v", err) - } - if res.NodeID != 8 { - t.Fatalf("NodeID = %d", res.NodeID) - } -} - -// TestBinarySendJSONReconnectsAfterBrokenConn covers the reconnect branch -// inside SendJSON. -func TestBinarySendJSONReconnectsAfterBrokenConn(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { - if msgType != wire.MsgJSON { - return wire.MsgError, wire.EncodeError("bad") - } - body, _ := json.Marshal(map[string]interface{}{"type": "ok"}) - return wire.MsgJSON, body - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - defer c.Close() - - c.mu.Lock() - _ = c.conn.Close() - c.mu.Unlock() - - resp, err := c.SendJSON(map[string]interface{}{"op": "x"}) - if err != nil { - t.Fatalf("SendJSON: %v", err) - } - if resp["type"] != "ok" { - t.Fatalf("type: %v", resp["type"]) - } -} - -// TestBinaryReconnectAllAttemptsFail covers the failure path: all 5 -// reconnect attempts fail and the client surfaces "reconnect failed". -// -// We pass a small backoff window indirectly by pointing at a closed port. -// 5 attempts * ~0.5s backoff each is up to ~7.5s of sleeping inside -// reconnect — too slow for -short. Instead, exercise the immediate path: -// close the client first so reconnect returns "client closed" without -// sleeping. This still covers the c.closed branch. -func TestBinaryReconnectShortCircuitsWhenClosed(t *testing.T) { - t.Parallel() - srv := newFakeBinaryServer(t, func(byte, []byte) (byte, []byte) { - return wire.MsgHeartbeatOK, wire.EncodeHeartbeatResp(1, false) - }) - - c, err := DialBinary(srv.addr()) - if err != nil { - t.Fatalf("DialBinary: %v", err) - } - - // Tear down the listener to make a future dial fail, then close the - // client to force the reconnect early-return branch. - srv.Close() - // Drop the local conn and mark closed before triggering reconnect. - c.mu.Lock() - _ = c.conn.Close() - c.closed = true - err = c.reconnect() - c.mu.Unlock() - if err == nil { - t.Fatalf("reconnect after Close should fail") - } - if !strings.Contains(err.Error(), "client closed") { - t.Fatalf("expected 'client closed' error, got: %v", err) - } -} - -// TestBinaryDialBinaryHandshakeWriteFailure exercises the -// "conn.Write(handshake) fails" branch in DialBinary. We can't intercept -// the write directly, but we can race a close: connect to a listener that -// accepts then immediately closes the conn before the handshake write -// completes. On macOS this typically surfaces as a write error on a -// half-closed socket. -func TestBinaryDialBinaryHandshakeWriteFailure(t *testing.T) { - t.Parallel() - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - addr := ln.Addr().String() - - done := make(chan struct{}) - go func() { - defer close(done) - conn, err := ln.Accept() - if err != nil { - return - } - // Drop the conn immediately. Then close the listener so subsequent - // dials return ECONNREFUSED if the test re-runs. - conn.Close() - }() - // DialBinary may succeed (handshake writes 5 bytes into the kernel buffer - // before EOF is observed) or fail. Either is fine — we just need the path - // to execute and not panic. The accept goroutine guarantees a real connect. - c, err := DialBinary(addr) - if err == nil && c != nil { - c.Close() - } - <-done - ln.Close() - - // As a deterministic companion: dial a closed port. This exercises the - // "net.Dial fails" branch. - closed, _ := net.Listen("tcp", "127.0.0.1:0") - closedAddr := closed.Addr().String() - closed.Close() - _, err = DialBinary(closedAddr) - if err == nil { - t.Fatalf("DialBinary to closed port should fail") - } -} - -// --- Backoff cap check via reconnect against unreachable addr ------------ - -// TestClientReconnectBackoffCapsAtMax pushes reconnect into >5s of backoff -// growth and verifies it returns the eventual "reconnect failed" wrap. -// We use a Client whose addr points to a closed port; reconnect dials it -// 5 times then gives up. Using a fresh-grabbed-then-released kernel port -// keeps each failed dial fast (ECONNREFUSED on loopback is sub-ms). -func TestClientReconnectExhaustsAttempts(t *testing.T) { - if testing.Short() { - // 5 attempts * 0.5s = ~7.5s of sleep — too slow for -short with -race. - t.Skip("skipping long reconnect-exhaustion test under -short") - } - t.Parallel() - ln, _ := net.Listen("tcp", "127.0.0.1:0") - addr := ln.Addr().String() - ln.Close() - - c := &Client{addr: addr} - c.mu.Lock() - err := c.reconnect(context.Background()) - c.mu.Unlock() - if err == nil { - t.Fatalf("expected reconnect failure") - } - if !strings.Contains(err.Error(), "reconnect failed") { - t.Fatalf("expected 'reconnect failed' wrap, got: %v", err) - } -} - -// --- Close after pool conn already closed: idempotency / second-close ---- - -func TestClosePoolEntryAlreadyClosedReturnsFirstErrOrNil(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 3) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - // Pre-close one secondary conn so Close()'s per-entry loop hits the - // "already-closed" path on it. - c.pool.entries[1].mu.Lock() - _ = c.pool.entries[1].conn.Close() - c.pool.entries[1].mu.Unlock() - - // Double-close should be safe. - if err := c.Close(); err != nil { - // Close may surface the first error (double-close on a TCPConn). - // That's acceptable — what matters is no panic. - t.Logf("Close returned (acceptable): %v", err) - } -} - -// --- sendOnEntry server error path (response with "error" key) ----------- - -func TestSendOnEntryReturnsServerErrorResponse(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { - return map[string]interface{}{"error": "rate-limited"} - }) - defer srv.close() - - c, err := DialPool(srv.addr(), 2) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - defer c.Close() - - resp, err := c.Send(map[string]interface{}{"type": "x"}) - if err == nil { - t.Fatalf("expected server error") - } - if resp == nil { - t.Fatalf("resp must be non-nil for server-error path") - } - if !strings.Contains(err.Error(), "rate-limited") { - t.Fatalf("error should contain server message, got: %v", err) - } -} - -// --- DialTLS happy path --------------------------------------------------- - -func TestDialTLSHappyPathConnects(t *testing.T) { - t.Parallel() - srv := newFakeTLSServer(t, echoHandler()) - cfg := &tls.Config{ - MinVersion: tls.VersionTLS12, - InsecureSkipVerify: true, //nolint:gosec // test-only - } - c, err := DialTLS(srv.addr(), cfg) - if err != nil { - t.Fatalf("DialTLS: %v", err) - } - defer c.Close() - resp, err := c.Send(map[string]interface{}{"type": "x"}) - if err != nil { - t.Fatalf("send: %v", err) - } - if got, _ := resp["type"].(string); got != "ok" { - t.Fatalf("type: %q", got) - } - // Ensure tlsConfig is retained so reconnect would use TLS too. - if c.tlsConfig == nil { - t.Fatalf("tlsConfig should be retained on Client after DialTLS") - } -} diff --git a/pkg/registry/client/zz_client_join_signature_test.go b/pkg/registry/client/zz_client_join_signature_test.go deleted file mode 100644 index 677b0b08..00000000 --- a/pkg/registry/client/zz_client_join_signature_test.go +++ /dev/null @@ -1,707 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package client - -import ( - "encoding/json" - "errors" - "strings" - "testing" - "time" -) - -// echoOnlyClient dials a fakeJSONServer with echoHandler and returns a -// connected Client plus the server so test bodies can assert wire payloads. -func echoOnlyClient(t *testing.T) (*Client, *fakeJSONServer) { - t.Helper() - srv := newFakeJSONServer(t, echoHandler()) - c, err := Dial(srv.addr()) - if err != nil { - srv.close() - t.Fatalf("dial: %v", err) - } - t.Cleanup(func() { c.Close(); srv.close() }) - return c, srv -} - -// assertEcho fetches the echoed request payload that the fake server round-tripped. -func assertEcho(t *testing.T, resp map[string]interface{}) map[string]interface{} { - t.Helper() - echo, ok := resp["echo"].(map[string]interface{}) - if !ok { - t.Fatalf("response missing echo key: %+v", resp) - } - return echo -} - -// --- JoinNetwork / LeaveNetwork : signature wins over admin_token -------- - -func TestJoinNetworkSignaturePreferredOverAdminToken(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - c.SetSigner(func(ch string) string { return "SIG:" + ch }) - - resp, err := c.JoinNetwork(11, 3, "tok", 4, "ADMIN_SHOULD_BE_IGNORED") - if err != nil { - t.Fatalf("join: %v", err) - } - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "join_network" { - t.Fatalf("type: %q", got) - } - if got, _ := echo["signature"].(string); got != "SIG:join_network:11:3" { - t.Fatalf("signature: %q", got) - } - if _, ok := echo["admin_token"]; ok { - t.Fatalf("admin_token should be omitted when signature present") - } -} - -func TestJoinNetworkFallsBackToAdminTokenWithoutSigner(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, err := c.JoinNetwork(11, 3, "tok", 4, "ADM") - if err != nil { - t.Fatalf("join: %v", err) - } - echo := assertEcho(t, resp) - if _, ok := echo["signature"]; ok { - t.Fatalf("signature should be absent with no signer") - } - if got, _ := echo["admin_token"].(string); got != "ADM" { - t.Fatalf("admin_token: %q", got) - } - if got, _ := echo["inviter_id"].(float64); uint32(got) != 4 { - t.Fatalf("inviter_id: %v", got) - } -} - -func TestLeaveNetworkSignatureOrAdminToken(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - // Signer wins. - c.SetSigner(func(ch string) string { return "SIG:" + ch }) - resp, _ := c.LeaveNetwork(5, 2, "ADMIN") - echo := assertEcho(t, resp) - if got, _ := echo["signature"].(string); got != "SIG:leave_network:5:2" { - t.Fatalf("signature: %q", got) - } - if _, ok := echo["admin_token"]; ok { - t.Fatalf("admin_token should be omitted when sig present") - } - // Drop signer → admin_token fallback. - c.SetSigner(nil) - resp, _ = c.LeaveNetwork(5, 2, "ADMIN") - echo = assertEcho(t, resp) - if got, _ := echo["admin_token"].(string); got != "ADMIN" { - t.Fatalf("admin_token fallback: %q", got) - } -} - -// --- DeleteNetwork / RenameNetwork : variadic node_id -------------------- - -func TestDeleteNetworkVariadicNodeIDAndAdminToken(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - // No nodeID, no adminToken. - resp, _ := c.DeleteNetwork(3, "") - echo := assertEcho(t, resp) - if _, ok := echo["node_id"]; ok { - t.Fatalf("node_id should be omitted when not passed") - } - if _, ok := echo["admin_token"]; ok { - t.Fatalf("admin_token should be omitted when blank") - } - // With both. - resp, _ = c.DeleteNetwork(3, "ADM", 77) - echo = assertEcho(t, resp) - if got, _ := echo["node_id"].(float64); uint32(got) != 77 { - t.Fatalf("node_id: %v", got) - } - if got, _ := echo["admin_token"].(string); got != "ADM" { - t.Fatalf("admin_token: %q", got) - } - // Explicit 0 node_id → still omitted. - resp, _ = c.DeleteNetwork(3, "ADM", 0) - echo = assertEcho(t, resp) - if _, ok := echo["node_id"]; ok { - t.Fatalf("node_id=0 should be omitted (matches client logic)") - } -} - -func TestRenameNetworkPassesName(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.RenameNetwork(1, "shiny", "ADM", 9) - echo := assertEcho(t, resp) - if got, _ := echo["name"].(string); got != "shiny" { - t.Fatalf("name: %q", got) - } - if got, _ := echo["node_id"].(float64); uint32(got) != 9 { - t.Fatalf("node_id: %v", got) - } -} - -// --- ListNetworks / ListNodes / SetNetworkEnterprise -------------------- - -func TestListNetworksBareType(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.ListNetworks() - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "list_networks" { - t.Fatalf("type: %q", got) - } -} - -func TestListNodesAdminTokenOptional(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - // Without admin token. - resp, _ := c.ListNodes(42) - echo := assertEcho(t, resp) - if _, ok := echo["admin_token"]; ok { - t.Fatalf("admin_token should be omitted when not supplied") - } - // With admin token. - resp, _ = c.ListNodes(42, "ADM") - echo = assertEcho(t, resp) - if got, _ := echo["admin_token"].(string); got != "ADM" { - t.Fatalf("admin_token: %q", got) - } -} - -func TestSetNetworkEnterpriseSerializesBool(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.SetNetworkEnterprise(7, true, "ADM") - echo := assertEcho(t, resp) - if got, _ := echo["enterprise"].(bool); !got { - t.Fatalf("enterprise should be true") - } - if got, _ := echo["admin_token"].(string); got != "ADM" { - t.Fatalf("admin_token: %q", got) - } -} - -// --- Signed thin wrappers: Deregister / Heartbeat / Punch ---------------- - -func TestSignedWrappersIncludeCorrectChallenge(t *testing.T) { - t.Parallel() - cases := []struct { - name string - call func(c *Client) (map[string]interface{}, error) - expect string - }{ - {"deregister", func(c *Client) (map[string]interface{}, error) { return c.Deregister(42) }, "deregister:42"}, - {"heartbeat", func(c *Client) (map[string]interface{}, error) { return c.Heartbeat(42) }, "heartbeat:42"}, - {"punch", func(c *Client) (map[string]interface{}, error) { return c.Punch(1, 42, 43) }, "punch:42:43"}, - {"poll_handshakes", func(c *Client) (map[string]interface{}, error) { return c.PollHandshakes(42) }, "poll_handshakes:42"}, - {"set_hostname", func(c *Client) (map[string]interface{}, error) { return c.SetHostname(42, "h") }, "set_hostname:42"}, - {"set_tags", func(c *Client) (map[string]interface{}, error) { return c.SetTags(42, []string{"a"}) }, "set_tags:42"}, - {"poll_invites", func(c *Client) (map[string]interface{}, error) { return c.PollInvites(42) }, "poll_invites:42"}, - {"set_key_expiry", func(c *Client) (map[string]interface{}, error) { return c.SetKeyExpiry(42, time.Unix(0, 0).UTC()) }, "set_key_expiry:42"}, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - c, _ := echoOnlyClient(t) - c.SetSigner(func(ch string) string { return "SIG:" + ch }) - resp, err := tc.call(c) - if err != nil { - t.Fatalf("call: %v", err) - } - echo := assertEcho(t, resp) - if got, _ := echo["signature"].(string); got != "SIG:"+tc.expect { - t.Fatalf("signature: want SIG:%s, got %q", tc.expect, got) - } - }) - } -} - -func TestSetKeyExpiryFormatsRFC3339(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - c.SetSigner(func(ch string) string { return "SIG:" + ch }) - moment := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC) - resp, _ := c.SetKeyExpiry(9, moment) - echo := assertEcho(t, resp) - if got, _ := echo["expires_at"].(string); got != "2030-01-02T03:04:05Z" { - t.Fatalf("expires_at: %q", got) - } -} - -// --- RequestHandshake / RespondHandshake (caller-supplied signature) ----- - -func TestRequestAndRespondHandshakePassThroughSignature(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.RequestHandshake(1, 2, "please", "SIG_REQ") - echo := assertEcho(t, resp) - if got, _ := echo["signature"].(string); got != "SIG_REQ" { - t.Fatalf("req signature: %q", got) - } - if got, _ := echo["justification"].(string); got != "please" { - t.Fatalf("justification: %q", got) - } - - resp, _ = c.RespondHandshake(3, 4, true, "SIG_RESP") - echo = assertEcho(t, resp) - if got, _ := echo["accept"].(bool); !got { - t.Fatalf("accept: %v", got) - } - if got, _ := echo["signature"].(string); got != "SIG_RESP" { - t.Fatalf("resp signature: %q", got) - } - - // Blank signature omitted. - resp, _ = c.RespondHandshake(3, 4, false, "") - echo = assertEcho(t, resp) - if _, ok := echo["signature"]; ok { - t.Fatalf("signature should be omitted when blank") - } -} - -// --- ResolveHostname / ResolveHostnameAs / CheckTrust -------------------- - -func TestResolveHostnameBothForms(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.ResolveHostname("alpha") - echo := assertEcho(t, resp) - if got, _ := echo["hostname"].(string); got != "alpha" { - t.Fatalf("hostname: %q", got) - } - if _, ok := echo["requester_id"]; ok { - t.Fatalf("requester_id should be absent") - } - - resp, _ = c.ResolveHostnameAs(99, "beta") - echo = assertEcho(t, resp) - if got, _ := echo["requester_id"].(float64); uint32(got) != 99 { - t.Fatalf("requester_id: %v", got) - } -} - -func TestCheckTrustReturnsTypedBool(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { - return map[string]interface{}{"type": "ok", "trusted": true} - }) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - - trusted, err := c.CheckTrust(1, 2) - if err != nil { - t.Fatalf("check trust: %v", err) - } - if !trusted { - t.Fatalf("expected trusted=true") - } -} - -func TestCheckTrustPropagatesError(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { - return map[string]interface{}{"error": "forbidden"} - }) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - - trusted, err := c.CheckTrust(1, 2) - if err == nil || !strings.Contains(err.Error(), "forbidden") { - t.Fatalf("expected forbidden error, got %v", err) - } - if trusted { - t.Fatalf("expected trusted=false on error") - } -} - -// --- Invite family -------------------------------------------------------- - -func TestInviteToNetworkSigAndAdminBothAllowed(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - c.SetSigner(func(ch string) string { return "SIG:" + ch }) - // Both signature AND admin_token are included (logic uses two independent ifs). - resp, _ := c.InviteToNetwork(3, 1, 2, "ADM") - echo := assertEcho(t, resp) - if got, _ := echo["signature"].(string); got != "SIG:invite:1:3:2" { - t.Fatalf("signature: %q", got) - } - if got, _ := echo["admin_token"].(string); got != "ADM" { - t.Fatalf("admin_token: %q", got) - } -} - -func TestRespondInvitePassesAccept(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - c.SetSigner(func(ch string) string { return "SIG:" + ch }) - resp, _ := c.RespondInvite(5, 9, false) - echo := assertEcho(t, resp) - if got, _ := echo["accept"].(bool); got { - t.Fatalf("accept: %v", got) - } - if got, _ := echo["signature"].(string); got != "SIG:respond_invite:5:9" { - t.Fatalf("signature: %q", got) - } -} - -// --- Member role operations --------------------------------------------- - -func TestMemberRoleOpsOmitBlankAdmin(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - ops := map[string]func() (map[string]interface{}, error){ - "promote_member": func() (map[string]interface{}, error) { return c.PromoteMember(1, 2, 3, "") }, - "demote_member": func() (map[string]interface{}, error) { return c.DemoteMember(1, 2, 3, "") }, - "kick_member": func() (map[string]interface{}, error) { return c.KickMember(1, 2, 3, "") }, - "transfer_ownership": func() (map[string]interface{}, error) { return c.TransferOwnership(1, 2, 3, "") }, - } - for name, op := range ops { - resp, err := op() - if err != nil { - t.Fatalf("%s: %v", name, err) - } - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != name { - t.Fatalf("%s: type=%q", name, got) - } - if _, ok := echo["admin_token"]; ok { - t.Fatalf("%s: admin_token should be omitted when blank", name) - } - } -} - -func TestGetMemberRoleSimple(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.GetMemberRole(3, 7) - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "get_member_role" { - t.Fatalf("type: %q", got) - } - if got, _ := echo["target_node_id"].(float64); uint32(got) != 7 { - t.Fatalf("target_node_id: %v", got) - } -} - -// --- Policy / ExprPolicy -------------------------------------------------- - -func TestSetNetworkPolicyMergesPolicyMap(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - policy := map[string]interface{}{ - "allow_public": true, - "max_members": float64(10), - "network_id": float64(999), // must NOT override the real network_id - "type": "evil_type", // must NOT override the real type - "admin_token": "EVIL", // must NOT override the real admin_token - } - resp, _ := c.SetNetworkPolicy(3, policy, "ADM") - echo := assertEcho(t, resp) - if got, _ := echo["allow_public"].(bool); !got { - t.Fatalf("allow_public: %v", got) - } - if got, _ := echo["max_members"].(float64); got != 10 { - t.Fatalf("max_members: %v", got) - } - // Explicit networkID parameter must win over a policy key of the same name. - if got, _ := echo["network_id"].(float64); got != 3 { - t.Fatalf("network_id: want 3 (explicit param wins), got %v", got) - } - if got, _ := echo["type"].(string); got != "set_network_policy" { - t.Fatalf("type: want set_network_policy (protected), got %q", got) - } - if got, _ := echo["admin_token"].(string); got != "ADM" { - t.Fatalf("admin_token: want ADM (explicit param wins), got %q", got) - } -} - -func TestGetNetworkPolicyType(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.GetNetworkPolicy(7) - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "get_network_policy" { - t.Fatalf("type: %q", got) - } -} - -func TestSetExprPolicyStringifiesJSON(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - raw := json.RawMessage(`{"rule":"true"}`) - resp, _ := c.SetExprPolicy(9, raw, "ADM") - echo := assertEcho(t, resp) - if got, _ := echo["expr_policy"].(string); got != `{"rule":"true"}` { - t.Fatalf("expr_policy: %q", got) - } -} - -func TestGetExprPolicyType(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.GetExprPolicy(9) - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "get_expr_policy" { - t.Fatalf("type: %q", got) - } -} - -// --- Admin wrappers (trivial payload formatters) ------------------------- - -func TestAdminWrappersIncludeAdminToken(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - cases := []struct { - name string - call func() (map[string]interface{}, error) - typ string - }{ - {"set_hostname_admin", func() (map[string]interface{}, error) { return c.SetHostnameAdmin(1, "h", "T") }, "set_hostname"}, - {"set_visibility_admin", func() (map[string]interface{}, error) { return c.SetVisibilityAdmin(1, true, "T") }, "set_visibility"}, - {"set_tags_admin", func() (map[string]interface{}, error) { return c.SetTagsAdmin(1, []string{"x"}, "T") }, "set_tags"}, - {"set_key_expiry_admin", func() (map[string]interface{}, error) { return c.SetKeyExpiryAdmin(1, time.Unix(0, 0).UTC(), "T") }, "set_key_expiry"}, - {"clear_key_expiry_admin", func() (map[string]interface{}, error) { return c.ClearKeyExpiryAdmin(1, "T") }, "set_key_expiry"}, - {"deregister_admin", func() (map[string]interface{}, error) { return c.DeregisterAdmin(1, "T") }, "deregister"}, - } - for _, tc := range cases { - resp, err := tc.call() - if err != nil { - t.Fatalf("%s: %v", tc.name, err) - } - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != tc.typ { - t.Fatalf("%s: type=%q want %q", tc.name, got, tc.typ) - } - if got, _ := echo["admin_token"].(string); got != "T" { - t.Fatalf("%s: admin_token=%q", tc.name, got) - } - } -} - -func TestClearKeyExpiryAdminSendsNeverLiteral(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.ClearKeyExpiryAdmin(1, "T") - echo := assertEcho(t, resp) - if got, _ := echo["expires_at"].(string); got != "never" { - t.Fatalf("expires_at: want 'never', got %q", got) - } -} - -func TestSetMemberTagsAndGetMemberTags(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.SetMemberTags(2, 3, []string{"gpu", "fast"}, "T") - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "set_member_tags" { - t.Fatalf("type: %q", got) - } - tags, _ := echo["tags"].([]interface{}) - if len(tags) != 2 || tags[0] != "gpu" || tags[1] != "fast" { - t.Fatalf("tags: %v", tags) - } - - resp, _ = c.GetMemberTags(2, 3) - echo = assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "get_member_tags" { - t.Fatalf("type: %q", got) - } -} - -// --- Audit log / Audit export / Webhooks / Identity / IDP / Provision ---- - -func TestGetAuditLogOmitsZeroNetworkID(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.GetAuditLog(0, "T") - echo := assertEcho(t, resp) - if _, ok := echo["network_id"]; ok { - t.Fatalf("network_id should be omitted when 0") - } - resp, _ = c.GetAuditLog(3, "T") - echo = assertEcho(t, resp) - if got, _ := echo["network_id"].(float64); uint16(got) != 3 { - t.Fatalf("network_id: %v", got) - } -} - -func TestWebhookWrappers(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - cases := []struct { - typ string - call func() (map[string]interface{}, error) - }{ - {"set_webhook", func() (map[string]interface{}, error) { return c.SetWebhook("http://x", "T") }}, - {"get_webhook", func() (map[string]interface{}, error) { return c.GetWebhook("T") }}, - {"get_webhook_dlq", func() (map[string]interface{}, error) { return c.GetWebhookDLQ("T") }}, - {"set_identity_webhook", func() (map[string]interface{}, error) { return c.SetIdentityWebhook("http://id", "T") }}, - } - for _, tc := range cases { - resp, err := tc.call() - if err != nil { - t.Fatalf("%s: %v", tc.typ, err) - } - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != tc.typ { - t.Fatalf("%s: type=%q", tc.typ, got) - } - } -} - -func TestIdentityExternalIDWrappers(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.SetExternalID(5, "ext-7", "T") - echo := assertEcho(t, resp) - if got, _ := echo["external_id"].(string); got != "ext-7" { - t.Fatalf("external_id: %q", got) - } - resp, _ = c.GetIdentity(5, "T") - echo = assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "get_identity" { - t.Fatalf("type: %q", got) - } -} - -func TestSetIDPConfigOptionalFields(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - // Only required fields. - resp, _ := c.SetIDPConfig("oidc", "https://idp", "", "", "", "", "T") - echo := assertEcho(t, resp) - for _, key := range []string{"issuer", "client_id", "tenant_id", "domain"} { - if _, ok := echo[key]; ok { - t.Fatalf("%s should be omitted when blank", key) - } - } - if got, _ := echo["idp_type"].(string); got != "oidc" { - t.Fatalf("idp_type: %q", got) - } - // All fields. - resp, _ = c.SetIDPConfig("oidc", "https://idp", "ISS", "CID", "TID", "example.com", "T") - echo = assertEcho(t, resp) - for _, key := range []string{"issuer", "client_id", "tenant_id", "domain"} { - if _, ok := echo[key]; !ok { - t.Fatalf("%s should be present when supplied", key) - } - } -} - -func TestGetIDPConfigAndGetProvisionStatus(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.GetIDPConfig("T") - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "get_idp_config" { - t.Fatalf("type: %q", got) - } - resp, _ = c.GetProvisionStatus("T") - echo = assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "get_provision_status" { - t.Fatalf("type: %q", got) - } -} - -func TestProvisionNetworkPassesBlueprint(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - bp := map[string]interface{}{"name": "bp", "networks": []interface{}{}} - resp, _ := c.ProvisionNetwork(bp, "T") - echo := assertEcho(t, resp) - blueprint, _ := echo["blueprint"].(map[string]interface{}) - if got, _ := blueprint["name"].(string); got != "bp" { - t.Fatalf("blueprint.name: %q", got) - } -} - -func TestSetAuditExportAllFields(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.SetAuditExport("splunk_hec", "https://hec", "TOK", "idx", "src", "T") - echo := assertEcho(t, resp) - if got, _ := echo["format"].(string); got != "splunk_hec" { - t.Fatalf("format: %q", got) - } - if got, _ := echo["endpoint"].(string); got != "https://hec" { - t.Fatalf("endpoint: %q", got) - } - if got, _ := echo["index"].(string); got != "idx" { - t.Fatalf("index: %q", got) - } - if got, _ := echo["source"].(string); got != "src" { - t.Fatalf("source: %q", got) - } -} - -func TestGetAuditExport(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.GetAuditExport("T") - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "get_audit_export" { - t.Fatalf("type: %q", got) - } -} - -// --- Directory sync / ValidateToken / GetKeyInfo ------------------------- - -func TestDirectorySyncConvertsEntriesAndPassesFlag(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - entries := []map[string]interface{}{ - {"id": "u1", "role": "admin"}, - {"id": "u2", "role": "member"}, - } - resp, _ := c.DirectorySync(1, entries, true, "T") - echo := assertEcho(t, resp) - list, _ := echo["entries"].([]interface{}) - if len(list) != 2 { - t.Fatalf("entries: %v", list) - } - first, _ := list[0].(map[string]interface{}) - if got, _ := first["id"].(string); got != "u1" { - t.Fatalf("first.id: %q", got) - } - if got, _ := echo["remove_unlisted"].(bool); !got { - t.Fatalf("remove_unlisted: %v", got) - } -} - -func TestDirectoryStatusSimple(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.DirectoryStatus(5, "T") - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "directory_status" { - t.Fatalf("type: %q", got) - } -} - -func TestValidateTokenPassesPayload(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.ValidateToken("jwt.header.sig", "T") - echo := assertEcho(t, resp) - if got, _ := echo["token"].(string); got != "jwt.header.sig" { - t.Fatalf("token: %q", got) - } -} - -func TestGetKeyInfoSimple(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, _ := c.GetKeyInfo(7) - echo := assertEcho(t, resp) - if got, _ := echo["type"].(string); got != "get_key_info" { - t.Fatalf("type: %q", got) - } -} - -// Ensure errors package remains used if inline error checks are trimmed. -var _ = errors.New diff --git a/pkg/registry/client/zz_client_nil_receiver_test.go b/pkg/registry/client/zz_client_nil_receiver_test.go deleted file mode 100644 index 1c0995e7..00000000 --- a/pkg/registry/client/zz_client_nil_receiver_test.go +++ /dev/null @@ -1,253 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package client - -import ( - "encoding/json" - "errors" - "testing" - "time" -) - -// TestNilClient_AllMethodsReturnError asserts that every exported *Client -// method is safe to call on a typed-nil receiver and returns ErrNoRegistry -// (or, for SetSigner/Close, a no-op without panic). Several callers in the -// daemon (loadPolicyRunners, ManagedEngine.fetchMembers, Daemon.Info → -// nodeNetworks) invoke registry methods without nil-checking the client, -// so the only acceptable behavior is "no panic; recoverable error." -// -// The test invokes each method, recovers any panic, and asserts the -// expected error. A panic counts as a regression and fails the test. -func TestNilClient_AllMethodsReturnError(t *testing.T) { - t.Parallel() - - var c *Client - - // callErr runs fn and asserts (a) no panic and (b) the returned error - // is ErrNoRegistry (using errors.Is). name identifies the method. - callErr := func(name string, fn func() error) { - t.Helper() - defer func() { - if r := recover(); r != nil { - t.Errorf("%s panicked on nil receiver: %v", name, r) - } - }() - err := fn() - if !errors.Is(err, ErrNoRegistry) { - t.Errorf("%s: err = %v, want ErrNoRegistry", name, err) - } - } - - // callMap discards the map return and asserts the error contract. - callMap := func(name string, fn func() (map[string]interface{}, error)) { - t.Helper() - callErr(name, func() error { - _, err := fn() - return err - }) - } - - // --- void / no-error methods (must not panic; nothing else to assert) --- - func() { - defer func() { - if r := recover(); r != nil { - t.Errorf("SetSigner panicked on nil receiver: %v", r) - } - }() - c.SetSigner(func(string) string { return "" }) - }() - - func() { - defer func() { - if r := recover(); r != nil { - t.Errorf("Close panicked on nil receiver: %v", r) - } - }() - if err := c.Close(); err != nil { - t.Errorf("Close on nil receiver: err = %v, want nil", err) - } - }() - - // --- methods that return (map, error) — go through Send --- - callMap("Send", func() (map[string]interface{}, error) { - return c.Send(map[string]interface{}{"type": "ping"}) - }) - callMap("Register", func() (map[string]interface{}, error) { return c.Register("127.0.0.1:0") }) - callMap("RegisterWithOwner", func() (map[string]interface{}, error) { - return c.RegisterWithOwner("127.0.0.1:0", "owner") - }) - callMap("RegisterWithKey", func() (map[string]interface{}, error) { - return c.RegisterWithKey("127.0.0.1:0", "key", "owner", nil) - }) - callMap("RegisterWithKeyOpts", func() (map[string]interface{}, error) { - return c.RegisterWithKeyOpts(RegisterOpts{ListenAddr: "127.0.0.1:0", PublicKey: "k"}) - }) - callMap("RotateKey", func() (map[string]interface{}, error) { - return c.RotateKey(1, "sig", "newkey") - }) - callMap("Lookup", func() (map[string]interface{}, error) { return c.Lookup(1) }) - callMap("Resolve", func() (map[string]interface{}, error) { return c.Resolve(1, 2) }) - callMap("ReportTrust", func() (map[string]interface{}, error) { return c.ReportTrust(1, 2) }) - callMap("RevokeTrust", func() (map[string]interface{}, error) { return c.RevokeTrust(1, 2) }) - callMap("SetVisibility", func() (map[string]interface{}, error) { return c.SetVisibility(1, true) }) - callMap("CreateNetwork", func() (map[string]interface{}, error) { - return c.CreateNetwork(1, "name", "open", "tok", "admin", false) - }) - callMap("CreateManagedNetwork", func() (map[string]interface{}, error) { - return c.CreateManagedNetwork(1, "name", "open", "tok", "admin", false, "{}") - }) - callMap("JoinNetwork", func() (map[string]interface{}, error) { - return c.JoinNetwork(1, 2, "tok", 3, "admin") - }) - callMap("LeaveNetwork", func() (map[string]interface{}, error) { - return c.LeaveNetwork(1, 2, "admin") - }) - callMap("DeleteNetwork", func() (map[string]interface{}, error) { return c.DeleteNetwork(1, "admin") }) - callMap("RenameNetwork", func() (map[string]interface{}, error) { - return c.RenameNetwork(1, "new", "admin") - }) - callMap("SetNetworkEnterprise", func() (map[string]interface{}, error) { - return c.SetNetworkEnterprise(1, true, "admin") - }) - callMap("ListNetworks", func() (map[string]interface{}, error) { return c.ListNetworks() }) - callMap("ListNodes", func() (map[string]interface{}, error) { return c.ListNodes(1) }) - callMap("Deregister", func() (map[string]interface{}, error) { return c.Deregister(1) }) - callMap("Heartbeat", func() (map[string]interface{}, error) { return c.Heartbeat(1) }) - callMap("Punch", func() (map[string]interface{}, error) { return c.Punch(1, 2, 3) }) - callMap("RequestHandshake", func() (map[string]interface{}, error) { - return c.RequestHandshake(1, 2, "why", "sig") - }) - callMap("PollHandshakes", func() (map[string]interface{}, error) { return c.PollHandshakes(1) }) - callMap("RespondHandshake", func() (map[string]interface{}, error) { - return c.RespondHandshake(1, 2, true, "sig") - }) - callMap("SetHostname", func() (map[string]interface{}, error) { return c.SetHostname(1, "h") }) - callMap("SetTags", func() (map[string]interface{}, error) { return c.SetTags(1, []string{"t"}) }) - callMap("ResolveHostname", func() (map[string]interface{}, error) { return c.ResolveHostname("h") }) - callMap("ResolveHostnameAs", func() (map[string]interface{}, error) { - return c.ResolveHostnameAs(1, "h") - }) - callMap("InviteToNetwork", func() (map[string]interface{}, error) { - return c.InviteToNetwork(1, 2, 3, "admin") - }) - callMap("PollInvites", func() (map[string]interface{}, error) { return c.PollInvites(1) }) - callMap("RespondInvite", func() (map[string]interface{}, error) { - return c.RespondInvite(1, 2, true) - }) - callMap("PromoteMember", func() (map[string]interface{}, error) { - return c.PromoteMember(1, 2, 3, "admin") - }) - callMap("DemoteMember", func() (map[string]interface{}, error) { - return c.DemoteMember(1, 2, 3, "admin") - }) - callMap("KickMember", func() (map[string]interface{}, error) { - return c.KickMember(1, 2, 3, "admin") - }) - callMap("TransferOwnership", func() (map[string]interface{}, error) { - return c.TransferOwnership(1, 2, 3, "admin") - }) - callMap("GetMemberRole", func() (map[string]interface{}, error) { - return c.GetMemberRole(1, 2) - }) - callMap("SetNetworkPolicy", func() (map[string]interface{}, error) { - return c.SetNetworkPolicy(1, map[string]interface{}{}, "admin") - }) - callMap("GetNetworkPolicy", func() (map[string]interface{}, error) { - return c.GetNetworkPolicy(1) - }) - callMap("SetExprPolicy", func() (map[string]interface{}, error) { - return c.SetExprPolicy(1, json.RawMessage(`{}`), "admin") - }) - callMap("GetExprPolicy", func() (map[string]interface{}, error) { return c.GetExprPolicy(1) }) - callMap("SetKeyExpiry", func() (map[string]interface{}, error) { - return c.SetKeyExpiry(1, time.Now()) - }) - callMap("GetKeyInfo", func() (map[string]interface{}, error) { return c.GetKeyInfo(1) }) - callMap("SetHostnameAdmin", func() (map[string]interface{}, error) { - return c.SetHostnameAdmin(1, "h", "admin") - }) - callMap("SetVisibilityAdmin", func() (map[string]interface{}, error) { - return c.SetVisibilityAdmin(1, true, "admin") - }) - callMap("SetTagsAdmin", func() (map[string]interface{}, error) { - return c.SetTagsAdmin(1, []string{"t"}, "admin") - }) - callMap("SetMemberTags", func() (map[string]interface{}, error) { - return c.SetMemberTags(1, 2, []string{"t"}, "admin") - }) - callMap("GetMemberTags", func() (map[string]interface{}, error) { - return c.GetMemberTags(1, 2) - }) - callMap("SetKeyExpiryAdmin", func() (map[string]interface{}, error) { - return c.SetKeyExpiryAdmin(1, time.Now(), "admin") - }) - callMap("ClearKeyExpiryAdmin", func() (map[string]interface{}, error) { - return c.ClearKeyExpiryAdmin(1, "admin") - }) - callMap("DeregisterAdmin", func() (map[string]interface{}, error) { - return c.DeregisterAdmin(1, "admin") - }) - callMap("GetAuditLog", func() (map[string]interface{}, error) { - return c.GetAuditLog(1, "admin") - }) - callMap("SetWebhook", func() (map[string]interface{}, error) { - return c.SetWebhook("http://x", "admin") - }) - callMap("GetWebhook", func() (map[string]interface{}, error) { return c.GetWebhook("admin") }) - callMap("GetWebhookDLQ", func() (map[string]interface{}, error) { - return c.GetWebhookDLQ("admin") - }) - callMap("SetIdentityWebhook", func() (map[string]interface{}, error) { - return c.SetIdentityWebhook("http://x", "admin") - }) - callMap("SetExternalID", func() (map[string]interface{}, error) { - return c.SetExternalID(1, "ext", "admin") - }) - callMap("GetIdentity", func() (map[string]interface{}, error) { - return c.GetIdentity(1, "admin") - }) - callMap("ProvisionNetwork", func() (map[string]interface{}, error) { - return c.ProvisionNetwork(map[string]interface{}{}, "admin") - }) - callMap("SetAuditExport", func() (map[string]interface{}, error) { - return c.SetAuditExport("splunk", "https://x", "t", "i", "s", "admin") - }) - callMap("GetAuditExport", func() (map[string]interface{}, error) { - return c.GetAuditExport("admin") - }) - callMap("SetIDPConfig", func() (map[string]interface{}, error) { - return c.SetIDPConfig("oidc", "https://x", "iss", "cid", "tid", "dom", "admin") - }) - callMap("GetIDPConfig", func() (map[string]interface{}, error) { - return c.GetIDPConfig("admin") - }) - callMap("GetProvisionStatus", func() (map[string]interface{}, error) { - return c.GetProvisionStatus("admin") - }) - callMap("DirectorySync", func() (map[string]interface{}, error) { - return c.DirectorySync(1, nil, false, "admin") - }) - callMap("DirectoryStatus", func() (map[string]interface{}, error) { - return c.DirectoryStatus(1, "admin") - }) - callMap("ValidateToken", func() (map[string]interface{}, error) { - return c.ValidateToken("tok", "admin") - }) - - // --- CheckTrust: (bool, error) — pinned separately because the - // return type differs and a non-false bool would be misleading. - func() { - defer func() { - if r := recover(); r != nil { - t.Errorf("CheckTrust panicked on nil receiver: %v", r) - } - }() - ok, err := c.CheckTrust(1, 2) - if ok { - t.Errorf("CheckTrust: ok = true, want false") - } - if !errors.Is(err, ErrNoRegistry) { - t.Errorf("CheckTrust: err = %v, want ErrNoRegistry", err) - } - }() -} diff --git a/pkg/registry/client/zz_client_pool_test.go b/pkg/registry/client/zz_client_pool_test.go deleted file mode 100644 index 04c314ed..00000000 --- a/pkg/registry/client/zz_client_pool_test.go +++ /dev/null @@ -1,861 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package client - -import ( - "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/hex" - "encoding/json" - "encoding/pem" - "math/big" - "net" - "strings" - "sync" - "sync/atomic" - "testing" - "time" -) - -// Coverage push for pkg/registry/client targeting the previously 0% / low% -// surfaces in client.go: -// -// - DialPool / DialTLSPool / initPool -// - sendPool, sendOnEntry, reconnectEntry, isClosed -// - Close with pooled secondary conns -// - DialTLSPinned full verify path (fingerprint match + mismatch) -// - Send: reconnect-failure error wrap -// -// All fake servers are 127.0.0.1:0 TCP listeners that speak the -// length-prefixed JSON wire protocol used by Client.Send. - -// --- helpers ---------------------------------------------------------------- - -// genSelfSignedCert returns a fresh single-host self-signed cert+key plus the -// raw DER bytes (for pin fingerprint computation). Used by the TLS dial tests. -func genSelfSignedCert(t *testing.T) (tlsCert tls.Certificate, derBytes []byte) { - t.Helper() - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("genkey: %v", err) - } - tmpl := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "pilot-test"}, - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - KeyUsage: x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, - DNSNames: []string{"localhost"}, - } - der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv) - if err != nil { - t.Fatalf("create cert: %v", err) - } - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) - keyDER, err := x509.MarshalECPrivateKey(priv) - if err != nil { - t.Fatalf("marshal key: %v", err) - } - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) - tlsCert, err = tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - t.Fatalf("X509KeyPair: %v", err) - } - return tlsCert, der -} - -// newFakeTLSServer wraps the existing fakeJSONServer with a TLS listener -// (so DialTLS / DialTLSPool / DialTLSPinned can connect). -type fakeTLSServer struct { - ln net.Listener - cert tls.Certificate - der []byte - handler func(req map[string]interface{}) map[string]interface{} - connections atomic.Uint32 - wg sync.WaitGroup - closeOnce sync.Once -} - -func newFakeTLSServer(t *testing.T, handler func(req map[string]interface{}) map[string]interface{}) *fakeTLSServer { - t.Helper() - cert, der := genSelfSignedCert(t) - cfg := &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - } - ln, err := tls.Listen("tcp", "127.0.0.1:0", cfg) - if err != nil { - t.Fatalf("tls listen: %v", err) - } - s := &fakeTLSServer{ln: ln, cert: cert, der: der, handler: handler} - s.wg.Add(1) - go s.accept() - t.Cleanup(s.close) - return s -} - -func (s *fakeTLSServer) addr() string { return s.ln.Addr().String() } - -func (s *fakeTLSServer) close() { - s.closeOnce.Do(func() { - s.ln.Close() - s.wg.Wait() - }) -} - -func (s *fakeTLSServer) accept() { - defer s.wg.Done() - for { - conn, err := s.ln.Accept() - if err != nil { - return - } - s.connections.Add(1) - s.wg.Add(1) - go func() { - defer s.wg.Done() - handleJSONOverConn(conn, s.handler) - }() - } -} - -// handleJSONOverConn runs the standard 4-byte length-prefix JSON loop on a conn. -func handleJSONOverConn(conn net.Conn, handler func(req map[string]interface{}) map[string]interface{}) { - defer conn.Close() - for { - var lenBuf [4]byte - if _, err := readFullN(conn, lenBuf[:]); err != nil { - return - } - n := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3]) - if n > 1<<20 { - return - } - body := make([]byte, n) - if _, err := readFullN(conn, body); err != nil { - return - } - req := map[string]interface{}{} - if err := jsonUnmarshalLite(body, &req); err != nil { - return - } - resp := handler(req) - if resp == nil { - return - } - out, _ := jsonMarshalLite(resp) - var outLen [4]byte - outLen[0] = byte(len(out) >> 24) - outLen[1] = byte(len(out) >> 16) - outLen[2] = byte(len(out) >> 8) - outLen[3] = byte(len(out)) - conn.Write(outLen[:]) - conn.Write(out) - } -} - -// Thin wrappers around encoding/json so the per-conn read loop helper stays -// readable. Same framing as the canonical fakeJSONServer.handle(). -func jsonUnmarshalLite(b []byte, v interface{}) error { return json.Unmarshal(b, v) } -func jsonMarshalLite(v interface{}) ([]byte, error) { return json.Marshal(v) } - -func readFullN(r net.Conn, buf []byte) (int, error) { - total := 0 - for total < len(buf) { - n, err := r.Read(buf[total:]) - if n > 0 { - total += n - } - if err != nil { - return total, err - } - } - return total, nil -} - -// --- DialPool / sendPool basic happy path ---------------------------------- - -func TestDialPoolSizeOneIsSingleConn(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 1) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - defer c.Close() - // size == 1 → no secondary pool entries, free chan stays nil. - if c.pool.free != nil { - t.Fatalf("pool.free should be nil for size=1, got %v", c.pool.free) - } - // Send still works via legacy single-conn path. - resp, err := c.Send(map[string]interface{}{"type": "hello"}) - if err != nil { - t.Fatalf("send: %v", err) - } - if got, _ := resp["type"].(string); got != "ok" { - t.Fatalf("type: %q", got) - } -} - -func TestDialPoolMultiConnExercisesSendPool(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 4) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - defer c.Close() - - if c.pool.free == nil { - t.Fatalf("pool.free should be initialised for size>1") - } - if len(c.pool.entries) != 4 { - t.Fatalf("pool entries: want 4, got %d", len(c.pool.entries)) - } - // Each pool entry corresponds to a real conn on the server. - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - if srv.connections.Load() == 4 { - break - } - time.Sleep(10 * time.Millisecond) - } - if srv.connections.Load() != 4 { - t.Fatalf("server connections: want 4, got %d", srv.connections.Load()) - } - - // A serial Send drives sendPool / sendOnEntry on entry[0] (or whichever - // is free), exercising the lock+round-trip path. - resp, err := c.Send(map[string]interface{}{"type": "ping"}) - if err != nil { - t.Fatalf("send: %v", err) - } - if got, _ := resp["type"].(string); got != "ok" { - t.Fatalf("type: %q", got) - } -} - -// TestDialPoolZeroOrNegativeSizeNormalisesToOne covers the size<=0 branch. -func TestDialPoolZeroOrNegativeSizeNormalisesToOne(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 0) - if err != nil { - t.Fatalf("DialPool(0): %v", err) - } - defer c.Close() - if c.pool.free != nil { - t.Fatalf("size<=0 should normalise to 1 (no pool)") - } - - c2, err := DialPool(srv.addr(), -3) - if err != nil { - t.Fatalf("DialPool(-3): %v", err) - } - defer c2.Close() - if c2.pool.free != nil { - t.Fatalf("negative size should normalise to 1 (no pool)") - } -} - -func TestDialPoolErrorOnUnreachable(t *testing.T) { - t.Parallel() - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - addr := ln.Addr().String() - ln.Close() - if _, err := DialPool(addr, 3); err == nil { - t.Fatalf("expected DialPool error on unreachable addr") - } -} - -// TestDialPoolPartialSecondaryFailureClosesPrimary covers the -// "primary dialed, secondary dial failed → close primary, return err" branch -// inside initPool. We accept the primary conn then close the listener to make -// secondary dials fail. -func TestDialPoolPartialSecondaryFailureClosesPrimary(t *testing.T) { - t.Parallel() - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - addr := ln.Addr().String() - - accepted := make(chan net.Conn, 1) - go func() { - conn, err := ln.Accept() - if err != nil { - return - } - accepted <- conn - // Close listener immediately so the next net.Dial inside initPool - // fails with ECONNREFUSED. - ln.Close() - }() - - _, err = DialPool(addr, 4) - if err == nil { - t.Fatalf("expected DialPool to fail when secondary dial errors") - } - if !strings.Contains(err.Error(), "dial pool conn") { - t.Fatalf("error should mention dial pool conn, got: %v", err) - } - - // Clean up the primary conn that the server accepted. - select { - case c := <-accepted: - c.Close() - case <-time.After(time.Second): - } -} - -// --- sendPool: closed-client guards --------------------------------------- - -func TestSendPoolAfterCloseFailsFast(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 3) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - c.Close() - - _, err = c.Send(map[string]interface{}{"type": "x"}) - if err == nil { - t.Fatalf("expected error after Close") - } - if !strings.Contains(err.Error(), "closed") { - t.Fatalf("expected 'closed' in error, got: %v", err) - } -} - -// TestSendPoolUnblocksOnCloseWhileBlocked ensures that a goroutine blocked -// in <-c.pool.free is woken up by Close (via the done channel select). -// We exhaust the pool, then Close, then assert that a pending Send returns -// with a "closed" error rather than hanging. -func TestSendPoolUnblocksOnCloseWhileBlocked(t *testing.T) { - t.Parallel() - // Handler that blocks until we tell it to release — so the in-flight - // Send holds its pool entry indefinitely. Pool size 1 → second Send - // is blocked on <-c.pool.free. - release := make(chan struct{}) - srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { - <-release - return map[string]interface{}{"type": "ok"} - }) - defer srv.close() - - c, err := DialPool(srv.addr(), 2) // primary + 1 secondary - if err != nil { - t.Fatalf("DialPool: %v", err) - } - - // Saturate both pool entries with in-flight Sends. - for i := 0; i < 2; i++ { - go func() { - _, _ = c.Send(map[string]interface{}{"type": "block"}) - }() - } - // Give them a chance to grab pool entries. - time.Sleep(50 * time.Millisecond) - - // Third Send blocks in <-c.pool.free. - thirdDone := make(chan error, 1) - go func() { - _, err := c.Send(map[string]interface{}{"type": "third"}) - thirdDone <- err - }() - time.Sleep(50 * time.Millisecond) - - // Close should unblock the waiter via the <-c.pool.done branch. - c.Close() - close(release) // let the in-flight Sends drain - - select { - case err := <-thirdDone: - if err == nil { - t.Fatalf("third Send should have errored after Close") - } - if !strings.Contains(err.Error(), "closed") { - t.Fatalf("expected 'closed' error, got: %v", err) - } - case <-time.After(2 * time.Second): - t.Fatalf("third Send did not return after Close — pool.done branch not wired") - } -} - -// --- sendPool: per-entry reconnect when the conn dies --------------------- - -func TestSendPoolReconnectsBrokenEntry(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 2) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - defer c.Close() - - // Kill the primary entry's conn so the next Send hitting it triggers - // reconnectEntry. We can't predict which entry the channel picks, so - // kill both — sendOnEntry on whichever one we get will fail then reconnect. - for _, e := range c.pool.entries { - e.mu.Lock() - _ = e.conn.Close() - e.mu.Unlock() - } - - resp, err := c.Send(map[string]interface{}{"type": "ping"}) - if err != nil { - t.Fatalf("send after killing entry: %v", err) - } - if got, _ := resp["type"].(string); got != "ok" { - t.Fatalf("type: %q", got) - } - // The reconnect path must have produced a new TCP conn. - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - if srv.connections.Load() >= 3 { - break - } - time.Sleep(10 * time.Millisecond) - } - if srv.connections.Load() < 3 { - t.Fatalf("expected reconnect to open a new conn, server saw %d", srv.connections.Load()) - } -} - -// TestReconnectEntrySyncsPrimary verifies that when the primary entry -// (entries[0]) is reconnected, c.conn is updated in lockstep so callers -// reading c.conn directly don't see a stale fd. This is the "if entry == -// c.pool.entries[0]" branch in reconnectEntry. -func TestReconnectEntrySyncsPrimary(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 2) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - defer c.Close() - - // Acquire entries[0] off the free channel directly to guarantee we're - // reconnecting the primary slot. - primary := <-c.pool.free - // Sanity: should be entries[0] OR entries[1]; force primary if not. - if primary != c.pool.entries[0] { - // put it back, grab the actual primary. - c.pool.free <- primary - // drain until we get entries[0]. - for i := 0; i < 4; i++ { - candidate := <-c.pool.free - if candidate == c.pool.entries[0] { - primary = candidate - break - } - c.pool.free <- candidate - } - } - - oldConn := c.conn - primary.mu.Lock() - _ = primary.conn.Close() - if err := c.reconnectEntry(context.Background(), primary); err != nil { - primary.mu.Unlock() - t.Fatalf("reconnectEntry: %v", err) - } - newConn := primary.conn - primary.mu.Unlock() - - if newConn == oldConn { - t.Fatalf("primary conn was not replaced by reconnectEntry") - } - // c.conn should be in sync with the new primary conn. - c.mu.Lock() - if c.conn != newConn { - c.mu.Unlock() - t.Fatalf("c.conn not synced after primary reconnect") - } - c.mu.Unlock() - - // Put the entry back so Close doesn't deadlock. - c.pool.free <- primary -} - -// TestReconnectEntryFailsWhenClosed exercises the early-return-on-closed -// branch of reconnectEntry. -func TestReconnectEntryFailsWhenClosed(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 2) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - c.Close() - - if err := c.reconnectEntry(context.Background(), c.pool.entries[0]); err == nil { - t.Fatalf("reconnectEntry on closed client should fail") - } -} - -// --- isClosed ------------------------------------------------------------- - -func TestIsClosedReflectsCloseState(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 2) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - if c.isClosed() { - t.Fatalf("fresh client should not report closed") - } - c.Close() - if !c.isClosed() { - t.Fatalf("client should report closed after Close()") - } -} - -// --- DialTLSPool --------------------------------------------------------- - -func TestDialTLSPoolNilConfigReturnsError(t *testing.T) { - t.Parallel() - if _, err := DialTLSPool("127.0.0.1:1", nil, 2); err == nil { - t.Fatalf("expected nil config error") - } -} - -func TestDialTLSPoolSucceedsAndDialsSize(t *testing.T) { - t.Parallel() - srv := newFakeTLSServer(t, echoHandler()) - - clientCfg := &tls.Config{ - MinVersion: tls.VersionTLS12, - InsecureSkipVerify: true, //nolint:gosec // test-only - } - c, err := DialTLSPool(srv.addr(), clientCfg, 3) - if err != nil { - t.Fatalf("DialTLSPool: %v", err) - } - defer c.Close() - if len(c.pool.entries) != 3 { - t.Fatalf("pool entries: %d, want 3", len(c.pool.entries)) - } - resp, err := c.Send(map[string]interface{}{"type": "hello"}) - if err != nil { - t.Fatalf("send over TLS pool: %v", err) - } - if got, _ := resp["type"].(string); got != "ok" { - t.Fatalf("type: %q", got) - } -} - -func TestDialTLSPoolSizeOneIsSingleConn(t *testing.T) { - t.Parallel() - srv := newFakeTLSServer(t, echoHandler()) - clientCfg := &tls.Config{ - MinVersion: tls.VersionTLS12, - InsecureSkipVerify: true, //nolint:gosec // test-only - } - c, err := DialTLSPool(srv.addr(), clientCfg, 1) - if err != nil { - t.Fatalf("DialTLSPool size=1: %v", err) - } - defer c.Close() - if c.pool.free != nil { - t.Fatalf("size=1 should not initialise pool channel") - } -} - -func TestDialTLSPoolDialErrorWrapsMessage(t *testing.T) { - t.Parallel() - ln, _ := net.Listen("tcp", "127.0.0.1:0") - addr := ln.Addr().String() - ln.Close() - cfg := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12} //nolint:gosec // test-only - if _, err := DialTLSPool(addr, cfg, 2); err == nil { - t.Fatalf("expected DialTLSPool error on unreachable addr") - } -} - -// --- DialTLSPinned: full verify path ------------------------------------- - -func TestDialTLSPinnedAcceptsMatchingFingerprint(t *testing.T) { - t.Parallel() - srv := newFakeTLSServer(t, echoHandler()) - - sum := sha256.Sum256(srv.der) - fp := hex.EncodeToString(sum[:]) - - c, err := DialTLSPinned(srv.addr(), fp) - if err != nil { - t.Fatalf("DialTLSPinned (matching fp): %v", err) - } - defer c.Close() - resp, err := c.Send(map[string]interface{}{"type": "hi"}) - if err != nil { - t.Fatalf("send over pinned conn: %v", err) - } - if got, _ := resp["type"].(string); got != "ok" { - t.Fatalf("type: %q", got) - } -} - -func TestDialTLSPinnedRejectsMismatchedFingerprint(t *testing.T) { - t.Parallel() - srv := newFakeTLSServer(t, echoHandler()) - - _, err := DialTLSPinned(srv.addr(), "00112233445566778899aabbccddeeff") - if err == nil { - t.Fatalf("expected fingerprint mismatch error") - } - if !strings.Contains(err.Error(), "fingerprint mismatch") && - !strings.Contains(err.Error(), "dial registry TLS pinned") { - t.Fatalf("error should mention fingerprint mismatch or pinned dial: %v", err) - } -} - -// --- Concurrent Send under -race confirms the regConn mutex is real ------ - -func TestSendConcurrentRaceFreeOnSingleConn(t *testing.T) { - t.Parallel() - var counter atomic.Uint64 - srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { - counter.Add(1) - return map[string]interface{}{"type": "ok"} - }) - defer srv.close() - - c, err := Dial(srv.addr()) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer c.Close() - - const goroutines = 16 - const callsEach = 25 - var wg sync.WaitGroup - wg.Add(goroutines) - for i := 0; i < goroutines; i++ { - go func() { - defer wg.Done() - for j := 0; j < callsEach; j++ { - if _, err := c.Send(map[string]interface{}{"type": "x"}); err != nil { - t.Errorf("send: %v", err) - return - } - } - }() - } - wg.Wait() - if got := counter.Load(); got != uint64(goroutines*callsEach) { - t.Fatalf("server saw %d requests, want %d", got, goroutines*callsEach) - } -} - -func TestSendConcurrentRaceFreeOnPool(t *testing.T) { - t.Parallel() - var counter atomic.Uint64 - srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { - counter.Add(1) - return map[string]interface{}{"type": "ok"} - }) - defer srv.close() - - c, err := DialPool(srv.addr(), 4) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - defer c.Close() - - const goroutines = 16 - const callsEach = 25 - var wg sync.WaitGroup - wg.Add(goroutines) - for i := 0; i < goroutines; i++ { - go func() { - defer wg.Done() - for j := 0; j < callsEach; j++ { - if _, err := c.Send(map[string]interface{}{"type": "x"}); err != nil { - t.Errorf("send: %v", err) - return - } - } - }() - } - wg.Wait() - if got := counter.Load(); got != uint64(goroutines*callsEach) { - t.Fatalf("server saw %d requests, want %d", got, goroutines*callsEach) - } -} - -// --- Send (single-conn) reconnect-failure branch ------------------------ - -// TestSendReconnectFailureSurfacesWrappedError covers the legacy-path -// branch: send fails, reconnect also fails → Client returns a "send failed -// and reconnect failed" wrap. We close the server first, then make Send -// hit a dead conn — both attempts (initial + reconnect) fail. -func TestSendReconnectFailureSurfacesWrappedError(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - c, err := Dial(srv.addr()) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer c.Close() - - // Kill the server first so the reconnect dial inside Send also fails. - srv.close() - // Also close the local end so the very first WriteMessage errors quickly. - c.mu.Lock() - _ = c.conn.Close() - c.mu.Unlock() - - _, err = c.Send(map[string]interface{}{"type": "x"}) - if err == nil { - t.Fatalf("expected error when both send and reconnect fail") - } - // Could be either "send failed and reconnect failed" or a raw send/recv - // error — accept any failure. - if err.Error() == "" { - t.Fatalf("error message must not be empty") - } -} - -// TestPoolSendReconnectFailureSurfacesWrappedError is the pool-path analogue. -func TestPoolSendReconnectFailureSurfacesWrappedError(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - c, err := DialPool(srv.addr(), 2) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - defer c.Close() - - // Kill the server, then kill both pool entries' conns so the round-trip - // fails AND reconnectEntry's dial fails. - srv.close() - for _, e := range c.pool.entries { - e.mu.Lock() - _ = e.conn.Close() - e.mu.Unlock() - } - - _, err = c.Send(map[string]interface{}{"type": "x"}) - if err == nil { - t.Fatalf("expected pool-path error when both send and reconnect fail") - } -} - -// --- Close: pool with secondary entries ----------------------------------- - -func TestClosePoolReleasesAllSecondaryConns(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 3) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - - // All entries should currently have non-nil conns. - conns := make([]net.Conn, len(c.pool.entries)) - for i, e := range c.pool.entries { - conns[i] = e.conn - } - - if err := c.Close(); err != nil { - t.Fatalf("Close: %v", err) - } - - // Every conn (primary + secondary) should now be unusable. - for i, conn := range conns { - if _, err := conn.Write([]byte{0}); err == nil { - t.Fatalf("conn %d should be closed after Close()", i) - } - } -} - -// --- Misc small branch fills --------------------------------------------- - -// Verify the helper Send returns a "client closed" error when isClosed -// is true AND we still try to send (covers the closed-guard inside sendPool). -func TestSendPoolReturnsClosedErrorBeforeAcquire(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := DialPool(srv.addr(), 2) - if err != nil { - t.Fatalf("DialPool: %v", err) - } - c.Close() - - _, err = c.Send(map[string]interface{}{"type": "x"}) - if err == nil || !strings.Contains(err.Error(), "closed") { - t.Fatalf("expected closed error, got: %v", err) - } -} - -// Smoke-test the RegisterWithKeyOpts RelayOnly + LANAddrs branch (the only -// "false" branches not exercised by existing tests). -func TestRegisterWithKeyOptsRelayOnlySerialized(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, err := c.RegisterWithKeyOpts(RegisterOpts{ - ListenAddr: "x:1", - PublicKey: "PUB", - LANAddrs: []string{"10.0.0.1:1"}, - RelayOnly: true, - }) - if err != nil { - t.Fatalf("register: %v", err) - } - echo := assertEcho(t, resp) - if got, _ := echo["relay_only"].(bool); !got { - t.Fatalf("relay_only: %v", got) - } - if _, ok := echo["owner"]; ok { - t.Fatalf("owner should be omitted when blank") - } -} - -// Ensure RegisterWithKey with multiple version variadic args picks the first -// non-empty (firstNonEmpty branch). -func TestRegisterWithKeyFirstNonEmptyVersion(t *testing.T) { - t.Parallel() - c, _ := echoOnlyClient(t) - resp, err := c.RegisterWithKey("x:1", "PUB", "", nil, "", "", "v2.0.0", "v3.0.0") - if err != nil { - t.Fatalf("register: %v", err) - } - echo := assertEcho(t, resp) - if got, _ := echo["version"].(string); got != "v2.0.0" { - t.Fatalf("version: want v2.0.0 (first non-empty), got %q", got) - } -} diff --git a/pkg/registry/client/zz_client_wire_test.go b/pkg/registry/client/zz_client_wire_test.go deleted file mode 100644 index b8cd93fe..00000000 --- a/pkg/registry/client/zz_client_wire_test.go +++ /dev/null @@ -1,618 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package client - -import ( - "crypto/tls" - "encoding/binary" - "encoding/json" - "io" - "net" - "strings" - "sync/atomic" - "testing" -) - -// fakeJSONServer speaks the registry JSON-over-TCP wire protocol -// (4-byte big-endian length prefix + JSON body). Each connection handshake -// is dispatched to a handler callback that can read the request and write -// a reply. -type fakeJSONServer struct { - ln net.Listener - handler func(req map[string]interface{}) map[string]interface{} - requests atomic.Uint32 - connections atomic.Uint32 - done chan struct{} -} - -func newFakeJSONServer(t *testing.T, handler func(req map[string]interface{}) map[string]interface{}) *fakeJSONServer { - t.Helper() - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - s := &fakeJSONServer{ln: ln, handler: handler, done: make(chan struct{})} - go s.accept() - return s -} - -func (s *fakeJSONServer) addr() string { return s.ln.Addr().String() } - -func (s *fakeJSONServer) close() { s.ln.Close(); close(s.done) } - -func (s *fakeJSONServer) accept() { - for { - conn, err := s.ln.Accept() - if err != nil { - return - } - s.connections.Add(1) - go s.handle(conn) - } -} - -func (s *fakeJSONServer) handle(conn net.Conn) { - defer conn.Close() - for { - var lenBuf [4]byte - if _, err := io.ReadFull(conn, lenBuf[:]); err != nil { - return - } - n := binary.BigEndian.Uint32(lenBuf[:]) - // Defensive cap: any caller that sends non-JSON framing (e.g. TLS - // ClientHello) would otherwise block this goroutine in io.ReadFull - // until the full test timeout. - if n > 1<<20 { - return - } - body := make([]byte, n) - if _, err := io.ReadFull(conn, body); err != nil { - return - } - var req map[string]interface{} - if err := json.Unmarshal(body, &req); err != nil { - return - } - s.requests.Add(1) - resp := s.handler(req) - if resp == nil { - return - } - out, _ := json.Marshal(resp) - var outLen [4]byte - binary.BigEndian.PutUint32(outLen[:], uint32(len(out))) - conn.Write(outLen[:]) - conn.Write(out) - } -} - -// Echo the request type, plus include every key that was sent, under "echo". -// Tests can assert that the wire payload carried the right keys. -func echoHandler() func(map[string]interface{}) map[string]interface{} { - return func(req map[string]interface{}) map[string]interface{} { - resp := map[string]interface{}{"type": "ok", "echo": req} - return resp - } -} - -// --- Dial / Close / Addr ---------------------------------------------------- - -func TestDialSuccessReturnsClientWithAddr(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := Dial(srv.addr()) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer c.Close() - - if c.addr != srv.addr() { - t.Fatalf("addr: want %q, got %q", srv.addr(), c.addr) - } - if c.conn == nil { - t.Fatalf("conn should be set") - } -} - -func TestDialErrorOnBadAddress(t *testing.T) { - t.Parallel() - // Grab a port from the kernel and immediately release it so Dial - // fails fast with ECONNREFUSED on loopback (no DNS/route wait). - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - addr := ln.Addr().String() - ln.Close() - - _, err = Dial(addr) - if err == nil { - t.Fatalf("expected error") - } - if !strings.Contains(err.Error(), "dial registry") { - t.Fatalf("error should mention dial registry: %v", err) - } -} - -func TestDialTLSReturnsErrorWhenConfigNil(t *testing.T) { - t.Parallel() - if _, err := DialTLS("127.0.0.1:1", nil); err == nil { - t.Fatalf("expected error on nil tlsConfig") - } -} - -// closeOnAcceptListener accepts each connection and immediately closes it, so -// a TLS dial against it fails fast with EOF during the handshake. -func closeOnAcceptListener(t *testing.T) (addr string, stop func()) { - t.Helper() - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - go func() { - for { - conn, err := ln.Accept() - if err != nil { - return - } - conn.Close() - } - }() - return ln.Addr().String(), func() { ln.Close() } -} - -func TestDialTLSFailsConnectToPlainServer(t *testing.T) { - t.Parallel() - addr, stop := closeOnAcceptListener(t) - defer stop() - _, err := DialTLS(addr, minimalTLSConfig()) - if err == nil { - t.Fatalf("expected TLS error") - } - if !strings.Contains(err.Error(), "dial registry TLS") { - t.Fatalf("error should mention TLS dial: %v", err) - } -} - -func TestDialTLSPinnedFailsConnectToPlainServer(t *testing.T) { - t.Parallel() - addr, stop := closeOnAcceptListener(t) - defer stop() - _, err := DialTLSPinned(addr, "deadbeef") - if err == nil { - t.Fatalf("expected TLS pin error") - } - if !strings.Contains(err.Error(), "dial registry TLS pinned") { - t.Fatalf("error should mention TLS pinned dial: %v", err) - } -} - -func TestCloseSafeWhenNilConn(t *testing.T) { - t.Parallel() - c := &Client{} - if err := c.Close(); err != nil { - t.Fatalf("Close on empty client should not error: %v", err) - } - if !c.closed { - t.Fatalf("client should report closed after Close()") - } -} - -func TestCloseClosesRealConn(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, err := Dial(srv.addr()) - if err != nil { - t.Fatalf("dial: %v", err) - } - if err := c.Close(); err != nil { - t.Fatalf("close: %v", err) - } - // After Close, conn.Write should error. - if _, err := c.conn.Write([]byte{0}); err == nil { - t.Fatalf("expected write error after Close") - } -} - -// --- Signer ----------------------------------------------------------------- - -func TestSetSignerReturnsSignature(t *testing.T) { - t.Parallel() - c := &Client{} - sig, err := c.sign("whatever") - if err == nil { - t.Fatalf("expected error with no signer, got sig=%q", sig) - } - c.SetSigner(func(challenge string) string { - return "sig(" + challenge + ")" - }) - sig, err = c.sign("abc") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if sig != "sig(abc)" { - t.Fatalf("expected sig(abc), got %q", sig) - } -} - -func TestResolveIncludesSignatureWhenSignerSet(t *testing.T) { - t.Parallel() - var gotChallenge string - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := Dial(srv.addr()) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer c.Close() - - c.SetSigner(func(challenge string) string { - gotChallenge = challenge - return "SIG" - }) - resp, err := c.Resolve(42, 7) - if err != nil { - t.Fatalf("resolve: %v", err) - } - if gotChallenge != "resolve:7:42" { - t.Fatalf("challenge: want resolve:7:42, got %q", gotChallenge) - } - echo, _ := resp["echo"].(map[string]interface{}) - if sig, _ := echo["signature"].(string); sig != "SIG" { - t.Fatalf("signature wire value: want SIG, got %q", sig) - } -} - -// --- Send / sendLocked ------------------------------------------------------ - -func TestSendReturnsServerErrorResponse(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { - return map[string]interface{}{"error": "boom"} - }) - defer srv.close() - - c, err := Dial(srv.addr()) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer c.Close() - - resp, err := c.Send(map[string]interface{}{"type": "ping"}) - if err == nil { - t.Fatalf("expected server error") - } - if !strings.Contains(err.Error(), "boom") { - t.Fatalf("error should contain 'boom': %v", err) - } - // resp is non-nil for server errors so the caller can inspect it. - if resp == nil { - t.Fatalf("expected non-nil response on server error") - } -} - -func TestSendHappyPath(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, err := Dial(srv.addr()) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer c.Close() - - resp, err := c.Send(map[string]interface{}{"type": "hello", "num": float64(3)}) - if err != nil { - t.Fatalf("send: %v", err) - } - if got, _ := resp["type"].(string); got != "ok" { - t.Fatalf("type: want ok, got %q", got) - } -} - -func TestSendReconnectsAfterDroppedConnection(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - - c, err := Dial(srv.addr()) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer c.Close() - - // Simulate a connection-level failure by closing the client's conn - // without marking the Client closed. The next Send should reconnect. - c.mu.Lock() - _ = c.conn.Close() - c.mu.Unlock() - - resp, err := c.Send(map[string]interface{}{"type": "hello"}) - if err != nil { - t.Fatalf("send after reconnect: %v", err) - } - if got, _ := resp["type"].(string); got != "ok" { - t.Fatalf("type: want ok, got %q", got) - } - if srv.connections.Load() < 2 { - t.Fatalf("expected second connection from reconnect, got %d", srv.connections.Load()) - } -} - -func TestSendFailsWhenClosed(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, err := Dial(srv.addr()) - if err != nil { - t.Fatalf("dial: %v", err) - } - c.Close() - - _, err = c.Send(map[string]interface{}{"type": "hello"}) - if err == nil { - t.Fatalf("expected error after Close") - } -} - -// --- Register family -------------------------------------------------------- - -func TestRegisterSendsCorrectWireMessage(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, err := Dial(srv.addr()) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer c.Close() - - resp, err := c.Register("1.2.3.4:4000") - if err != nil { - t.Fatalf("register: %v", err) - } - echo, _ := resp["echo"].(map[string]interface{}) - if got, _ := echo["type"].(string); got != "register" { - t.Fatalf("wire type: want register, got %q", got) - } - if got, _ := echo["listen_addr"].(string); got != "1.2.3.4:4000" { - t.Fatalf("listen_addr: want 1.2.3.4:4000, got %q", got) - } -} - -func TestRegisterWithOwnerIncludesOwner(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - - resp, err := c.RegisterWithOwner("x:1", "alice@example.com") - if err != nil { - t.Fatalf("register: %v", err) - } - echo, _ := resp["echo"].(map[string]interface{}) - if got, _ := echo["owner"].(string); got != "alice@example.com" { - t.Fatalf("owner: %q", got) - } -} - -func TestRegisterWithKeyOmitsBlankOwnerAndLAN(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - - resp, err := c.RegisterWithKey("x:1", "PUB==", "", nil) - if err != nil { - t.Fatalf("register: %v", err) - } - echo, _ := resp["echo"].(map[string]interface{}) - if _, ok := echo["owner"]; ok { - t.Fatalf("owner should be omitted when blank") - } - if _, ok := echo["lan_addrs"]; ok { - t.Fatalf("lan_addrs should be omitted when empty") - } - if _, ok := echo["version"]; ok { - t.Fatalf("version should be omitted when not supplied") - } - if got, _ := echo["public_key"].(string); got != "PUB==" { - t.Fatalf("public_key: %q", got) - } -} - -func TestRegisterWithKeyIncludesAllFields(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - - resp, err := c.RegisterWithKey("x:1", "PUB==", "bob", []string{"10.0.0.1:80"}, "v1.2.3") - if err != nil { - t.Fatalf("register: %v", err) - } - echo, _ := resp["echo"].(map[string]interface{}) - if got, _ := echo["owner"].(string); got != "bob" { - t.Fatalf("owner: %q", got) - } - if got, _ := echo["version"].(string); got != "v1.2.3" { - t.Fatalf("version: %q", got) - } - lan, _ := echo["lan_addrs"].([]interface{}) - if len(lan) != 1 || lan[0] != "10.0.0.1:80" { - t.Fatalf("lan_addrs: %v", lan) - } -} - -// --- Lookup / Resolve / ReportTrust / RevokeTrust / SetVisibility ---------- - -func TestLookupSendsNodeID(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - - resp, err := c.Lookup(42) - if err != nil { - t.Fatalf("lookup: %v", err) - } - echo, _ := resp["echo"].(map[string]interface{}) - if got, _ := echo["type"].(string); got != "lookup" { - t.Fatalf("type: %q", got) - } - if got := uint32(echo["node_id"].(float64)); got != 42 { - t.Fatalf("node_id: %d", got) - } -} - -func TestReportTrustAndRevokeTrustFormat(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - c.SetSigner(func(ch string) string { return "SIG:" + ch }) - - for name, fn := range map[string]func() (map[string]interface{}, error){ - "report_trust": func() (map[string]interface{}, error) { return c.ReportTrust(1, 2) }, - "revoke_trust": func() (map[string]interface{}, error) { return c.RevokeTrust(1, 2) }, - } { - resp, err := fn() - if err != nil { - t.Fatalf("%s: %v", name, err) - } - echo, _ := resp["echo"].(map[string]interface{}) - if got, _ := echo["type"].(string); got != name { - t.Fatalf("%s: type=%q", name, got) - } - if got := uint32(echo["node_id"].(float64)); got != 1 { - t.Fatalf("%s: node_id=%d", name, got) - } - if got := uint32(echo["peer_id"].(float64)); got != 2 { - t.Fatalf("%s: peer_id=%d", name, got) - } - } -} - -func TestSetVisibilityPublicFlagSerialized(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - c.SetSigner(func(ch string) string { return "SIG:" + ch }) - - resp, err := c.SetVisibility(9, true) - if err != nil { - t.Fatalf("set_visibility: %v", err) - } - echo, _ := resp["echo"].(map[string]interface{}) - if got, _ := echo["public"].(bool); got != true { - t.Fatalf("public: %v", got) - } -} - -// --- CreateNetwork / CreateManagedNetwork ---------------------------------- - -func TestCreateNetworkBasicAndFull(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - - // Basic: no adminToken, no enterprise, no networkAdminToken. - resp, err := c.CreateNetwork(1, "foo", "public", "tok", "", false) - if err != nil { - t.Fatalf("create: %v", err) - } - echo, _ := resp["echo"].(map[string]interface{}) - if _, ok := echo["admin_token"]; ok { - t.Fatalf("admin_token should be omitted when blank") - } - if _, ok := echo["enterprise"]; ok { - t.Fatalf("enterprise should be omitted when false") - } - if _, ok := echo["network_admin_token"]; ok { - t.Fatalf("network_admin_token should be omitted when not supplied") - } - - // Full: adminToken + enterprise + networkAdminToken. - resp, err = c.CreateNetwork(1, "foo", "public", "tok", "ADM", true, "NAT") - if err != nil { - t.Fatalf("create full: %v", err) - } - echo, _ = resp["echo"].(map[string]interface{}) - if got, _ := echo["admin_token"].(string); got != "ADM" { - t.Fatalf("admin_token: %q", got) - } - if got, _ := echo["enterprise"].(bool); !got { - t.Fatalf("enterprise: %v", got) - } - if got, _ := echo["network_admin_token"].(string); got != "NAT" { - t.Fatalf("network_admin_token: %q", got) - } -} - -func TestCreateManagedNetworkIncludesRules(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - - resp, err := c.CreateManagedNetwork(2, "n", "invite", "tok", "", false, `{"a":1}`) - if err != nil { - t.Fatalf("managed: %v", err) - } - echo, _ := resp["echo"].(map[string]interface{}) - if got, _ := echo["rules"].(string); got != `{"a":1}` { - t.Fatalf("rules: %q", got) - } -} - -// --- RotateKey -------------------------------------------------------------- - -func TestRotateKeyOmitsBlankSignatureAndPubKey(t *testing.T) { - t.Parallel() - srv := newFakeJSONServer(t, echoHandler()) - defer srv.close() - c, _ := Dial(srv.addr()) - defer c.Close() - - resp, err := c.RotateKey(7, "", "") - if err != nil { - t.Fatalf("rotate: %v", err) - } - echo, _ := resp["echo"].(map[string]interface{}) - if _, ok := echo["signature"]; ok { - t.Fatalf("signature should be omitted when blank") - } - if _, ok := echo["new_public_key"]; ok { - t.Fatalf("new_public_key should be omitted when blank") - } - - resp, err = c.RotateKey(7, "SIG", "NPK") - if err != nil { - t.Fatalf("rotate full: %v", err) - } - echo, _ = resp["echo"].(map[string]interface{}) - if got, _ := echo["signature"].(string); got != "SIG" { - t.Fatalf("signature: %q", got) - } - if got, _ := echo["new_public_key"].(string); got != "NPK" { - t.Fatalf("new_public_key: %q", got) - } -} - -func minimalTLSConfig() *tls.Config { - return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} //nolint:gosec // test-only -} diff --git a/pkg/registry/wire/blueprint.go b/pkg/registry/wire/blueprint.go deleted file mode 100644 index 62129f97..00000000 --- a/pkg/registry/wire/blueprint.go +++ /dev/null @@ -1,178 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package wire - -import ( - "encoding/json" - "fmt" - "os" - - "github.com/TeoSlayer/pilotprotocol/pkg/urlvalidate" -) - -// NetworkBlueprint defines a declarative configuration for provisioning -// an enterprise network. Enterprises apply blueprints via the registry -// protocol or the pilotctl CLI to create and configure networks in one shot. -type NetworkBlueprint struct { - // Network settings - Name string `json:"name"` - JoinRule string `json:"join_rule,omitempty"` // "open", "token", "invite" (default: "open") - JoinToken string `json:"join_token,omitempty"` // required if join_rule = "token" - Enterprise bool `json:"enterprise,omitempty"` // enable enterprise features - - // Policy - Policy *BlueprintPolicy `json:"policy,omitempty"` - ExprPolicy json.RawMessage `json:"expr_policy,omitempty"` - - // RBAC pre-assignments (by external_id) - Roles []BlueprintRole `json:"roles,omitempty"` - - // Identity provider configuration - IdentityProvider *BlueprintIdentityProvider `json:"identity_provider,omitempty"` - - // Observability - Webhooks *BlueprintWebhooks `json:"webhooks,omitempty"` - - // Audit export - AuditExport *BlueprintAuditExport `json:"audit_export,omitempty"` - - // Per-network admin token (optional override) - NetworkAdminToken string `json:"network_admin_token,omitempty"` -} - -// BlueprintPolicy defines the network policy section of a blueprint. -type BlueprintPolicy struct { - MaxMembers int `json:"max_members,omitempty"` - AllowedPorts []uint16 `json:"allowed_ports,omitempty"` - Description string `json:"description,omitempty"` -} - -// BlueprintRole pre-assigns RBAC roles by external identity. -type BlueprintRole struct { - ExternalID string `json:"external_id"` - Role string `json:"role"` // "owner", "admin", "member" -} - -// BlueprintIdentityProvider configures external identity verification. -type BlueprintIdentityProvider struct { - Type string `json:"type"` // "oidc", "saml", "webhook", "entra_id", "ldap" - URL string `json:"url"` // verification endpoint - Issuer string `json:"issuer,omitempty"` // OIDC issuer URL - ClientID string `json:"client_id,omitempty"` // OIDC client ID - TenantID string `json:"tenant_id,omitempty"` // Azure AD / Entra ID tenant - Domain string `json:"domain,omitempty"` // LDAP domain -} - -// BlueprintWebhooks configures webhook endpoints for the network. -type BlueprintWebhooks struct { - AuditURL string `json:"audit_url,omitempty"` // audit event webhook - IdentityURL string `json:"identity_url,omitempty"` // identity verification webhook -} - -// BlueprintAuditExport configures external audit log export. -type BlueprintAuditExport struct { - Format string `json:"format"` // "json", "splunk_hec", "syslog_cef" - Endpoint string `json:"endpoint"` // destination URL or address - Token string `json:"token,omitempty"` // auth token (e.g., Splunk HEC token) - Index string `json:"index,omitempty"` // Splunk index - Source string `json:"source,omitempty"` // source identifier -} - -// LoadBlueprint reads a network blueprint from a JSON file. -func LoadBlueprint(path string) (*NetworkBlueprint, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("read blueprint: %w", err) - } - var bp NetworkBlueprint - if err := json.Unmarshal(data, &bp); err != nil { - return nil, fmt.Errorf("parse blueprint: %w", err) - } - if bp.Name == "" { - return nil, fmt.Errorf("blueprint: name is required") - } - return &bp, nil -} - -// ValidateBlueprint checks a blueprint for configuration errors. -func ValidateBlueprint(bp *NetworkBlueprint) error { - if bp.Name == "" { - return fmt.Errorf("name is required") - } - switch bp.JoinRule { - case "", "open", "token", "invite": - default: - return fmt.Errorf("invalid join_rule %q (must be open, token, or invite)", bp.JoinRule) - } - if bp.JoinRule == "token" && bp.JoinToken == "" { - return fmt.Errorf("join_token is required when join_rule is token") - } - for _, r := range bp.Roles { - if r.ExternalID == "" { - return fmt.Errorf("role entry: external_id is required") - } - switch r.Role { - case "owner", "admin", "member": - default: - return fmt.Errorf("invalid role %q for %s", r.Role, r.ExternalID) - } - } - if bp.IdentityProvider != nil { - switch bp.IdentityProvider.Type { - case "oidc", "saml", "webhook", "entra_id", "ldap": - default: - return fmt.Errorf("invalid identity_provider type %q", bp.IdentityProvider.Type) - } - if bp.IdentityProvider.URL == "" { - return fmt.Errorf("identity_provider.url is required") - } - if err := urlvalidate.Validate(bp.IdentityProvider.URL); err != nil { - return fmt.Errorf("identity_provider.url: %w", err) - } - } - if bp.Webhooks != nil { - if bp.Webhooks.AuditURL != "" { - if err := urlvalidate.Validate(bp.Webhooks.AuditURL); err != nil { - return fmt.Errorf("webhooks.audit_url: %w", err) - } - } - if bp.Webhooks.IdentityURL != "" { - if err := urlvalidate.Validate(bp.Webhooks.IdentityURL); err != nil { - return fmt.Errorf("webhooks.identity_url: %w", err) - } - } - } - if bp.AuditExport != nil { - switch bp.AuditExport.Format { - case "json", "splunk_hec", "syslog_cef": - default: - return fmt.Errorf("invalid audit_export format %q", bp.AuditExport.Format) - } - if bp.AuditExport.Endpoint == "" { - return fmt.Errorf("audit_export.endpoint is required") - } - // syslog_cef sinks accept raw host:port targets; only the HTTP(S) - // formats need SSRF validation. - if bp.AuditExport.Format == "json" || bp.AuditExport.Format == "splunk_hec" { - if err := urlvalidate.Validate(bp.AuditExport.Endpoint); err != nil { - return fmt.Errorf("audit_export.endpoint: %w", err) - } - } - } - if len(bp.ExprPolicy) > 0 { - var check struct { - Version int `json:"version"` - Rules json.RawMessage `json:"rules"` - } - if err := json.Unmarshal(bp.ExprPolicy, &check); err != nil { - return fmt.Errorf("expr_policy: invalid JSON: %w", err) - } - if check.Version != 1 { - return fmt.Errorf("expr_policy: unsupported version %d (want 1)", check.Version) - } - if len(check.Rules) == 0 || string(check.Rules) == "null" { - return fmt.Errorf("expr_policy: at least one rule is required") - } - } - return nil -} diff --git a/pkg/registry/wire/rules.go b/pkg/registry/wire/rules.go deleted file mode 100644 index f19e566b..00000000 --- a/pkg/registry/wire/rules.go +++ /dev/null @@ -1,204 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package wire - -import ( - "encoding/json" - "fmt" - "time" -) - -// NetworkRules defines the managed network ruleset. When set on a NetworkInfo, -// the network becomes "managed" — daemon-local link lifecycle is governed by -// these rules. The registry only stores and distributes the rules; all cycle -// logic runs inside each daemon. -type NetworkRules struct { - Links int `json:"links"` // max managed peers per node - Cycle string `json:"cycle"` // Go duration: "24h", "1h" - Prune int `json:"prune"` // how many to drop per cycle - PruneBy string `json:"prune_by"` // "score", "age", "activity" - Fill int `json:"fill"` // how many to add per cycle - FillHow string `json:"fill_how"` // "random" - Grace string `json:"grace,omitempty"` // grace period for new members -} - -// ValidateRules checks that a NetworkRules is well-formed. Returns nil if valid. -func ValidateRules(r *NetworkRules) error { - if r == nil { - return nil - } - if r.Links < 1 { - return fmt.Errorf("rules: links must be >= 1 (got %d)", r.Links) - } - if r.Cycle == "" { - return fmt.Errorf("rules: cycle is required") - } - d, err := time.ParseDuration(r.Cycle) - if err != nil { - return fmt.Errorf("rules: invalid cycle duration %q: %w", r.Cycle, err) - } - if d < 1*time.Minute { - return fmt.Errorf("rules: cycle must be >= 1m (got %s)", r.Cycle) - } - if r.Prune < 0 { - return fmt.Errorf("rules: prune must be >= 0 (got %d)", r.Prune) - } - if r.Fill < 0 { - return fmt.Errorf("rules: fill must be >= 0 (got %d)", r.Fill) - } - if r.Prune > r.Links { - return fmt.Errorf("rules: prune (%d) cannot exceed links (%d)", r.Prune, r.Links) - } - if r.Fill > r.Links { - return fmt.Errorf("rules: fill (%d) cannot exceed links (%d)", r.Fill, r.Links) - } - - switch r.PruneBy { - case "score", "age", "activity": - // valid - case "": - return fmt.Errorf("rules: prune_by is required") - default: - return fmt.Errorf("rules: unknown prune_by strategy %q (want score|age|activity)", r.PruneBy) - } - - switch r.FillHow { - case "random": - // valid - case "": - return fmt.Errorf("rules: fill_how is required") - default: - return fmt.Errorf("rules: unknown fill_how strategy %q (want random)", r.FillHow) - } - - if r.Grace != "" { - g, err := time.ParseDuration(r.Grace) - if err != nil { - return fmt.Errorf("rules: invalid grace duration %q: %w", r.Grace, err) - } - if g < 0 { - return fmt.Errorf("rules: grace must be >= 0") - } - } - - return nil -} - -// ParseRules unmarshals a JSON string into NetworkRules and validates it. -func ParseRules(raw string) (*NetworkRules, error) { - var r NetworkRules - if err := json.Unmarshal([]byte(raw), &r); err != nil { - return nil, fmt.Errorf("rules: invalid JSON: %w", err) - } - if err := ValidateRules(&r); err != nil { - return nil, err - } - return &r, nil -} - -// RulesToPolicy converts a NetworkRules struct into a PolicyDocument JSON -// (json.RawMessage). This provides backward compatibility: existing managed -// networks continue to work through the policy engine. -func RulesToPolicy(r *NetworkRules) (json.RawMessage, error) { - if r == nil { - return nil, nil - } - - type action struct { - Type string `json:"type"` - Params map[string]interface{} `json:"params,omitempty"` - } - type rule struct { - Name string `json:"name"` - On string `json:"on"` - Match string `json:"match"` - Actions []action `json:"actions"` - } - type policyDoc struct { - Version int `json:"version"` - Config map[string]interface{} `json:"config,omitempty"` - Rules []rule `json:"rules"` - } - - doc := policyDoc{ - Version: 1, - Config: map[string]interface{}{ - "max_peers": r.Links, - "cycle": r.Cycle, - }, - Rules: []rule{ - { - Name: "cycle-prune-fill", - On: "cycle", - Match: "true", - Actions: []action{ - {Type: "prune", Params: map[string]interface{}{"count": r.Prune, "by": r.PruneBy}}, - {Type: "fill", Params: map[string]interface{}{"count": r.Fill, "how": r.FillHow}}, - }, - }, - }, - } - - if r.Grace != "" { - doc.Config["grace"] = r.Grace - } - - data, err := json.Marshal(doc) - if err != nil { - return nil, fmt.Errorf("rules-to-policy: %w", err) - } - return json.RawMessage(data), nil -} - -// AllowedPortsToPolicy converts a port allowlist into a PolicyDocument JSON -// (json.RawMessage). This replaces the old AllowedPorts mechanism with -// equivalent policy rules. -func AllowedPortsToPolicy(ports []uint16) (json.RawMessage, error) { - if len(ports) == 0 { - return nil, nil - } - - // Build match expression: "port in [80, 443, 1001]" - var buf []byte - buf = append(buf, "port in ["...) - for i, p := range ports { - if i > 0 { - buf = append(buf, ", "...) - } - buf = fmt.Appendf(buf, "%d", p) - } - buf = append(buf, ']') - matchExpr := string(buf) - - type action struct { - Type string `json:"type"` - } - type rule struct { - Name string `json:"name"` - On string `json:"on"` - Match string `json:"match"` - Actions []action `json:"actions"` - } - type policyDoc struct { - Version int `json:"version"` - Rules []rule `json:"rules"` - } - - doc := policyDoc{ - Version: 1, - Rules: []rule{ - {Name: "allow-ports", On: "connect", Match: matchExpr, Actions: []action{{Type: "allow"}}}, - {Name: "allow-ports-dg", On: "datagram", Match: matchExpr, Actions: []action{{Type: "allow"}}}, - {Name: "allow-ports-dial", On: "dial", Match: matchExpr, Actions: []action{{Type: "allow"}}}, - {Name: "deny-rest", On: "connect", Match: "true", Actions: []action{{Type: "deny"}}}, - {Name: "deny-rest-dg", On: "datagram", Match: "true", Actions: []action{{Type: "deny"}}}, - {Name: "deny-rest-dial", On: "dial", Match: "true", Actions: []action{{Type: "deny"}}}, - }, - } - - data, err := json.Marshal(doc) - if err != nil { - return nil, fmt.Errorf("ports-to-policy: %w", err) - } - return json.RawMessage(data), nil -} diff --git a/pkg/registry/wire/wire.go b/pkg/registry/wire/wire.go deleted file mode 100644 index 4ea64e3e..00000000 --- a/pkg/registry/wire/wire.go +++ /dev/null @@ -1,595 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -// Package wire defines the binary wire format shared between the registry -// client and server. It contains protocol constants, framing, and the -// encode/decode helpers that both sides use to talk over the same TCP -// connection. Pure types and functions — no networking, no logging, no I/O -// beyond the io.Reader/io.Writer abstractions used by the framing layer. -package wire - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "io" - "math" - "net" - "time" -) - -// maxCount caps wire-controlled list lengths to prevent a malicious -// peer from triggering large allocations (e.g. netCount=65535 → -// 130 KB make()). All frames are bounded by MaxMessageSize (64 MiB) -// but per-field allocations without caps can cause memory pressure -// before the overall frame limit is reached. -const maxCount = 1024 - -// WriteMessageDeadline bounds how long a single JSON response write can -// take. If a client is slow to drain (overloaded host, kernel buffer -// pressure) we'd otherwise hold the request goroutine + response payload -// in memory indefinitely. After this deadline expires, w.Write returns -// an error and the caller can drop the connection cleanly. -const WriteMessageDeadline = 5 * time.Second - -// MaxMessageSize is the maximum allowed wire message size (64 MiB). -// Messages exceeding this limit cause the connection to be closed. -// Note: must stay well below the binary wire magic (0x50494C54 ≈ 1.3B) -// for protocol auto-detection to work. Sized for full registry snapshot -// in subscribe_replication: ~26 MiB at 100k+ nodes, with headroom. -const MaxMessageSize = 64 * 1024 * 1024 - -// Binary wire format for high-throughput operations. -// -// Protocol negotiation: binary clients send magic 0x50494C54 ("PILT") + 1 byte -// version as the first 5 bytes of the connection. The server detects this vs a -// JSON length prefix (which is always < 64KB) and switches mode per-connection. -// -// Binary frame: [4B total_length][1B msg_type][payload] -// -// Message types: -// 0x00 = JSON passthrough (payload is JSON bytes) -// 0x01 = heartbeat request -// 0x81 = heartbeat response -// 0x02 = lookup request -// 0x82 = lookup response -// 0x03 = resolve request -// 0x83 = resolve response -// 0xFF = error response - -// Magic is the 4-byte magic sent by binary clients at connection start. -var Magic = [4]byte{0x50, 0x49, 0x4C, 0x54} // "PILT" - -// Version is the current binary protocol version. -const Version byte = 1 - -// Binary message type constants. -const ( - MsgJSON byte = 0x00 - MsgHeartbeat byte = 0x01 - MsgHeartbeatOK byte = 0x81 - MsgLookup byte = 0x02 - MsgLookupOK byte = 0x82 - MsgResolve byte = 0x03 - MsgResolveOK byte = 0x83 - MsgError byte = 0xFF -) - -// ReadFrame reads a single binary frame: [4B length][1B type][payload]. -func ReadFrame(r io.Reader) (msgType byte, payload []byte, err error) { - var hdr [5]byte - if _, err = io.ReadFull(r, hdr[:]); err != nil { - return 0, nil, err - } - length := binary.BigEndian.Uint32(hdr[:4]) - if length < 1 { - return 0, nil, fmt.Errorf("binary frame too short") - } - if length > MaxMessageSize { - return 0, nil, fmt.Errorf("binary frame too large: %d bytes (max %d)", length, MaxMessageSize) - } - msgType = hdr[4] - payloadLen := length - 1 // length includes the type byte - if payloadLen > 0 { - payload = make([]byte, payloadLen) - if _, err = io.ReadFull(r, payload); err != nil { - return 0, nil, err - } - } - return msgType, payload, nil -} - -// WriteFrame writes a single binary frame. -func WriteFrame(w io.Writer, msgType byte, payload []byte) error { - length := uint32(1 + len(payload)) // type byte + payload - var hdr [5]byte - binary.BigEndian.PutUint32(hdr[:4], length) - hdr[4] = msgType - if _, err := w.Write(hdr[:]); err != nil { - return err - } - if len(payload) > 0 { - if _, err := w.Write(payload); err != nil { - return err - } - } - return nil -} - -// --- Heartbeat --- - -// HeartbeatReq holds a decoded binary heartbeat request: [4B node_id][64B signature]. -type HeartbeatReq struct { - NodeID uint32 - Signature [64]byte -} - -// EncodeHeartbeatReq encodes a heartbeat request payload. -func EncodeHeartbeatReq(nodeID uint32, sig []byte) []byte { - buf := make([]byte, 4+64) - binary.BigEndian.PutUint32(buf[:4], nodeID) - copy(buf[4:], sig) - return buf -} - -// DecodeHeartbeatReq decodes a heartbeat request payload. -func DecodeHeartbeatReq(payload []byte) (HeartbeatReq, error) { - if len(payload) < 68 { - return HeartbeatReq{}, fmt.Errorf("heartbeat request too short: %d bytes", len(payload)) - } - var req HeartbeatReq - req.NodeID = binary.BigEndian.Uint32(payload[:4]) - copy(req.Signature[:], payload[4:68]) - return req, nil -} - -// EncodeHeartbeatResp encodes the heartbeat response: [8B unix_time][1B flags]. -// flags bit0 = key_expiry_warning. -func EncodeHeartbeatResp(unixTime int64, keyExpiryWarning bool) []byte { - buf := make([]byte, 9) - binary.BigEndian.PutUint64(buf[:8], uint64(unixTime)) - if keyExpiryWarning { - buf[8] = 1 - } - return buf -} - -// DecodeHeartbeatResp decodes a heartbeat response. -func DecodeHeartbeatResp(payload []byte) (unixTime int64, keyExpiryWarning bool, err error) { - if len(payload) < 9 { - return 0, false, fmt.Errorf("heartbeat response too short: %d bytes", len(payload)) - } - unixTime = int64(binary.BigEndian.Uint64(payload[:8])) - keyExpiryWarning = payload[8]&1 != 0 - return unixTime, keyExpiryWarning, nil -} - -// --- Lookup --- - -// EncodeLookupReq encodes a lookup request: [4B node_id]. -func EncodeLookupReq(nodeID uint32) []byte { - buf := make([]byte, 4) - binary.BigEndian.PutUint32(buf, nodeID) - return buf -} - -// DecodeLookupReq decodes a lookup request. -func DecodeLookupReq(payload []byte) (uint32, error) { - if len(payload) < 4 { - return 0, fmt.Errorf("lookup request too short: %d bytes", len(payload)) - } - return binary.BigEndian.Uint32(payload[:4]), nil -} - -// EncodeLookupResp encodes a lookup response. -// Format: [4B node_id][1B flags][4B reserved][2B net_count][net_ids...] -// -// [1B pubkey_len][pubkey...][1B hostname_len][hostname...] -// [1B tags_count][for each: 1B len, bytes...][2B addr_len][addr...] -// [1B extid_len][extid...] -// -// The 4-byte reserved field was formerly polo_score; it is written as zero -// and ignored on decode to preserve wire-format compatibility. -func EncodeLookupResp(nodeID uint32, public, taskExec bool, - networks []uint16, pubKey []byte, hostname string, tags []string, - realAddr string, externalID string) []byte { - - // Calculate size - size := 4 + 1 + 4 + 2 + len(networks)*2 // node_id + flags + reserved + nets - size += 1 + len(pubKey) // pubkey - size += 1 + len(hostname) // hostname - size += 1 // tags count - for _, t := range tags { - size += 1 + len(t) // tag len + tag - } - size += 2 + len(realAddr) // real_addr (only if public) - size += 1 + len(externalID) // external_id - - buf := make([]byte, 0, size) - - // node_id - buf = binary.BigEndian.AppendUint32(buf, nodeID) - - // flags - var flags byte - if public { - flags |= 0x01 - } - if taskExec { - flags |= 0x02 - } - buf = append(buf, flags) - - // reserved (was polo_score) — always zero - buf = binary.BigEndian.AppendUint32(buf, 0) - - // networks - buf = binary.BigEndian.AppendUint16(buf, uint16(len(networks))) - for _, n := range networks { - buf = binary.BigEndian.AppendUint16(buf, n) - } - - // pubkey - if len(pubKey) > 255 { - pubKey = pubKey[:255] - } - buf = append(buf, byte(len(pubKey))) - buf = append(buf, pubKey...) - - // hostname - if len(hostname) > 255 { - hostname = hostname[:255] - } - buf = append(buf, byte(len(hostname))) - buf = append(buf, []byte(hostname)...) - - // tags - if len(tags) > 255 { - tags = tags[:255] - } - buf = append(buf, byte(len(tags))) - for _, t := range tags { - if len(t) > 255 { - t = t[:255] - } - buf = append(buf, byte(len(t))) - buf = append(buf, []byte(t)...) - } - - // real_addr - buf = binary.BigEndian.AppendUint16(buf, uint16(len(realAddr))) - buf = append(buf, []byte(realAddr)...) - - // external_id - if len(externalID) > 255 { - externalID = externalID[:255] - } - buf = append(buf, byte(len(externalID))) - buf = append(buf, []byte(externalID)...) - - return buf -} - -// --- Resolve --- - -// EncodeResolveReq encodes a resolve request: [4B node_id][4B requester_id][64B signature]. -func EncodeResolveReq(nodeID, requesterID uint32, sig []byte) []byte { - buf := make([]byte, 4+4+64) - binary.BigEndian.PutUint32(buf[:4], nodeID) - binary.BigEndian.PutUint32(buf[4:8], requesterID) - copy(buf[8:], sig) - return buf -} - -// DecodeResolveReq decodes a resolve request. -func DecodeResolveReq(payload []byte) (nodeID, requesterID uint32, sig []byte, err error) { - if len(payload) < 72 { - return 0, 0, nil, fmt.Errorf("resolve request too short: %d bytes", len(payload)) - } - nodeID = binary.BigEndian.Uint32(payload[:4]) - requesterID = binary.BigEndian.Uint32(payload[4:8]) - sig = payload[8:72] - return nodeID, requesterID, sig, nil -} - -// EncodeResolveResp encodes a resolve response. -// Format: [4B node_id][2B addr_len][addr...][2B lan_count][for each: 2B len, addr...] -// -// [4B key_age_days] (math.MaxUint32 if unknown) -func EncodeResolveResp(nodeID uint32, realAddr string, lanAddrs []string, keyAgeDays int) []byte { - size := 4 + 2 + len(realAddr) + 2 + 4 - for _, la := range lanAddrs { - size += 2 + len(la) - } - buf := make([]byte, 0, size) - - buf = binary.BigEndian.AppendUint32(buf, nodeID) - - buf = binary.BigEndian.AppendUint16(buf, uint16(len(realAddr))) - buf = append(buf, []byte(realAddr)...) - - buf = binary.BigEndian.AppendUint16(buf, uint16(len(lanAddrs))) - for _, la := range lanAddrs { - buf = binary.BigEndian.AppendUint16(buf, uint16(len(la))) - buf = append(buf, []byte(la)...) - } - - if keyAgeDays < 0 { - buf = binary.BigEndian.AppendUint32(buf, math.MaxUint32) - } else { - buf = binary.BigEndian.AppendUint32(buf, uint32(keyAgeDays)) - } - - return buf -} - -// --- Error --- - -// EncodeError encodes an error message frame payload. -func EncodeError(msg string) []byte { - if len(msg) > 65000 { - msg = msg[:65000] - } - buf := make([]byte, 2+len(msg)) - binary.BigEndian.PutUint16(buf[:2], uint16(len(msg))) - copy(buf[2:], msg) - return buf -} - -// DecodeError decodes an error message frame payload. -func DecodeError(payload []byte) string { - if len(payload) < 2 { - return "unknown error" - } - length := binary.BigEndian.Uint16(payload[:2]) - if int(length) > len(payload)-2 { - length = uint16(len(payload) - 2) - } - return string(payload[2 : 2+length]) -} - -// --- Lookup response decoder (client-side) --- - -// LookupResult holds the decoded fields from a binary lookup response. -type LookupResult struct { - NodeID uint32 - Public bool - TaskExec bool - Networks []uint16 - PubKey []byte - Hostname string - Tags []string - RealAddr string - ExternalID string -} - -// DecodeLookupResp decodes a binary lookup response. -func DecodeLookupResp(payload []byte) (LookupResult, error) { - var r LookupResult - if len(payload) < 11 { - return r, fmt.Errorf("lookup response too short: %d bytes", len(payload)) - } - - off := 0 - r.NodeID = binary.BigEndian.Uint32(payload[off : off+4]) - off += 4 - flags := payload[off] - off++ - r.Public = flags&0x01 != 0 - r.TaskExec = flags&0x02 != 0 - off += 4 // skip reserved field (was polo_score) - - if off+2 > len(payload) { - return r, fmt.Errorf("truncated network count") - } - netCount := int(binary.BigEndian.Uint16(payload[off : off+2])) - off += 2 - if netCount > maxCount { - return r, fmt.Errorf("network count %d exceeds cap %d", netCount, maxCount) - } - r.Networks = make([]uint16, netCount) - for i := 0; i < netCount; i++ { - if off+2 > len(payload) { - return r, fmt.Errorf("truncated networks at index %d", i) - } - r.Networks[i] = binary.BigEndian.Uint16(payload[off : off+2]) - off += 2 - } - - if off >= len(payload) { - return r, fmt.Errorf("truncated pubkey length") - } - pkLen := int(payload[off]) - off++ - if off+pkLen > len(payload) { - return r, fmt.Errorf("truncated pubkey data") - } - if pkLen > 0 { - r.PubKey = make([]byte, pkLen) - copy(r.PubKey, payload[off:off+pkLen]) - } - off += pkLen - - if off >= len(payload) { - return r, fmt.Errorf("truncated hostname length") - } - hnLen := int(payload[off]) - off++ - if off+hnLen > len(payload) { - return r, fmt.Errorf("truncated hostname data") - } - r.Hostname = string(payload[off : off+hnLen]) - off += hnLen - - if off >= len(payload) { - return r, fmt.Errorf("truncated tags count") - } - tagCount := int(payload[off]) - off++ - if tagCount > maxCount { - return r, fmt.Errorf("tag count %d exceeds cap %d", tagCount, maxCount) - } - r.Tags = make([]string, tagCount) - for i := 0; i < tagCount; i++ { - if off >= len(payload) { - return r, fmt.Errorf("truncated tag length at index %d", i) - } - tLen := int(payload[off]) - off++ - if off+tLen > len(payload) { - return r, fmt.Errorf("truncated tag data at index %d", i) - } - r.Tags[i] = string(payload[off : off+tLen]) - off += tLen - } - - if off+2 > len(payload) { - return r, fmt.Errorf("truncated real_addr length") - } - addrLen := int(binary.BigEndian.Uint16(payload[off : off+2])) - off += 2 - if off+addrLen > len(payload) { - return r, fmt.Errorf("truncated real_addr data") - } - r.RealAddr = string(payload[off : off+addrLen]) - off += addrLen - - if off >= len(payload) { - return r, fmt.Errorf("truncated external_id length") - } - eidLen := int(payload[off]) - off++ - if off+eidLen > len(payload) { - return r, fmt.Errorf("truncated external_id data") - } - r.ExternalID = string(payload[off : off+eidLen]) - - return r, nil -} - -// --- Resolve response decoder (client-side) --- - -// ResolveResult holds the decoded fields from a binary resolve response. -type ResolveResult struct { - NodeID uint32 - RealAddr string - LANAddrs []string - KeyAgeDays int // -1 if unknown -} - -// DecodeResolveResp decodes a binary resolve response. -func DecodeResolveResp(payload []byte) (ResolveResult, error) { - var r ResolveResult - if len(payload) < 12 { - return r, fmt.Errorf("resolve response too short: %d bytes", len(payload)) - } - - off := 0 - r.NodeID = binary.BigEndian.Uint32(payload[off : off+4]) - off += 4 - - if off+2 > len(payload) { - return r, fmt.Errorf("truncated addr length") - } - addrLen := int(binary.BigEndian.Uint16(payload[off : off+2])) - off += 2 - if off+addrLen > len(payload) { - return r, fmt.Errorf("truncated addr data") - } - r.RealAddr = string(payload[off : off+addrLen]) - off += addrLen - - if off+2 > len(payload) { - return r, fmt.Errorf("truncated lan_addrs count") - } - lanCount := int(binary.BigEndian.Uint16(payload[off : off+2])) - off += 2 - if lanCount > maxCount { - return r, fmt.Errorf("lan_addrs count %d exceeds cap %d", lanCount, maxCount) - } - r.LANAddrs = make([]string, lanCount) - for i := 0; i < lanCount; i++ { - if off+2 > len(payload) { - return r, fmt.Errorf("truncated lan addr length at index %d", i) - } - laLen := int(binary.BigEndian.Uint16(payload[off : off+2])) - off += 2 - if off+laLen > len(payload) { - return r, fmt.Errorf("truncated lan addr data at index %d", i) - } - r.LANAddrs[i] = string(payload[off : off+laLen]) - off += laLen - } - - if off+4 > len(payload) { - return r, fmt.Errorf("truncated key_age_days") - } - raw := binary.BigEndian.Uint32(payload[off : off+4]) - if raw == math.MaxUint32 { - r.KeyAgeDays = -1 - } else { - r.KeyAgeDays = int(raw) - } - - return r, nil -} - -// --- JSON message framing --- -// -// The non-binary JSON protocol uses a 4-byte big-endian length prefix -// followed by a JSON body. ReadMessage/WriteMessage are the helpers both -// the client and the server use over the same TCP connection. - -// ReadMessage reads a length-prefixed JSON message from r and decodes -// it into a map. -func ReadMessage(r io.Reader) (map[string]interface{}, error) { - var lenBuf [4]byte - if _, err := io.ReadFull(r, lenBuf[:]); err != nil { - return nil, err - } - length := binary.BigEndian.Uint32(lenBuf[:]) - if length > MaxMessageSize { - return nil, fmt.Errorf("message too large: %d bytes (max %d)", length, MaxMessageSize) - } - - body := make([]byte, length) - if _, err := io.ReadFull(r, body); err != nil { - return nil, err - } - - var msg map[string]interface{} - if err := json.Unmarshal(body, &msg); err != nil { - return nil, fmt.Errorf("json decode: %w", err) - } - return msg, nil -} - -// WriteMessage writes a length-prefixed JSON-encoded message to w. -// If w is a net.Conn, a write deadline is applied. -func WriteMessage(w io.Writer, msg map[string]interface{}) error { - body, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("json encode: %w", err) - } - return WriteRawMessage(w, body) -} - -// WriteRawMessage writes a length-prefixed raw JSON body to w. -// Callers that have already produced the JSON bytes (e.g., a list-nodes -// cache hit) can skip the json.Marshal step. -func WriteRawMessage(w io.Writer, body []byte) error { - var lenBuf [4]byte - binary.BigEndian.PutUint32(lenBuf[:], uint32(len(body))) - - if c, ok := w.(net.Conn); ok { - _ = c.SetWriteDeadline(time.Now().Add(WriteMessageDeadline)) - defer c.SetWriteDeadline(time.Time{}) - } - - if _, err := w.Write(lenBuf[:]); err != nil { - return err - } - if _, err := w.Write(body); err != nil { - return err - } - return nil -} diff --git a/pkg/registry/wire/zz_blueprint_test.go b/pkg/registry/wire/zz_blueprint_test.go deleted file mode 100644 index 5245905b..00000000 --- a/pkg/registry/wire/zz_blueprint_test.go +++ /dev/null @@ -1,254 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package wire_test - -import ( - "encoding/json" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -func TestLoadBlueprint_HappyPath(t *testing.T) { - t.Parallel() - bp := &wire.NetworkBlueprint{ - Name: "test-net", - JoinRule: "open", - } - data, err := json.Marshal(bp) - if err != nil { - t.Fatalf("marshal: %v", err) - } - dir := t.TempDir() - path := filepath.Join(dir, "bp.json") - if err := os.WriteFile(path, data, 0600); err != nil { - t.Fatalf("write: %v", err) - } - got, err := wire.LoadBlueprint(path) - if err != nil { - t.Fatalf("LoadBlueprint: %v", err) - } - if got.Name != bp.Name { - t.Errorf("Name: got %q, want %q", got.Name, bp.Name) - } - if got.JoinRule != bp.JoinRule { - t.Errorf("JoinRule: got %q, want %q", got.JoinRule, bp.JoinRule) - } -} - -func TestLoadBlueprint_FileNotFound(t *testing.T) { - t.Parallel() - _, err := wire.LoadBlueprint(filepath.Join(t.TempDir(), "does-not-exist.json")) - if err == nil || !strings.Contains(err.Error(), "read blueprint") { - t.Fatalf("want 'read blueprint' err, got %v", err) - } -} - -func TestLoadBlueprint_BadJSON(t *testing.T) { - t.Parallel() - dir := t.TempDir() - path := filepath.Join(dir, "bad.json") - if err := os.WriteFile(path, []byte(`{not json`), 0600); err != nil { - t.Fatalf("write: %v", err) - } - _, err := wire.LoadBlueprint(path) - if err == nil || !strings.Contains(err.Error(), "parse blueprint") { - t.Fatalf("want 'parse blueprint' err, got %v", err) - } -} - -func TestLoadBlueprint_MissingName(t *testing.T) { - t.Parallel() - dir := t.TempDir() - path := filepath.Join(dir, "noname.json") - if err := os.WriteFile(path, []byte(`{}`), 0600); err != nil { - t.Fatalf("write: %v", err) - } - _, err := wire.LoadBlueprint(path) - if err == nil || !strings.Contains(err.Error(), "name is required") { - t.Fatalf("want 'name is required' err, got %v", err) - } -} - -func TestValidateBlueprint_HappyPath(t *testing.T) { - t.Parallel() - bp := &wire.NetworkBlueprint{Name: "net", JoinRule: "open"} - if err := wire.ValidateBlueprint(bp); err != nil { - t.Fatalf("ValidateBlueprint: %v", err) - } -} - -func TestValidateBlueprint_NameRequired(t *testing.T) { - t.Parallel() - err := wire.ValidateBlueprint(&wire.NetworkBlueprint{}) - if err == nil || !strings.Contains(err.Error(), "name is required") { - t.Fatalf("want 'name is required', got %v", err) - } -} - -func TestValidateBlueprint_AllJoinRules(t *testing.T) { - t.Parallel() - for _, jr := range []string{"", "open", "token", "invite"} { - bp := &wire.NetworkBlueprint{Name: "n", JoinRule: jr} - if jr == "token" { - bp.JoinToken = "tok" - } - if err := wire.ValidateBlueprint(bp); err != nil { - t.Errorf("JoinRule=%q: %v", jr, err) - } - } -} - -func TestValidateBlueprint_InvalidJoinRule(t *testing.T) { - t.Parallel() - bp := &wire.NetworkBlueprint{Name: "n", JoinRule: "weird"} - err := wire.ValidateBlueprint(bp) - if err == nil || !strings.Contains(err.Error(), "invalid join_rule") { - t.Fatalf("got %v", err) - } -} - -func TestValidateBlueprint_TokenRuleNeedsToken(t *testing.T) { - t.Parallel() - bp := &wire.NetworkBlueprint{Name: "n", JoinRule: "token"} - err := wire.ValidateBlueprint(bp) - if err == nil || !strings.Contains(err.Error(), "join_token is required") { - t.Fatalf("got %v", err) - } -} - -func TestValidateBlueprint_RoleRequiresExternalID(t *testing.T) { - t.Parallel() - bp := &wire.NetworkBlueprint{ - Name: "n", - Roles: []wire.BlueprintRole{{Role: "admin"}}, - } - err := wire.ValidateBlueprint(bp) - if err == nil || !strings.Contains(err.Error(), "external_id is required") { - t.Fatalf("got %v", err) - } -} - -func TestValidateBlueprint_RoleValidAndInvalid(t *testing.T) { - t.Parallel() - for _, r := range []string{"owner", "admin", "member"} { - bp := &wire.NetworkBlueprint{ - Name: "n", - Roles: []wire.BlueprintRole{{ExternalID: "u", Role: r}}, - } - if err := wire.ValidateBlueprint(bp); err != nil { - t.Errorf("role=%q: %v", r, err) - } - } - bp := &wire.NetworkBlueprint{ - Name: "n", - Roles: []wire.BlueprintRole{{ExternalID: "u", Role: "superadmin"}}, - } - if err := wire.ValidateBlueprint(bp); err == nil || - !strings.Contains(err.Error(), "invalid role") { - t.Fatalf("got %v", err) - } -} - -func TestValidateBlueprint_IdentityProvider(t *testing.T) { - t.Parallel() - // missing URL - err := wire.ValidateBlueprint(&wire.NetworkBlueprint{ - Name: "n", - IdentityProvider: &wire.BlueprintIdentityProvider{Type: "oidc"}, - }) - if err == nil || !strings.Contains(err.Error(), "identity_provider.url is required") { - t.Fatalf("got %v", err) - } - // invalid type - err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ - Name: "n", - IdentityProvider: &wire.BlueprintIdentityProvider{Type: "weird", URL: "https://a.b"}, - }) - if err == nil || !strings.Contains(err.Error(), "invalid identity_provider type") { - t.Fatalf("got %v", err) - } - // happy path for each valid type - for _, typ := range []string{"oidc", "saml", "webhook", "entra_id", "ldap"} { - err := wire.ValidateBlueprint(&wire.NetworkBlueprint{ - Name: "n", - IdentityProvider: &wire.BlueprintIdentityProvider{ - Type: typ, - URL: "https://example.com/auth", - }, - }) - if err != nil { - t.Errorf("type=%q: %v", typ, err) - } - } -} - -func TestValidateBlueprint_AuditExport(t *testing.T) { - t.Parallel() - // invalid format - err := wire.ValidateBlueprint(&wire.NetworkBlueprint{ - Name: "n", - AuditExport: &wire.BlueprintAuditExport{ - Format: "weird", - Endpoint: "https://a.b", - }, - }) - if err == nil || !strings.Contains(err.Error(), "invalid audit_export format") { - t.Fatalf("got %v", err) - } - // missing endpoint - err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ - Name: "n", - AuditExport: &wire.BlueprintAuditExport{Format: "json"}, - }) - if err == nil || !strings.Contains(err.Error(), "audit_export.endpoint is required") { - t.Fatalf("got %v", err) - } - // happy path syslog_cef (no URL validation) - err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ - Name: "n", - AuditExport: &wire.BlueprintAuditExport{Format: "syslog_cef", Endpoint: "1.2.3.4:514"}, - }) - if err != nil { - t.Errorf("syslog_cef: %v", err) - } -} - -func TestValidateBlueprint_ExprPolicy(t *testing.T) { - t.Parallel() - // invalid JSON - err := wire.ValidateBlueprint(&wire.NetworkBlueprint{ - Name: "n", - ExprPolicy: json.RawMessage(`{not json`), - }) - if err == nil || !strings.Contains(err.Error(), "expr_policy: invalid JSON") { - t.Fatalf("got %v", err) - } - // wrong version - err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ - Name: "n", - ExprPolicy: json.RawMessage(`{"version":2,"rules":[1]}`), - }) - if err == nil || !strings.Contains(err.Error(), "unsupported version") { - t.Fatalf("got %v", err) - } - // no rules - err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ - Name: "n", - ExprPolicy: json.RawMessage(`{"version":1,"rules":null}`), - }) - if err == nil || !strings.Contains(err.Error(), "at least one rule") { - t.Fatalf("got %v", err) - } - // happy path - err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ - Name: "n", - ExprPolicy: json.RawMessage(`{"version":1,"rules":[{"on":"connect","match":"true"}]}`), - }) - if err != nil { - t.Errorf("happy: %v", err) - } -} diff --git a/pkg/registry/wire/zz_decode_edge_test.go b/pkg/registry/wire/zz_decode_edge_test.go deleted file mode 100644 index 15bbeca2..00000000 --- a/pkg/registry/wire/zz_decode_edge_test.go +++ /dev/null @@ -1,110 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package wire_test - -import ( - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -func TestDecodeLookupReq_Truncated(t *testing.T) { - t.Parallel() - for _, n := range []int{0, 1, 2, 3} { - _, err := wire.DecodeLookupReq(make([]byte, n)) - if err == nil || !strings.Contains(err.Error(), "too short") { - t.Errorf("len=%d: want 'too short' err, got %v", n, err) - } - } -} - -func TestDecodeLookupReq_HappyPath(t *testing.T) { - t.Parallel() - got, err := wire.DecodeLookupReq([]byte{0x00, 0x00, 0xCA, 0xFE}) - if err != nil { - t.Fatalf("DecodeLookupReq: %v", err) - } - if got != 0xCAFE { - t.Errorf("got %x, want CAFE", got) - } -} - -func TestEncodeLookupResp_RoundTripWithAllFields(t *testing.T) { - t.Parallel() - // Build a fully-populated lookup response, then decode it. - encoded := wire.EncodeLookupResp( - 0xABCD, // nodeID - true, // public - true, // taskExec - []uint16{1, 2, 3}, // networks - []byte("0123456789012345"), // pubkey (16 bytes) - "host.example", // hostname - []string{"tag1", "tag2"}, // tags - "1.2.3.4:4000", // realAddr (only if public) - "ext-id-xyz", // externalID - ) - if len(encoded) == 0 { - t.Fatal("EncodeLookupResp returned empty") - } - resp, err := wire.DecodeLookupResp(encoded) - if err != nil { - t.Fatalf("DecodeLookupResp: %v", err) - } - if resp.NodeID != 0xABCD { - t.Errorf("NodeID = %x, want ABCD", resp.NodeID) - } - if !resp.Public { - t.Errorf("Public = false, want true") - } - if resp.Hostname != "host.example" { - t.Errorf("Hostname = %q, want host.example", resp.Hostname) - } - if resp.ExternalID != "ext-id-xyz" { - t.Errorf("ExternalID = %q", resp.ExternalID) - } - if len(resp.Networks) != 3 { - t.Errorf("Networks len = %d, want 3", len(resp.Networks)) - } - if len(resp.Tags) != 2 { - t.Errorf("Tags len = %d, want 2", len(resp.Tags)) - } -} - -func TestEncodeLookupResp_PrivateNodeNoAddr(t *testing.T) { - t.Parallel() - // Private node: realAddr is encoded but should not be revealed by - // post-decode contract. - encoded := wire.EncodeLookupResp( - 1, false, false, []uint16{}, []byte{}, "host", []string{}, "", "", - ) - resp, err := wire.DecodeLookupResp(encoded) - if err != nil { - t.Fatalf("DecodeLookupResp: %v", err) - } - if resp.Public { - t.Errorf("Public = true, want false") - } -} - -func TestDecodeError_Truncated(t *testing.T) { - t.Parallel() - // DecodeError returns a string. Truncated → fallback string. - for _, n := range []int{0, 1} { - got := wire.DecodeError(make([]byte, n)) - if got != "unknown error" { - t.Errorf("len=%d: got %q, want 'unknown error'", n, got) - } - } -} - -func TestDecodeError_HappyPath(t *testing.T) { - t.Parallel() - // 2-byte length prefix + body - msg := "internal error" - buf := []byte{byte(len(msg) >> 8), byte(len(msg))} - buf = append(buf, msg...) - if got := wire.DecodeError(buf); got != msg { - t.Errorf("got %q, want %q", got, msg) - } -} diff --git a/pkg/registry/wire/zz_decode_truncation_test.go b/pkg/registry/wire/zz_decode_truncation_test.go deleted file mode 100644 index 05c210f8..00000000 --- a/pkg/registry/wire/zz_decode_truncation_test.go +++ /dev/null @@ -1,108 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package wire_test - -import ( - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -// TestDecodeLookupResp_TruncationCascade walks every truncation -// boundary in DecodeLookupResp by encoding a fully-populated response -// then progressively trimming the payload from the back. -func TestDecodeLookupResp_TruncationCascade(t *testing.T) { - t.Parallel() - full := wire.EncodeLookupResp( - 0x1234, - true, - false, - []uint16{1, 2}, - []byte("pubkey-32-bytes-AAAAAAAAAAAAAAAA"), - "hostname", - []string{"tagA", "tagB"}, - "10.0.0.1:4000", - "ext-id", - ) - // Truncate to every shorter length and ensure decode either succeeds - // (only happens at exact length boundaries) or returns a truncation - // error. This exercises every "if off >= len(payload)" branch. - for i := 0; i < len(full); i++ { - _, err := wire.DecodeLookupResp(full[:i]) - if err == nil { - continue // accidental valid prefix is fine - } - // Every error should contain "truncated" or "too short". - if !strings.Contains(err.Error(), "truncated") && - !strings.Contains(err.Error(), "too short") { - t.Errorf("len=%d: unexpected err %v", i, err) - } - } -} - -// TestDecodeResolveResp_TruncationCascade does the same for ResolveResp. -func TestDecodeResolveResp_TruncationCascade(t *testing.T) { - t.Parallel() - full := wire.EncodeResolveResp(0x1234, "10.0.0.5:4000", []string{"192.168.1.10:4000"}, 7) - for i := 0; i < len(full); i++ { - _, err := wire.DecodeResolveResp(full[:i]) - if err == nil { - continue - } - if !strings.Contains(err.Error(), "truncated") && - !strings.Contains(err.Error(), "too short") && - !strings.Contains(err.Error(), "decode") { - t.Errorf("len=%d: unexpected err %v", i, err) - } - } -} - -// TestDecodeResolveReq_Truncation drills the short-buffer branches. -func TestDecodeResolveReq_Truncation(t *testing.T) { - t.Parallel() - for _, n := range []int{0, 1, 4, 8, 16, 32, 64, 71} { - _, _, _, err := wire.DecodeResolveReq(make([]byte, n)) - if err == nil { - t.Errorf("len=%d: want error, got nil", n) - } - } -} - -// TestDecodeHeartbeatResp_Truncation exercises the small response decoder. -func TestDecodeHeartbeatResp_Truncation(t *testing.T) { - t.Parallel() - for _, n := range []int{0, 1, 2, 3, 4, 5, 8} { - _, _, err := wire.DecodeHeartbeatResp(make([]byte, n)) - if err == nil { - t.Errorf("len=%d: want error, got nil", n) - } - } -} - -// TestDecodeError_LengthExceedsBuffer covers the clamping branch where the -// length prefix lies about how much data follows. -func TestDecodeError_LengthExceedsBuffer(t *testing.T) { - t.Parallel() - // Length prefix says 100 bytes, but buffer only has 5 bytes of body. - buf := []byte{0x00, 0x64, 'h', 'e', 'l', 'l', 'o'} // 0x0064 = 100 - got := wire.DecodeError(buf) - if got != "hello" { - t.Errorf("got %q, want 'hello' (clamped)", got) - } -} - -// TestEncodeError_OverlongMessageTruncated covers EncodeError's 65000-byte cap. -func TestEncodeError_OverlongMessageTruncated(t *testing.T) { - t.Parallel() - long := strings.Repeat("x", 70000) - encoded := wire.EncodeError(long) - // Encoded payload = 2-byte length + body. Body should be 65000 chars. - if len(encoded) != 2+65000 { - t.Errorf("encoded length = %d, want %d", len(encoded), 2+65000) - } - // Decode and ensure round-trip is the truncated form. - if got := wire.DecodeError(encoded); len(got) != 65000 { - t.Errorf("decoded length = %d, want 65000", len(got)) - } -} diff --git a/pkg/registry/wire/zz_frame_test.go b/pkg/registry/wire/zz_frame_test.go deleted file mode 100644 index 80ff8b4b..00000000 --- a/pkg/registry/wire/zz_frame_test.go +++ /dev/null @@ -1,110 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package wire_test - -import ( - "bytes" - "encoding/binary" - "errors" - "io" - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -func TestReadFrameWriteFrame_RoundTrip(t *testing.T) { - t.Parallel() - var buf bytes.Buffer - payload := []byte("hello-frame-body") - if err := wire.WriteFrame(&buf, 0x42, payload); err != nil { - t.Fatalf("WriteFrame: %v", err) - } - msgType, got, err := wire.ReadFrame(&buf) - if err != nil { - t.Fatalf("ReadFrame: %v", err) - } - if msgType != 0x42 { - t.Errorf("msgType = %x, want 0x42", msgType) - } - if !bytes.Equal(got, payload) { - t.Errorf("payload = %q, want %q", got, payload) - } -} - -func TestReadFrame_EmptyPayload(t *testing.T) { - t.Parallel() - var buf bytes.Buffer - if err := wire.WriteFrame(&buf, 0x01, nil); err != nil { - t.Fatalf("WriteFrame: %v", err) - } - msgType, payload, err := wire.ReadFrame(&buf) - if err != nil { - t.Fatalf("ReadFrame: %v", err) - } - if msgType != 0x01 { - t.Errorf("msgType = %x, want 0x01", msgType) - } - if len(payload) != 0 { - t.Errorf("payload len = %d, want 0", len(payload)) - } -} - -func TestReadFrame_HeaderTruncated(t *testing.T) { - t.Parallel() - _, _, err := wire.ReadFrame(bytes.NewReader([]byte{0x00, 0x01})) // 2 bytes, need 5 - if !errors.Is(err, io.ErrUnexpectedEOF) && err != io.EOF { - t.Errorf("want EOF/ErrUnexpectedEOF, got %v", err) - } -} - -func TestReadFrame_LengthZero(t *testing.T) { - t.Parallel() - var hdr [5]byte - binary.BigEndian.PutUint32(hdr[:4], 0) // length = 0 → too short - hdr[4] = 0x01 - _, _, err := wire.ReadFrame(bytes.NewReader(hdr[:])) - if err == nil || !strings.Contains(err.Error(), "too short") { - t.Errorf("want 'too short', got %v", err) - } -} - -func TestReadFrame_LengthExceedsMax(t *testing.T) { - t.Parallel() - var hdr [5]byte - binary.BigEndian.PutUint32(hdr[:4], wire.MaxMessageSize+1) - _, _, err := wire.ReadFrame(bytes.NewReader(hdr[:])) - if err == nil || !strings.Contains(err.Error(), "too large") { - t.Errorf("want 'too large', got %v", err) - } -} - -func TestReadFrame_PayloadTruncated(t *testing.T) { - t.Parallel() - var hdr [5]byte - binary.BigEndian.PutUint32(hdr[:4], 100) // claims 99 bytes of payload - hdr[4] = 0x01 - _, _, err := wire.ReadFrame(bytes.NewReader(append(hdr[:], []byte("short")...))) - if !errors.Is(err, io.ErrUnexpectedEOF) { - t.Errorf("want ErrUnexpectedEOF, got %v", err) - } -} - -func TestWriteFrame_HeaderWriteError(t *testing.T) { - t.Parallel() - bang := errors.New("hdr-fail") - err := wire.WriteFrame(&failingWriter{err: bang}, 0x01, []byte("xx")) - if !errors.Is(err, bang) { - t.Errorf("want hdr-fail, got %v", err) - } -} - -func TestWriteFrame_PayloadWriteError(t *testing.T) { - t.Parallel() - bang := errors.New("payload-fail") - // allow 5-byte header, fail body - err := wire.WriteFrame(&shortWriter{allow: 5, err: bang}, 0x01, []byte("hello")) - if !errors.Is(err, bang) { - t.Errorf("want payload-fail, got %v", err) - } -} diff --git a/pkg/registry/wire/zz_fuzz_wire_test.go b/pkg/registry/wire/zz_fuzz_wire_test.go deleted file mode 100644 index c0a98781..00000000 --- a/pkg/registry/wire/zz_fuzz_wire_test.go +++ /dev/null @@ -1,216 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package wire_test - -import ( - "bytes" - "encoding/binary" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -// FuzzReadFrame exercises the binary frame reader. -// Wire format: [4B length][1B type][payload]. The length field is -// length-prefixed; a malicious or buggy peer could send a 4-byte header -// that claims gigabytes — the MaxMessageSize cap should keep it bounded, -// but fuzzing confirms no panic / OOM regression slips in. -func FuzzReadFrame(f *testing.F) { - // Seed: valid empty-payload JSON frame. - { - var buf bytes.Buffer - wire.WriteFrame(&buf, wire.MsgJSON, []byte("{}")) - f.Add(buf.Bytes()) - } - // Seed: valid heartbeat req. - { - var buf bytes.Buffer - wire.WriteFrame(&buf, wire.MsgHeartbeat, wire.EncodeHeartbeatReq(42, make([]byte, 64))) - f.Add(buf.Bytes()) - } - // Seed: lookup req. - { - var buf bytes.Buffer - wire.WriteFrame(&buf, wire.MsgLookup, wire.EncodeLookupReq(0xDEADBEEF)) - f.Add(buf.Bytes()) - } - // Adversarial: huge length field, no body. - { - var hdr [5]byte - binary.BigEndian.PutUint32(hdr[:4], 0xFFFFFFFF) - hdr[4] = wire.MsgJSON - f.Add(hdr[:]) - } - // Adversarial: length=0 (below the "must include type byte" minimum). - { - var hdr [5]byte - binary.BigEndian.PutUint32(hdr[:4], 0) - hdr[4] = wire.MsgJSON - f.Add(hdr[:]) - } - f.Add([]byte{}) - f.Add(make([]byte, 4)) - - f.Fuzz(func(t *testing.T, data []byte) { - defer func() { - if r := recover(); r != nil { - t.Errorf("panic on input %x: %v", data, r) - } - }() - r := bytes.NewReader(data) - _, _, _ = wire.ReadFrame(r) - }) -} - -// FuzzDecodeLookupResp targets the wire-controlled allocation path -// flagged in PILOT-131. The decoder pulls counts (network count, tag -// count, length-prefixed fields) directly from the input — a 16-bit -// network count or 8-bit tag count drives `make([]uint16, n)` / -// `make([]string, n)`. Truncated inputs must surface as errors, not -// panics, and not unbounded allocations. -func FuzzDecodeLookupResp(f *testing.F) { - f.Add(wire.EncodeLookupResp(1, false, false, nil, nil, "", nil, "", "")) - f.Add(wire.EncodeLookupResp(0xDEADBEEF, true, true, - []uint16{1, 2, 3}, []byte("pubkey"), "host", []string{"a", "b"}, - "1.2.3.4:5", "extid")) - f.Add(wire.EncodeLookupResp(7, true, false, - []uint16{42}, bytes.Repeat([]byte{0x55}, 255), "h", []string{"tag"}, - "", "")) - - // Adversarial: header claims many networks but no body follows. - { - buf := make([]byte, 11) - binary.BigEndian.PutUint32(buf[:4], 1) - buf[4] = 0 - // reserved (4) zero - binary.BigEndian.PutUint16(buf[9:11], 0xFFFF) // claim 65535 networks - f.Add(buf) - } - // Adversarial: pubkey_len > remaining bytes. - { - buf := make([]byte, 12) - binary.BigEndian.PutUint32(buf[:4], 1) - // reserved + netcount = 0 - buf[11] = 0xFF // pubkey_len = 255 - f.Add(buf) - } - // Minimum-size buffer. - f.Add(make([]byte, 11)) - f.Add([]byte{}) - - f.Fuzz(func(t *testing.T, data []byte) { - defer func() { - if r := recover(); r != nil { - t.Errorf("panic on input %x: %v", data, r) - } - }() - _, _ = wire.DecodeLookupResp(data) - }) -} - -// FuzzDecodeResolveResp covers the resolve response decoder which has -// the same wire-controlled allocation shape (count + length-prefixed -// LAN addrs). -func FuzzDecodeResolveResp(f *testing.F) { - f.Add(wire.EncodeResolveResp(1, "1.2.3.4:5", nil, 0)) - f.Add(wire.EncodeResolveResp(2, "10.0.0.1:9000", - []string{"192.168.1.1", "10.0.0.5"}, 30)) - f.Add(wire.EncodeResolveResp(3, "", nil, -1)) - f.Add(make([]byte, 12)) - f.Add([]byte{}) - - // Adversarial: LAN count overflow. - { - buf := make([]byte, 8) - binary.BigEndian.PutUint32(buf[:4], 1) - binary.BigEndian.PutUint16(buf[4:6], 0) // addr_len = 0 - binary.BigEndian.PutUint16(buf[6:8], 0xFFFF) // 65535 LAN addrs - f.Add(buf) - } - - f.Fuzz(func(t *testing.T, data []byte) { - defer func() { - if r := recover(); r != nil { - t.Errorf("panic on input %x: %v", data, r) - } - }() - _, _ = wire.DecodeResolveResp(data) - }) -} - -// FuzzDecodeHeartbeatReq / Resp / LookupReq / Error are simple -// fixed-shape decoders — fuzz them anyway since they're entry points. -func FuzzDecodeHeartbeatReq(f *testing.F) { - f.Add(wire.EncodeHeartbeatReq(1, make([]byte, 64))) - f.Add(wire.EncodeHeartbeatReq(0xFFFFFFFF, bytes.Repeat([]byte{0xAA}, 64))) - f.Add([]byte{}) - f.Add(make([]byte, 67)) // one byte short - - f.Fuzz(func(t *testing.T, data []byte) { - defer func() { - if r := recover(); r != nil { - t.Errorf("panic on input %x: %v", data, r) - } - }() - _, _ = wire.DecodeHeartbeatReq(data) - }) -} - -func FuzzDecodeError(f *testing.F) { - f.Add(wire.EncodeError("oh no")) - f.Add(wire.EncodeError("")) - f.Add([]byte{0x00}) - - f.Fuzz(func(t *testing.T, data []byte) { - defer func() { - if r := recover(); r != nil { - t.Errorf("panic on input %x: %v", data, r) - } - }() - _ = wire.DecodeError(data) - }) -} - -// FuzzReadMessage exercises the JSON length-prefixed message reader. -// The 4-byte length is wire-controlled; the MaxMessageSize check is the -// only guard against `make([]byte, hugeLength)`. Verify no panic and no -// OOM-by-allocation regression. -func FuzzReadMessage(f *testing.F) { - // Seed: valid 2-byte JSON `{}`. - { - var buf bytes.Buffer - _ = wire.WriteMessage(&buf, map[string]interface{}{}) - f.Add(buf.Bytes()) - } - { - var buf bytes.Buffer - _ = wire.WriteMessage(&buf, map[string]interface{}{ - "op": "lookup", "node_id": float64(42), - }) - f.Add(buf.Bytes()) - } - // Adversarial: header claims big payload. - { - hdr := make([]byte, 4) - binary.BigEndian.PutUint32(hdr, 0xFFFFFFFF) - f.Add(hdr) - } - // Length declares 4GB but no body follows. - { - hdr := make([]byte, 4) - binary.BigEndian.PutUint32(hdr, 0x7FFFFFFF) - f.Add(hdr) - } - f.Add([]byte{}) - f.Add(make([]byte, 3)) - - f.Fuzz(func(t *testing.T, data []byte) { - defer func() { - if r := recover(); r != nil { - t.Errorf("panic on input %x: %v", data, r) - } - }() - r := bytes.NewReader(data) - _, _ = wire.ReadMessage(r) - }) -} diff --git a/pkg/registry/wire/zz_message_framing_test.go b/pkg/registry/wire/zz_message_framing_test.go deleted file mode 100644 index fe1a20de..00000000 --- a/pkg/registry/wire/zz_message_framing_test.go +++ /dev/null @@ -1,151 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package wire_test - -import ( - "bytes" - "encoding/binary" - "errors" - "io" - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -// failingWriter returns the supplied error on every Write call. Used to -// exercise the early-return branch of WriteRawMessage. -type failingWriter struct{ err error } - -func (f *failingWriter) Write(p []byte) (int, error) { return 0, f.err } - -// shortWriter accepts the first N bytes then errors. Used to fail the -// SECOND Write inside WriteRawMessage. -type shortWriter struct { - allow int - err error -} - -func (s *shortWriter) Write(p []byte) (int, error) { - if s.allow >= len(p) { - s.allow -= len(p) - return len(p), nil - } - return 0, s.err -} - -func TestWriteReadMessageRoundTrip(t *testing.T) { - t.Parallel() - msg := map[string]interface{}{ - "op": "register", - "email": "a@b.co", - "port": float64(4000), // json.Unmarshal turns numbers into float64 - } - var buf bytes.Buffer - if err := wire.WriteMessage(&buf, msg); err != nil { - t.Fatalf("WriteMessage: %v", err) - } - got, err := wire.ReadMessage(&buf) - if err != nil { - t.Fatalf("ReadMessage: %v", err) - } - for k, v := range msg { - if got[k] != v { - t.Errorf("key %q: got %v (%T), want %v (%T)", k, got[k], got[k], v, v) - } - } -} - -func TestWriteMessageJSONEncodeError(t *testing.T) { - t.Parallel() - // channels can't be JSON-encoded → json.Marshal fails. - bad := map[string]interface{}{"ch": make(chan int)} - err := wire.WriteMessage(&bytes.Buffer{}, bad) - if err == nil || !strings.Contains(err.Error(), "json encode") { - t.Fatalf("want 'json encode' err, got %v", err) - } -} - -func TestReadMessageTooLarge(t *testing.T) { - t.Parallel() - // Synthesise a length prefix > MaxMessageSize without writing the body. - var lenBuf [4]byte - binary.BigEndian.PutUint32(lenBuf[:], wire.MaxMessageSize+1) - r := bytes.NewReader(lenBuf[:]) - _, err := wire.ReadMessage(r) - if err == nil || !strings.Contains(err.Error(), "too large") { - t.Fatalf("want 'too large' err, got %v", err) - } -} - -func TestReadMessageEOFOnPrefix(t *testing.T) { - t.Parallel() - // Empty reader → io.EOF on the length prefix read. - _, err := wire.ReadMessage(bytes.NewReader(nil)) - if !errors.Is(err, io.EOF) { - t.Fatalf("want io.EOF, got %v", err) - } -} - -func TestReadMessageTruncatedBody(t *testing.T) { - t.Parallel() - // 100-byte prefix but only 5 bytes of body → ErrUnexpectedEOF. - var lenBuf [4]byte - binary.BigEndian.PutUint32(lenBuf[:], 100) - r := bytes.NewReader(append(lenBuf[:], []byte("short")...)) - _, err := wire.ReadMessage(r) - if !errors.Is(err, io.ErrUnexpectedEOF) { - t.Fatalf("want ErrUnexpectedEOF, got %v", err) - } -} - -func TestReadMessageBadJSON(t *testing.T) { - t.Parallel() - body := []byte("{not json") - var lenBuf [4]byte - binary.BigEndian.PutUint32(lenBuf[:], uint32(len(body))) - r := bytes.NewReader(append(lenBuf[:], body...)) - _, err := wire.ReadMessage(r) - if err == nil || !strings.Contains(err.Error(), "json decode") { - t.Fatalf("want 'json decode' err, got %v", err) - } -} - -func TestWriteRawMessageHappyPath(t *testing.T) { - t.Parallel() - body := []byte(`{"ok":true}`) - var buf bytes.Buffer - if err := wire.WriteRawMessage(&buf, body); err != nil { - t.Fatalf("WriteRawMessage: %v", err) - } - // Verify prefix - if buf.Len() != 4+len(body) { - t.Fatalf("buf len %d, want %d", buf.Len(), 4+len(body)) - } - gotLen := binary.BigEndian.Uint32(buf.Bytes()[:4]) - if int(gotLen) != len(body) { - t.Errorf("length prefix %d, want %d", gotLen, len(body)) - } - if !bytes.Equal(buf.Bytes()[4:], body) { - t.Errorf("body mismatch: got %q", buf.Bytes()[4:]) - } -} - -func TestWriteRawMessageErrorOnPrefix(t *testing.T) { - t.Parallel() - bang := errors.New("boom") - err := wire.WriteRawMessage(&failingWriter{err: bang}, []byte(`{}`)) - if !errors.Is(err, bang) { - t.Fatalf("want boom, got %v", err) - } -} - -func TestWriteRawMessageErrorOnBody(t *testing.T) { - t.Parallel() - bang := errors.New("boom") - // allow the 4-byte prefix, fail on the body write - err := wire.WriteRawMessage(&shortWriter{allow: 4, err: bang}, []byte(`{"x":1}`)) - if !errors.Is(err, bang) { - t.Fatalf("want boom on body write, got %v", err) - } -} diff --git a/pkg/registry/wire/zz_rules_test.go b/pkg/registry/wire/zz_rules_test.go deleted file mode 100644 index 64f5f924..00000000 --- a/pkg/registry/wire/zz_rules_test.go +++ /dev/null @@ -1,334 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package wire_test - -import ( - "encoding/json" - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -// --- ValidateRules error branches ---------------------------------------- - -func TestValidateRulesNilReturnsNil(t *testing.T) { - t.Parallel() - if err := wire.ValidateRules(nil); err != nil { - t.Fatalf("nil rules: %v", err) - } -} - -func TestValidateRulesLinksRequired(t *testing.T) { - t.Parallel() - cases := []int{0, -5} - for _, l := range cases { - r := &wire.NetworkRules{Links: l, Cycle: "1h", PruneBy: "score", FillHow: "random"} - err := wire.ValidateRules(r) - if err == nil || !strings.Contains(err.Error(), "links must be >= 1") { - t.Fatalf("links=%d: %v", l, err) - } - } -} - -func TestValidateRulesCycleRequired(t *testing.T) { - t.Parallel() - r := &wire.NetworkRules{Links: 5, Cycle: "", PruneBy: "score", FillHow: "random"} - err := wire.ValidateRules(r) - if err == nil || !strings.Contains(err.Error(), "cycle is required") { - t.Fatalf("expected cycle-required error, got %v", err) - } -} - -func TestValidateRulesCycleInvalidDuration(t *testing.T) { - t.Parallel() - r := &wire.NetworkRules{Links: 5, Cycle: "not-a-duration", PruneBy: "score", FillHow: "random"} - err := wire.ValidateRules(r) - if err == nil || !strings.Contains(err.Error(), "invalid cycle duration") { - t.Fatalf("%v", err) - } -} - -func TestValidateRulesCycleTooShort(t *testing.T) { - t.Parallel() - r := &wire.NetworkRules{Links: 5, Cycle: "30s", PruneBy: "score", FillHow: "random"} - err := wire.ValidateRules(r) - if err == nil || !strings.Contains(err.Error(), "cycle must be >= 1m") { - t.Fatalf("%v", err) - } -} - -func TestValidateRulesPruneFillNegativeOrOverflow(t *testing.T) { - t.Parallel() - base := wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "score", FillHow: "random"} - // Prune < 0 - r := base - r.Prune = -1 - if err := wire.ValidateRules(&r); err == nil || !strings.Contains(err.Error(), "prune must be >= 0") { - t.Fatalf("prune<0: %v", err) - } - // Fill < 0 - r = base - r.Fill = -1 - if err := wire.ValidateRules(&r); err == nil || !strings.Contains(err.Error(), "fill must be >= 0") { - t.Fatalf("fill<0: %v", err) - } - // Prune > Links - r = base - r.Prune = 10 - if err := wire.ValidateRules(&r); err == nil || !strings.Contains(err.Error(), "cannot exceed links") { - t.Fatalf("prune>links: %v", err) - } - // Fill > Links - r = base - r.Fill = 10 - if err := wire.ValidateRules(&r); err == nil || !strings.Contains(err.Error(), "fill (10) cannot exceed links") { - t.Fatalf("fill>links: %v", err) - } -} - -func TestValidateRulesPruneByAllValidValues(t *testing.T) { - t.Parallel() - for _, pb := range []string{"score", "age", "activity"} { - r := &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: pb, FillHow: "random"} - if err := wire.ValidateRules(r); err != nil { - t.Fatalf("prune_by=%q: %v", pb, err) - } - } -} - -func TestValidateRulesPruneByRequiredAndUnknown(t *testing.T) { - t.Parallel() - // Empty - r := &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "", FillHow: "random"} - if err := wire.ValidateRules(r); err == nil || !strings.Contains(err.Error(), "prune_by is required") { - t.Fatalf("empty prune_by: %v", err) - } - // Unknown - r = &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "lottery", FillHow: "random"} - if err := wire.ValidateRules(r); err == nil || !strings.Contains(err.Error(), "unknown prune_by strategy") { - t.Fatalf("unknown prune_by: %v", err) - } -} - -func TestValidateRulesFillHowRequiredAndUnknown(t *testing.T) { - t.Parallel() - r := &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "score", FillHow: ""} - if err := wire.ValidateRules(r); err == nil || !strings.Contains(err.Error(), "fill_how is required") { - t.Fatalf("empty fill_how: %v", err) - } - r = &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "score", FillHow: "roundrobin"} - if err := wire.ValidateRules(r); err == nil || !strings.Contains(err.Error(), "unknown fill_how strategy") { - t.Fatalf("unknown fill_how: %v", err) - } -} - -func TestValidateRulesGraceInvalid(t *testing.T) { - t.Parallel() - r := &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "score", FillHow: "random", Grace: "not-a-duration"} - if err := wire.ValidateRules(r); err == nil || !strings.Contains(err.Error(), "invalid grace duration") { - t.Fatalf("bad grace: %v", err) - } - // Note: time.ParseDuration rejects literal negatives like "-1m" for some inputs. - // We rely on the `g < 0` branch being effectively unreachable via parsing in practice, - // but verify parseable non-negative grace succeeds. -} - -func TestValidateRulesGraceEmptyOrValid(t *testing.T) { - t.Parallel() - r := &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "score", FillHow: "random", Grace: ""} - if err := wire.ValidateRules(r); err != nil { - t.Fatalf("empty grace: %v", err) - } - r.Grace = "10m" - if err := wire.ValidateRules(r); err != nil { - t.Fatalf("valid grace: %v", err) - } -} - -func TestValidateRulesHappyPath(t *testing.T) { - t.Parallel() - r := &wire.NetworkRules{Links: 10, Cycle: "1h", Prune: 2, PruneBy: "score", Fill: 2, FillHow: "random", Grace: "5m"} - if err := wire.ValidateRules(r); err != nil { - t.Fatalf("happy: %v", err) - } -} - -// --- ParseRules ----------------------------------------------------------- - -func TestParseRulesBadJSON(t *testing.T) { - t.Parallel() - _, err := wire.ParseRules(`{not json`) - if err == nil || !strings.Contains(err.Error(), "invalid JSON") { - t.Fatalf("%v", err) - } -} - -func TestParseRulesInvalidRules(t *testing.T) { - t.Parallel() - _, err := wire.ParseRules(`{"links":0,"cycle":"1h","prune_by":"score","fill_how":"random"}`) - if err == nil || !strings.Contains(err.Error(), "links must be >= 1") { - t.Fatalf("%v", err) - } -} - -func TestParseRulesHappyPath(t *testing.T) { - t.Parallel() - r, err := wire.ParseRules(`{"links":5,"cycle":"1h","prune":1,"prune_by":"age","fill":1,"fill_how":"random"}`) - if err != nil { - t.Fatalf("%v", err) - } - if r.Links != 5 || r.Cycle != "1h" || r.Prune != 1 || r.PruneBy != "age" || r.Fill != 1 || r.FillHow != "random" { - t.Fatalf("parsed: %+v", r) - } -} - -// --- RulesToPolicy -------------------------------------------------------- - -func TestRulesToPolicyNilReturnsNilNil(t *testing.T) { - t.Parallel() - raw, err := wire.RulesToPolicy(nil) - if err != nil { - t.Fatalf("%v", err) - } - if raw != nil { - t.Fatalf("expected nil json.RawMessage for nil rules, got %s", string(raw)) - } -} - -func TestRulesToPolicyShapeAndContentWithoutGrace(t *testing.T) { - t.Parallel() - r := &wire.NetworkRules{Links: 7, Cycle: "2h", Prune: 3, PruneBy: "age", Fill: 2, FillHow: "random"} - raw, err := wire.RulesToPolicy(r) - if err != nil { - t.Fatalf("%v", err) - } - var doc map[string]interface{} - if err := json.Unmarshal(raw, &doc); err != nil { - t.Fatalf("%v", err) - } - if doc["version"].(float64) != 1 { - t.Fatalf("version: %v", doc["version"]) - } - cfg := doc["config"].(map[string]interface{}) - if cfg["max_peers"].(float64) != 7 { - t.Fatalf("max_peers: %v", cfg["max_peers"]) - } - if cfg["cycle"].(string) != "2h" { - t.Fatalf("cycle: %v", cfg["cycle"]) - } - if _, hasGrace := cfg["grace"]; hasGrace { - t.Fatalf("grace should be absent when Grace=\"\"") - } - rules := doc["rules"].([]interface{}) - if len(rules) != 1 { - t.Fatalf("rules count: %d", len(rules)) - } - // rule[0] = cycle-prune-fill; prune action first, fill action second - r1 := rules[0].(map[string]interface{}) - if r1["name"].(string) != "cycle-prune-fill" || r1["on"].(string) != "cycle" { - t.Fatalf("rule 0: %+v", r1) - } - actions := r1["actions"].([]interface{}) - pruneA := actions[0].(map[string]interface{}) - if pruneA["type"].(string) != "prune" { - t.Fatalf("first action: %+v", pruneA) - } - params := pruneA["params"].(map[string]interface{}) - if params["count"].(float64) != 3 || params["by"].(string) != "age" { - t.Fatalf("prune params: %+v", params) - } - fillA := actions[1].(map[string]interface{}) - if fillA["type"].(string) != "fill" { - t.Fatalf("second action: %+v", fillA) - } - fillP := fillA["params"].(map[string]interface{}) - if fillP["count"].(float64) != 2 || fillP["how"].(string) != "random" { - t.Fatalf("fill params: %+v", fillP) - } -} - -func TestRulesToPolicyIncludesGraceWhenSet(t *testing.T) { - t.Parallel() - r := &wire.NetworkRules{Links: 5, Cycle: "1h", Prune: 1, PruneBy: "score", Fill: 1, FillHow: "random", Grace: "15m"} - raw, err := wire.RulesToPolicy(r) - if err != nil { - t.Fatalf("%v", err) - } - var doc map[string]interface{} - _ = json.Unmarshal(raw, &doc) - cfg := doc["config"].(map[string]interface{}) - if cfg["grace"].(string) != "15m" { - t.Fatalf("grace: %v", cfg["grace"]) - } -} - -// --- AllowedPortsToPolicy ------------------------------------------------- - -func TestAllowedPortsToPolicyEmptyReturnsNilNil(t *testing.T) { - t.Parallel() - raw, err := wire.AllowedPortsToPolicy(nil) - if err != nil || raw != nil { - t.Fatalf("nil ports: raw=%v err=%v", raw, err) - } - raw, err = wire.AllowedPortsToPolicy([]uint16{}) - if err != nil || raw != nil { - t.Fatalf("empty ports: raw=%v err=%v", raw, err) - } -} - -func TestAllowedPortsToPolicyMatchExpressionAndRules(t *testing.T) { - t.Parallel() - raw, err := wire.AllowedPortsToPolicy([]uint16{80, 443, 7001}) - if err != nil { - t.Fatalf("%v", err) - } - // Raw text contains the exact match expression. - s := string(raw) - if !strings.Contains(s, `"port in [80, 443, 7001]"`) { - t.Fatalf("match expr not formatted as expected:\n%s", s) - } - var doc map[string]interface{} - if err := json.Unmarshal(raw, &doc); err != nil { - t.Fatalf("%v", err) - } - if doc["version"].(float64) != 1 { - t.Fatalf("version: %v", doc["version"]) - } - rules := doc["rules"].([]interface{}) - if len(rules) != 6 { - t.Fatalf("rules count: %d, want 6 (3 allow + 3 deny)", len(rules)) - } - // Expected names in order. - wantNames := []string{"allow-ports", "allow-ports-dg", "allow-ports-dial", "deny-rest", "deny-rest-dg", "deny-rest-dial"} - for i, want := range wantNames { - r := rules[i].(map[string]interface{}) - if r["name"].(string) != want { - t.Fatalf("rule[%d].name = %q, want %q", i, r["name"], want) - } - } - // Allow rules use the built match expr; deny rules use "true". - for i := 0; i < 3; i++ { - r := rules[i].(map[string]interface{}) - if r["match"].(string) != "port in [80, 443, 7001]" { - t.Fatalf("allow rule[%d] match: %q", i, r["match"]) - } - } - for i := 3; i < 6; i++ { - r := rules[i].(map[string]interface{}) - if r["match"].(string) != "true" { - t.Fatalf("deny rule[%d] match: %q", i, r["match"]) - } - } -} - -func TestAllowedPortsToPolicySinglePort(t *testing.T) { - t.Parallel() - raw, err := wire.AllowedPortsToPolicy([]uint16{7}) - if err != nil { - t.Fatalf("%v", err) - } - if !strings.Contains(string(raw), `"port in [7]"`) { - t.Fatalf("single-port match expr:\n%s", string(raw)) - } -} diff --git a/pkg/registry/wire/zz_wire_test.go b/pkg/registry/wire/zz_wire_test.go deleted file mode 100644 index e91014b8..00000000 --- a/pkg/registry/wire/zz_wire_test.go +++ /dev/null @@ -1,415 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package wire_test - -import ( - "bytes" - "math" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" -) - -func TestWireFrameRoundTrip(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - msgType byte - payload []byte - }{ - {"empty payload", wire.MsgHeartbeat, nil}, - {"small payload", wire.MsgLookup, []byte{1, 2, 3, 4}}, - {"max type", wire.MsgError, []byte("test error")}, - {"json passthrough", wire.MsgJSON, []byte(`{"type":"heartbeat","node_id":42}`)}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var buf bytes.Buffer - if err := wire.WriteFrame(&buf, tt.msgType, tt.payload); err != nil { - t.Fatalf("write frame: %v", err) - } - - gotType, gotPayload, err := wire.ReadFrame(&buf) - if err != nil { - t.Fatalf("read frame: %v", err) - } - if gotType != tt.msgType { - t.Fatalf("type: got 0x%02x, want 0x%02x", gotType, tt.msgType) - } - if !bytes.Equal(gotPayload, tt.payload) { - t.Fatalf("payload: got %v, want %v", gotPayload, tt.payload) - } - }) - } -} - -func TestWireFrameMultipleMessages(t *testing.T) { - t.Parallel() - - var buf bytes.Buffer - for i := 0; i < 10; i++ { - payload := []byte{byte(i), byte(i + 1)} - if err := wire.WriteFrame(&buf, byte(i), payload); err != nil { - t.Fatalf("write frame %d: %v", i, err) - } - } - - for i := 0; i < 10; i++ { - gotType, gotPayload, err := wire.ReadFrame(&buf) - if err != nil { - t.Fatalf("read frame %d: %v", i, err) - } - if gotType != byte(i) { - t.Fatalf("frame %d type: got 0x%02x, want 0x%02x", i, gotType, byte(i)) - } - if len(gotPayload) != 2 || gotPayload[0] != byte(i) || gotPayload[1] != byte(i+1) { - t.Fatalf("frame %d payload mismatch", i) - } - } -} - -func TestWireFrameTooLarge(t *testing.T) { - t.Parallel() - - // Write a frame claiming a payload larger than MaxMessageSize - var buf bytes.Buffer - wire.WriteFrame(&buf, wire.MsgJSON, make([]byte, wire.MaxMessageSize+1)) - - _, _, err := wire.ReadFrame(&buf) - if err == nil { - t.Fatal("expected error for oversized frame") - } -} - -func TestHeartbeatReqRoundTrip(t *testing.T) { - t.Parallel() - - var sig [64]byte - for i := range sig { - sig[i] = byte(i) - } - - payload := wire.EncodeHeartbeatReq(42, sig[:]) - req, err := wire.DecodeHeartbeatReq(payload) - if err != nil { - t.Fatalf("decode: %v", err) - } - if req.NodeID != 42 { - t.Fatalf("nodeID: got %d, want 42", req.NodeID) - } - if req.Signature != sig { - t.Fatal("signature mismatch") - } -} - -func TestHeartbeatReqTooShort(t *testing.T) { - t.Parallel() - _, err := wire.DecodeHeartbeatReq([]byte{1, 2, 3}) - if err == nil { - t.Fatal("expected error for short payload") - } -} - -func TestHeartbeatRespRoundTrip(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - unixTime int64 - keyExpiryWarning bool - }{ - {"no warning", 1700000000, false}, - {"with warning", 1700000000, true}, - {"zero time", 0, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - payload := wire.EncodeHeartbeatResp(tt.unixTime, tt.keyExpiryWarning) - gotTime, gotWarning, err := wire.DecodeHeartbeatResp(payload) - if err != nil { - t.Fatalf("decode: %v", err) - } - if gotTime != tt.unixTime { - t.Fatalf("time: got %d, want %d", gotTime, tt.unixTime) - } - if gotWarning != tt.keyExpiryWarning { - t.Fatalf("warning: got %v, want %v", gotWarning, tt.keyExpiryWarning) - } - }) - } -} - -func TestLookupReqRoundTrip(t *testing.T) { - t.Parallel() - - payload := wire.EncodeLookupReq(12345) - nodeID, err := wire.DecodeLookupReq(payload) - if err != nil { - t.Fatalf("decode: %v", err) - } - if nodeID != 12345 { - t.Fatalf("nodeID: got %d, want 12345", nodeID) - } -} - -func TestLookupRespRoundTrip(t *testing.T) { - t.Parallel() - - pubKey := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} - - payload := wire.EncodeLookupResp( - 42, // nodeID - true, // public - true, // taskExec - []uint16{1, 2, 3}, // networks - pubKey, // pubKey - "test-host", // hostname - []string{"svc", "primary"}, // tags - "10.0.0.1:4000", // realAddr - "ext-123", // externalID - ) - - result, err := wire.DecodeLookupResp(payload) - if err != nil { - t.Fatalf("decode: %v", err) - } - - if result.NodeID != 42 { - t.Fatalf("NodeID: got %d, want 42", result.NodeID) - } - if !result.Public { - t.Fatal("expected Public=true") - } - if !result.TaskExec { - t.Fatal("expected TaskExec=true") - } - if len(result.Networks) != 3 || result.Networks[0] != 1 || result.Networks[2] != 3 { - t.Fatalf("Networks: got %v, want [1,2,3]", result.Networks) - } - if !bytes.Equal(result.PubKey, pubKey) { - t.Fatal("PubKey mismatch") - } - if result.Hostname != "test-host" { - t.Fatalf("Hostname: got %q, want %q", result.Hostname, "test-host") - } - if len(result.Tags) != 2 || result.Tags[0] != "svc" || result.Tags[1] != "primary" { - t.Fatalf("Tags: got %v", result.Tags) - } - if result.RealAddr != "10.0.0.1:4000" { - t.Fatalf("RealAddr: got %q", result.RealAddr) - } - if result.ExternalID != "ext-123" { - t.Fatalf("ExternalID: got %q", result.ExternalID) - } -} - -func TestLookupRespMinimal(t *testing.T) { - t.Parallel() - - payload := wire.EncodeLookupResp(1, false, false, nil, nil, "", nil, "", "") - result, err := wire.DecodeLookupResp(payload) - if err != nil { - t.Fatalf("decode: %v", err) - } - if result.NodeID != 1 { - t.Fatalf("NodeID: got %d, want 1", result.NodeID) - } - if result.Public || result.TaskExec { - t.Fatal("expected both flags false") - } - if len(result.Networks) != 0 { - t.Fatal("expected empty networks") - } - if len(result.PubKey) != 0 { - t.Fatal("expected empty pubkey") - } -} - -func TestResolveReqRoundTrip(t *testing.T) { - t.Parallel() - - sig := make([]byte, 64) - for i := range sig { - sig[i] = byte(i + 100) - } - - payload := wire.EncodeResolveReq(10, 20, sig) - nodeID, requesterID, gotSig, err := wire.DecodeResolveReq(payload) - if err != nil { - t.Fatalf("decode: %v", err) - } - if nodeID != 10 { - t.Fatalf("nodeID: got %d, want 10", nodeID) - } - if requesterID != 20 { - t.Fatalf("requesterID: got %d, want 20", requesterID) - } - if !bytes.Equal(gotSig, sig) { - t.Fatal("signature mismatch") - } -} - -func TestResolveRespRoundTrip(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - nodeID uint32 - realAddr string - lanAddrs []string - keyAgeDays int - }{ - {"basic", 42, "10.0.0.1:4000", nil, 30}, - {"with LANs", 42, "10.0.0.1:4000", []string{"192.168.1.1:4000", "192.168.2.1:4000"}, 30}, - {"unknown key age", 42, "10.0.0.1:4000", nil, -1}, - {"zero key age", 42, "10.0.0.1:4000", nil, 0}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - payload := wire.EncodeResolveResp(tt.nodeID, tt.realAddr, tt.lanAddrs, tt.keyAgeDays) - result, err := wire.DecodeResolveResp(payload) - if err != nil { - t.Fatalf("decode: %v", err) - } - if result.NodeID != tt.nodeID { - t.Fatalf("NodeID: got %d, want %d", result.NodeID, tt.nodeID) - } - if result.RealAddr != tt.realAddr { - t.Fatalf("RealAddr: got %q, want %q", result.RealAddr, tt.realAddr) - } - if len(result.LANAddrs) != len(tt.lanAddrs) { - t.Fatalf("LANAddrs length: got %d, want %d", len(result.LANAddrs), len(tt.lanAddrs)) - } - for i, la := range result.LANAddrs { - if la != tt.lanAddrs[i] { - t.Fatalf("LANAddrs[%d]: got %q, want %q", i, la, tt.lanAddrs[i]) - } - } - if result.KeyAgeDays != tt.keyAgeDays { - t.Fatalf("KeyAgeDays: got %d, want %d", result.KeyAgeDays, tt.keyAgeDays) - } - }) - } -} - -func TestResolveRespMaxKeyAge(t *testing.T) { - t.Parallel() - - // Verify math.MaxUint32 maps to -1 - payload := wire.EncodeResolveResp(1, "addr", nil, -1) - result, err := wire.DecodeResolveResp(payload) - if err != nil { - t.Fatalf("decode: %v", err) - } - if result.KeyAgeDays != -1 { - t.Fatalf("KeyAgeDays: got %d, want -1", result.KeyAgeDays) - } - - // Verify large positive value round-trips - payload = wire.EncodeResolveResp(1, "addr", nil, int(math.MaxUint32-1)) - result, err = wire.DecodeResolveResp(payload) - if err != nil { - t.Fatalf("decode: %v", err) - } - if result.KeyAgeDays != int(math.MaxUint32-1) { - t.Fatalf("KeyAgeDays: got %d, want %d", result.KeyAgeDays, math.MaxUint32-1) - } -} - -func TestWireErrorRoundTrip(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - msg string - }{ - {"simple", "not found"}, - {"empty", ""}, - {"long", string(make([]byte, 1000))}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - payload := wire.EncodeError(tt.msg) - got := wire.DecodeError(payload) - if got != tt.msg { - t.Fatalf("got %q, want %q", got, tt.msg) - } - }) - } -} - -func TestWireErrorTruncation(t *testing.T) { - t.Parallel() - - // Messages > 65000 are truncated - longMsg := string(make([]byte, 70000)) - payload := wire.EncodeError(longMsg) - got := wire.DecodeError(payload) - if len(got) != 65000 { - t.Fatalf("expected truncated to 65000, got %d", len(got)) - } -} - -func TestWireProtocolNegotiationMagic(t *testing.T) { - t.Parallel() - - // Verify the magic bytes are correct - if wire.Magic != [4]byte{0x50, 0x49, 0x4C, 0x54} { - t.Fatalf("magic: got %v, want PILT", wire.Magic) - } - // Verify magic != any valid JSON length prefix (which must be < MaxMessageSize) - magicAsLen := uint32(wire.Magic[0])<<24 | uint32(wire.Magic[1])<<16 | uint32(wire.Magic[2])<<8 | uint32(wire.Magic[3]) - if magicAsLen <= wire.MaxMessageSize { - t.Fatalf("magic as length (%d) must be > MaxMessageSize (%d) for protocol detection", magicAsLen, wire.MaxMessageSize) - } -} - -func BenchmarkEncodeHeartbeatReq(b *testing.B) { - sig := make([]byte, 64) - for i := 0; i < b.N; i++ { - wire.EncodeHeartbeatReq(42, sig) - } -} - -func BenchmarkDecodeHeartbeatReq(b *testing.B) { - sig := make([]byte, 64) - payload := wire.EncodeHeartbeatReq(42, sig) - for i := 0; i < b.N; i++ { - wire.DecodeHeartbeatReq(payload) - } -} - -func BenchmarkEncodeLookupResp(b *testing.B) { - pubKey := make([]byte, 32) - networks := []uint16{1, 2, 3} - tags := []string{"svc", "primary"} - for i := 0; i < b.N; i++ { - wire.EncodeLookupResp(42, true, true, networks, pubKey, "test-host", tags, "10.0.0.1:4000", "ext-123") - } -} - -func BenchmarkDecodeLookupResp(b *testing.B) { - pubKey := make([]byte, 32) - networks := []uint16{1, 2, 3} - tags := []string{"svc", "primary"} - payload := wire.EncodeLookupResp(42, true, true, networks, pubKey, "test-host", tags, "10.0.0.1:4000", "ext-123") - for i := 0; i < b.N; i++ { - wire.DecodeLookupResp(payload) - } -} - -func BenchmarkWireFrameRoundTrip(b *testing.B) { - payload := make([]byte, 68) // heartbeat size - var buf bytes.Buffer - for i := 0; i < b.N; i++ { - buf.Reset() - wire.WriteFrame(&buf, wire.MsgHeartbeat, payload) - wire.ReadFrame(&buf) - } -} diff --git a/pkg/secure/client.go b/pkg/secure/client.go deleted file mode 100644 index 0b0de977..00000000 --- a/pkg/secure/client.go +++ /dev/null @@ -1,24 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package secure - -import ( - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" -) - -// Dial connects to a remote agent's secure port and performs the handshake. -// Returns an encrypted connection that implements net.Conn. -func Dial(d *driver.Driver, addr protocol.Addr, auth ...*HandshakeConfig) (*SecureConn, error) { - conn, err := d.DialAddr(addr, protocol.PortSecure) - if err != nil { - return nil, err - } - - sc, err := Handshake(conn, false, auth...) - if err != nil { - conn.Close() - return nil, err - } - return sc, nil -} diff --git a/pkg/secure/secure.go b/pkg/secure/secure.go deleted file mode 100644 index f3d94dc3..00000000 --- a/pkg/secure/secure.go +++ /dev/null @@ -1,773 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package secure - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/ecdh" - "crypto/ed25519" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "encoding/binary" - "fmt" - "io" - "net" - "sync" - "time" -) - -// MaxEncryptedMessageLen limits the maximum decrypted message size to prevent -// memory exhaustion from a malicious peer advertising a huge msgLen. -const MaxEncryptedMessageLen = 16 * 1024 * 1024 // 16 MB - -// HandshakeTimeout is the maximum time allowed for the ECDH handshake. -const HandshakeTimeout = 10 * time.Second - -// AuthFrameLen is the total size of an auth frame: -// nodeID(4) + timestamp(8) + nonce(16) + ed25519_signature(64) = 92 bytes. -const AuthFrameLen = 4 + 8 + 16 + 64 - -// authTimestampSkew is the maximum allowed time difference for auth timestamps. -const authTimestampSkew = 5 * time.Second - -// replayCacheExpiry is how long nonces are kept in the replay cache. -const replayCacheExpiry = 1 * time.Hour - -// HandshakeConfig holds identity authentication parameters for the secure -// channel handshake. If nil is passed to Handshake, authentication is skipped -// (backward compatibility for tests and unauthenticated channels). -type HandshakeConfig struct { - NodeID uint32 - Signer ed25519.PrivateKey - PeerPubKey ed25519.PublicKey -} - -// PeerPubKeyLookup returns the Ed25519 public key for a given node ID. -// Used by the server to look up a connecting client's identity for auth -// verification. Returns nil if the node is unknown. -// -// Definitive declaration of PeerPubKeyLookup; do not duplicate in this package. -type PeerPubKeyLookup func(nodeID uint32) ed25519.PublicKey - -// replayCache prevents reuse of auth nonces within a 1-hour window. -var replayCache = struct { - sync.Mutex - nonces map[[16]byte]time.Time -}{nonces: make(map[[16]byte]time.Time)} - -func init() { - go replayCacheCleaner() -} - -// replayCacheCleaner periodically removes expired nonce entries. -func replayCacheCleaner() { - ticker := time.NewTicker(5 * time.Minute) - for range ticker.C { - now := time.Now() - replayCache.Lock() - for k, t := range replayCache.nonces { - if now.Sub(t) > replayCacheExpiry { - delete(replayCache.nonces, k) - } - } - replayCache.Unlock() - } -} - -// maxReplayCacheEntries caps the replay cache to prevent memory exhaustion (M1 fix). -const maxReplayCacheEntries = 100000 - -// CheckAndRecordNonce returns an error if the nonce was already seen within -// the replay window, otherwise records it and returns nil. -func CheckAndRecordNonce(nonce [16]byte) error { - replayCache.Lock() - defer replayCache.Unlock() - if _, exists := replayCache.nonces[nonce]; exists { - return fmt.Errorf("auth nonce replay detected") - } - if len(replayCache.nonces) >= maxReplayCacheEntries { - return fmt.Errorf("auth replay cache full") - } - replayCache.nonces[nonce] = time.Now() - return nil -} - -// ResetReplayCache clears the replay cache. Exported for testing only. -func ResetReplayCache() { - replayCache.Lock() - defer replayCache.Unlock() - replayCache.nonces = make(map[[16]byte]time.Time) -} - -// InjectReplayNonce adds a nonce to the replay cache. Exported for testing only. -func InjectReplayNonce(nonce [16]byte) { - replayCache.Lock() - defer replayCache.Unlock() - replayCache.nonces[nonce] = time.Now() -} - -// CheckReplayNonce checks if a nonce is in the replay cache without recording it. -// Exported for testing only. -func CheckReplayNonce(nonce [16]byte) error { - replayCache.Lock() - defer replayCache.Unlock() - if _, exists := replayCache.nonces[nonce]; exists { - return fmt.Errorf("auth nonce replay detected") - } - return nil -} - -// HandshakeWithTimestampOffset performs an authenticated handshake but shifts -// the auth frame timestamp by the given offset. Exported for testing only. -func HandshakeWithTimestampOffset(conn net.Conn, isServer bool, cfg *HandshakeConfig, offset time.Duration) (*SecureConn, error) { - conn.SetDeadline(time.Now().Add(HandshakeTimeout)) - defer conn.SetDeadline(time.Time{}) - - curve := ecdh.X25519() - privKey, err := curve.GenerateKey(rand.Reader) - if err != nil { - return nil, fmt.Errorf("generate key: %w", err) - } - localPub := privKey.PublicKey().Bytes() - - var remotePub []byte - - if isServer { - remotePub, err = ReadExact(conn, 32) - if err != nil { - return nil, fmt.Errorf("read client key: %w", err) - } - if _, err := conn.Write(localPub); err != nil { - return nil, fmt.Errorf("send server key: %w", err) - } - } else { - if _, err := conn.Write(localPub); err != nil { - return nil, fmt.Errorf("send client key: %w", err) - } - remotePub, err = ReadExact(conn, 32) - if err != nil { - return nil, fmt.Errorf("read server key: %w", err) - } - } - - peerKey, err := curve.NewPublicKey(remotePub) - if err != nil { - return nil, fmt.Errorf("parse peer key: %w", err) - } - - shared, err := privKey.ECDH(peerKey) - if err != nil { - return nil, fmt.Errorf("ecdh: %w", err) - } - - // HKDF-SHA256 key derivation (H1 fix) - mac := hmac.New(sha256.New, nil) - mac.Write(shared) - prk := mac.Sum(nil) - mac = hmac.New(sha256.New, prk) - mac.Write([]byte("pilot-secure-v1")) - mac.Write([]byte{0x01}) - key := mac.Sum(nil) - - block, err := aes.NewCipher(key) - if err != nil { - return nil, fmt.Errorf("aes: %w", err) - } - aead, err := cipher.NewGCM(block) - if err != nil { - return nil, fmt.Errorf("gcm: %w", err) - } - - // Zero intermediate key material (H4 fix) - for i := range shared { - shared[i] = 0 - } - for i := range key { - key[i] = 0 - } - for i := range prk { - prk[i] = 0 - } - - sc := &SecureConn{raw: conn, aead: aead} - if isServer { - sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x01} - } else { - sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x02} - } - - if cfg != nil && cfg.Signer != nil { - if err := performAuthWithOffset(sc, cfg, localPub, remotePub, isServer, offset); err != nil { - sc.Close() - return nil, fmt.Errorf("auth: %w", err) - } - } - - return sc, nil -} - -// performAuthWithOffset is like performAuth but applies a timestamp offset. -func performAuthWithOffset(sc *SecureConn, cfg *HandshakeConfig, localX25519Pub, remoteX25519Pub []byte, isServer bool, offset time.Duration) error { - // Use shifted timestamp - ts := uint64(time.Now().Add(offset).Unix()) - - var authNonce [16]byte - if _, err := rand.Read(authNonce[:]); err != nil { - return fmt.Errorf("generate auth nonce: %w", err) - } - - sigMsg := BuildAuthSignMessage(cfg.NodeID, localX25519Pub, ts, authNonce) - signature := ed25519.Sign(cfg.Signer, sigMsg) - - frame := make([]byte, AuthFrameLen) - binary.BigEndian.PutUint32(frame[0:4], cfg.NodeID) - binary.BigEndian.PutUint64(frame[4:12], ts) - copy(frame[12:28], authNonce[:]) - copy(frame[28:92], signature) - - now := time.Now() // verifier uses current time - - if isServer { - if _, err := sc.Write(frame); err != nil { - return fmt.Errorf("send auth frame: %w", err) - } - peerFrame, err := readAuthFrame(sc) - if err != nil { - return fmt.Errorf("read peer auth frame: %w", err) - } - peerNodeID, err := VerifyAuthFrame(peerFrame, cfg.PeerPubKey, remoteX25519Pub, now) - if err != nil { - return err - } - sc.PeerNodeID = peerNodeID - } else { - peerFrame, err := readAuthFrame(sc) - if err != nil { - return fmt.Errorf("read peer auth frame: %w", err) - } - peerNodeID, err := VerifyAuthFrame(peerFrame, cfg.PeerPubKey, remoteX25519Pub, now) - if err != nil { - return err - } - sc.PeerNodeID = peerNodeID - if _, err := sc.Write(frame); err != nil { - return fmt.Errorf("send auth frame: %w", err) - } - } - - return nil -} - -// SecureConn wraps a net.Conn with AES-256-GCM encryption. -// After a successful ECDH handshake, all reads and writes are encrypted. -type SecureConn struct { - raw net.Conn - aead cipher.AEAD - rmu sync.Mutex - wmu sync.Mutex - nonce uint64 // monotonic counter for nonces - noncePrefix [4]byte // role-based prefix for nonce domain separation - readBuf []byte // leftover plaintext from a previous Read - PeerNodeID uint32 // authenticated peer node ID (0 if unauthenticated) -} - -// Handshake performs an ECDH key exchange over the connection. -// isServer determines which side reads first. -// An optional HandshakeConfig enables mutual Ed25519 authentication inside the -// encrypted channel after the ECDH exchange. Pass nil or omit for unauthenticated -// mode (backward compatible). -// A deadline is set to prevent indefinite blocking (M14 fix). -func Handshake(conn net.Conn, isServer bool, auth ...*HandshakeConfig) (*SecureConn, error) { - // Set handshake deadline to prevent indefinite blocking (M14 fix) - conn.SetDeadline(time.Now().Add(HandshakeTimeout)) - defer conn.SetDeadline(time.Time{}) // clear deadline after handshake - - // Generate ephemeral X25519 key pair - curve := ecdh.X25519() - privKey, err := curve.GenerateKey(rand.Reader) - if err != nil { - return nil, fmt.Errorf("generate key: %w", err) - } - localPub := privKey.PublicKey().Bytes() // 32 bytes - - var remotePub []byte - - if isServer { - // Server: read client's public key first, then send ours - remotePub, err = ReadExact(conn, 32) - if err != nil { - return nil, fmt.Errorf("read client key: %w", err) - } - if _, err := conn.Write(localPub); err != nil { - return nil, fmt.Errorf("send server key: %w", err) - } - } else { - // Client: send our public key first, then read server's - if _, err := conn.Write(localPub); err != nil { - return nil, fmt.Errorf("send client key: %w", err) - } - remotePub, err = ReadExact(conn, 32) - if err != nil { - return nil, fmt.Errorf("read server key: %w", err) - } - } - - // Parse remote public key - peerKey, err := curve.NewPublicKey(remotePub) - if err != nil { - return nil, fmt.Errorf("parse peer key: %w", err) - } - - // Compute shared secret - shared, err := privKey.ECDH(peerKey) - if err != nil { - return nil, fmt.Errorf("ecdh: %w", err) - } - - // HKDF-SHA256 key derivation (H1 fix) - mac := hmac.New(sha256.New, nil) - mac.Write(shared) - prk := mac.Sum(nil) - mac = hmac.New(sha256.New, prk) - mac.Write([]byte("pilot-secure-v1")) - mac.Write([]byte{0x01}) - key := mac.Sum(nil) - - // Create AES-GCM cipher - block, err := aes.NewCipher(key) - if err != nil { - return nil, fmt.Errorf("aes: %w", err) - } - aead, err := cipher.NewGCM(block) - if err != nil { - return nil, fmt.Errorf("gcm: %w", err) - } - - // Zero intermediate key material (H4 fix) - for i := range shared { - shared[i] = 0 - } - for i := range key { - key[i] = 0 - } - for i := range prk { - prk[i] = 0 - } - - sc := &SecureConn{raw: conn, aead: aead} - // Use role-based nonce prefix to prevent nonce collision (C3 fix). - // Both sides share the same AES-GCM key; using deterministic prefixes - // based on role ensures the nonce spaces never overlap. - if isServer { - sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x01} // server prefix - } else { - sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x02} // client prefix - } - - // Perform mutual Ed25519 authentication if config provided. - // This happens INSIDE the encrypted channel (after ECDH). - var cfg *HandshakeConfig - if len(auth) > 0 { - cfg = auth[0] - } - if cfg != nil && cfg.Signer != nil { - if err := performAuth(sc, cfg, localPub, remotePub, isServer); err != nil { - sc.Close() - return nil, fmt.Errorf("auth: %w", err) - } - } - - return sc, nil -} - -// HandshakeWithLookup is like Handshake with auth, but uses a lookup function -// to resolve the peer's Ed25519 public key by nodeID. This is used by servers -// that don't know the peer's identity until they read the auth frame. -func HandshakeWithLookup(conn net.Conn, isServer bool, cfg *HandshakeConfig, lookup PeerPubKeyLookup) (*SecureConn, error) { - // Set handshake deadline to prevent indefinite blocking (M14 fix) - conn.SetDeadline(time.Now().Add(HandshakeTimeout)) - defer conn.SetDeadline(time.Time{}) // clear deadline after handshake - - // Generate ephemeral X25519 key pair - curve := ecdh.X25519() - privKey, err := curve.GenerateKey(rand.Reader) - if err != nil { - return nil, fmt.Errorf("generate key: %w", err) - } - localPub := privKey.PublicKey().Bytes() // 32 bytes - - var remotePub []byte - - if isServer { - remotePub, err = ReadExact(conn, 32) - if err != nil { - return nil, fmt.Errorf("read client key: %w", err) - } - if _, err := conn.Write(localPub); err != nil { - return nil, fmt.Errorf("send server key: %w", err) - } - } else { - if _, err := conn.Write(localPub); err != nil { - return nil, fmt.Errorf("send client key: %w", err) - } - remotePub, err = ReadExact(conn, 32) - if err != nil { - return nil, fmt.Errorf("read server key: %w", err) - } - } - - peerKey, err := curve.NewPublicKey(remotePub) - if err != nil { - return nil, fmt.Errorf("parse peer key: %w", err) - } - - shared, err := privKey.ECDH(peerKey) - if err != nil { - return nil, fmt.Errorf("ecdh: %w", err) - } - - // HKDF-SHA256 key derivation (H1 fix) - mac := hmac.New(sha256.New, nil) - mac.Write(shared) - prk := mac.Sum(nil) - mac = hmac.New(sha256.New, prk) - mac.Write([]byte("pilot-secure-v1")) - mac.Write([]byte{0x01}) - key := mac.Sum(nil) - - block, err := aes.NewCipher(key) - if err != nil { - return nil, fmt.Errorf("aes: %w", err) - } - aead, err := cipher.NewGCM(block) - if err != nil { - return nil, fmt.Errorf("gcm: %w", err) - } - - // Zero intermediate key material (H4 fix) - for i := range shared { - shared[i] = 0 - } - for i := range key { - key[i] = 0 - } - for i := range prk { - prk[i] = 0 - } - - sc := &SecureConn{raw: conn, aead: aead} - if isServer { - sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x01} - } else { - sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x02} - } - - if cfg != nil && cfg.Signer != nil { - if err := performAuthWithLookup(sc, cfg, localPub, remotePub, isServer, lookup); err != nil { - sc.Close() - return nil, fmt.Errorf("auth: %w", err) - } - } - - return sc, nil -} - -// performAuthWithLookup is like performAuth but resolves the peer's Ed25519 -// pubkey via a lookup function after reading the peer's auth frame. -func performAuthWithLookup(sc *SecureConn, cfg *HandshakeConfig, localX25519Pub, remoteX25519Pub []byte, isServer bool, lookup PeerPubKeyLookup) error { - now := time.Now() - ts := uint64(now.Unix()) - - var authNonce [16]byte - if _, err := rand.Read(authNonce[:]); err != nil { - return fmt.Errorf("generate auth nonce: %w", err) - } - - sigMsg := BuildAuthSignMessage(cfg.NodeID, localX25519Pub, ts, authNonce) - signature := ed25519.Sign(cfg.Signer, sigMsg) - - frame := make([]byte, AuthFrameLen) - binary.BigEndian.PutUint32(frame[0:4], cfg.NodeID) - binary.BigEndian.PutUint64(frame[4:12], ts) - copy(frame[12:28], authNonce[:]) - copy(frame[28:92], signature) - - if isServer { - if _, err := sc.Write(frame); err != nil { - return fmt.Errorf("send auth frame: %w", err) - } - peerFrame, err := readAuthFrame(sc) - if err != nil { - return fmt.Errorf("read peer auth frame: %w", err) - } - // Extract peer's nodeID to look up their pubkey - peerNodeID := binary.BigEndian.Uint32(peerFrame[0:4]) - peerPubKey := lookup(peerNodeID) - if peerPubKey == nil { - return fmt.Errorf("unknown peer node %d: no public key found", peerNodeID) - } - peerNodeID, err = VerifyAuthFrame(peerFrame, peerPubKey, remoteX25519Pub, now) - if err != nil { - return err - } - sc.PeerNodeID = peerNodeID - } else { - peerFrame, err := readAuthFrame(sc) - if err != nil { - return fmt.Errorf("read peer auth frame: %w", err) - } - peerNodeID := binary.BigEndian.Uint32(peerFrame[0:4]) - peerPubKey := lookup(peerNodeID) - if peerPubKey == nil { - return fmt.Errorf("unknown peer node %d: no public key found", peerNodeID) - } - peerNodeID, err = VerifyAuthFrame(peerFrame, peerPubKey, remoteX25519Pub, now) - if err != nil { - return err - } - sc.PeerNodeID = peerNodeID - if _, err := sc.Write(frame); err != nil { - return fmt.Errorf("send auth frame: %w", err) - } - } - - return nil -} - -// performAuth executes the mutual Ed25519 authentication protocol inside the -// already-encrypted SecureConn. Both sides send an auth frame and verify the -// peer's frame. -// -// Auth frame format (92 bytes): -// -// [nodeID(4)][timestamp(8)][nonce(16)][ed25519_signature(64)] -// -// Signature covers: -// -// "pilot-secure-auth:" + nodeID(4) + X25519_ephemeral_pubkey(32) + timestamp(8) + nonce(16) -// -// Each side signs its OWN X25519 ephemeral pubkey (localPub). The verifier -// reconstructs the signed message using the peer's X25519 pubkey (remotePub, -// which was received during the ECDH exchange). This binds the ephemeral ECDH -// key to the long-term Ed25519 identity: a MITM cannot substitute their own -// X25519 key because they cannot produce a valid Ed25519 signature for it. -func performAuth(sc *SecureConn, cfg *HandshakeConfig, localX25519Pub, remoteX25519Pub []byte, isServer bool) error { - // Build our auth frame - now := time.Now() - ts := uint64(now.Unix()) - - var authNonce [16]byte - if _, err := rand.Read(authNonce[:]); err != nil { - return fmt.Errorf("generate auth nonce: %w", err) - } - - // Sign over our own X25519 pubkey to bind our identity to this ECDH session - sigMsg := BuildAuthSignMessage(cfg.NodeID, localX25519Pub, ts, authNonce) - signature := ed25519.Sign(cfg.Signer, sigMsg) - - // Build the wire frame: nodeID(4) + timestamp(8) + nonce(16) + signature(64) - frame := make([]byte, AuthFrameLen) - binary.BigEndian.PutUint32(frame[0:4], cfg.NodeID) - binary.BigEndian.PutUint64(frame[4:12], ts) - copy(frame[12:28], authNonce[:]) - copy(frame[28:92], signature) - - // Exchange auth frames. Server sends first, then reads. - // Client reads first, then sends. This prevents deadlock on net.Pipe. - if isServer { - if _, err := sc.Write(frame); err != nil { - return fmt.Errorf("send auth frame: %w", err) - } - peerFrame, err := readAuthFrame(sc) - if err != nil { - return fmt.Errorf("read peer auth frame: %w", err) - } - peerNodeID, err := VerifyAuthFrame(peerFrame, cfg.PeerPubKey, remoteX25519Pub, now) - if err != nil { - return err - } - sc.PeerNodeID = peerNodeID - } else { - peerFrame, err := readAuthFrame(sc) - if err != nil { - return fmt.Errorf("read peer auth frame: %w", err) - } - peerNodeID, err := VerifyAuthFrame(peerFrame, cfg.PeerPubKey, remoteX25519Pub, now) - if err != nil { - return err - } - sc.PeerNodeID = peerNodeID - if _, err := sc.Write(frame); err != nil { - return fmt.Errorf("send auth frame: %w", err) - } - } - - return nil -} - -// readAuthFrame reads exactly AuthFrameLen bytes from the SecureConn. -// The data is already decrypted by SecureConn.Read. -func readAuthFrame(sc *SecureConn) ([]byte, error) { - frame := make([]byte, AuthFrameLen) - n := 0 - for n < AuthFrameLen { - nn, err := sc.Read(frame[n:]) - if err != nil { - return nil, err - } - n += nn - } - return frame, nil -} - -// VerifyAuthFrame validates a peer's auth frame. The peer signed over their own -// X25519 ephemeral pubkey (peerX25519Pub), which we received during the ECDH -// exchange. We reconstruct the signed message and verify against the peer's -// Ed25519 public key from the registry. -func VerifyAuthFrame(frame []byte, peerEdPubKey ed25519.PublicKey, peerX25519Pub []byte, now time.Time) (uint32, error) { - if len(frame) != AuthFrameLen { - return 0, fmt.Errorf("auth frame wrong size: %d", len(frame)) - } - - peerNodeID := binary.BigEndian.Uint32(frame[0:4]) - peerTS := binary.BigEndian.Uint64(frame[4:12]) - var peerNonce [16]byte - copy(peerNonce[:], frame[12:28]) - peerSig := frame[28:92] - - // Check timestamp within skew window - peerTime := time.Unix(int64(peerTS), 0) - diff := now.Sub(peerTime) - if diff < 0 { - diff = -diff - } - if diff > authTimestampSkew { - return 0, fmt.Errorf("auth timestamp expired: skew %v exceeds %v", diff, authTimestampSkew) - } - - // Check nonce replay - if err := CheckAndRecordNonce(peerNonce); err != nil { - return 0, err - } - - // Reconstruct the message the peer signed: domain + nodeID + peerX25519Pub + timestamp + nonce - sigMsg := BuildAuthSignMessage(peerNodeID, peerX25519Pub, peerTS, peerNonce) - - // Verify Ed25519 signature - if !ed25519.Verify(peerEdPubKey, sigMsg, peerSig) { - return 0, fmt.Errorf("auth signature verification failed") - } - - return peerNodeID, nil -} - -// BuildAuthSignMessage constructs the message that is signed in the auth frame. -// Format: "pilot-secure-auth:" + nodeID(4) + X25519_ephemeral_pubkey(32) + timestamp(8) + nonce(16) -func BuildAuthSignMessage(nodeID uint32, x25519Pub []byte, timestamp uint64, nonce [16]byte) []byte { - domain := []byte("pilot-secure-auth:") - msg := make([]byte, len(domain)+4+32+8+16) - copy(msg, domain) - off := len(domain) - binary.BigEndian.PutUint32(msg[off:off+4], nodeID) - off += 4 - copy(msg[off:off+32], x25519Pub) - off += 32 - binary.BigEndian.PutUint64(msg[off:off+8], timestamp) - off += 8 - copy(msg[off:off+16], nonce[:]) - return msg -} - -// Read decrypts and reads data from the connection. -// Leftover plaintext from a previous decryption is returned first (H14 fix). -func (sc *SecureConn) Read(b []byte) (int, error) { - sc.rmu.Lock() - defer sc.rmu.Unlock() - - // Return buffered leftover data first (H14 fix — prevents silent truncation) - if len(sc.readBuf) > 0 { - n := copy(b, sc.readBuf) - sc.readBuf = sc.readBuf[n:] - return n, nil - } - - // Read 4-byte length prefix - lenBuf, err := ReadExact(sc.raw, 4) - if err != nil { - return 0, err - } - msgLen := binary.BigEndian.Uint32(lenBuf) - if msgLen < uint32(sc.aead.NonceSize()) { - return 0, fmt.Errorf("encrypted message too short") - } - // Reject unreasonably large messages to prevent OOM (M13 fix) - if msgLen > uint32(MaxEncryptedMessageLen) { - return 0, fmt.Errorf("encrypted message too large: %d bytes", msgLen) - } - - // Read nonce + ciphertext - ciphertext, err := ReadExact(sc.raw, int(msgLen)) - if err != nil { - return 0, err - } - - nonce := ciphertext[:sc.aead.NonceSize()] - encrypted := ciphertext[sc.aead.NonceSize():] - - // Decrypt with sender's nonce prefix as AAD (H3 fix) - peerAAD := nonce[:4] - plaintext, err := sc.aead.Open(nil, nonce, encrypted, peerAAD) - if err != nil { - return 0, fmt.Errorf("decrypt: %w", err) - } - - n := copy(b, plaintext) - // Buffer any remaining plaintext for subsequent Read calls (H14 fix) - if n < len(plaintext) { - sc.readBuf = make([]byte, len(plaintext)-n) - copy(sc.readBuf, plaintext[n:]) - } - return n, nil -} - -// Write encrypts and writes data to the connection. -func (sc *SecureConn) Write(b []byte) (int, error) { - sc.wmu.Lock() - defer sc.wmu.Unlock() - - // Generate nonce from prefix + counter - nonce := make([]byte, sc.aead.NonceSize()) - copy(nonce[0:4], sc.noncePrefix[:]) - sc.nonce++ - binary.BigEndian.PutUint64(nonce[sc.aead.NonceSize()-8:], sc.nonce) - - // Encrypt with nonce prefix as AAD (H3 fix) - ciphertext := sc.aead.Seal(nil, nonce, b, sc.noncePrefix[:]) - - // Write: [4-byte length][nonce][ciphertext] - total := len(nonce) + len(ciphertext) - msg := make([]byte, 4+total) - binary.BigEndian.PutUint32(msg[0:4], uint32(total)) - copy(msg[4:], nonce) - copy(msg[4+len(nonce):], ciphertext) - - if _, err := sc.raw.Write(msg); err != nil { - return 0, err - } - return len(b), nil -} - -func (sc *SecureConn) Close() error { return sc.raw.Close() } -func (sc *SecureConn) LocalAddr() net.Addr { return sc.raw.LocalAddr() } -func (sc *SecureConn) RemoteAddr() net.Addr { return sc.raw.RemoteAddr() } -func (sc *SecureConn) SetDeadline(t time.Time) error { return sc.raw.SetDeadline(t) } -func (sc *SecureConn) SetReadDeadline(t time.Time) error { return sc.raw.SetReadDeadline(t) } -func (sc *SecureConn) SetWriteDeadline(t time.Time) error { return sc.raw.SetWriteDeadline(t) } - -func ReadExact(r io.Reader, n int) ([]byte, error) { - buf := make([]byte, n) - _, err := io.ReadFull(r, buf) - return buf, err -} diff --git a/pkg/secure/server.go b/pkg/secure/server.go deleted file mode 100644 index 8fec7b9f..00000000 --- a/pkg/secure/server.go +++ /dev/null @@ -1,102 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package secure - -import ( - "crypto/ed25519" - "log/slog" - "net" - - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" -) - -// Handler is called for each new secure connection. -type Handler func(conn net.Conn) - -// Server listens on port 443 and upgrades connections to encrypted channels. -type Server struct { - driver *driver.Driver - handler Handler - authNodeID uint32 - authSigner ed25519.PrivateKey - peerLookup PeerPubKeyLookup -} - -// NewServer creates a secure channel server (unauthenticated ECDH). -func NewServer(d *driver.Driver, handler Handler) *Server { - return &Server{driver: d, handler: handler} -} - -// NewAuthServer creates a secure channel server with Ed25519 authentication. -// The server authenticates itself and verifies connecting clients using the -// lookup function to obtain each client's expected Ed25519 public key. -func NewAuthServer(d *driver.Driver, handler Handler, nodeID uint32, signer ed25519.PrivateKey, lookup PeerPubKeyLookup) *Server { - return &Server{ - driver: d, - handler: handler, - authNodeID: nodeID, - authSigner: signer, - peerLookup: lookup, - } -} - -// Driver returns the underlying packet driver. Exposed for tests. -func (s *Server) Driver() *driver.Driver { return s.driver } - -// Handler returns the per-connection handler callback. Exposed for tests. -func (s *Server) Handler() Handler { return s.handler } - -// AuthNodeID returns the authenticated node id (zero when unauth). -// Exposed for tests. -func (s *Server) AuthNodeID() uint32 { return s.authNodeID } - -// AuthSigner returns the server's Ed25519 signing key (nil when unauth). -// Exposed for tests. -func (s *Server) AuthSigner() ed25519.PrivateKey { return s.authSigner } - -// PeerLookup returns the per-peer pubkey lookup (nil when unauth). -// Exposed for tests. -func (s *Server) PeerLookup() PeerPubKeyLookup { return s.peerLookup } - -// ListenAndServe binds port 443 and starts accepting secure connections. -func (s *Server) ListenAndServe() error { - ln, err := s.driver.Listen(protocol.PortSecure) - if err != nil { - return err - } - - slog.Info("secure server listening", "port", protocol.PortSecure) - - for { - conn, err := ln.Accept() - if err != nil { - return err - } - go s.handleConn(conn) - } -} - -func (s *Server) handleConn(conn net.Conn) { - var sc *SecureConn - var err error - - if s.authSigner != nil { - // Use lookup-based handshake: the peer's nodeID is extracted from - // their auth frame, then the lookup function resolves their Ed25519 - // pubkey for signature verification. - sc, err = HandshakeWithLookup(conn, true, &HandshakeConfig{ - NodeID: s.authNodeID, - Signer: s.authSigner, - }, s.peerLookup) - } else { - sc, err = Handshake(conn, true) - } - - if err != nil { - slog.Warn("secure handshake failed", "err", err) - conn.Close() - return - } - s.handler(sc) -} diff --git a/pkg/secure/zz_extra_coverage_test.go b/pkg/secure/zz_extra_coverage_test.go deleted file mode 100644 index ec2eb607..00000000 --- a/pkg/secure/zz_extra_coverage_test.go +++ /dev/null @@ -1,1496 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -// Package secure_test — extra coverage tests targeting error paths in -// Handshake / HandshakeWithLookup / HandshakeWithTimestampOffset, the -// SecureConn Read/Write framing edges, performAuth* error branches, -// and the Dial/ListenAndServe surfaces that require a minimal IPC -// daemon mock to reach. -// -// Goal: bring pkg/secure from ~80% to ≥95% statement coverage. These -// tests use only public APIs (or test-only exported helpers already in -// the package). -package secure_test - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/ed25519" - "crypto/rand" - "encoding/binary" - "encoding/hex" - "errors" - "io" - "net" - "os" - "path/filepath" - "strings" - "sync" - "testing" - "time" - - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" - "github.com/TeoSlayer/pilotprotocol/pkg/secure" -) - -// --------------------------------------------------------------------------- -// Fake daemon — minimal IPC peer that drives the public driver.Driver API -// so we can reach secure.Dial, Server.ListenAndServe, and Server.handleConn -// through their real call paths. -// -// Wire format: each IPC frame is a length-prefixed buffer (ipcutil.Read / -// Write), and the first byte is the command opcode (see pkg/driver/ipc.go). -// We hard-code the opcodes here because they are private to driver, but -// the wire format is stable. -// --------------------------------------------------------------------------- - -const ( - cmdBind byte = 0x01 - cmdBindOK byte = 0x02 - cmdDial byte = 0x03 - cmdDialOK byte = 0x04 - cmdAccept byte = 0x05 - cmdSend byte = 0x06 - cmdRecv byte = 0x07 - cmdClose byte = 0x08 - cmdCloseOK byte = 0x09 -) - -// shortSocketPath returns a /tmp path short enough for macOS unix socket -// length limit (~104 chars). -func shortSocketPath(t *testing.T) string { - t.Helper() - var b [6]byte - if _, err := rand.Read(b[:]); err != nil { - t.Fatal(err) - } - p := filepath.Join("/tmp", "ss-"+hex.EncodeToString(b[:])+".sock") - t.Cleanup(func() { _ = os.Remove(p) }) - return p -} - -// fakeDaemon implements just enough of the daemon IPC contract for one -// connection at a time. It runs handlers per opcode; the test sets these -// up before connecting via driver.Connect. -type fakeDaemon struct { - t *testing.T - ln net.Listener - path string - mu sync.Mutex - conn net.Conn - connSet chan struct{} - handlers map[byte]func(frame []byte) [][]byte - // per-connID send forwarders — used by Dial/Accept happy-paths - bridges map[uint32]chan<- []byte -} - -func newFakeDaemon(t *testing.T) *fakeDaemon { - t.Helper() - p := shortSocketPath(t) - ln, err := net.Listen("unix", p) - if err != nil { - t.Fatalf("listen: %v", err) - } - d := &fakeDaemon{ - t: t, - ln: ln, - path: p, - connSet: make(chan struct{}), - handlers: make(map[byte]func(frame []byte) [][]byte), - bridges: make(map[uint32]chan<- []byte), - } - go d.loop() - return d -} - -func (d *fakeDaemon) loop() { - conn, err := d.ln.Accept() - if err != nil { - return - } - d.mu.Lock() - d.conn = conn - d.mu.Unlock() - close(d.connSet) - for { - frame, err := ipcutil.Read(conn) - if err != nil { - return - } - if len(frame) == 0 { - continue - } - cmd := frame[0] - d.mu.Lock() - h := d.handlers[cmd] - var bridgeCh chan<- []byte - var payload []byte - if cmd == cmdSend && len(frame) >= 5 { - id := binary.BigEndian.Uint32(frame[1:5]) - bridgeCh = d.bridges[id] - payload = append([]byte(nil), frame[5:]...) - } - d.mu.Unlock() - if bridgeCh != nil { - bridgeCh <- payload - continue - } - if h == nil { - continue - } - for _, r := range h(frame) { - _ = ipcutil.Write(conn, r) - } - } -} - -func (d *fakeDaemon) push(frame []byte) { - d.mu.Lock() - c := d.conn - d.mu.Unlock() - if c == nil { - <-d.connSet - d.mu.Lock() - c = d.conn - d.mu.Unlock() - } - _ = ipcutil.Write(c, frame) -} - -func (d *fakeDaemon) onCmd(cmd byte, h func(frame []byte) [][]byte) { - d.mu.Lock() - defer d.mu.Unlock() - d.handlers[cmd] = h -} - -func (d *fakeDaemon) registerBridge(connID uint32, toDriver chan<- []byte) { - d.mu.Lock() - defer d.mu.Unlock() - d.bridges[connID] = toDriver -} - -func (d *fakeDaemon) close() { - _ = d.ln.Close() - select { - case <-d.connSet: - case <-time.After(100 * time.Millisecond): - } - d.mu.Lock() - c := d.conn - d.mu.Unlock() - if c != nil { - _ = c.Close() - } -} - -// pumpRecv emits cmdRecv frames carrying `data` for connID to the driver. -func (d *fakeDaemon) pumpRecv(connID uint32, data []byte) { - frame := make([]byte, 1+4+len(data)) - frame[0] = cmdRecv - binary.BigEndian.PutUint32(frame[1:5], connID) - copy(frame[5:], data) - d.push(frame) -} - -// pumpAccept emits a cmdAccept frame for port `port` with the given conn. -func (d *fakeDaemon) pumpAccept(port uint16, connID uint32) { - addrSize := protocol.AddrSize - frame := make([]byte, 1+2+4+addrSize+2) - frame[0] = cmdAccept - binary.BigEndian.PutUint16(frame[1:3], port) - binary.BigEndian.PutUint32(frame[3:7], connID) - binary.BigEndian.PutUint16(frame[3+4+addrSize:], 0) - d.push(frame) -} - -// bridgeDriverToPipe wires a fakeDaemon conn-side to one half of a net.Pipe. -// Bytes the driver writes via cmdSend(connID, ...) are forwarded to the pipe -// (writeable end) `peer`. Bytes that arrive on `peer` are pushed back to the -// driver as cmdRecv(connID, ...). This lets us run secure.Handshake on both -// the driver-Conn side and a raw secure.Handshake on the peer side against -// each other, exercising secure.Dial and ListenAndServe happy paths. -func (d *fakeDaemon) bridgeDriverToPipe(connID uint32, peer net.Conn) { - toPeer := make(chan []byte, 32) - d.registerBridge(connID, toPeer) - go func() { - for data := range toPeer { - if _, err := peer.Write(data); err != nil { - return - } - } - }() - go func() { - buf := make([]byte, 4096) - for { - n, err := peer.Read(buf) - if n > 0 { - d.pumpRecv(connID, buf[:n]) - } - if err != nil { - return - } - } - }() -} - -// failAfterNWrites wraps a net.Conn and returns errClosedPipe after the Nth -// write. Used to surgically trigger sc.Write failures at specific points in -// the authenticated handshake protocol. -type failAfterNWrites struct { - net.Conn - mu sync.Mutex - writes int - failAfter int -} - -func (f *failAfterNWrites) Write(b []byte) (int, error) { - f.mu.Lock() - f.writes++ - n := f.writes - f.mu.Unlock() - if n > f.failAfter { - return 0, io.ErrClosedPipe - } - return f.Conn.Write(b) -} - -// --------------------------------------------------------------------------- -// secure.Dial -// --------------------------------------------------------------------------- - -func TestDialDialAddrErrorPropagates(t *testing.T) { - d := newFakeDaemon(t) - drv, err := driver.Connect(d.path) - if err != nil { - t.Fatalf("connect: %v", err) - } - go func() { - time.Sleep(50 * time.Millisecond) - d.close() - drv.Close() - }() - _, err = secure.Dial(drv, protocol.Addr{Network: 1, Node: 1}) - if err == nil { - t.Fatal("expected dial error after daemon close") - } -} - -func TestDialHandshakeErrorClosesConn(t *testing.T) { - d := newFakeDaemon(t) - defer d.close() - - const connID uint32 = 0xCAFEBABE - d.onCmd(cmdDial, func(frame []byte) [][]byte { - resp := make([]byte, 1+4) - resp[0] = cmdDialOK - binary.BigEndian.PutUint32(resp[1:5], connID) - return [][]byte{resp} - }) - d.onCmd(cmdClose, func(frame []byte) [][]byte { - resp := make([]byte, 5) - resp[0] = cmdCloseOK - binary.BigEndian.PutUint32(resp[1:5], connID) - return [][]byte{resp} - }) - - drv, err := driver.Connect(d.path) - if err != nil { - t.Fatal(err) - } - defer drv.Close() - - go func() { - time.Sleep(20 * time.Millisecond) - d.pumpRecv(connID, []byte{0x01, 0x02, 0x03}) - time.Sleep(20 * time.Millisecond) - d.close() - }() - - _, err = secure.Dial(drv, protocol.Addr{Network: 1, Node: 1}) - if err == nil { - t.Fatal("expected handshake error after daemon close") - } -} - -func TestDialHappyPathWithBridge(t *testing.T) { - d := newFakeDaemon(t) - defer d.close() - - const connID uint32 = 0xAA00AA00 - d.onCmd(cmdDial, func(frame []byte) [][]byte { - resp := make([]byte, 5) - resp[0] = cmdDialOK - binary.BigEndian.PutUint32(resp[1:5], connID) - return [][]byte{resp} - }) - d.onCmd(cmdClose, func(frame []byte) [][]byte { - resp := make([]byte, 5) - resp[0] = cmdCloseOK - binary.BigEndian.PutUint32(resp[1:5], connID) - return [][]byte{resp} - }) - - drv, err := driver.Connect(d.path) - if err != nil { - t.Fatal(err) - } - defer drv.Close() - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - d.bridgeDriverToPipe(connID, pa) - - peerDone := make(chan *secure.SecureConn, 1) - peerErr := make(chan error, 1) - go func() { - sc, err := secure.Handshake(pb, true) - if err != nil { - peerErr <- err - return - } - peerDone <- sc - }() - - sc, err := secure.Dial(drv, protocol.Addr{Network: 1, Node: 1}) - if err != nil { - t.Fatalf("Dial: %v", err) - } - select { - case <-peerDone: - case err := <-peerErr: - t.Fatalf("peer handshake: %v", err) - case <-time.After(3 * time.Second): - t.Fatal("peer handshake timed out") - } - if sc == nil { - t.Fatal("Dial returned nil conn") - } - sc.Close() -} - -// --------------------------------------------------------------------------- -// Server: ListenAndServe + handleConn paths -// --------------------------------------------------------------------------- - -func TestListenAndServeBindError(t *testing.T) { - d := newFakeDaemon(t) - defer d.close() - drv, err := driver.Connect(d.path) - if err != nil { - t.Fatal(err) - } - defer drv.Close() - - s := secure.NewServer(drv, func(_ net.Conn) {}) - errCh := make(chan error, 1) - go func() { errCh <- s.ListenAndServe() }() - time.Sleep(50 * time.Millisecond) - d.close() - drv.Close() - select { - case err := <-errCh: - if err == nil { - t.Fatal("expected error from ListenAndServe") - } - case <-time.After(5 * time.Second): - t.Fatal("ListenAndServe never returned") - } -} - -func TestListenAndServeAcceptErrorReturns(t *testing.T) { - d := newFakeDaemon(t) - defer d.close() - - d.onCmd(cmdBind, func(frame []byte) [][]byte { - resp := make([]byte, 3) - resp[0] = cmdBindOK - binary.BigEndian.PutUint16(resp[1:3], protocol.PortSecure) - return [][]byte{resp} - }) - - drv, err := driver.Connect(d.path) - if err != nil { - t.Fatal(err) - } - defer drv.Close() - - s := secure.NewServer(drv, func(_ net.Conn) {}) - errCh := make(chan error, 1) - go func() { errCh <- s.ListenAndServe() }() - time.Sleep(50 * time.Millisecond) - d.close() - drv.Close() - select { - case <-errCh: - case <-time.After(5 * time.Second): - t.Fatal("ListenAndServe never returned after daemon close") - } -} - -func TestListenAndServeHandshakeFailsUnauthBranch(t *testing.T) { - d := newFakeDaemon(t) - defer d.close() - - const connID uint32 = 0x77777777 - d.onCmd(cmdBind, func(frame []byte) [][]byte { - resp := make([]byte, 3) - resp[0] = cmdBindOK - binary.BigEndian.PutUint16(resp[1:3], protocol.PortSecure) - return [][]byte{resp} - }) - d.onCmd(cmdClose, func(frame []byte) [][]byte { - resp := make([]byte, 5) - resp[0] = cmdCloseOK - binary.BigEndian.PutUint32(resp[1:5], connID) - return [][]byte{resp} - }) - - drv, err := driver.Connect(d.path) - if err != nil { - t.Fatal(err) - } - defer drv.Close() - - handlerCalled := make(chan struct{}, 1) - s := secure.NewServer(drv, func(_ net.Conn) { handlerCalled <- struct{}{} }) - - go func() { _ = s.ListenAndServe() }() - time.Sleep(80 * time.Millisecond) - d.pumpAccept(protocol.PortSecure, connID) - time.Sleep(40 * time.Millisecond) - d.pumpRecv(connID, []byte{0x01}) - time.Sleep(40 * time.Millisecond) - d.close() - drv.Close() - - select { - case <-handlerCalled: - t.Fatal("handler should not be called on handshake failure") - case <-time.After(200 * time.Millisecond): - } -} - -func TestListenAndServeAuthBranchHandshakeFails(t *testing.T) { - d := newFakeDaemon(t) - defer d.close() - - const connID uint32 = 0x88888888 - d.onCmd(cmdBind, func(frame []byte) [][]byte { - resp := make([]byte, 3) - resp[0] = cmdBindOK - binary.BigEndian.PutUint16(resp[1:3], protocol.PortSecure) - return [][]byte{resp} - }) - d.onCmd(cmdClose, func(frame []byte) [][]byte { - resp := make([]byte, 5) - resp[0] = cmdCloseOK - binary.BigEndian.PutUint32(resp[1:5], connID) - return [][]byte{resp} - }) - - drv, err := driver.Connect(d.path) - if err != nil { - t.Fatal(err) - } - defer drv.Close() - - _, signer, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } - lookup := func(_ uint32) ed25519.PublicKey { return nil } - - handlerCalled := make(chan struct{}, 1) - s := secure.NewAuthServer(drv, func(_ net.Conn) { handlerCalled <- struct{}{} }, 42, signer, lookup) - - go func() { _ = s.ListenAndServe() }() - time.Sleep(80 * time.Millisecond) - d.pumpAccept(protocol.PortSecure, connID) - time.Sleep(40 * time.Millisecond) - d.pumpRecv(connID, []byte{0x01}) - time.Sleep(40 * time.Millisecond) - d.close() - drv.Close() - - select { - case <-handlerCalled: - t.Fatal("handler must not be called on failed handshake") - case <-time.After(200 * time.Millisecond): - } -} - -func TestListenAndServeHandlerInvokedOnSuccess(t *testing.T) { - d := newFakeDaemon(t) - defer d.close() - - const connID uint32 = 0xBB00BB00 - d.onCmd(cmdBind, func(frame []byte) [][]byte { - resp := make([]byte, 3) - resp[0] = cmdBindOK - binary.BigEndian.PutUint16(resp[1:3], protocol.PortSecure) - return [][]byte{resp} - }) - d.onCmd(cmdClose, func(frame []byte) [][]byte { - resp := make([]byte, 5) - resp[0] = cmdCloseOK - binary.BigEndian.PutUint32(resp[1:5], connID) - return [][]byte{resp} - }) - - drv, err := driver.Connect(d.path) - if err != nil { - t.Fatal(err) - } - defer drv.Close() - - handlerCh := make(chan struct{}, 1) - s := secure.NewServer(drv, func(_ net.Conn) { handlerCh <- struct{}{} }) - - go func() { _ = s.ListenAndServe() }() - time.Sleep(80 * time.Millisecond) - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - d.bridgeDriverToPipe(connID, pa) - - go func() { _, _ = secure.Handshake(pb, false) }() - d.pumpAccept(protocol.PortSecure, connID) - - select { - case <-handlerCh: - case <-time.After(5 * time.Second): - t.Fatal("handler never invoked") - } -} - -// --------------------------------------------------------------------------- -// Handshake error branches — server/client sides closing mid-flight -// --------------------------------------------------------------------------- - -func TestHandshakeServerReadClientKeyFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - defer a.Close() - defer b.Close() - b.Close() - _, err := secure.Handshake(a, true) - if err == nil { - t.Fatal("expected read-client-key error") - } - if !strings.Contains(err.Error(), "read client key") { - t.Errorf("err = %v", err) - } -} - -func TestHandshakeClientSendKeyFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - a.Close() - b.Close() - _, err := secure.Handshake(a, false) - if err == nil { - t.Fatal("expected send-client-key error") - } -} - -func TestHandshakeClientReadServerKeyFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - defer a.Close() - defer b.Close() - go func() { - buf := make([]byte, 32) - _, _ = io.ReadFull(b, buf) - b.Close() - }() - _, err := secure.Handshake(a, false) - if err == nil { - t.Fatal("expected read-server-key error") - } - if !strings.Contains(err.Error(), "read server key") { - t.Errorf("err = %v", err) - } -} - -func TestHandshakeServerSendKeyFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - defer a.Close() - defer b.Close() - go func() { - junk := make([]byte, 32) - _, _ = rand.Read(junk) - _, _ = b.Write(junk) - b.Close() - }() - _, err := secure.Handshake(a, true) - if err == nil { - t.Fatal("expected send-server-key error") - } -} - -func TestHandshakeWithLookupServerReadKeyFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - a.Close() - b.Close() - _, err := secure.HandshakeWithLookup(a, true, nil, nil) - if err == nil { - t.Fatal("expected error") - } -} - -func TestHandshakeWithLookupClientWriteFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - a.Close() - b.Close() - _, err := secure.HandshakeWithLookup(a, false, nil, nil) - if err == nil { - t.Fatal("expected error") - } -} - -func TestHandshakeWithLookupClientReadFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - defer a.Close() - defer b.Close() - go func() { - buf := make([]byte, 32) - _, _ = io.ReadFull(b, buf) - b.Close() - }() - _, err := secure.HandshakeWithLookup(a, false, nil, nil) - if err == nil { - t.Fatal("expected error") - } -} - -func TestHandshakeWithLookupServerWriteKeyFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - defer a.Close() - defer b.Close() - go func() { - junk := make([]byte, 32) - _, _ = rand.Read(junk) - _, _ = b.Write(junk) - b.Close() - }() - _, err := secure.HandshakeWithLookup(a, true, nil, nil) - if err == nil { - t.Fatal("expected error") - } -} - -// --------------------------------------------------------------------------- -// HandshakeWithTimestampOffset error branches -// --------------------------------------------------------------------------- - -func TestHandshakeWithTimestampOffsetServerReadKeyFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - a.Close() - b.Close() - _, err := secure.HandshakeWithTimestampOffset(a, true, nil, 0) - if err == nil { - t.Fatal("expected error") - } -} - -func TestHandshakeWithTimestampOffsetClientWriteFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - a.Close() - b.Close() - _, err := secure.HandshakeWithTimestampOffset(a, false, nil, 0) - if err == nil { - t.Fatal("expected error") - } -} - -func TestHandshakeWithTimestampOffsetServerWriteFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - defer a.Close() - defer b.Close() - go func() { - junk := make([]byte, 32) - _, _ = rand.Read(junk) - _, _ = b.Write(junk) - b.Close() - }() - _, err := secure.HandshakeWithTimestampOffset(a, true, nil, 0) - if err == nil { - t.Fatal("expected error") - } -} - -func TestHandshakeWithTimestampOffsetClientReadFails(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - defer a.Close() - defer b.Close() - go func() { - buf := make([]byte, 32) - _, _ = io.ReadFull(b, buf) - b.Close() - }() - _, err := secure.HandshakeWithTimestampOffset(a, false, nil, 0) - if err == nil { - t.Fatal("expected error") - } -} - -func TestHandshakeWithTimestampOffsetMutual(t *testing.T) { - secure.ResetReplayCache() - _, srvPriv, _ := ed25519.GenerateKey(rand.Reader) - _, cliPriv, _ := ed25519.GenerateKey(rand.Reader) - srvPub := srvPriv.Public().(ed25519.PublicKey) - cliPub := cliPriv.Public().(ed25519.PublicKey) - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - type res struct { - sc *secure.SecureConn - err error - } - chA := make(chan res, 1) - chB := make(chan res, 1) - go func() { sc, err := secure.HandshakeWithTimestampOffset(pa, true, cfgServer, 0); chA <- res{sc, err} }() - go func() { sc, err := secure.HandshakeWithTimestampOffset(pb, false, cfgClient, 0); chB <- res{sc, err} }() - rA := <-chA - rB := <-chB - if rA.err != nil { - t.Fatalf("server: %v", rA.err) - } - if rB.err != nil { - t.Fatalf("client: %v", rB.err) - } - if rA.sc.PeerNodeID != 2 || rB.sc.PeerNodeID != 1 { - t.Errorf("peer IDs wrong: %d, %d", rA.sc.PeerNodeID, rB.sc.PeerNodeID) - } -} - -func TestHandshakeWithTimestampOffsetUnauthSkipsAuth(t *testing.T) { - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - type res struct { - sc *secure.SecureConn - err error - } - chA := make(chan res, 1) - chB := make(chan res, 1) - go func() { sc, err := secure.HandshakeWithTimestampOffset(pa, true, nil, 0); chA <- res{sc, err} }() - go func() { sc, err := secure.HandshakeWithTimestampOffset(pb, false, nil, 0); chB <- res{sc, err} }() - rA := <-chA - rB := <-chB - if rA.err != nil || rB.err != nil { - t.Fatalf("handshake errors: %v / %v", rA.err, rB.err) - } - rA.sc.Close() - rB.sc.Close() -} - -// --------------------------------------------------------------------------- -// Handshake — ECDH low-order pubkey hits the "ecdh:" branch -// --------------------------------------------------------------------------- - -func TestHandshakeECDHFailsOnLowOrderPoint(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - defer a.Close() - defer b.Close() - go func() { - zeros := make([]byte, 32) - _, _ = b.Write(zeros) - buf := make([]byte, 32) - _, _ = io.ReadFull(b, buf) - }() - _, err := secure.Handshake(a, true) - if err == nil { - t.Fatal("expected ecdh low-order error") - } - if !strings.Contains(err.Error(), "ecdh") { - t.Errorf("err = %v", err) - } -} - -func TestHandshakeWithLookupECDHFailsOnLowOrderPoint(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - defer a.Close() - defer b.Close() - go func() { - zeros := make([]byte, 32) - _, _ = b.Write(zeros) - buf := make([]byte, 32) - _, _ = io.ReadFull(b, buf) - }() - _, err := secure.HandshakeWithLookup(a, true, nil, nil) - if err == nil { - t.Fatal("expected ecdh error") - } - if !strings.Contains(err.Error(), "ecdh") { - t.Errorf("err = %v", err) - } -} - -func TestHandshakeWithTimestampOffsetECDHFailsOnLowOrderPoint(t *testing.T) { - t.Parallel() - a, b := net.Pipe() - defer a.Close() - defer b.Close() - go func() { - zeros := make([]byte, 32) - _, _ = b.Write(zeros) - buf := make([]byte, 32) - _, _ = io.ReadFull(b, buf) - }() - _, err := secure.HandshakeWithTimestampOffset(a, true, nil, 0) - if err == nil { - t.Fatal("expected ecdh error") - } - if !strings.Contains(err.Error(), "ecdh") { - t.Errorf("err = %v", err) - } -} - -// --------------------------------------------------------------------------- -// SecureConn.Read framing error paths -// --------------------------------------------------------------------------- - -func TestSecureConnReadRejectsMessageTooShort(t *testing.T) { - t.Parallel() - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - type res struct { - sc *secure.SecureConn - err error - } - chA := make(chan res, 1) - chB := make(chan res, 1) - go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() - go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() - rA := <-chA - rB := <-chB - if rA.err != nil || rB.err != nil { - t.Fatalf("handshake: %v %v", rA.err, rB.err) - } - defer rA.sc.Close() - defer rB.sc.Close() - - go func() { - var hdr [4]byte - binary.BigEndian.PutUint32(hdr[:], 4) // too short - _, _ = pb.Write(hdr[:]) - _, _ = pb.Write([]byte{0x00, 0x00, 0x00, 0x00}) - }() - - buf := make([]byte, 16) - _, err := rA.sc.Read(buf) - if err == nil { - t.Fatal("expected error on too-short message") - } - if !strings.Contains(err.Error(), "too short") { - t.Errorf("err = %v", err) - } -} - -func TestSecureConnReadRejectsMessageTooLarge(t *testing.T) { - t.Parallel() - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - type res struct { - sc *secure.SecureConn - err error - } - chA := make(chan res, 1) - chB := make(chan res, 1) - go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() - go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() - rA := <-chA - rB := <-chB - if rA.err != nil || rB.err != nil { - t.Fatalf("handshake: %v %v", rA.err, rB.err) - } - defer rA.sc.Close() - defer rB.sc.Close() - - go func() { - var hdr [4]byte - binary.BigEndian.PutUint32(hdr[:], secure.MaxEncryptedMessageLen+1) - _, _ = pb.Write(hdr[:]) - }() - - buf := make([]byte, 16) - _, err := rA.sc.Read(buf) - if err == nil { - t.Fatal("expected too-large error") - } - if !strings.Contains(err.Error(), "too large") { - t.Errorf("err = %v", err) - } -} - -func TestSecureConnReadDecryptFails(t *testing.T) { - t.Parallel() - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - type res struct { - sc *secure.SecureConn - err error - } - chA := make(chan res, 1) - chB := make(chan res, 1) - go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() - go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() - rA := <-chA - rB := <-chB - if rA.err != nil || rB.err != nil { - t.Fatalf("handshake: %v %v", rA.err, rB.err) - } - defer rA.sc.Close() - defer rB.sc.Close() - - go func() { - const total = 32 - var hdr [4]byte - binary.BigEndian.PutUint32(hdr[:], total) - _, _ = pb.Write(hdr[:]) - payload := make([]byte, total) - _, _ = rand.Read(payload) - _, _ = pb.Write(payload) - }() - - buf := make([]byte, 16) - _, err := rA.sc.Read(buf) - if err == nil { - t.Fatal("expected decrypt error") - } - if !strings.Contains(err.Error(), "decrypt") { - t.Errorf("err = %v", err) - } -} - -func TestSecureConnReadLengthPrefixError(t *testing.T) { - t.Parallel() - pa, pb := net.Pipe() - type res struct { - sc *secure.SecureConn - err error - } - chA := make(chan res, 1) - chB := make(chan res, 1) - go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() - go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() - rA := <-chA - rB := <-chB - if rA.err != nil || rB.err != nil { - t.Fatalf("handshake: %v %v", rA.err, rB.err) - } - pb.Close() - rB.sc.Close() - defer pa.Close() - defer rA.sc.Close() - _, err := rA.sc.Read(make([]byte, 16)) - if err == nil { - t.Fatal("expected error reading length") - } -} - -func TestSecureConnReadCiphertextReadError(t *testing.T) { - t.Parallel() - pa, pb := net.Pipe() - type res struct { - sc *secure.SecureConn - err error - } - chA := make(chan res, 1) - chB := make(chan res, 1) - go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() - go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() - rA := <-chA - rB := <-chB - if rA.err != nil || rB.err != nil { - t.Fatalf("handshake: %v %v", rA.err, rB.err) - } - - go func() { - var hdr [4]byte - binary.BigEndian.PutUint32(hdr[:], 32) - _, _ = pb.Write(hdr[:]) - pb.Close() - }() - defer pa.Close() - defer rA.sc.Close() - _, err := rA.sc.Read(make([]byte, 16)) - if err == nil { - t.Fatal("expected error reading ciphertext body") - } - rB.sc.Close() -} - -// --------------------------------------------------------------------------- -// SecureConn.Write error paths -// --------------------------------------------------------------------------- - -func TestSecureConnWriteErrorOnClosedConn(t *testing.T) { - t.Parallel() - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - type res struct { - sc *secure.SecureConn - err error - } - chA := make(chan res, 1) - chB := make(chan res, 1) - go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() - go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() - rA := <-chA - rB := <-chB - if rA.err != nil || rB.err != nil { - t.Fatalf("handshake: %v %v", rA.err, rB.err) - } - pa.Close() - _, err := rA.sc.Write([]byte("oops")) - if err == nil { - t.Fatal("expected write error after raw conn close") - } - rB.sc.Close() -} - -// --------------------------------------------------------------------------- -// performAuth* error paths — VerifyAuthFrame failures on each side -// --------------------------------------------------------------------------- - -func TestAuthServerVerifyFailsClientPasses(t *testing.T) { - secure.ResetReplayCache() - srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) - _, cliPriv, _ := ed25519.GenerateKey(rand.Reader) - wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: wrongPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { _, err := secure.Handshake(pa, true, cfgServer); errA <- err }() - go func() { _, err := secure.Handshake(pb, false, cfgClient); errB <- err }() - <-errA - <-errB -} - -func TestAuthClientVerifyFailsServerPasses(t *testing.T) { - secure.ResetReplayCache() - _, srvPriv, _ := ed25519.GenerateKey(rand.Reader) - cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) - wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: wrongPub} - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { _, err := secure.Handshake(pa, true, cfgServer); errA <- err }() - go func() { _, err := secure.Handshake(pb, false, cfgClient); errB <- err }() - <-errA - <-errB -} - -func TestAuthOffsetClientVerifyFails(t *testing.T) { - secure.ResetReplayCache() - _, srvPriv, _ := ed25519.GenerateKey(rand.Reader) - cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) - wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: wrongPub} - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { - _, err := secure.HandshakeWithTimestampOffset(pa, true, cfgServer, 0) - errA <- err - }() - go func() { - _, err := secure.HandshakeWithTimestampOffset(pb, false, cfgClient, 0) - errB <- err - }() - <-errA - <-errB -} - -func TestAuthOffsetServerVerifyFails(t *testing.T) { - secure.ResetReplayCache() - srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) - _, cliPriv, _ := ed25519.GenerateKey(rand.Reader) - wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: wrongPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { - _, err := secure.HandshakeWithTimestampOffset(pa, true, cfgServer, 0) - errA <- err - }() - go func() { - _, err := secure.HandshakeWithTimestampOffset(pb, false, cfgClient, 0) - errB <- err - }() - <-errA - <-errB -} - -func TestAuthLookupServerVerifyFails(t *testing.T) { - secure.ResetReplayCache() - serverPub, serverPriv, _ := ed25519.GenerateKey(rand.Reader) - _, clientPriv, _ := ed25519.GenerateKey(rand.Reader) - wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) - const srvID, cliID = uint32(101), uint32(202) - - srvLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == cliID { - return wrongPub - } - return nil - } - cliLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == srvID { - return serverPub - } - return nil - } - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { - _, err := secure.HandshakeWithLookup(pa, true, &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, srvLookup) - errA <- err - }() - go func() { - _, err := secure.HandshakeWithLookup(pb, false, &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, cliLookup) - errB <- err - }() - <-errA - <-errB -} - -func TestAuthLookupClientVerifyFails(t *testing.T) { - secure.ResetReplayCache() - _, serverPriv, _ := ed25519.GenerateKey(rand.Reader) - clientPub, clientPriv, _ := ed25519.GenerateKey(rand.Reader) - wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) - const srvID, cliID = uint32(901), uint32(902) - - srvLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == cliID { - return clientPub - } - return nil - } - cliLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == srvID { - return wrongPub - } - return nil - } - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { - _, err := secure.HandshakeWithLookup(pa, true, &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, srvLookup) - errA <- err - }() - go func() { - _, err := secure.HandshakeWithLookup(pb, false, &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, cliLookup) - errB <- err - }() - <-errA - cliErr := <-errB - if cliErr == nil { - t.Fatal("expected client verify error") - } -} - -// --------------------------------------------------------------------------- -// performAuth* — post-ECDH auth-frame write failures -// --------------------------------------------------------------------------- - -func TestAuthServerPostECDHWriteFails(t *testing.T) { - secure.ResetReplayCache() - srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) - cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} - - wrapped := &failAfterNWrites{Conn: pa, failAfter: 1} - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { _, err := secure.Handshake(wrapped, true, cfgServer); errA <- err }() - go func() { _, err := secure.Handshake(pb, false, cfgClient); errB <- err }() - srvErr := <-errA - <-errB - if srvErr == nil { - t.Fatal("expected server auth-write error") - } -} - -func TestAuthClientPostVerifyWriteFails(t *testing.T) { - secure.ResetReplayCache() - srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) - cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} - - wrapped := &failAfterNWrites{Conn: pb, failAfter: 1} - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { _, err := secure.Handshake(pa, true, cfgServer); errA <- err }() - go func() { _, err := secure.Handshake(wrapped, false, cfgClient); errB <- err }() - <-errA - cliErr := <-errB - if cliErr == nil { - t.Fatal("expected client post-verify write error") - } -} - -func TestAuthOffsetServerPostECDHWriteFails(t *testing.T) { - secure.ResetReplayCache() - srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) - cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} - - wrapped := &failAfterNWrites{Conn: pa, failAfter: 1} - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { - _, err := secure.HandshakeWithTimestampOffset(wrapped, true, cfgServer, 0) - errA <- err - }() - go func() { - _, err := secure.HandshakeWithTimestampOffset(pb, false, cfgClient, 0) - errB <- err - }() - srvErr := <-errA - <-errB - if srvErr == nil { - t.Fatal("expected server auth-write error") - } -} - -func TestAuthOffsetClientPostVerifyWriteFails(t *testing.T) { - secure.ResetReplayCache() - srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) - cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} - - wrapped := &failAfterNWrites{Conn: pb, failAfter: 1} - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { - _, err := secure.HandshakeWithTimestampOffset(pa, true, cfgServer, 0) - errA <- err - }() - go func() { - _, err := secure.HandshakeWithTimestampOffset(wrapped, false, cfgClient, 0) - errB <- err - }() - <-errA - cliErr := <-errB - if cliErr == nil { - t.Fatal("expected client post-verify write error") - } -} - -func TestAuthLookupServerPostECDHWriteFails(t *testing.T) { - secure.ResetReplayCache() - serverPub, serverPriv, _ := ed25519.GenerateKey(rand.Reader) - clientPub, clientPriv, _ := ed25519.GenerateKey(rand.Reader) - const srvID, cliID = uint32(701), uint32(702) - srvLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == cliID { - return clientPub - } - return nil - } - cliLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == srvID { - return serverPub - } - return nil - } - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - wrapped := &failAfterNWrites{Conn: pa, failAfter: 1} - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { - _, err := secure.HandshakeWithLookup(wrapped, true, &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, srvLookup) - errA <- err - }() - go func() { - _, err := secure.HandshakeWithLookup(pb, false, &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, cliLookup) - errB <- err - }() - srvErr := <-errA - <-errB - if srvErr == nil { - t.Fatal("expected server auth-write error") - } -} - -func TestAuthLookupClientPostVerifyWriteFails(t *testing.T) { - secure.ResetReplayCache() - serverPub, serverPriv, _ := ed25519.GenerateKey(rand.Reader) - clientPub, clientPriv, _ := ed25519.GenerateKey(rand.Reader) - const srvID, cliID = uint32(801), uint32(802) - srvLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == cliID { - return clientPub - } - return nil - } - cliLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == srvID { - return serverPub - } - return nil - } - - pa, pb := net.Pipe() - defer pa.Close() - defer pb.Close() - - wrapped := &failAfterNWrites{Conn: pb, failAfter: 1} - - errA := make(chan error, 1) - errB := make(chan error, 1) - go func() { - _, err := secure.HandshakeWithLookup(pa, true, &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, srvLookup) - errA <- err - }() - go func() { - _, err := secure.HandshakeWithLookup(wrapped, false, &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, cliLookup) - errB <- err - }() - <-errA - cliErr := <-errB - if cliErr == nil { - t.Fatal("expected client post-verify write error") - } -} - -// --------------------------------------------------------------------------- -// CheckAndRecordNonce cap — fill the cache to maxReplayCacheEntries then -// confirm the next insert errors. -// --------------------------------------------------------------------------- - -func TestCheckAndRecordNonceCacheFull(t *testing.T) { - secure.ResetReplayCache() - for i := 0; i < 100000; i++ { - var n [16]byte - binary.BigEndian.PutUint64(n[:8], uint64(i)) - secure.InjectReplayNonce(n) - } - var fresh [16]byte - binary.BigEndian.PutUint64(fresh[:8], 0xFFFFFFFFFFFFFFFF) - err := secure.CheckAndRecordNonce(fresh) - if err == nil || !strings.Contains(err.Error(), "cache full") { - t.Fatalf("expected cache-full error, got %v", err) - } - secure.ResetReplayCache() -} - -// --------------------------------------------------------------------------- -// Sanity: AES-GCM nonce size assumption matches. -// --------------------------------------------------------------------------- - -func TestAesGcmNonceSizeIs12(t *testing.T) { - t.Parallel() - key := make([]byte, 32) - block, err := aes.NewCipher(key) - if err != nil { - t.Fatal(err) - } - g, err := cipher.NewGCM(block) - if err != nil { - t.Fatal(err) - } - if g.NonceSize() != 12 { - t.Errorf("nonce size = %d, want 12", g.NonceSize()) - } -} - -var _ = errors.New -var _ = secure.AuthFrameLen diff --git a/pkg/secure/zz_handshake_lookup_test.go b/pkg/secure/zz_handshake_lookup_test.go deleted file mode 100644 index fe9571ee..00000000 --- a/pkg/secure/zz_handshake_lookup_test.go +++ /dev/null @@ -1,284 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package secure_test - -import ( - "crypto/ed25519" - "crypto/rand" - "errors" - "io" - "net" - "testing" - "time" - - "github.com/TeoSlayer/pilotprotocol/pkg/secure" -) - -// runLookupHandshake connects two net.Pipe ends, runs HandshakeWithLookup on -// each side concurrently, and returns both resulting SecureConns (or errors). -// NOTE: callers must NOT mark tests t.Parallel() since HandshakeWithLookup -// mutates the global replay cache (see iter 12 lesson). -func runLookupHandshake(t *testing.T, serverCfg, clientCfg *secure.HandshakeConfig, serverLookup, clientLookup secure.PeerPubKeyLookup) (*secure.SecureConn, error, *secure.SecureConn, error) { - t.Helper() - s, c := net.Pipe() - type result struct { - sc *secure.SecureConn - err error - } - srvCh := make(chan result, 1) - cliCh := make(chan result, 1) - go func() { - sc, err := secure.HandshakeWithLookup(s, true, serverCfg, serverLookup) - srvCh <- result{sc, err} - }() - go func() { - sc, err := secure.HandshakeWithLookup(c, false, clientCfg, clientLookup) - cliCh <- result{sc, err} - }() - select { - case <-time.After(5 * time.Second): - s.Close() - c.Close() - t.Fatal("handshake timed out") - default: - } - srv := <-srvCh - cli := <-cliCh - return srv.sc, srv.err, cli.sc, cli.err -} - -func newEd25519KeyPair(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) { - t.Helper() - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } - return pub, priv -} - -func TestHandshakeWithLookupHappyPath(t *testing.T) { - secure.ResetReplayCache() - serverPub, serverPriv := newEd25519KeyPair(t) - clientPub, clientPriv := newEd25519KeyPair(t) - const srvID, cliID = uint32(0x10001), uint32(0x20002) - - srvLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == cliID { - return clientPub - } - return nil - } - cliLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == srvID { - return serverPub - } - return nil - } - - srvSC, srvErr, cliSC, cliErr := runLookupHandshake(t, - &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, - &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, - srvLookup, cliLookup) - - if srvErr != nil { - t.Fatalf("server handshake: %v", srvErr) - } - if cliErr != nil { - t.Fatalf("client handshake: %v", cliErr) - } - if srvSC.PeerNodeID != cliID { - t.Errorf("server saw peer=%d, want %d", srvSC.PeerNodeID, cliID) - } - if cliSC.PeerNodeID != srvID { - t.Errorf("client saw peer=%d, want %d", cliSC.PeerNodeID, srvID) - } - - // End-to-end data exchange proves the derived keys match on both sides. - done := make(chan error, 1) - go func() { - buf := make([]byte, 5) - if _, err := io.ReadFull(cliSC, buf); err != nil { - done <- err - return - } - if string(buf) != "ping!" { - done <- errors.New("bad payload: " + string(buf)) - return - } - done <- nil - }() - if _, err := srvSC.Write([]byte("ping!")); err != nil { - t.Fatalf("write: %v", err) - } - select { - case err := <-done: - if err != nil { - t.Fatalf("read: %v", err) - } - case <-time.After(2 * time.Second): - t.Fatal("timeout reading from encrypted stream") - } - cliSC.Close() - srvSC.Close() -} - -func TestHandshakeWithLookupServerRejectsUnknownPeer(t *testing.T) { - secure.ResetReplayCache() - serverPub, serverPriv := newEd25519KeyPair(t) - _, clientPriv := newEd25519KeyPair(t) - const srvID, cliID = uint32(0x30003), uint32(0x40004) - - srvLookup := func(_ uint32) ed25519.PublicKey { return nil } // unknown - cliLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == srvID { - return serverPub - } - return nil - } - - _, srvErr, _, _ := runLookupHandshake(t, - &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, - &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, - srvLookup, cliLookup) - - // Server reads client's auth frame AFTER writing its own, then looks up - // the client's pubkey and rejects on nil. Client has already completed - // its side (read server frame, verified, wrote its own) so it returns - // without error — only the server-side error surfaces. - if srvErr == nil { - t.Fatal("server should have rejected unknown peer") - } -} - -func TestHandshakeWithLookupClientRejectsUnknownServer(t *testing.T) { - secure.ResetReplayCache() - _, serverPriv := newEd25519KeyPair(t) - clientPub, clientPriv := newEd25519KeyPair(t) - const srvID, cliID = uint32(0x50005), uint32(0x60006) - - srvLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == cliID { - return clientPub - } - return nil - } - cliLookup := func(_ uint32) ed25519.PublicKey { return nil } - - _, srvErr, _, cliErr := runLookupHandshake(t, - &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, - &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, - srvLookup, cliLookup) - - if cliErr == nil { - t.Fatal("client should have rejected unknown server") - } - if srvErr == nil { - t.Fatal("server should have failed after client closed") - } -} - -func TestHandshakeWithLookupBadSignatureRejected(t *testing.T) { - secure.ResetReplayCache() - serverPub, serverPriv := newEd25519KeyPair(t) - _, clientPriv := newEd25519KeyPair(t) - // Third unrelated pubkey — server will look up client by nodeID but - // get a key that doesn't match the client's actual signer. - wrongPub, _ := newEd25519KeyPair(t) - const srvID, cliID = uint32(0x70007), uint32(0x80008) - - srvLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == cliID { - return wrongPub // signature will fail to verify - } - return nil - } - cliLookup := func(nodeID uint32) ed25519.PublicKey { - if nodeID == srvID { - return serverPub - } - return nil - } - - _, srvErr, _, _ := runLookupHandshake(t, - &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, - &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, - srvLookup, cliLookup) - - if srvErr == nil { - t.Fatal("server should have rejected bad signature") - } -} - -func TestHandshakeWithLookupNoAuthSkipsLookup(t *testing.T) { - secure.ResetReplayCache() - s, c := net.Pipe() - srvCh := make(chan error, 1) - cliCh := make(chan error, 1) - // No signer in cfg — auth is skipped, lookup is never called. - go func() { - _, err := secure.HandshakeWithLookup(s, true, nil, nil) - srvCh <- err - }() - go func() { - _, err := secure.HandshakeWithLookup(c, false, nil, nil) - cliCh <- err - }() - select { - case err := <-srvCh: - if err != nil { - t.Fatalf("server: %v", err) - } - case <-time.After(2 * time.Second): - t.Fatal("server handshake timed out") - } - select { - case err := <-cliCh: - if err != nil { - t.Fatalf("client: %v", err) - } - case <-time.After(2 * time.Second): - t.Fatal("client handshake timed out") - } - s.Close() - c.Close() -} - -// ---------- Server constructors ---------- - -func TestNewServerSetsFields(t *testing.T) { - t.Parallel() - called := false - h := func(_ net.Conn) { called = true } - s := secure.NewServer(nil, h) - if s.Driver() != nil { - t.Error("driver should be nil") - } - if s.Handler() == nil { - t.Fatal("handler nil") - } - if s.AuthSigner() != nil || s.AuthNodeID() != 0 || s.PeerLookup() != nil { - t.Error("unauth server should not populate auth fields") - } - // Sanity: handler invocable. - s.Handler()(nil) - if !called { - t.Error("handler not invoked") - } -} - -func TestNewAuthServerSetsFields(t *testing.T) { - t.Parallel() - _, priv := newEd25519KeyPair(t) - lookup := func(_ uint32) ed25519.PublicKey { return nil } - h := func(_ net.Conn) {} - s := secure.NewAuthServer(nil, h, 0xABCD1234, priv, lookup) - if s.AuthNodeID() != 0xABCD1234 { - t.Errorf("authNodeID = %#x", s.AuthNodeID()) - } - if s.AuthSigner() == nil { - t.Error("authSigner nil") - } - if s.PeerLookup() == nil { - t.Error("peerLookup nil") - } -} diff --git a/pkg/secure/zz_secure_test.go b/pkg/secure/zz_secure_test.go deleted file mode 100644 index 4d1aae10..00000000 --- a/pkg/secure/zz_secure_test.go +++ /dev/null @@ -1,586 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package secure_test - -import ( - "bytes" - "crypto/ed25519" - "crypto/rand" - "encoding/binary" - "errors" - "net" - "strings" - "sync" - "testing" - "time" - - "github.com/TeoSlayer/pilotprotocol/pkg/secure" -) - -// pipePair returns two connected net.Conn endpoints (in-process pipe). -func pipePair() (net.Conn, net.Conn) { - return net.Pipe() -} - -// genIdentity returns an Ed25519 keypair. -func genIdentity(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) { - t.Helper() - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } - return pub, priv -} - -// handshakeBoth runs both ends of Handshake concurrently and returns errors. -func handshakeBoth(t *testing.T, a, b net.Conn, cfgA, cfgB *secure.HandshakeConfig) (*secure.SecureConn, *secure.SecureConn) { - t.Helper() - type result struct { - sc *secure.SecureConn - err error - } - chA := make(chan result, 1) - chB := make(chan result, 1) - go func() { - var sc *secure.SecureConn - var err error - if cfgA != nil { - sc, err = secure.Handshake(a, true, cfgA) - } else { - sc, err = secure.Handshake(a, true) - } - chA <- result{sc, err} - }() - go func() { - var sc *secure.SecureConn - var err error - if cfgB != nil { - sc, err = secure.Handshake(b, false, cfgB) - } else { - sc, err = secure.Handshake(b, false) - } - chB <- result{sc, err} - }() - rA := <-chA - rB := <-chB - if rA.err != nil { - t.Fatalf("server handshake: %v", rA.err) - } - if rB.err != nil { - t.Fatalf("client handshake: %v", rB.err) - } - return rA.sc, rB.sc -} - -// --------------------------------------------------------------------------- -// Unauthenticated handshake + Read/Write round-trip -// --------------------------------------------------------------------------- - -func TestUnauthenticatedHandshakeRoundTrip(t *testing.T) { - t.Parallel() - a, b := pipePair() - defer a.Close() - defer b.Close() - - server, client := handshakeBoth(t, a, b, nil, nil) - - msg := []byte("hello secure world") - go func() { _, _ = client.Write(msg) }() - - got := make([]byte, len(msg)) - n, err := server.Read(got) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(got[:n], msg) { - t.Errorf("got %q, want %q", got[:n], msg) - } -} - -func TestEncryptedReadBuffersLeftover(t *testing.T) { - t.Parallel() - a, b := pipePair() - defer a.Close() - defer b.Close() - server, client := handshakeBoth(t, a, b, nil, nil) - - msg := bytes.Repeat([]byte("X"), 1024) - go func() { _, _ = client.Write(msg) }() - - // Read with a small buffer to force leftover-buffering path - small := make([]byte, 100) - n, err := server.Read(small) - if err != nil { - t.Fatal(err) - } - if n != 100 { - t.Errorf("first read = %d, want 100", n) - } - // Drain remaining 924 bytes via subsequent reads (should hit readBuf) - rest := make([]byte, 1024) - off := 0 - for off < 924 { - k, err := server.Read(rest[off:]) - if err != nil { - t.Fatal(err) - } - off += k - } - if off != 924 { - t.Errorf("drained %d, want 924", off) - } -} - -func TestEncryptedHidesPlaintextOnWire(t *testing.T) { - // Use a tap to capture raw bytes, then verify plaintext is not present. - a, b := pipePair() - defer a.Close() - defer b.Close() - server, client := handshakeBoth(t, a, b, nil, nil) - - // Send a recognizable plaintext via client; ensure it doesn't appear - // directly on the wire by reading both server-side decrypted data and - // verifying decryption returns the same bytes (i.e., AEAD is exercised). - plain := []byte("PLAINTEXT-MARKER-12345") - go func() { _, _ = client.Write(plain) }() - got := make([]byte, len(plain)) - if _, err := server.Read(got); err != nil { - t.Fatal(err) - } - if !bytes.Equal(got, plain) { - t.Errorf("decryption mismatch: got %q want %q", got, plain) - } - // Indirect check: server has consumed all, no further bytes pending -} - -func TestNonceUniquenessAcrossWrites(t *testing.T) { - a, b := pipePair() - defer a.Close() - defer b.Close() - server, client := handshakeBoth(t, a, b, nil, nil) - - // Send N distinct messages from client; server reads them. Each Write - // increments the nonce counter so duplicates would be a SUT bug. - const N = 5 - go func() { - for i := 0; i < N; i++ { - _, _ = client.Write([]byte{byte(i)}) - } - }() - for i := 0; i < N; i++ { - buf := make([]byte, 1) - if _, err := server.Read(buf); err != nil { - t.Fatal(err) - } - if buf[0] != byte(i) { - t.Errorf("msg %d: got %d", i, buf[0]) - } - } -} - -// --------------------------------------------------------------------------- -// Authenticated handshake -// --------------------------------------------------------------------------- - -func TestAuthenticatedHandshakeMutual(t *testing.T) { - secure.ResetReplayCache() - - srvPub, srvPriv := genIdentity(t) - cliPub, cliPriv := genIdentity(t) - - a, b := pipePair() - defer a.Close() - defer b.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 100, Signer: srvPriv, PeerPubKey: cliPub} - cfgClient := &secure.HandshakeConfig{NodeID: 200, Signer: cliPriv, PeerPubKey: srvPub} - - server, client := handshakeBoth(t, a, b, cfgServer, cfgClient) - - if server.PeerNodeID != 200 { - t.Errorf("server PeerNodeID = %d, want 200", server.PeerNodeID) - } - if client.PeerNodeID != 100 { - t.Errorf("client PeerNodeID = %d, want 100", client.PeerNodeID) - } -} - -func TestAuthenticatedHandshakeWrongPeerKeyFails(t *testing.T) { - secure.ResetReplayCache() - - _, srvPriv := genIdentity(t) - _, cliPriv := genIdentity(t) - wrongPub, _ := genIdentity(t) // server expects this, client signs with different key - - a, b := pipePair() - defer a.Close() - defer b.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: wrongPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: wrongPub} - - type r struct{ err error } - chA := make(chan r, 1) - chB := make(chan r, 1) - go func() { _, err := secure.Handshake(a, true, cfgServer); chA <- r{err} }() - go func() { _, err := secure.Handshake(b, false, cfgClient); chB <- r{err} }() - - rA := <-chA - rB := <-chB - if rA.err == nil && rB.err == nil { - t.Fatal("expected at least one side to fail with wrong PeerPubKey") - } -} - -// --------------------------------------------------------------------------- -// Replay cache and timestamp skew -// --------------------------------------------------------------------------- - -func TestHandshakeWithTimestampOffsetExpiredFails(t *testing.T) { - secure.ResetReplayCache() - srvPub, srvPriv := genIdentity(t) - cliPub, cliPriv := genIdentity(t) - - a, b := pipePair() - defer a.Close() - defer b.Close() - - cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} - cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} - - // Server uses normal timestamp; client uses 10s offset → exceeds 5s skew. - chA := make(chan error, 1) - chB := make(chan error, 1) - go func() { - _, err := secure.HandshakeWithTimestampOffset(a, true, cfgServer, 0) - chA <- err - }() - go func() { - _, err := secure.HandshakeWithTimestampOffset(b, false, cfgClient, 10*time.Second) - chB <- err - }() - - errA := <-chA - errB := <-chB - // Server reads client's frame and finds it exceeds skew. - if errA == nil { - t.Fatal("expected server to reject expired auth") - } - if !strings.Contains(errA.Error(), "timestamp expired") && - !strings.Contains(errA.Error(), "skew") { - t.Errorf("unexpected err: %v", errA) - } - _ = errB // client side may also error or close -} - -func TestReplayCacheRejectsRepeat(t *testing.T) { - secure.ResetReplayCache() - - var nonce [16]byte - if _, err := rand.Read(nonce[:]); err != nil { - t.Fatal(err) - } - if err := secure.CheckAndRecordNonce(nonce); err != nil { - t.Fatalf("first record: %v", err) - } - // Same nonce again → replay - err := secure.CheckAndRecordNonce(nonce) - if err == nil || !strings.Contains(err.Error(), "replay") { - t.Fatalf("expected replay error, got %v", err) - } -} - -func TestCheckReplayNonceDoesNotRecord(t *testing.T) { - secure.ResetReplayCache() - - var nonce [16]byte - if _, err := rand.Read(nonce[:]); err != nil { - t.Fatal(err) - } - // CheckReplayNonce should report not-present (nil err) without inserting. - if err := secure.CheckReplayNonce(nonce); err != nil { - t.Fatalf("expected fresh nonce nil err, got %v", err) - } - // Then we can record it - if err := secure.CheckAndRecordNonce(nonce); err != nil { - t.Fatal(err) - } - // Now CheckReplayNonce should report replay - if err := secure.CheckReplayNonce(nonce); err == nil { - t.Fatal("expected replay error from CheckReplayNonce") - } -} - -func TestInjectReplayNonceTriggersReplay(t *testing.T) { - secure.ResetReplayCache() - - var nonce [16]byte - if _, err := rand.Read(nonce[:]); err != nil { - t.Fatal(err) - } - secure.InjectReplayNonce(nonce) - if err := secure.CheckAndRecordNonce(nonce); err == nil { - t.Fatal("expected replay after inject") - } -} - -// --------------------------------------------------------------------------- -// BuildAuthSignMessage -// --------------------------------------------------------------------------- - -func TestBuildAuthSignMessageStable(t *testing.T) { - x25519 := bytes.Repeat([]byte{0xAB}, 32) - var nonce [16]byte - for i := range nonce { - nonce[i] = byte(i) - } - got := secure.BuildAuthSignMessage(0xDEADBEEF, x25519, 0x1122334455667788, nonce) - // Layout: domain(18) + nodeID(4) + pub(32) + ts(8) + nonce(16) = 78 - if len(got) != 18+4+32+8+16 { - t.Errorf("len = %d, want 78", len(got)) - } - if !bytes.HasPrefix(got, []byte("pilot-secure-auth:")) { - t.Errorf("missing domain prefix: %q", got[:18]) - } - if id := binary.BigEndian.Uint32(got[18:22]); id != 0xDEADBEEF { - t.Errorf("nodeID encoding wrong: %x", id) - } - if !bytes.Equal(got[22:54], x25519) { - t.Errorf("pubkey not embedded correctly") - } - if ts := binary.BigEndian.Uint64(got[54:62]); ts != 0x1122334455667788 { - t.Errorf("timestamp encoding wrong: %x", ts) - } - if !bytes.Equal(got[62:78], nonce[:]) { - t.Errorf("nonce not embedded correctly") - } -} - -func TestBuildAuthSignMessageDifferentInputsDiffer(t *testing.T) { - x := bytes.Repeat([]byte{0x00}, 32) - var n1, n2 [16]byte - n2[0] = 1 - a := secure.BuildAuthSignMessage(1, x, 100, n1) - b := secure.BuildAuthSignMessage(1, x, 100, n2) - if bytes.Equal(a, b) { - t.Fatal("messages with different nonces should differ") - } -} - -// --------------------------------------------------------------------------- -// VerifyAuthFrame -// --------------------------------------------------------------------------- - -func TestVerifyAuthFrameWrongSize(t *testing.T) { - _, err := secure.VerifyAuthFrame(make([]byte, 10), nil, nil, time.Now()) - if err == nil || !strings.Contains(err.Error(), "wrong size") { - t.Fatalf("expected wrong-size err, got %v", err) - } -} - -func TestVerifyAuthFrameExpiredTimestamp(t *testing.T) { - secure.ResetReplayCache() - pub, priv := genIdentity(t) - x25519 := bytes.Repeat([]byte{0xAB}, 32) - expiredTS := uint64(time.Now().Add(-time.Hour).Unix()) - var nonce [16]byte - rand.Read(nonce[:]) - - frame := make([]byte, secure.AuthFrameLen) - binary.BigEndian.PutUint32(frame[0:4], 42) - binary.BigEndian.PutUint64(frame[4:12], expiredTS) - copy(frame[12:28], nonce[:]) - sig := ed25519.Sign(priv, secure.BuildAuthSignMessage(42, x25519, expiredTS, nonce)) - copy(frame[28:92], sig) - - _, err := secure.VerifyAuthFrame(frame, pub, x25519, time.Now()) - if err == nil || !strings.Contains(err.Error(), "expired") { - t.Fatalf("expected expired err, got %v", err) - } -} - -func TestVerifyAuthFrameReplayDetected(t *testing.T) { - secure.ResetReplayCache() - pub, priv := genIdentity(t) - x25519 := bytes.Repeat([]byte{0xAB}, 32) - now := time.Now() - ts := uint64(now.Unix()) - var nonce [16]byte - rand.Read(nonce[:]) - - build := func() []byte { - frame := make([]byte, secure.AuthFrameLen) - binary.BigEndian.PutUint32(frame[0:4], 42) - binary.BigEndian.PutUint64(frame[4:12], ts) - copy(frame[12:28], nonce[:]) - sig := ed25519.Sign(priv, secure.BuildAuthSignMessage(42, x25519, ts, nonce)) - copy(frame[28:92], sig) - return frame - } - - if _, err := secure.VerifyAuthFrame(build(), pub, x25519, now); err != nil { - t.Fatalf("first verify: %v", err) - } - // Second verify with the SAME nonce → replay - _, err := secure.VerifyAuthFrame(build(), pub, x25519, now) - if err == nil || !strings.Contains(err.Error(), "replay") { - t.Fatalf("expected replay error, got %v", err) - } -} - -func TestVerifyAuthFrameBadSignature(t *testing.T) { - secure.ResetReplayCache() - pub, _ := genIdentity(t) // verifier key - _, otherPriv := genIdentity(t) - x25519 := bytes.Repeat([]byte{0xAB}, 32) - now := time.Now() - ts := uint64(now.Unix()) - var nonce [16]byte - rand.Read(nonce[:]) - - frame := make([]byte, secure.AuthFrameLen) - binary.BigEndian.PutUint32(frame[0:4], 42) - binary.BigEndian.PutUint64(frame[4:12], ts) - copy(frame[12:28], nonce[:]) - // Sign with a DIFFERENT key → verification must fail - sig := ed25519.Sign(otherPriv, secure.BuildAuthSignMessage(42, x25519, ts, nonce)) - copy(frame[28:92], sig) - - _, err := secure.VerifyAuthFrame(frame, pub, x25519, now) - if err == nil || !strings.Contains(err.Error(), "verification failed") { - t.Fatalf("expected sig verify err, got %v", err) - } -} - -// --------------------------------------------------------------------------- -// ReadExact -// --------------------------------------------------------------------------- - -func TestReadExactSuccess(t *testing.T) { - t.Parallel() - got, err := secure.ReadExact(bytes.NewReader([]byte("hello world")), 5) - if err != nil { - t.Fatal(err) - } - if string(got) != "hello" { - t.Errorf("got %q", got) - } -} - -func TestReadExactShortFails(t *testing.T) { - t.Parallel() - _, err := secure.ReadExact(bytes.NewReader([]byte("hi")), 5) - if err == nil { - t.Fatal("expected error reading 5 from 2-byte source") - } - if !errors.Is(err, errors.New("")) && err.Error() == "" { - t.Fatalf("expected non-empty error, got %v", err) - } -} - -// --------------------------------------------------------------------------- -// secure.SecureConn passthrough methods -// --------------------------------------------------------------------------- - -func TestSecureConnAddrAndDeadlinePassthrough(t *testing.T) { - t.Parallel() - a, b := pipePair() - defer a.Close() - defer b.Close() - server, _ := handshakeBoth(t, a, b, nil, nil) - - if server.LocalAddr() == nil { - t.Error("LocalAddr nil") - } - if server.RemoteAddr() == nil { - t.Error("RemoteAddr nil") - } - - dl := time.Now().Add(time.Second) - if err := server.SetDeadline(dl); err != nil { - t.Errorf("SetDeadline: %v", err) - } - if err := server.SetReadDeadline(dl); err != nil { - t.Errorf("SetReadDeadline: %v", err) - } - if err := server.SetWriteDeadline(dl); err != nil { - t.Errorf("SetWriteDeadline: %v", err) - } -} - -func TestSecureConnCloseClosesUnderlying(t *testing.T) { - t.Parallel() - a, b := pipePair() - server, _ := handshakeBoth(t, a, b, nil, nil) - - if err := server.Close(); err != nil { - t.Errorf("Close: %v", err) - } - // After Close the underlying conn rejects further writes. - if _, err := a.Write([]byte("x")); err == nil { - t.Error("expected raw write to fail after Close") - } - b.Close() -} - -// --------------------------------------------------------------------------- -// Handshake error: unparseable peer key -// --------------------------------------------------------------------------- - -func TestHandshakeRejectsBadPeerKey(t *testing.T) { - t.Parallel() - a, b := pipePair() - defer a.Close() - defer b.Close() - - // Server expects 32-byte X25519 pub from client. Send 32 bytes of 0xFF - // which is an invalid (non-canonical) curve point. - go func() { - // Client side: write garbage instead of running Handshake - junk := bytes.Repeat([]byte{0xFF}, 32) - _, _ = b.Write(junk) - // Read server's pubkey to unblock its Write - buf := make([]byte, 32) - _, _ = b.Read(buf) - }() - - _, err := secure.Handshake(a, true) - // Either ECDH or NewPublicKey may reject — both valid. - if err == nil { - t.Skip("server accepted; some Go versions accept all-1s as pubkey — skip") - } -} - -// --------------------------------------------------------------------------- -// Concurrent writes serialise (no nonce reuse / corruption) -// --------------------------------------------------------------------------- - -func TestConcurrentWritesSerialise(t *testing.T) { - t.Parallel() - a, b := pipePair() - defer a.Close() - defer b.Close() - server, client := handshakeBoth(t, a, b, nil, nil) - - const N = 20 - var wg sync.WaitGroup - for i := 0; i < N; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - _, _ = client.Write([]byte{byte(i)}) - }(i) - } - - // Read all messages on server side; ensure decryption succeeds for each. - got := make(map[byte]bool) - for len(got) < N { - buf := make([]byte, 1) - _, err := server.Read(buf) - if err != nil { - t.Fatalf("decrypt err during concurrent writes: %v", err) - } - got[buf[0]] = true - } - wg.Wait() -} diff --git a/pkg/urlvalidate/validate.go b/pkg/urlvalidate/validate.go deleted file mode 100644 index 081a9c13..00000000 --- a/pkg/urlvalidate/validate.go +++ /dev/null @@ -1,68 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -// Package urlvalidate provides SSRF-prevention checks shared across packages -// that accept operator-supplied URLs (webhook endpoints, audit export sinks, -// identity provider verification callbacks, etc.). -// -// The rules are intentionally conservative: -// - Only http and https schemes are allowed. -// - Link-local addresses (IPv4 169.254.0.0/16, IPv6 fe80::/10) are blocked -// because they include cloud metadata services and host-local adjacencies. -// - A small allowlist of cloud metadata hostnames is blocked outright. DNS -// is case-insensitive, so the comparison lowercases the hostname before -// matching — "Metadata.Google.Internal" must not bypass the blocklist. -// -// Placing this in a neutral package lets both pkg/daemon and pkg/registry -// (which cannot import pkg/daemon) share exactly one implementation. -package urlvalidate - -import ( - "fmt" - "net" - "net/url" - "strings" -) - -// Validate returns nil if rawURL is an acceptable http(s) endpoint that does -// not point at a link-local or well-known cloud-metadata target. Callers are -// responsible for deciding whether an empty URL (which returns an error here) -// should be interpreted as "disable" before calling. -func Validate(rawURL string) error { - parsed, err := url.Parse(rawURL) - if err != nil { - return fmt.Errorf("invalid URL: %w", err) - } - if parsed.Scheme != "http" && parsed.Scheme != "https" { - return fmt.Errorf("URL must use http or https scheme, got %q", parsed.Scheme) - } - host := parsed.Hostname() - if host == "" { - return fmt.Errorf("URL must have a host") - } - // Strip IPv6 zone identifier (e.g. "fe80::1%eth0") before parsing. - // net.ParseIP does not handle zone suffixes, so without this a - // link-local address with a zone ID would pass the check unnoticed. - ipStr := host - if i := strings.IndexByte(ipStr, '%'); i != -1 { - ipStr = ipStr[:i] - } - if ip := net.ParseIP(ipStr); ip != nil { - if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { - return fmt.Errorf("URL cannot target link-local address %s", host) - } - } - switch strings.ToLower(host) { - case - // GCP - "metadata.google.internal", - "metadata.google.com", - // AWS (DNS names that reach the EC2 instance metadata service - // without traversing the link-local IP path) - "ec2.internal", - "instance-data-service.ec2.internal", - // Azure (IMDS DNS endpoint) - "metadata.azure.com": - return fmt.Errorf("URL cannot target cloud metadata endpoint %s", host) - } - return nil -} diff --git a/pkg/urlvalidate/zz_cloud_metadata_test.go b/pkg/urlvalidate/zz_cloud_metadata_test.go deleted file mode 100644 index 8af6493f..00000000 --- a/pkg/urlvalidate/zz_cloud_metadata_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package urlvalidate_test - -// Regression for SSRF allowlist gaps: the original implementation -// blocked GCP metadata.google.{internal,com} + link-local IPs (which -// covers 169.254.169.254 reaching the EC2/Azure metadata services by -// IP). But the AWS DNS-name path (ec2.internal, -// instance-data-service.ec2.internal) and Azure DNS-name path -// (metadata.azure.com) reached the metadata service without hitting -// the link-local check, leaving an SSRF vector for a webhook -// destination targeting `http://ec2.internal/...`. - -import ( - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/urlvalidate" -) - -func TestValidate_BlocksAWSMetadataHostnames(t *testing.T) { - t.Parallel() - - cases := []string{ - "http://ec2.internal/latest/meta-data/iam/security-credentials/", - "http://instance-data-service.ec2.internal/", - "http://EC2.Internal/", // case-insensitive - } - for _, in := range cases { - err := urlvalidate.Validate(in) - if err == nil { - t.Errorf("Validate(%q) returned nil — AWS metadata hostname not blocked", in) - continue - } - if !strings.Contains(err.Error(), "metadata") { - t.Errorf("Validate(%q) error %q does not mention 'metadata'", in, err.Error()) - } - } -} - -func TestValidate_BlocksAzureMetadataHostname(t *testing.T) { - t.Parallel() - - err := urlvalidate.Validate("http://metadata.azure.com/metadata/instance?api-version=2021-02-01") - if err == nil { - t.Fatal("Azure metadata.azure.com not blocked — SSRF vector") - } - if !strings.Contains(err.Error(), "metadata") { - t.Errorf("expected error to mention 'metadata', got: %v", err) - } -} - -func TestValidate_StillAllowsLegitimateHosts(t *testing.T) { - t.Parallel() - - for _, in := range []string{ - "https://example.com/webhook", - "https://hooks.slack.com/services/T00/B00/abc", - "https://internal-api.example.com/", - } { - if err := urlvalidate.Validate(in); err != nil { - t.Errorf("Validate(%q) wrongly rejected: %v", in, err) - } - } -} diff --git a/pkg/urlvalidate/zz_validate_edge_test.go b/pkg/urlvalidate/zz_validate_edge_test.go deleted file mode 100644 index d51dc839..00000000 --- a/pkg/urlvalidate/zz_validate_edge_test.go +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package urlvalidate_test - -import ( - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/urlvalidate" -) - -func TestValidate_ParseError(t *testing.T) { - t.Parallel() - // %ZZ is an invalid percent-encoding → url.Parse returns an error. - err := urlvalidate.Validate("http://example.com/%ZZ") - if err == nil || !strings.Contains(err.Error(), "invalid URL") { - t.Fatalf("want 'invalid URL', got %v", err) - } -} - -func TestValidate_NoHost(t *testing.T) { - t.Parallel() - // "http:" parses but Hostname() returns "". - err := urlvalidate.Validate("http:") - if err == nil || !strings.Contains(err.Error(), "URL must have a host") { - t.Fatalf("want 'URL must have a host', got %v", err) - } -} - -func TestValidate_LinkLocalIPv6WithZone(t *testing.T) { - t.Parallel() - // The code strips %zoneid before passing to net.ParseIP. Cover that branch. - err := urlvalidate.Validate("http://[fe80::1%25eth0]/") - if err == nil || !strings.Contains(err.Error(), "link-local") { - t.Fatalf("want 'link-local', got %v", err) - } -} - -func TestValidate_LinkLocalIPv4Multicast(t *testing.T) { - t.Parallel() - // 224.0.0.1 is in IPv4 link-local multicast block 224.0.0.0/24. - err := urlvalidate.Validate("http://224.0.0.1/") - if err == nil || !strings.Contains(err.Error(), "link-local") { - t.Fatalf("want 'link-local', got %v", err) - } -} - -func TestValidate_NormalPublicHostsAllowed(t *testing.T) { - t.Parallel() - // Spot-checks for non-error happy paths beyond what the table covers. - for _, u := range []string{ - "https://hooks.example.com/webhook", - "http://example.org:8080/path?x=1", - "https://api.example.io/audit", - } { - if err := urlvalidate.Validate(u); err != nil { - t.Errorf("%s: unexpected error: %v", u, err) - } - } -} diff --git a/pkg/urlvalidate/zz_validate_test.go b/pkg/urlvalidate/zz_validate_test.go deleted file mode 100644 index bb55839c..00000000 --- a/pkg/urlvalidate/zz_validate_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later - -package urlvalidate_test - -import ( - "strings" - "testing" - - "github.com/TeoSlayer/pilotprotocol/pkg/urlvalidate" -) - -func TestValidate(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - url string - wantErr bool - errMsg string - }{ - {"valid http", "http://example.com/hook", false, ""}, - {"valid https", "https://hooks.example.com/pilot", false, ""}, - {"valid with port", "https://example.com:8443/hook", false, ""}, - {"valid routable IPv4", "http://192.168.1.100:9000/hook", false, ""}, - - {"ftp scheme", "ftp://example.com/hook", true, "http or https"}, - {"file scheme", "file:///etc/passwd", true, "http or https"}, - {"no scheme", "example.com/hook", true, "http or https"}, - {"empty", "", true, "http or https"}, - - {"link-local ipv4", "http://169.254.169.254/metadata", true, "link-local"}, - {"link-local ipv6", "http://[fe80::1]/hook", true, "link-local"}, - - {"gcp metadata", "http://metadata.google.internal/", true, "cloud metadata"}, - {"gcp metadata alt", "http://metadata.google.com/", true, "cloud metadata"}, - {"gcp metadata mixed case", "http://Metadata.Google.Internal/", true, "cloud metadata"}, - {"gcp metadata upper case", "http://METADATA.GOOGLE.INTERNAL/", true, "cloud metadata"}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - err := urlvalidate.Validate(tc.url) - if tc.wantErr { - if err == nil { - t.Fatalf("expected error for URL %q", tc.url) - } - if tc.errMsg != "" && !strings.Contains(err.Error(), tc.errMsg) { - t.Fatalf("expected error containing %q, got: %v", tc.errMsg, err) - } - } else if err != nil { - t.Fatalf("unexpected error for URL %q: %v", tc.url, err) - } - }) - } -} diff --git a/tests/compat/zz_real_beacon_test.go b/tests/compat/zz_real_beacon_test.go index 92e007ff..c0b396bd 100644 --- a/tests/compat/zz_real_beacon_test.go +++ b/tests/compat/zz_real_beacon_test.go @@ -29,10 +29,10 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon/transport/wss" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/pilot-protocol/beacon" bwss "github.com/pilot-protocol/beacon/wss" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/protocol" ) // startRealBeacon brings up a real *beacon.Server with its compat WSS diff --git a/tests/daemon/zz_dup_ack_empty_unacked_cong_bug_test.go b/tests/daemon/zz_dup_ack_empty_unacked_cong_bug_test.go index e2f098cc..06f9f7a8 100644 --- a/tests/daemon/zz_dup_ack_empty_unacked_cong_bug_test.go +++ b/tests/daemon/zz_dup_ack_empty_unacked_cong_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestDupACKWithEmptyUnackedShrinksCongWin verifies that three duplicate ACKs diff --git a/tests/daemon/zz_find_conn_timewait_shadow_bug_test.go b/tests/daemon/zz_find_conn_timewait_shadow_bug_test.go index a57b6d43..56c866ac 100644 --- a/tests/daemon/zz_find_conn_timewait_shadow_bug_test.go +++ b/tests/daemon/zz_find_conn_timewait_shadow_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFindConnectionPrefersActiveOverTimeWait verifies that when two connections diff --git a/tests/daemon/zz_listener_closed_channel_bug_test.go b/tests/daemon/zz_listener_closed_channel_bug_test.go index 5d304253..f66a8350 100644 --- a/tests/daemon/zz_listener_closed_channel_bug_test.go +++ b/tests/daemon/zz_listener_closed_channel_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestListenerSendAfterUnbindSafe verifies that sending a connection to a diff --git a/tests/daemon/zz_zero_window_peerrecvwin_bug_test.go b/tests/daemon/zz_zero_window_peerrecvwin_bug_test.go index 9a424459..3ae71d9c 100644 --- a/tests/daemon/zz_zero_window_peerrecvwin_bug_test.go +++ b/tests/daemon/zz_zero_window_peerrecvwin_bug_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestZeroWindowAdvertisementNotHonored verifies that when the peer advertises diff --git a/tests/internal/policy/zz_policy_test.go b/tests/internal/policy/zz_policy_test.go index 154b165d..5fd27086 100644 --- a/tests/internal/policy/zz_policy_test.go +++ b/tests/internal/policy/zz_policy_test.go @@ -6,7 +6,7 @@ import ( "encoding/json" "testing" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" + registry "github.com/pilot-protocol/common/registry/wire" policy "github.com/pilot-protocol/policy/policylang" ) diff --git a/tests/pkg/config/zz_config_test.go b/tests/pkg/config/zz_config_test.go index 89a74b81..14762bdd 100644 --- a/tests/pkg/config/zz_config_test.go +++ b/tests/pkg/config/zz_config_test.go @@ -8,7 +8,7 @@ import ( "path/filepath" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/config" + "github.com/pilot-protocol/common/config" ) func TestLoadValidJSON(t *testing.T) { diff --git a/tests/pkg/coreapi/zz_lifecycle_test.go b/tests/pkg/coreapi/zz_lifecycle_test.go index b3bd9952..b3354932 100644 --- a/tests/pkg/coreapi/zz_lifecycle_test.go +++ b/tests/pkg/coreapi/zz_lifecycle_test.go @@ -8,7 +8,7 @@ import ( "fmt" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/coreapi" + "github.com/pilot-protocol/common/coreapi" ) type fakeService struct { diff --git a/tests/pkg/logging/zz_logging_test.go b/tests/pkg/logging/zz_logging_test.go index 01a49504..df01cf9a 100644 --- a/tests/pkg/logging/zz_logging_test.go +++ b/tests/pkg/logging/zz_logging_test.go @@ -9,7 +9,7 @@ import ( "strings" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/logging" + "github.com/pilot-protocol/common/logging" ) func TestSetupWriterJSONFormat(t *testing.T) { diff --git a/tests/pkg/registry/wire/zz_rules_test.go b/tests/pkg/registry/wire/zz_rules_test.go index 64f5f924..512ecd59 100644 --- a/tests/pkg/registry/wire/zz_rules_test.go +++ b/tests/pkg/registry/wire/zz_rules_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" + "github.com/pilot-protocol/common/registry/wire" ) // --- ValidateRules error branches ---------------------------------------- diff --git a/tests/pkg/registry/wire/zz_wire_test.go b/tests/pkg/registry/wire/zz_wire_test.go index e91014b8..1467eb60 100644 --- a/tests/pkg/registry/wire/zz_wire_test.go +++ b/tests/pkg/registry/wire/zz_wire_test.go @@ -7,7 +7,7 @@ import ( "math" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" + "github.com/pilot-protocol/common/registry/wire" ) func TestWireFrameRoundTrip(t *testing.T) { diff --git a/tests/pkg/secure/zz_handshake_lookup_test.go b/tests/pkg/secure/zz_handshake_lookup_test.go index fe9571ee..9d044a05 100644 --- a/tests/pkg/secure/zz_handshake_lookup_test.go +++ b/tests/pkg/secure/zz_handshake_lookup_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/secure" + "github.com/pilot-protocol/common/secure" ) // runLookupHandshake connects two net.Pipe ends, runs HandshakeWithLookup on diff --git a/tests/pkg/secure/zz_secure_test.go b/tests/pkg/secure/zz_secure_test.go index 4d1aae10..21d83e4e 100644 --- a/tests/pkg/secure/zz_secure_test.go +++ b/tests/pkg/secure/zz_secure_test.go @@ -14,7 +14,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/secure" + "github.com/pilot-protocol/common/secure" ) // pipePair returns two connected net.Conn endpoints (in-process pipe). diff --git a/tests/pkg/urlvalidate/zz_validate_test.go b/tests/pkg/urlvalidate/zz_validate_test.go index bb55839c..05bb50e9 100644 --- a/tests/pkg/urlvalidate/zz_validate_test.go +++ b/tests/pkg/urlvalidate/zz_validate_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/urlvalidate" + "github.com/pilot-protocol/common/urlvalidate" ) func TestValidate(t *testing.T) { diff --git a/tests/plugins/policy/zz_shipped_blueprints_test.go b/tests/plugins/policy/zz_shipped_blueprints_test.go index d6c740fc..b3183f54 100644 --- a/tests/plugins/policy/zz_shipped_blueprints_test.go +++ b/tests/plugins/policy/zz_shipped_blueprints_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/wire" + registry "github.com/pilot-protocol/common/registry/wire" "github.com/pilot-protocol/policy" ) diff --git a/tests/regtestutil/regtestutil.go b/tests/regtestutil/regtestutil.go index 745749ca..cabeafb4 100644 --- a/tests/regtestutil/regtestutil.go +++ b/tests/regtestutil/regtestutil.go @@ -15,7 +15,7 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/testenv.go b/tests/testenv.go index 504f765d..724b2800 100644 --- a/tests/testenv.go +++ b/tests/testenv.go @@ -13,10 +13,10 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/beacon" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/driver" + registryclient "github.com/pilot-protocol/common/registry/client" "github.com/pilot-protocol/dataexchange" "github.com/pilot-protocol/eventstream" "github.com/pilot-protocol/handshake" diff --git a/tests/zz_admin_token_test.go b/tests/zz_admin_token_test.go index fee89372..8484e9c7 100644 --- a/tests/zz_admin_token_test.go +++ b/tests/zz_admin_token_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_audit_test.go b/tests/zz_audit_test.go index 993de15a..ea07309d 100644 --- a/tests/zz_audit_test.go +++ b/tests/zz_audit_test.go @@ -17,9 +17,9 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/beacon" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_auto_join_test.go b/tests/zz_auto_join_test.go index 4b0e7b17..a4db3494 100644 --- a/tests/zz_auto_join_test.go +++ b/tests/zz_auto_join_test.go @@ -15,7 +15,7 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registryclient "github.com/pilot-protocol/common/registry/client" "github.com/pilot-protocol/webhook" ) diff --git a/tests/zz_beacon_registry_test.go b/tests/zz_beacon_registry_test.go index da76ac28..8fb75c96 100644 --- a/tests/zz_beacon_registry_test.go +++ b/tests/zz_beacon_registry_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_bench_helpers_test.go b/tests/zz_bench_helpers_test.go index 1455cf69..e93eab7d 100644 --- a/tests/zz_bench_helpers_test.go +++ b/tests/zz_bench_helpers_test.go @@ -13,8 +13,8 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" "github.com/pilot-protocol/beacon" + "github.com/pilot-protocol/common/driver" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_bench_latency_test.go b/tests/zz_bench_latency_test.go index 6bf6ae2d..4866bbd3 100644 --- a/tests/zz_bench_latency_test.go +++ b/tests/zz_bench_latency_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // BenchmarkLatencyUnderLoad measures p50/p95/p99 round-trip latency for small diff --git a/tests/zz_broadcast_test.go b/tests/zz_broadcast_test.go index fb6de3b9..1dd67a2d 100644 --- a/tests/zz_broadcast_test.go +++ b/tests/zz_broadcast_test.go @@ -8,7 +8,7 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registryclient "github.com/pilot-protocol/common/registry/client" ) func TestBroadcast(t *testing.T) { diff --git a/tests/zz_commands_test.go b/tests/zz_commands_test.go index c916e31e..64e808aa 100644 --- a/tests/zz_commands_test.go +++ b/tests/zz_commands_test.go @@ -11,9 +11,9 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" icrypto "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/protocol" + registryclient "github.com/pilot-protocol/common/registry/client" ) // ====================== diff --git a/tests/zz_compat_dial_test.go b/tests/zz_compat_dial_test.go index e734d4ed..b9cdbb21 100644 --- a/tests/zz_compat_dial_test.go +++ b/tests/zz_compat_dial_test.go @@ -24,9 +24,9 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/pilot-protocol/beacon" + "github.com/pilot-protocol/common/driver" + "github.com/pilot-protocol/common/protocol" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_compat_registry_tls_test.go b/tests/zz_compat_registry_tls_test.go index 0fdc075e..0363d722 100644 --- a/tests/zz_compat_registry_tls_test.go +++ b/tests/zz_compat_registry_tls_test.go @@ -28,8 +28,8 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/driver" + "github.com/pilot-protocol/common/protocol" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_config_test.go b/tests/zz_config_test.go index 5a05d5fb..56d83b79 100644 --- a/tests/zz_config_test.go +++ b/tests/zz_config_test.go @@ -8,7 +8,7 @@ import ( "path/filepath" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/config" + "github.com/pilot-protocol/common/config" ) // NOTE: These tests modify the global flag.CommandLine and cannot use t.Parallel(). diff --git a/tests/zz_dashboard_helper_test.go b/tests/zz_dashboard_helper_test.go index 75fa45f0..1367bfec 100644 --- a/tests/zz_dashboard_helper_test.go +++ b/tests/zz_dashboard_helper_test.go @@ -10,8 +10,8 @@ package tests import ( "testing" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" icrypto "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" ) func dashRegisterNode(t *testing.T, addr, hostname string) { diff --git a/tests/zz_data_exchange_policy_test.go b/tests/zz_data_exchange_policy_test.go index 17fd80b9..fabba3cf 100644 --- a/tests/zz_data_exchange_policy_test.go +++ b/tests/zz_data_exchange_policy_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + "github.com/pilot-protocol/common/protocol" + registryclient "github.com/pilot-protocol/common/registry/client" ) // TestDataExchangePolicy exercises the "data-exchange" network policy: diff --git a/tests/zz_datagram_test.go b/tests/zz_datagram_test.go index 47a48f45..2ae68baa 100644 --- a/tests/zz_datagram_test.go +++ b/tests/zz_datagram_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registryclient "github.com/pilot-protocol/common/registry/client" ) // TestUnicastDatagram verifies point-to-point datagram delivery. diff --git a/tests/zz_enterprise_gate_test.go b/tests/zz_enterprise_gate_test.go index a1d6242d..72b19645 100644 --- a/tests/zz_enterprise_gate_test.go +++ b/tests/zz_enterprise_gate_test.go @@ -18,8 +18,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_fin_ack_test.go b/tests/zz_fin_ack_test.go index 46a1ff0d..22a71d6c 100644 --- a/tests/zz_fin_ack_test.go +++ b/tests/zz_fin_ack_test.go @@ -8,7 +8,7 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestFinAckNoStorm — §4.2 FIN-ACK non-storm regression. diff --git a/tests/zz_fuzz_config_test.go b/tests/zz_fuzz_config_test.go index 9dab7231..21bf0bcd 100644 --- a/tests/zz_fuzz_config_test.go +++ b/tests/zz_fuzz_config_test.go @@ -10,7 +10,7 @@ import ( "path/filepath" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/config" + "github.com/pilot-protocol/common/config" ) // --------------------------------------------------------------------------- diff --git a/tests/zz_fuzz_daemon_test.go b/tests/zz_fuzz_daemon_test.go index feb3431d..5636f6cb 100644 --- a/tests/zz_fuzz_daemon_test.go +++ b/tests/zz_fuzz_daemon_test.go @@ -10,7 +10,7 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --------------------------------------------------------------------------- diff --git a/tests/zz_fuzz_ipc_test.go b/tests/zz_fuzz_ipc_test.go index da004ac1..2c0854ae 100644 --- a/tests/zz_fuzz_ipc_test.go +++ b/tests/zz_fuzz_ipc_test.go @@ -10,7 +10,7 @@ import ( "sync" "testing" - "github.com/TeoSlayer/pilotprotocol/internal/ipcutil" + "github.com/pilot-protocol/common/ipcutil" ) // --------------------------------------------------------------------------- diff --git a/tests/zz_fuzz_protocol_test.go b/tests/zz_fuzz_protocol_test.go index 58bd4e7f..e5618ea2 100644 --- a/tests/zz_fuzz_protocol_test.go +++ b/tests/zz_fuzz_protocol_test.go @@ -11,7 +11,7 @@ import ( "strings" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // --------------------------------------------------------------------------- diff --git a/tests/zz_fuzz_registry_server_test.go b/tests/zz_fuzz_registry_server_test.go index 5ff9b7c0..9555b09e 100644 --- a/tests/zz_fuzz_registry_server_test.go +++ b/tests/zz_fuzz_registry_server_test.go @@ -14,8 +14,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_fuzz_secure_test.go b/tests/zz_fuzz_secure_test.go index d8a05e36..5445849e 100644 --- a/tests/zz_fuzz_secure_test.go +++ b/tests/zz_fuzz_secure_test.go @@ -18,7 +18,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/secure" + "github.com/pilot-protocol/common/secure" ) // --------------------------------------------------------------------------- diff --git a/tests/zz_handshake_test.go b/tests/zz_handshake_test.go index 324d8c7f..45411fc4 100644 --- a/tests/zz_handshake_test.go +++ b/tests/zz_handshake_test.go @@ -10,8 +10,8 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + "github.com/pilot-protocol/common/driver" + registryclient "github.com/pilot-protocol/common/registry/client" ) func TestHandshakeMutualAutoApprove(t *testing.T) { diff --git a/tests/zz_health_endpoint_test.go b/tests/zz_health_endpoint_test.go index a09502c5..61f68984 100644 --- a/tests/zz_health_endpoint_test.go +++ b/tests/zz_health_endpoint_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_hostname_privacy_test.go b/tests/zz_hostname_privacy_test.go index 3bc5a49f..993a2cb2 100644 --- a/tests/zz_hostname_privacy_test.go +++ b/tests/zz_hostname_privacy_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registryclient "github.com/pilot-protocol/common/registry/client" ) // TestResolveHostnamePrivateNodeRequiresTrust verifies that resolving a private diff --git a/tests/zz_hostname_test.go b/tests/zz_hostname_test.go index aad4b13e..0264311b 100644 --- a/tests/zz_hostname_test.go +++ b/tests/zz_hostname_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_identity_test.go b/tests/zz_identity_test.go index 63fa4e9c..6741eb7b 100644 --- a/tests/zz_identity_test.go +++ b/tests/zz_identity_test.go @@ -11,9 +11,9 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/driver" + registryclient "github.com/pilot-protocol/common/registry/client" ) // waitForSocketRemoval polls until the given unix socket file is removed, diff --git a/tests/zz_integration_test.go b/tests/zz_integration_test.go index 221d007f..9dba6af1 100644 --- a/tests/zz_integration_test.go +++ b/tests/zz_integration_test.go @@ -22,8 +22,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" icrypto "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_invite_acceptance_test.go b/tests/zz_invite_acceptance_test.go index 7b4e8b22..c6e22483 100644 --- a/tests/zz_invite_acceptance_test.go +++ b/tests/zz_invite_acceptance_test.go @@ -14,7 +14,7 @@ import ( registry "github.com/pilot-protocol/rendezvous" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registryclient "github.com/pilot-protocol/common/registry/client" ) // TestInviteRequiresAcceptance verifies the full invite flow: diff --git a/tests/zz_ipc_ops_test.go b/tests/zz_ipc_ops_test.go index e3bbae48..3f09ddf8 100644 --- a/tests/zz_ipc_ops_test.go +++ b/tests/zz_ipc_ops_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registryclient "github.com/pilot-protocol/common/registry/client" ) // TestSetHostnameViaIPC verifies the driver → IPC → daemon → registry round-trip diff --git a/tests/zz_ipc_test.go b/tests/zz_ipc_test.go index 483d4caf..27fd529b 100644 --- a/tests/zz_ipc_test.go +++ b/tests/zz_ipc_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" + "github.com/pilot-protocol/common/driver" ) // TestIPCDisconnectRecovery tests that driver operations return errors when diff --git a/tests/zz_ipv6_test.go b/tests/zz_ipv6_test.go index 3de21154..d2fccdc1 100644 --- a/tests/zz_ipv6_test.go +++ b/tests/zz_ipv6_test.go @@ -10,8 +10,8 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" "github.com/pilot-protocol/beacon" + "github.com/pilot-protocol/common/driver" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_key_lifecycle_test.go b/tests/zz_key_lifecycle_test.go index fedbfe12..79c89f96 100644 --- a/tests/zz_key_lifecycle_test.go +++ b/tests/zz_key_lifecycle_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_lifecycle_test.go b/tests/zz_lifecycle_test.go index 17b25a07..bafc0edd 100644 --- a/tests/zz_lifecycle_test.go +++ b/tests/zz_lifecycle_test.go @@ -10,7 +10,7 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" + "github.com/pilot-protocol/common/driver" ) // TestDialClosedPort verifies that dialing a port with no listener returns an error (RST). diff --git a/tests/zz_limits_test.go b/tests/zz_limits_test.go index 8d735f9d..9c28cbca 100644 --- a/tests/zz_limits_test.go +++ b/tests/zz_limits_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" + "github.com/pilot-protocol/common/driver" ) func TestAcceptQueueNoOrphan(t *testing.T) { diff --git a/tests/zz_member_tags_test.go b/tests/zz_member_tags_test.go index 89acd608..0f3fed35 100644 --- a/tests/zz_member_tags_test.go +++ b/tests/zz_member_tags_test.go @@ -10,8 +10,8 @@ import ( "path/filepath" "testing" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" "github.com/pilot-protocol/policy" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_metrics_test.go b/tests/zz_metrics_test.go index 36568840..609c5e71 100644 --- a/tests/zz_metrics_test.go +++ b/tests/zz_metrics_test.go @@ -12,8 +12,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" icrypto "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_multi_beacon_test.go b/tests/zz_multi_beacon_test.go index 1cf61630..eb96dd5e 100644 --- a/tests/zz_multi_beacon_test.go +++ b/tests/zz_multi_beacon_test.go @@ -10,9 +10,9 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/pilot-protocol/beacon" + "github.com/pilot-protocol/common/driver" + "github.com/pilot-protocol/common/protocol" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_nameserver_test.go b/tests/zz_nameserver_test.go index 90e21f50..d5294cae 100644 --- a/tests/zz_nameserver_test.go +++ b/tests/zz_nameserver_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/driver" + "github.com/pilot-protocol/common/protocol" "github.com/pilot-protocol/nameserver" ) diff --git a/tests/zz_nat_traversal_test.go b/tests/zz_nat_traversal_test.go index ded3aa4b..a7ed22e0 100644 --- a/tests/zz_nat_traversal_test.go +++ b/tests/zz_nat_traversal_test.go @@ -13,8 +13,8 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/pilot-protocol/beacon" + "github.com/pilot-protocol/common/protocol" ) // TestBeaconPunchRequest verifies that the beacon correctly handles diff --git a/tests/zz_network_policy_test.go b/tests/zz_network_policy_test.go index 25e7761d..1a26d3c7 100644 --- a/tests/zz_network_policy_test.go +++ b/tests/zz_network_policy_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_network_test.go b/tests/zz_network_test.go index 663fa344..eb27ae98 100644 --- a/tests/zz_network_test.go +++ b/tests/zz_network_test.go @@ -5,8 +5,8 @@ package tests import ( "testing" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_peer_resilience_test.go b/tests/zz_peer_resilience_test.go index 429132f4..bc7ece45 100644 --- a/tests/zz_peer_resilience_test.go +++ b/tests/zz_peer_resilience_test.go @@ -9,9 +9,9 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/beacon" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_persistence_test.go b/tests/zz_persistence_test.go index 94133efb..a9f8b302 100644 --- a/tests/zz_persistence_test.go +++ b/tests/zz_persistence_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_pilotctl_network_test.go b/tests/zz_pilotctl_network_test.go index cb0b472f..3ce5df3a 100644 --- a/tests/zz_pilotctl_network_test.go +++ b/tests/zz_pilotctl_network_test.go @@ -17,7 +17,7 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registryclient "github.com/pilot-protocol/common/registry/client" "github.com/pilot-protocol/webhook" ) diff --git a/tests/zz_port_leak_test.go b/tests/zz_port_leak_test.go index 267583eb..bc31e05b 100644 --- a/tests/zz_port_leak_test.go +++ b/tests/zz_port_leak_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestEphemeralPortLeakOnSendDatagramError — §4.4 port-leak regression. diff --git a/tests/zz_privacy_test.go b/tests/zz_privacy_test.go index 5a4fa9a5..e5b91de1 100644 --- a/tests/zz_privacy_test.go +++ b/tests/zz_privacy_test.go @@ -12,7 +12,7 @@ import ( "testing" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registryclient "github.com/pilot-protocol/common/registry/client" ) // TestPrivateNodeResolveBlocked verifies that a private node cannot be resolved diff --git a/tests/zz_protocol_test.go b/tests/zz_protocol_test.go index 883011ff..02800f4a 100644 --- a/tests/zz_protocol_test.go +++ b/tests/zz_protocol_test.go @@ -7,7 +7,7 @@ import ( "encoding/binary" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) func TestAddrString(t *testing.T) { diff --git a/tests/zz_protocol_version_test.go b/tests/zz_protocol_version_test.go index 7ba1351c..f5a89de1 100644 --- a/tests/zz_protocol_version_test.go +++ b/tests/zz_protocol_version_test.go @@ -11,9 +11,9 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/protocol" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_rbac_test.go b/tests/zz_rbac_test.go index 23f43f4b..c3a5f255 100644 --- a/tests/zz_rbac_test.go +++ b/tests/zz_rbac_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" ) // TestRBACOwnerRole verifies that the creator of a network gets the owner role. diff --git a/tests/zz_rc6_peer_restart_recovery_test.go b/tests/zz_rc6_peer_restart_recovery_test.go index ddfc1ee9..14d5d8ef 100644 --- a/tests/zz_rc6_peer_restart_recovery_test.go +++ b/tests/zz_rc6_peer_restart_recovery_test.go @@ -40,9 +40,9 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" "github.com/pilot-protocol/beacon" + "github.com/pilot-protocol/common/driver" + "github.com/pilot-protocol/common/protocol" registry "github.com/pilot-protocol/rendezvous" pluginsruntime "github.com/pilot-protocol/runtime" ) diff --git a/tests/zz_registry_hardening_test.go b/tests/zz_registry_hardening_test.go index d526a7b5..4f8e5b83 100644 --- a/tests/zz_registry_hardening_test.go +++ b/tests/zz_registry_hardening_test.go @@ -12,8 +12,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_registry_loadtest_test.go b/tests/zz_registry_loadtest_test.go index bb9d1f13..667a07e2 100644 --- a/tests/zz_registry_loadtest_test.go +++ b/tests/zz_registry_loadtest_test.go @@ -18,8 +18,8 @@ import ( "testing" "time" - registry "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registry "github.com/pilot-protocol/common/registry/client" ) // --------------------------------------------------------------------------- diff --git a/tests/zz_replication_test.go b/tests/zz_replication_test.go index c1b36647..25ec8678 100644 --- a/tests/zz_replication_test.go +++ b/tests/zz_replication_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_reregistration_test.go b/tests/zz_reregistration_test.go index 81d1174a..d2202847 100644 --- a/tests/zz_reregistration_test.go +++ b/tests/zz_reregistration_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/beacon" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_scaling_test.go b/tests/zz_scaling_test.go index 69e897d9..d9e998bb 100644 --- a/tests/zz_scaling_test.go +++ b/tests/zz_scaling_test.go @@ -14,8 +14,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_secure_auth_test.go b/tests/zz_secure_auth_test.go index 2d2f7fb3..d851714e 100644 --- a/tests/zz_secure_auth_test.go +++ b/tests/zz_secure_auth_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/secure" + "github.com/pilot-protocol/common/secure" ) // generateTestIdentity creates a random Ed25519 keypair for testing. diff --git a/tests/zz_secure_test.go b/tests/zz_secure_test.go index 39b7b235..5f824540 100644 --- a/tests/zz_secure_test.go +++ b/tests/zz_secure_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/secure" + "github.com/pilot-protocol/common/secure" ) func TestSecureChannel(t *testing.T) { diff --git a/tests/zz_secure_unit_test.go b/tests/zz_secure_unit_test.go index 7cbfaaf1..0adb1983 100644 --- a/tests/zz_secure_unit_test.go +++ b/tests/zz_secure_unit_test.go @@ -8,7 +8,7 @@ import ( "sync" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/secure" + "github.com/pilot-protocol/common/secure" ) func TestSecureHandshakeAndRoundTrip(t *testing.T) { diff --git a/tests/zz_security_fixes_test.go b/tests/zz_security_fixes_test.go index 2a250ccb..40995c96 100644 --- a/tests/zz_security_fixes_test.go +++ b/tests/zz_security_fixes_test.go @@ -17,9 +17,9 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" "github.com/pilot-protocol/common/fsutil" + registryclient "github.com/pilot-protocol/common/registry/client" "github.com/pilot-protocol/dataexchange" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_security_phase2_test.go b/tests/zz_security_phase2_test.go index 0e7bd1d6..2646d053 100644 --- a/tests/zz_security_phase2_test.go +++ b/tests/zz_security_phase2_test.go @@ -9,9 +9,9 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/protocol" + registryclient "github.com/pilot-protocol/common/registry/client" "github.com/pilot-protocol/policy" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_shutdown_test.go b/tests/zz_shutdown_test.go index d5d176b5..c09706e5 100644 --- a/tests/zz_shutdown_test.go +++ b/tests/zz_shutdown_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + "github.com/pilot-protocol/common/driver" + registryclient "github.com/pilot-protocol/common/registry/client" ) // TestGracefulShutdown verifies the documented shutdown contract: Stop() diff --git a/tests/zz_snapshot_test.go b/tests/zz_snapshot_test.go index 7a214d56..ec52cbd3 100644 --- a/tests/zz_snapshot_test.go +++ b/tests/zz_snapshot_test.go @@ -12,8 +12,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_stop_idempotent_test.go b/tests/zz_stop_idempotent_test.go index fcf29cd1..1b5c35de 100644 --- a/tests/zz_stop_idempotent_test.go +++ b/tests/zz_stop_idempotent_test.go @@ -12,9 +12,9 @@ import ( "testing" "time" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" "github.com/pilot-protocol/common/crypto" + "github.com/pilot-protocol/common/driver" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_stress_test.go b/tests/zz_stress_test.go index 2d9a0ff3..b7e062aa 100644 --- a/tests/zz_stress_test.go +++ b/tests/zz_stress_test.go @@ -10,7 +10,7 @@ import ( "sync/atomic" "testing" - "github.com/TeoSlayer/pilotprotocol/pkg/driver" + "github.com/pilot-protocol/common/driver" ) func TestStressConcurrentConnections(t *testing.T) { diff --git a/tests/zz_syn_trust_gate_test.go b/tests/zz_syn_trust_gate_test.go index 1269fa2a..3645833f 100644 --- a/tests/zz_syn_trust_gate_test.go +++ b/tests/zz_syn_trust_gate_test.go @@ -8,7 +8,7 @@ import ( "time" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" + registryclient "github.com/pilot-protocol/common/registry/client" ) // localUDPAddr converts a daemon's tunnel address to a localhost UDPAddr diff --git a/tests/zz_tags_test.go b/tests/zz_tags_test.go index a92c92a5..5125fbd0 100644 --- a/tests/zz_tags_test.go +++ b/tests/zz_tags_test.go @@ -14,8 +14,8 @@ import ( "testing" "time" - registryclient "github.com/TeoSlayer/pilotprotocol/pkg/registry/client" icrypto "github.com/pilot-protocol/common/crypto" + registryclient "github.com/pilot-protocol/common/registry/client" registry "github.com/pilot-protocol/rendezvous" ) diff --git a/tests/zz_wire_golden_test.go b/tests/zz_wire_golden_test.go index a6c20242..b707a276 100644 --- a/tests/zz_wire_golden_test.go +++ b/tests/zz_wire_golden_test.go @@ -10,7 +10,7 @@ import ( "testing" "github.com/TeoSlayer/pilotprotocol/pkg/daemon" - "github.com/TeoSlayer/pilotprotocol/pkg/protocol" + "github.com/pilot-protocol/common/protocol" ) // TestWireFormatGolden is the P6 wire-protocol invariance gate. It reads each