diff --git a/.dockerignore b/.dockerignore index c66da84c..c86b6356 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - target dev *.mp4 diff --git a/.editorconfig b/.editorconfig index ecab13e5..6ac865e4 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - root = true [*] diff --git a/.github/workflows/fly.yml b/.github/workflows/fly.yml index 912141a8..debb4731 100644 --- a/.github/workflows/fly.yml +++ b/.github/workflows/fly.yml @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - name: Fly Deploy on: push: diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index f5283296..83250781 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - name: pr on: @@ -12,15 +8,6 @@ env: CARGO_TERM_COLOR: always jobs: - reuse: - runs-on: ubuntu-latest - permissions: - contents: read - steps: - - uses: actions/checkout@v4 - - name: REUSE Compliance Check - uses: fsfe/reuse-action@v5 - build: runs-on: ubuntu-latest diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f533c5f4..f4137343 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - name: release permissions: diff --git a/.gitignore b/.gitignore index 2dc04329..91334f38 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - .DS_Store target/ logs/ diff --git a/.rustfmt.toml b/.rustfmt.toml index f550ecc0..b37b2250 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,6 +1,2 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - max_width = 100 hard_tabs = false diff --git a/Cargo.lock b/Cargo.lock index e44e6b2c..8efcbf7d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18,16 +18,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] -name = "ahash" -version = "0.8.12" +name = "adler2" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" -dependencies = [ - "cfg-if", - "once_cell", - "version_check", - "zerocopy", -] +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "aho-corasick" @@ -237,11 +231,17 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.3", "object", "rustc-demangle", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -456,10 +456,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] -name = "crossbeam-epoch" -version = "0.9.18" +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" dependencies = [ "crossbeam-utils", ] @@ -533,6 +542,29 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" +[[package]] +name = "env_filter" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -574,16 +606,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0399f9d26e5191ce32c498bebd31e7a3ceabc2745f0ac54af3f335126c3f24b3" [[package]] -name = "fnv" -version = "1.0.7" +name = "flate2" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide 0.8.9", +] [[package]] -name = "foldhash" -version = "0.1.5" +name = "fnv" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" @@ -771,18 +807,23 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.15.5" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" -dependencies = [ - "foldhash", -] +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" [[package]] -name = "hashbrown" -version = "0.16.1" +name = "hdrhistogram" +version = "7.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" +dependencies = [ + "base64 0.21.7", + "byteorder", + "crossbeam-channel", + "flate2", + "nom", + "num-traits", +] [[package]] name = "heck" @@ -851,6 +892,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "hyper" version = "1.4.1" @@ -889,23 +936,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "hyper-rustls" -version = "0.27.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" -dependencies = [ - "http", - "hyper", - "hyper-util", - "rustls 0.23.31", - "rustls-native-certs 0.8.1", - "rustls-pki-types", - "tokio", - "tokio-rustls 0.26.0", - "tower-service", -] - [[package]] name = "hyper-serve" version = "0.6.2" @@ -1072,9 +1102,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.85" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" dependencies = [ "once_cell", "wasm-bindgen", @@ -1163,53 +1193,6 @@ version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" -[[package]] -name = "metrics" -version = "0.24.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d5312e9ba3771cfa961b585728215e3d972c950a3eed9252aa093d6301277e8" -dependencies = [ - "ahash", - "portable-atomic", -] - -[[package]] -name = "metrics-exporter-prometheus" -version = "0.16.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd7399781913e5393588a8d8c6a2867bf85fb38eaf2502fdce465aad2dc6f034" -dependencies = [ - "base64", - "http-body-util", - "hyper", - "hyper-rustls 0.27.7", - "hyper-util", - "indexmap 2.13.0", - "ipnet", - "metrics", - "metrics-util", - "quanta", - "thiserror 1.0.61", - "tokio", - "tracing", -] - -[[package]] -name = "metrics-util" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8496cc523d1f94c1385dd8f0f0c2c480b2b8aeccb5b7e4485ad6365523ae376" -dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", - "hashbrown 0.15.5", - "metrics", - "quanta", - "rand 0.9.2", - "rand_xoshiro", - "sketches-ddsketch", -] - [[package]] name = "mime" version = "0.3.17" @@ -1231,6 +1214,16 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + [[package]] name = "mio" version = "1.1.0" @@ -1244,36 +1237,39 @@ dependencies = [ [[package]] name = "moq-api" -version = "0.2.10" +version = "0.2.7" dependencies = [ "axum", "clap", + "env_logger", "hyper", + "log", "redis", "reqwest", "serde", "serde_json", "thiserror 1.0.61", "tokio", - "tracing", - "tracing-subscriber", "url", ] [[package]] name = "moq-catalog" -version = "0.2.3" +version = "0.2.2" dependencies = [ "serde", ] [[package]] name = "moq-clock-ietf" -version = "0.6.14" +version = "0.6.8" dependencies = [ "anyhow", "chrono", "clap", + "env_logger", + "futures", + "log", "moq-native-ietf", "moq-transport", "tokio", @@ -1284,12 +1280,13 @@ dependencies = [ [[package]] name = "moq-native-ietf" -version = "0.8.0" +version = "0.7.1" dependencies = [ "anyhow", "clap", "futures", "hex", + "log", "moq-transport", "quinn", "rand 0.8.5", @@ -1297,9 +1294,7 @@ dependencies = [ "rustls 0.23.31", "rustls-native-certs 0.7.0", "rustls-pemfile", - "socket2 0.5.7", "tokio", - "tracing", "url", "web-transport", "web-transport-quinn", @@ -1308,11 +1303,14 @@ dependencies = [ [[package]] name = "moq-pub" -version = "0.8.14" +version = "0.8.8" dependencies = [ "anyhow", "bytes", "clap", + "env_logger", + "futures", + "log", "moq-catalog", "moq-native-ietf", "moq-transport", @@ -1327,18 +1325,19 @@ dependencies = [ [[package]] name = "moq-relay-ietf" -version = "0.7.18" +version = "0.7.10" dependencies = [ "anyhow", + "arc-swap", "async-trait", "axum", "clap", + "env_logger", "fs2", "futures", "hex", "hyper-serve", - "metrics", - "metrics-exporter-prometheus", + "log", "moq-api", "moq-native-ietf", "moq-transport", @@ -1346,7 +1345,6 @@ dependencies = [ "serde_json", "thiserror 2.0.17", "tokio", - "tokio-util", "tower-http", "tracing", "tracing-subscriber", @@ -1356,10 +1354,12 @@ dependencies = [ [[package]] name = "moq-sub" -version = "0.4.8" +version = "0.4.2" dependencies = [ "anyhow", "clap", + "env_logger", + "log", "moq-catalog", "moq-native-ietf", "moq-transport", @@ -1373,12 +1373,35 @@ dependencies = [ [[package]] name = "moq-test-client" -version = "0.1.6" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "env_logger", + "log", + "moq-native-ietf", + "moq-transport", + "tokio", + "tracing", + "tracing-subscriber", + "url", + "web-transport", +] + +[[package]] +name = "moq-topn-test" +version = "0.1.0" dependencies = [ "anyhow", + "bytes", "clap", + "hdrhistogram", "moq-native-ietf", + "moq-relay-ietf", "moq-transport", + "rand 0.8.5", + "serde", + "serde_json", "tokio", "tracing", "tracing-subscriber", @@ -1388,17 +1411,17 @@ dependencies = [ [[package]] name = "moq-transport" -version = "0.14.2" +version = "0.12.2" dependencies = [ "bytes", "futures", + "log", "paste", "serde", "serde_json", "serde_with", "thiserror 1.0.61", "tokio", - "tracing", "uuid", "web-transport", ] @@ -1464,9 +1487,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.2" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "521739c6d2bac4aa25192232afe6841231376b2b26d4d9fae5ecf8ca5772e441" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "num-integer" @@ -1592,12 +1615,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "portable-atomic" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" - [[package]] name = "powerfmt" version = "0.2.0" @@ -1641,21 +1658,6 @@ dependencies = [ "smallvec", ] -[[package]] -name = "quanta" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" -dependencies = [ - "crossbeam-utils", - "libc", - "once_cell", - "raw-cpuid", - "wasi 0.11.0+wasi-snapshot-preview1", - "web-sys", - "winapi", -] - [[package]] name = "quinn" version = "0.11.9" @@ -1789,24 +1791,6 @@ dependencies = [ "getrandom 0.3.3", ] -[[package]] -name = "rand_xoshiro" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f703f4665700daf5512dcca5f43afa6af89f09db47fb56be587f80636bda2d41" -dependencies = [ - "rand_core 0.9.3", -] - -[[package]] -name = "raw-cpuid" -version = "11.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" -dependencies = [ - "bitflags", -] - [[package]] name = "redis" version = "0.32.5" @@ -1914,7 +1898,7 @@ version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "566cafdd92868e0939d3fb961bd0dc25fcfaaed179291093b3d43e6b3150ea10" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "futures-core", "futures-util", @@ -1922,7 +1906,7 @@ dependencies = [ "http-body", "http-body-util", "hyper", - "hyper-rustls 0.26.0", + "hyper-rustls", "hyper-util", "ipnet", "js-sys", @@ -2065,7 +2049,7 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" dependencies = [ - "base64", + "base64 0.22.1", "rustls-pki-types", ] @@ -2297,7 +2281,7 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2c45cd61fefa9db6f254525d46e392b852e0e61d9a1fd36e5bd183450a556d5" dependencies = [ - "base64", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", @@ -2329,7 +2313,7 @@ version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d471eaefb14f4b30032525bdb124b36e55ba9cb1292080e06f1a236cd10fe87" dependencies = [ - "base64", + "base64 0.22.1", "indexmap 2.13.0", "ref-cast", ] @@ -2365,16 +2349,16 @@ dependencies = [ ] [[package]] -name = "siphasher" -version = "1.0.1" +name = "simd-adler32" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" [[package]] -name = "sketches-ddsketch" -version = "0.3.0" +name = "siphasher" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1e9a774a6c28142ac54bb25d25562e6bcf957493a184f15ad4eebccb23e410a" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" [[package]] name = "slab" @@ -2501,30 +2485,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.47" +version = "0.3.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde_core", + "serde", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.8" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-macros" -version = "0.2.27" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" dependencies = [ "num-conv", "time-core", @@ -2782,12 +2766,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - [[package]] name = "walkdir" version = "2.5.0" @@ -2833,9 +2811,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.108" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" dependencies = [ "cfg-if", "once_cell", @@ -2858,9 +2836,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.108" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2868,9 +2846,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.108" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" dependencies = [ "bumpalo", "proc-macro2", @@ -2881,18 +2859,18 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.108" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" dependencies = [ "unicode-ident", ] [[package]] name = "web-streams" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15c4d5dbf19463c4b65e974303d453cc11991873c7a4a4953214f791d73303a2" +checksum = "48465a648c14f53f6d8319b95bc336a44627f6aa6bd94270463777af8ed65deb" dependencies = [ "thiserror 2.0.17", "wasm-bindgen", @@ -2902,9 +2880,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.85" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "312e32e551d92129218ea9a2452120f4aabc03529ef03e4d0d82fb2780608598" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" dependencies = [ "js-sys", "wasm-bindgen", @@ -2922,9 +2900,9 @@ dependencies = [ [[package]] name = "web-transport" -version = "0.10.1" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb0d7e0f5ab200ec6b36fef6f3eee02d41d561fe73b85d2ab822af1b3fb4c680" +checksum = "23c3f78eca5afa10eb7b8ab64b4e5e521a006f0cbd88de09e44d55ef37e8855a" dependencies = [ "bytes", "thiserror 2.0.17", @@ -2935,9 +2913,9 @@ dependencies = [ [[package]] name = "web-transport-proto" -version = "0.5.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17633ea7058419f87cbb7f341ab75ac5c1d6d187c154b0bd4c87539e66f4c4e4" +checksum = "0225d295c8ac00a2e9a498aefeaf3f3c6186da12a251c938189b15b82ea22808" dependencies = [ "bytes", "http", @@ -2949,9 +2927,9 @@ dependencies = [ [[package]] name = "web-transport-quinn" -version = "0.11.4" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b195557749e84091d7b912a25e190e9606283b5121d041faf538b0b55f40d7" +checksum = "82e77c81fe4cf56c1049e07c6ed9c00862a967010fe9da4f5e02dc7f4d71fdac" dependencies = [ "bytes", "futures", @@ -2969,18 +2947,18 @@ dependencies = [ [[package]] name = "web-transport-trait" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "802d6aa508f2c63c9050ceabc17265bbf90ed4d6f4e4357e987583883628e79c" +checksum = "cb67841c4a481ca3c1412ee4c9f463987401991e1ddc000903df2124f3dc85e9" dependencies = [ "bytes", ] [[package]] name = "web-transport-wasm" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03261961ff4d65f873dd0521909b6795e5d7fe40581df2b7897db05e62db9620" +checksum = "6816176def6e8df1636c8fc2c401f37add41ccad1518705e209d9a7ada3d144c" dependencies = [ "bytes", "js-sys", @@ -3317,26 +3295,6 @@ version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" -[[package]] -name = "zerocopy" -version = "0.8.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" -dependencies = [ - "zerocopy-derive", -] - -[[package]] -name = "zerocopy-derive" -version = "0.8.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "zeroize" version = "1.7.0" diff --git a/Cargo.toml b/Cargo.toml index 61a90c1d..27c8279c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - [workspace] members = [ "moq-transport", @@ -13,13 +9,14 @@ members = [ "moq-native-ietf", "moq-catalog", "moq-test-client", + "moq-topn-test", ] resolver = "2" [workspace.dependencies] web-transport = "0.10" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } +env_logger = "0.11" +log = { version = "0.4", features = ["std"] } # Use debug symbols in production until things are more stable [profile.release] diff --git a/Dockerfile b/Dockerfile index 679b84c2..59059ca7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - FROM rust:bookworm as builder # Create a build directory and copy over all of the files diff --git a/LICENSES/Apache-2.0.txt b/LICENSE-APACHE similarity index 94% rename from LICENSES/Apache-2.0.txt rename to LICENSE-APACHE index 55b36619..d6456956 100644 --- a/LICENSES/Apache-2.0.txt +++ b/LICENSE-APACHE @@ -176,9 +176,18 @@ END OF TERMS AND CONDITIONS - - Copyright (c) 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors - Copyright (c) 2023-2024 Luke Curley and contributors + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/LICENSES/MIT.txt b/LICENSE-MIT similarity index 88% rename from LICENSES/MIT.txt rename to LICENSE-MIT index a6443463..fbd437cf 100644 --- a/LICENSES/MIT.txt +++ b/LICENSE-MIT @@ -1,7 +1,6 @@ MIT License -Copyright (c) 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -Copyright (c) 2023-2024 Luke Curley and contributors +Copyright (c) 2023 Luke Curley Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Makefile b/Makefile index bba48b86..dfb4e1c6 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - export CAROOT ?= $(shell cd dev ; go run filippo.io/mkcert -CAROOT) .PHONY: run diff --git a/REUSE.toml b/REUSE.toml deleted file mode 100644 index 39f45dc3..00000000 --- a/REUSE.toml +++ /dev/null @@ -1,49 +0,0 @@ -version = 1 - -# Pre-a84afe3 files never edited after: Luke Curley only -[[annotations]] -path = [ - "dev/README.md", - "moq-api/README.md", - "moq-pub/README.md", - "moq-sub/README.md", - "moq-transport/README.md", - "dev/go.sum", - ".vscode/extensions.json", -] -SPDX-FileCopyrightText = "2023-2024 Luke Curley and contributors" -SPDX-License-Identifier = "MIT OR Apache-2.0" - -# Pre-a84afe3 files edited after: both copyrights -[[annotations]] -path = [ - "README.md", - "moq-api/CHANGELOG.md", - "moq-pub/CHANGELOG.md", - "moq-sub/CHANGELOG.md", - "moq-transport/CHANGELOG.md", - "Cargo.lock", - ".github/logo.svg", -] -SPDX-FileCopyrightText = [ - "2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors", - "2023-2024 Luke Curley and contributors", -] -SPDX-License-Identifier = "MIT OR Apache-2.0" - -# Post-a84afe3 files (new): Cloudflare only -[[annotations]] -path = [ - "moq-catalog/CHANGELOG.md", - "moq-clock-ietf/CHANGELOG.md", - "moq-native-ietf/CHANGELOG.md", - "moq-relay-ietf/CHANGELOG.md", - "moq-relay-ietf/README.md", - "moq-test-client/CHANGELOG.md", - "moq-test-client/README.md", - "deploy/MLOG_SETUP.md", - "deploy/QLOG_SETUP.md", - "flake.lock", -] -SPDX-FileCopyrightText = "2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors" -SPDX-License-Identifier = "MIT OR Apache-2.0" diff --git a/default.nix b/default.nix index 0e30cbf2..4a6d2a3c 100644 --- a/default.nix +++ b/default.nix @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - { pkgs ? import { }, }: diff --git a/deploy/MLOG_SETUP.md b/deploy/MLOG_SETUP.md deleted file mode 100644 index 92c37ca3..00000000 --- a/deploy/MLOG_SETUP.md +++ /dev/null @@ -1,358 +0,0 @@ -# MoQ Transport Logging (mlog) - -mlog is structured logging for MoQ Transport protocol events. It records every control message and data stream event the relay processes for a given connection, using [JSON-SEQ](https://www.rfc-editor.org/rfc/rfc7464) (RFC 7464) so each record is prefixed with ASCII Record Separator (`0x1e`) and ends with a newline. The events are [qlog](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-main-schema/)-compatible, following the [`draft-pardue-moq-qlog-moq-events`](https://datatracker.ietf.org/doc/draft-pardue-moq-qlog-moq-events/) IETF draft. - -Where QUIC qlog (see [QLOG_SETUP.md](QLOG_SETUP.md)) captures transport-layer events like handshakes, congestion, and packet loss, mlog captures the MoQ application layer: SETUP exchanges, namespace announcements, subscriptions, and object delivery. - -## Quick Start: Verify Your Messages Reached the Relay - -This walkthrough uses `moq-test-client` against the public interop relay to demonstrate the basics. You can adapt the same technique for your own client. - -### Step 1: Run a test - -```bash -cargo run --bin moq-test-client -- \ - --relay https://interop-relay.cloudflare.mediaoverquic.com:443 \ - --tls-disable-verify \ - --test setup-only -``` - -### Step 2: Find your Connection ID - -The test client prints the Connection ID in its TAP output: - -``` -ok 1 - setup-only - --- - duration_ms: 521 - connection_id: 22c73802597dcd91ef662a3cd67a03e0 - ... -``` - -### Step 3: Fetch your mlog - -```bash -curl https://interop-relay.cloudflare.mediaoverquic.com:443/mlog/22c73802597dcd91ef662a3cd67a03e0 -``` - -### Step 4: Read the results - -The output is JSON-SEQ: each record starts with ASCII Record Separator (`0x1e`) and contains a JSON object. The first record is a header; subsequent records are events: - -```json -{"time":0.179,"name":"moqt:control_message_parsed","data":{"message_type":"client_setup","supported_versions":["DRAFT_14"],...}} -{"time":0.216,"name":"moqt:control_message_created","data":{"message_type":"server_setup","selected_version":"DRAFT_14",...}} -``` - -- `control_message_parsed` = the relay **received** a message from your client -- `control_message_created` = the relay **sent** a message to your client -- `time` = milliseconds since connection start - -If you see `client_setup` parsed and `server_setup` created, your client is speaking valid MoQ Transport and the relay understood it. - -## Worked Examples - -### Example 1: SETUP handshake (`setup-only`) - -The simplest test: connect, exchange CLIENT_SETUP/SERVER_SETUP, disconnect. - -```bash -cargo run --bin moq-test-client -- \ - --relay https://interop-relay.cloudflare.mediaoverquic.com:443 \ - --tls-disable-verify \ - --test setup-only -``` - -**mlog output** (reformatted for readability): - -```json -{ - "time": 0.179, - "name": "moqt:control_message_parsed", - "data": { - "event_type": "control_message_parsed", - "stream_id": 0, - "message_type": "client_setup", - "number_of_supported_versions": 1, - "supported_versions": ["DRAFT_14"], - "parameters": [["2", "100"]] - } -} -``` -The relay parsed your CLIENT_SETUP. It offered version DRAFT_14. The `parameters` array contains SETUP parameters as `[id, value]` pairs (here, PATH with max length 100). - -```json -{ - "time": 0.216, - "name": "moqt:control_message_created", - "data": { - "event_type": "control_message_created", - "stream_id": 0, - "message_type": "server_setup", - "selected_version": "DRAFT_14", - "parameters": [["2", "100"]] - } -} -``` -The relay responded with SERVER_SETUP, selecting DRAFT_14. The handshake is complete. - -**What to look for:** -- If you see `client_setup` parsed but no `server_setup` created, the relay rejected your version or parameters -- If you see nothing at all, your connection didn't reach the MoQ layer (check QUIC connectivity) - -### Example 2: Publishing a namespace (`announce-only`) - -After SETUP, announce a namespace and verify the relay accepts it. - -```bash -cargo run --bin moq-test-client -- \ - --relay https://interop-relay.cloudflare.mediaoverquic.com:443 \ - --tls-disable-verify \ - --test announce-only -``` - -**mlog output** (after the SETUP exchange): - -```json -{ - "time": 64.779, - "name": "moqt:control_message_parsed", - "data": { - "event_type": "control_message_parsed", - "stream_id": 0, - "message_type": "publish_namespace", - "request_id": 0, - "track_namespace": "/moq-test/interop", - "parameters": [] - } -} -``` -The relay received your PUBLISH_NAMESPACE for `/moq-test/interop`. - -```json -{ - "time": 65.526, - "name": "moqt:control_message_created", - "data": { - "event_type": "control_message_created", - "stream_id": 0, - "message_type": "publish_namespace_ok", - "request_id": 0 - } -} -``` -The relay accepted the namespace with PUBLISH_NAMESPACE_OK. Your client is now registered as a publisher for this namespace, and the relay will route incoming subscriptions to you. - -**What to look for:** -- `publish_namespace` parsed confirms the relay received your announcement -- `publish_namespace_ok` created confirms it was accepted -- The `request_id` ties the response to the request - -### Example 3: Full publish-subscribe flow (`announce-subscribe`) - -This test uses two connections: a publisher and a subscriber. The test client reports both Connection IDs: - -```bash -cargo run --bin moq-test-client -- \ - --relay https://interop-relay.cloudflare.mediaoverquic.com:443 \ - --tls-disable-verify \ - --test announce-subscribe -``` - -**TAP output:** - -``` -ok 1 - announce-subscribe - --- - duration_ms: 3436 - publisher_connection_id: 71d4b5eb1a807779af03331c330d5fa9 - subscriber_connection_id: 08d0b03ede133f0839435bff64ed2fc5 - ... -``` - -Fetch both mlogs to see each side of the exchange: - -```bash -# Publisher's view -curl https://interop-relay.cloudflare.mediaoverquic.com:443/mlog/71d4b5eb1a807779af03331c330d5fa9 - -# Subscriber's view -curl https://interop-relay.cloudflare.mediaoverquic.com:443/mlog/08d0b03ede133f0839435bff64ed2fc5 -``` - -**Publisher mlog** (after SETUP): - -```json -{"time":231.481,"name":"moqt:control_message_parsed","data":{"message_type":"publish_namespace","request_id":0,"track_namespace":"/moq-test/interop",...}} -{"time":233.084,"name":"moqt:control_message_created","data":{"message_type":"publish_namespace_ok","request_id":0}} -``` - -The publisher announced `/moq-test/interop` and the relay accepted it. - -**Subscriber mlog** (after SETUP): - -```json -{"time":47.169,"name":"moqt:control_message_parsed","data":{"message_type":"subscribe","subscribe_id":0,"track_namespace":"/moq-test/interop","track_name":"test-track",...}} -{"time":48.622,"name":"moqt:control_message_created","data":{"message_type":"subscribe_ok","subscribe_id":0,"track_alias":0,...}} -``` - -The subscriber sent a SUBSCRIBE for track `test-track` under namespace `/moq-test/interop`, and the relay responded with SUBSCRIBE_OK. This confirms the relay successfully routed the subscription to the publisher's namespace. - -**When data flows**, you'll also see data plane events in the mlog. In the publisher's mlog, `*_parsed` events show data the relay received: - -```json -{"time":395.872,"name":"moqt:subgroup_header_parsed","data":{"header_type":"SubgroupIdExt","track_alias":0,"group_id":1,"publisher_priority":128}} -{"time":397.166,"name":"moqt:subgroup_object_parsed","data":{"group_id":1,"subgroup_id":0,"object_id":0,"object_payload_length":1024}} -``` - -In the subscriber's mlog, `*_created` events show data the relay forwarded: - -```json -{"time":395.872,"name":"moqt:subgroup_header_created","data":{"header_type":"SubgroupIdExt","track_alias":0,"group_id":1,"publisher_priority":128,"subgroup_id":0}} -{"time":397.166,"name":"moqt:subgroup_object_created","data":{"group_id":1,"subgroup_id":0,"object_id":0,"object_payload_length":1024}} -``` - -- `object_payload_length` tells you the size of the payload the relay handled - -## mlog Format Reference - -### File format - -mlog uses [JSON-SEQ (RFC 7464)](https://www.rfc-editor.org/rfc/rfc7464): each record is prefixed with ASCII Record Separator (`0x1e`) and ends with a newline. - -**Line 1 — Header:** - -```json -{ - "qlog_version": "0.3", - "qlog_format": "JSON-SEQ", - "title": "moq-relay", - "description": "MoQ Transport events", - "trace": { - "vantage_point": { "type": "server" }, - "event_schemas": [ - "urn:ietf:params:qlog:events:loglevel", - "urn:ietf:params:qlog:events:moqt" - ] - } -} -``` - -**Subsequent lines — Events:** - -```json -{ - "time": , - "name": "", - "data": { ... } -} -``` - -### Event types - -| Event Name | Direction | Description | -|------------|-----------|-------------| -| `moqt:control_message_parsed` | Received | Relay received a control message from the client | -| `moqt:control_message_created` | Sent | Relay sent a control message to the client | -| `moqt:subgroup_header_parsed` | Received | Relay received a subgroup stream header | -| `moqt:subgroup_header_created` | Sent | Relay sent (forwarded) a subgroup stream header | -| `moqt:subgroup_object_parsed` | Received | Relay received an object within a subgroup | -| `moqt:subgroup_object_created` | Sent | Relay sent (forwarded) an object within a subgroup | -| `moqt:object_datagram_parsed` | Received | Relay received a datagram-delivered object | -| `moqt:object_datagram_created` | Sent | Relay sent (forwarded) a datagram-delivered object | - -**Naming convention:** `*_parsed` = the relay received and parsed this from the wire. `*_created` = the relay constructed and sent this on the wire. - -### Control message types - -These appear in the `message_type` field of control message events: - -| Message Type | Protocol Reference | -|-------------|-------------------| -| `client_setup` / `server_setup` | MoQT §3.3 | -| `publish_namespace` / `publish_namespace_ok` | MoQT §6.2 | -| `subscribe` / `subscribe_ok` / `subscribe_error` | MoQT §5.1 | -| `unsubscribe` | MoQT §5.1 | - -### Known limitations - -- **`stream_id` is currently a placeholder.** The `stream_id` field in all events is `0`. Actual QUIC stream IDs are not yet plumbed through to the mlog layer. Don't use this field for stream correlation — use `subscribe_id`, `track_alias`, or `group_id` instead. - -### File naming - -mlog files are named `_server.mlog`, where the connection ID is the hex-encoded QUIC Connection ID (typically 32 hex characters). - -## Running Your Own Relay with mlog - -### Using the dev script - -The simplest way to run a local relay with mlog enabled: - -```bash -MLOG_DIR=/tmp/mlog ./dev/relay -``` - -This starts the relay with: -- mlog writing to `/tmp/mlog/` -- mlog HTTP serving enabled at `https://localhost:4443/mlog/` -- Self-signed TLS certificates (auto-generated) - -Then run tests against it: - -```bash -cargo run --bin moq-test-client -- \ - --relay https://localhost:4443 \ - --tls-disable-verify \ - --test setup-only -``` - -And fetch the mlog: - -```bash -curl -k https://localhost:4443/mlog/ -``` - -You can also pass extra arguments directly: - -```bash -./dev/relay --mlog-dir /tmp/mlog --mlog-serve -``` - -### Relay flags - -| Flag | Description | -|------|-------------| -| `--mlog-dir ` | Directory to write mlog files to | -| `--mlog-serve` | Serve mlog files over HTTPS at `/mlog/` | -| `--dev` | Required for `--mlog-serve` (enables the HTTP endpoint; already set by `./dev/relay`) | - -Both `--mlog-dir` and `--mlog-serve` can be used independently: you can write mlog files to disk without serving them over HTTP, or you might only need the HTTP endpoint in some setups. - -## Troubleshooting - -**"Connection not found" when fetching mlog** -- mlog files on the interop relay use ephemeral storage and are lost on restart/redeploy -- If you get a 404, the relay may have been restarted since your test. Run the test again. - -**Parsing JSON-SEQ with jq** -- Use `jq --seq` (jq 1.6+) to parse JSON-SEQ directly, or strip the Record Separator first: `tr -d '\x1e' < file.mlog | jq '.'` - -**mlog is empty (header only, no events)** -- The QUIC connection was established but no MoQ messages were exchanged -- Check that your client is sending CLIENT_SETUP on the control stream after connecting - -**No SERVER_SETUP in response to CLIENT_SETUP** -- The relay may not support the MoQ Transport version(s) your client offered -- Check the `supported_versions` array in your CLIENT_SETUP against what the relay supports - -**Correlating with QUIC qlog** -- Both mlog and qlog use the same Connection ID -- qlog captures QUIC-layer events (handshake, streams, congestion); mlog captures MoQ-layer events -- For a complete picture, fetch both: `/qlog/` and `/mlog/` - -## Further Reading - -- [draft-pardue-moq-qlog-moq-events](https://datatracker.ietf.org/doc/draft-pardue-moq-qlog-moq-events/) — The IETF draft defining the mlog event schema -- [QLOG_SETUP.md](QLOG_SETUP.md) — QUIC qlog setup for the relay -- [moq-test-client README](../moq-test-client/README.md) — Test client documentation and available test cases -- [draft-ietf-moq-transport](https://datatracker.ietf.org/doc/draft-ietf-moq-transport/) — The MoQ Transport protocol specification diff --git a/deploy/fly-relay.sh b/deploy/fly-relay.sh index dd6ab34c..197c4fa3 100755 --- a/deploy/fly-relay.sh +++ b/deploy/fly-relay.sh @@ -1,7 +1,4 @@ #!/usr/bin/env sh -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 export PORT=${PORT:-4443} export RUST_LOG=${RUST_LOG:-info} diff --git a/deploy/publish b/deploy/publish index 63d374cb..9507def0 100755 --- a/deploy/publish +++ b/deploy/publish @@ -1,8 +1,4 @@ #!/bin/bash -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - set -euo pipefail ADDR=${ADDR:-"https://relay.quic.video"} diff --git a/dev/.gitignore b/dev/.gitignore index f8fe410e..07539d66 100644 --- a/dev/.gitignore +++ b/dev/.gitignore @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - *.crt *.key *.hex diff --git a/dev/cert b/dev/cert index 4daa234a..5e90ca51 100755 --- a/dev/cert +++ b/dev/cert @@ -1,8 +1,4 @@ #!/bin/bash -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - set -euo pipefail cd "$(dirname "${BASH_SOURCE[0]}")" diff --git a/dev/clock b/dev/clock index 08cadef7..88286bf7 100755 --- a/dev/clock +++ b/dev/clock @@ -1,8 +1,4 @@ #!/bin/bash -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - set -euo pipefail # Change directory to the root of the project diff --git a/dev/go.mod b/dev/go.mod index 5efcd202..ac3c3d05 100644 --- a/dev/go.mod +++ b/dev/go.mod @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - module github.com/kixelated/warp/cert go 1.18 diff --git a/dev/pub b/dev/pub index f7314ca7..4d038ef0 100755 --- a/dev/pub +++ b/dev/pub @@ -1,8 +1,4 @@ #!/bin/bash -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - set -euo pipefail # Change directory to the root of the project diff --git a/dev/pub_multi_track b/dev/pub_multi_track index c9ced8c1..7fde8113 100755 --- a/dev/pub_multi_track +++ b/dev/pub_multi_track @@ -1,7 +1,4 @@ #!/bin/bash -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - set -euo pipefail # Change directory to the root of the project diff --git a/dev/relay b/dev/relay index 6ec72f57..a4137026 100755 --- a/dev/relay +++ b/dev/relay @@ -1,8 +1,4 @@ #!/bin/bash -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - set -euo pipefail # Change directory to the root of the project @@ -31,13 +27,6 @@ else QLOG="" fi -# Enable mlog logging if MLOG_DIR is set -MLOG_DIR="${MLOG_DIR:-}" -if [ -n "$MLOG_DIR" ]; then - MLOG="--mlog-dir $MLOG_DIR --mlog-serve" -else - MLOG="" -fi # A list of optional args ARGS="" @@ -56,4 +45,4 @@ fi echo "Publish URL: https://quic.video/publish/?server=localhost:$PORT" # Run the relay and forward any arguments -cargo run --bin moq-relay-ietf -- --bind "$BIND" --tls-cert "$CERT" --tls-key "$KEY" --dev $QLOG $MLOG $ARGS "$@" +cargo run --bin moq-relay-ietf -- --bind "$BIND" --tls-cert "$CERT" --tls-key "$KEY" --dev $QLOG $ARGS "$@" diff --git a/dev/sub b/dev/sub index 1b04863b..f5c8bde2 100755 --- a/dev/sub +++ b/dev/sub @@ -1,8 +1,4 @@ #!/bin/bash -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - set -euo pipefail # Change directory to the root of the project @@ -20,6 +16,6 @@ ADDR="${ADDR:-$HOST:$PORT}" NAME="${NAME:-bbb}" # Combine the host and name into a URL. -URL="${URL:-"https://$ADDR"}" +URL="${URL:-"https://$ADDR/$NAME"}" cargo run --bin moq-sub -- --name "$NAME" "$URL" "$@" | ffplay - diff --git a/docker-compose.yml b/docker-compose.yml index ad2032a5..509c207b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - x-moq: &x-moq build: . environment: diff --git a/flake.nix b/flake.nix index 396a1a13..85587f17 100644 --- a/flake.nix +++ b/flake.nix @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - { inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; outputs = diff --git a/fly.toml b/fly.toml index 9a108245..e1d0d12d 100644 --- a/fly.toml +++ b/fly.toml @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - app = "moq-rs-interop-relay" kill_signal = "SIGINT" kill_timeout = 5 diff --git a/moq-api/CHANGELOG.md b/moq-api/CHANGELOG.md index 55fcd17d..ae35a0fa 100644 --- a/moq-api/CHANGELOG.md +++ b/moq-api/CHANGELOG.md @@ -6,25 +6,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.2.10](https://github.com/cloudflare/moq-rs/compare/moq-api-v0.2.9...moq-api-v0.2.10) - 2026-03-31 - -### Other - -- Make repo REUSE v3.3 compliant -- Bring copyright notices, license docs up to date - -## [0.2.9](https://github.com/cloudflare/moq-rs/compare/moq-api-v0.2.8...moq-api-v0.2.9) - 2026-02-18 - -### Other - -- update Cargo.lock dependencies - -## [0.2.8](https://github.com/cloudflare/moq-rs/compare/moq-api-v0.2.7...moq-api-v0.2.8) - 2026-02-18 - -### Other - -- migrate from log crate to tracing - ## [0.2.7](https://github.com/cloudflare/moq-rs/compare/moq-api-v0.2.6...moq-api-v0.2.7) - 2025-12-18 ### Other diff --git a/moq-api/Cargo.toml b/moq-api/Cargo.toml index b9d1cb28..57fb5b60 100644 --- a/moq-api/Cargo.toml +++ b/moq-api/Cargo.toml @@ -1,15 +1,11 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - [package] name = "moq-api" description = "Media over QUIC" -authors = ["moq-rs contributors"] -repository = "https://github.com/cloudflare/moq-rs" +authors = ["Luke Curley"] +repository = "https://github.com/englishm/moq-rs" license = "MIT OR Apache-2.0" -version = "0.2.10" +version = "0.2.7" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] @@ -42,6 +38,6 @@ redis = { version = "0.32", features = [ url = { version = "2", features = ["serde"] } # Error handling -tracing = { workspace = true } -tracing-subscriber = { workspace = true } +log = { workspace = true } +env_logger = { workspace = true } thiserror = "1" diff --git a/moq-api/src/client.rs b/moq-api/src/client.rs index 53bda5eb..224392f7 100644 --- a/moq-api/src/client.rs +++ b/moq-api/src/client.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use url::Url; use crate::{ApiError, Origin}; diff --git a/moq-api/src/error.rs b/moq-api/src/error.rs index 8e38b34b..60e02406 100644 --- a/moq-api/src/error.rs +++ b/moq-api/src/error.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use thiserror::Error; #[derive(Error, Debug)] diff --git a/moq-api/src/lib.rs b/moq-api/src/lib.rs index 6672e0b8..be117a02 100644 --- a/moq-api/src/lib.rs +++ b/moq-api/src/lib.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - mod client; mod error; mod model; diff --git a/moq-api/src/main.rs b/moq-api/src/main.rs index 477f87d0..d29832ec 100644 --- a/moq-api/src/main.rs +++ b/moq-api/src/main.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use clap::Parser; mod server; @@ -10,13 +6,7 @@ use server::{Server, ServerConfig}; #[tokio::main] async fn main() -> Result<(), ApiError> { - // Initialize tracing with env filter (respects RUST_LOG environment variable) - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), - ) - .init(); + env_logger::init(); let config = ServerConfig::parse(); let server = Server::new(config); diff --git a/moq-api/src/model.rs b/moq-api/src/model.rs index 6bb24304..d6f0c4fc 100644 --- a/moq-api/src/model.rs +++ b/moq-api/src/model.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use serde::{Deserialize, Serialize}; use url::Url; diff --git a/moq-api/src/server.rs b/moq-api/src/server.rs index 5d11d1a4..98f407fc 100644 --- a/moq-api/src/server.rs +++ b/moq-api/src/server.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::net; use axum::{ @@ -41,7 +37,7 @@ impl Server { } pub async fn run(self) -> Result<(), ApiError> { - tracing::info!("connecting to redis: url={}", self.config.redis); + log::info!("connecting to redis: url={}", self.config.redis); // Create the redis client. let redis = redis::Client::open(self.config.redis)?; @@ -57,7 +53,7 @@ impl Server { ) .with_state(redis); - tracing::info!("serving requests: bind={}", self.config.bind); + log::info!("serving requests: bind={}", self.config.bind); let listener = tokio::net::TcpListener::bind(&self.config.bind).await?; axum::serve(listener, app.into_make_service()).await?; diff --git a/moq-catalog/CHANGELOG.md b/moq-catalog/CHANGELOG.md index f2537c0c..b48982f1 100644 --- a/moq-catalog/CHANGELOG.md +++ b/moq-catalog/CHANGELOG.md @@ -7,12 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.2.3](https://github.com/cloudflare/moq-rs/compare/moq-catalog-v0.2.2...moq-catalog-v0.2.3) - 2026-03-31 - -### Other - -- Make repo REUSE v3.3 compliant - ## [0.2.2](https://github.com/englishm/moq-rs/compare/moq-catalog-v0.2.1...moq-catalog-v0.2.2) - 2025-01-16 ### Other diff --git a/moq-catalog/Cargo.toml b/moq-catalog/Cargo.toml index 63058ada..2bd0a75e 100644 --- a/moq-catalog/Cargo.toml +++ b/moq-catalog/Cargo.toml @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - [package] name = "moq-catalog" description = "Media over QUIC" @@ -9,7 +5,7 @@ authors = ["Luke Curley"] repository = "https://github.com/englishm/moq-rs" license = "MIT OR Apache-2.0" -version = "0.2.3" +version = "0.2.2" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] diff --git a/moq-catalog/src/lib.rs b/moq-catalog/src/lib.rs index 0c3b7e03..25fd6908 100644 --- a/moq-catalog/src/lib.rs +++ b/moq-catalog/src/lib.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! This module contains the structs and functions for the MoQ catalog format /// The catalog format is a JSON file that describes the tracks available in a broadcast. /// diff --git a/moq-clock-ietf/CHANGELOG.md b/moq-clock-ietf/CHANGELOG.md index f65c19cc..77731d81 100644 --- a/moq-clock-ietf/CHANGELOG.md +++ b/moq-clock-ietf/CHANGELOG.md @@ -6,46 +6,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.6.14](https://github.com/cloudflare/moq-rs/compare/moq-clock-ietf-v0.6.13...moq-clock-ietf-v0.6.14) - 2026-05-20 - -### Other - -- update Cargo.lock dependencies - -## [0.6.13](https://github.com/cloudflare/moq-rs/compare/moq-clock-ietf-v0.6.12...moq-clock-ietf-v0.6.13) - 2026-04-10 - -### Fixed - -- cross-platform dual-stack binding for IPv6 sockets - -### Other - -- Merge pull request #151 from englishm-cloudflare/me/ipv6-dual-stack-binding - -## [0.6.12](https://github.com/cloudflare/moq-rs/compare/moq-clock-ietf-v0.6.11...moq-clock-ietf-v0.6.12) - 2026-03-31 - -### Other - -- Make repo REUSE v3.3 compliant - -## [0.6.11](https://github.com/cloudflare/moq-rs/compare/moq-clock-ietf-v0.6.10...moq-clock-ietf-v0.6.11) - 2026-03-27 - -### Added - -- add Transport enum and connection path extraction - -## [0.6.10](https://github.com/cloudflare/moq-rs/compare/moq-clock-ietf-v0.6.9...moq-clock-ietf-v0.6.10) - 2026-02-18 - -### Other - -- update Cargo.lock dependencies - -## [0.6.9](https://github.com/cloudflare/moq-rs/compare/moq-clock-ietf-v0.6.8...moq-clock-ietf-v0.6.9) - 2026-02-18 - -### Other - -- migrate from log crate to tracing - ## [0.6.8](https://github.com/cloudflare/moq-rs/compare/moq-clock-ietf-v0.6.7...moq-clock-ietf-v0.6.8) - 2026-02-03 ### Other diff --git a/moq-clock-ietf/Cargo.toml b/moq-clock-ietf/Cargo.toml index bca4f26e..854687a7 100644 --- a/moq-clock-ietf/Cargo.toml +++ b/moq-clock-ietf/Cargo.toml @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - [package] name = "moq-clock-ietf" description = "CLOCK over QUIC" @@ -8,7 +5,7 @@ authors = ["Luke Curley"] repository = "https://github.com/englishm/moq-rs" license = "MIT OR Apache-2.0" -version = "0.6.14" +version = "0.6.8" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] @@ -17,20 +14,23 @@ categories = ["multimedia", "network-programming", "web-programming"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -moq-native-ietf = { path = "../moq-native-ietf", version = "0.8" } -moq-transport = { path = "../moq-transport", version = "0.14" } +moq-native-ietf = { path = "../moq-native-ietf", version = "0.7" } +moq-transport = { path = "../moq-transport", version = "0.12" } # QUIC url = "2" # Async stuff tokio = { version = "1", features = ["full"] } +futures = "0.3" # CLI, logging, error handling clap = { version = "4", features = ["derive"] } -tracing = { workspace = true } -tracing-subscriber = { workspace = true } +log = { workspace = true } +env_logger = { workspace = true } anyhow = { version = "1", features = ["backtrace"] } +tracing = "0.1" +tracing-subscriber = "0.3" # CLOCK STUFF chrono = "0.4" diff --git a/moq-clock-ietf/src/cli.rs b/moq-clock-ietf/src/cli.rs index fd08590b..558dfd03 100644 --- a/moq-clock-ietf/src/cli.rs +++ b/moq-clock-ietf/src/cli.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use clap::Parser; use std::net; use url::Url; diff --git a/moq-clock-ietf/src/clock.rs b/moq-clock-ietf/src/clock.rs index ebedcab9..ada2bbd6 100644 --- a/moq-clock-ietf/src/clock.rs +++ b/moq-clock-ietf/src/clock.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use anyhow::Context; use moq_transport::serve::{ Datagram, DatagramsReader, DatagramsWriter, StreamReader, Subgroup, SubgroupWriter, @@ -48,13 +45,14 @@ impl Publisher { group_id: next_group_id as u64, subgroup_id: 0, priority: 0, + header_type: None, }) .context("failed to create minute segment")?; // Spawn a new task to handle sending the object every second tokio::spawn(async move { if let Err(err) = Self::send_subgroup_objects(subgroup_writer, now).await { - tracing::warn!("failed to send minute: {:?}", err); + log::warn!("failed to send minute: {:?}", err); } }); @@ -69,6 +67,7 @@ impl Publisher { priority: 127, payload: time_str.clone().into_bytes().into(), extension_headers: Default::default(), + status: None, }) .context("failed to write datagram")?; diff --git a/moq-clock-ietf/src/main.rs b/moq-clock-ietf/src/main.rs index ac8a1ef6..eab274b7 100644 --- a/moq-clock-ietf/src/main.rs +++ b/moq-clock-ietf/src/main.rs @@ -1,9 +1,7 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use moq_native_ietf::quic; use anyhow::Context; +use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; mod cli; mod clock; @@ -13,34 +11,58 @@ use cli::Cli; use moq_transport::{ coding::TrackNamespace, - serve, - session::{Publisher, Subscriber}, + serve::{self, TracksReader}, + session::{Publisher, SessionError, Subscriber}, }; -/// The main entry point for the MoQ Clock IETF example. +async fn serve_subscriptions( + mut publisher: Publisher, + tracks: TracksReader, +) -> Result<(), SessionError> { + let mut tasks: FuturesUnordered> = + FuturesUnordered::new(); + + loop { + tokio::select! { + Some(subscribed) = publisher.subscribed() => { + let info = subscribed.info.clone(); + let tracks = tracks.clone(); + log::info!("serving subscribe: {:?}", info); + + tasks.push(async move { + if let Err(err) = Publisher::serve_subscribe(subscribed, tracks).await { + log::warn!("failed serving subscribe: {:?}, error: {}", info, err); + } + }.boxed()); + } + _ = tasks.next(), if !tasks.is_empty() => {} + else => return Ok(()), + } + } +} + #[tokio::main] async fn main() -> anyhow::Result<()> { - // Initialize tracing with env filter (respects RUST_LOG environment variable) - // Default to info level, but suppress quinn's verbose output - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info,quinn=warn")), - ) - .init(); + env_logger::init(); + + // Disable tracing so we don't get a bunch of Quinn spam. + let tracer = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::WARN) + .finish(); + tracing::subscriber::set_global_default(tracer).unwrap(); let config = Cli::parse(); let tls = config.tls.load()?; // Create the QUIC endpoint - let quic = quic::Endpoint::new(quic::Config::new(config.bind, None, tls)?)?; + let quic = quic::Endpoint::new(quic::Config::new(config.bind, None, tls))?; - tracing::info!("connecting to server: url={}", config.url); + log::info!("connecting to server: url={}", config.url); // Connect to the server - let (session, connection_id, transport) = quic.client.connect(&config.url, None).await?; + let (session, connection_id) = quic.client.connect(&config.url, None).await?; - tracing::info!( + log::info!( "connected with CID: {} (use this to look up qlog/mlog on server)", connection_id ); @@ -48,12 +70,12 @@ async fn main() -> anyhow::Result<()> { // Depending on whether we are publishing or subscribing, create the appropriate session if config.publish { // Create the publisher session - let (session, mut publisher) = Publisher::connect(session, transport) + let (session, mut publisher) = Publisher::connect(session) .await .context("failed to create MoQ Transport session")?; if config.datagrams { - tracing::info!("publishing clock via datagrams"); + log::info!("publishing clock via datagrams"); let (mut tracks_writer, _, tracks_reader) = serve::Tracks { namespace: TrackNamespace::from_utf8_path(&config.namespace), @@ -63,13 +85,19 @@ async fn main() -> anyhow::Result<()> { let track_writer = tracks_writer.create(&config.track).unwrap(); let clock_publisher = clock::Publisher::new_datagram(track_writer.datagrams()?); + let publish_ns = publisher + .publish_namespace(tracks_reader.namespace.clone()) + .await + .context("failed to register namespace")?; + tokio::select! { res = session.run() => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, - res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, + res = serve_subscriptions(publisher, tracks_reader) => res.context("failed to serve tracks")?, + res = publish_ns.closed() => res.context("namespace closed")?, } } else { - tracing::info!("publishing clock via streams"); + log::info!("publishing clock via streams"); let (mut tracks_writer, _, tracks_reader) = serve::Tracks { namespace: TrackNamespace::from_utf8_path(&config.namespace), @@ -79,15 +107,21 @@ async fn main() -> anyhow::Result<()> { let track_writer = tracks_writer.create(&config.track).unwrap(); let clock_publisher = clock::Publisher::new(track_writer.subgroups()?); + let publish_ns = publisher + .publish_namespace(tracks_reader.namespace.clone()) + .await + .context("failed to register namespace")?; + tokio::select! { res = session.run() => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, - res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, + res = serve_subscriptions(publisher, tracks_reader) => res.context("failed to serve tracks")?, + res = publish_ns.closed() => res.context("namespace closed")?, } } } else { // Create the subscriber session - let (session, mut subscriber) = Subscriber::connect(session, transport) + let (session, mut subscriber) = Subscriber::connect(session) .await .context("failed to create MoQ Transport session")?; diff --git a/moq-native-ietf/CHANGELOG.md b/moq-native-ietf/CHANGELOG.md index 67a21a43..f70f6ce2 100644 --- a/moq-native-ietf/CHANGELOG.md +++ b/moq-native-ietf/CHANGELOG.md @@ -6,40 +6,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.8.0](https://github.com/cloudflare/moq-rs/compare/moq-native-ietf-v0.7.5...moq-native-ietf-v0.8.0) - 2026-04-10 - -### Fixed - -- cross-platform dual-stack binding for IPv6 sockets - -### Other - -- Merge pull request #151 from englishm-cloudflare/me/ipv6-dual-stack-binding - -## [0.7.5](https://github.com/cloudflare/moq-rs/compare/moq-native-ietf-v0.7.4...moq-native-ietf-v0.7.5) - 2026-03-31 - -### Other - -- Make repo REUSE v3.3 compliant - -## [0.7.4](https://github.com/cloudflare/moq-rs/compare/moq-native-ietf-v0.7.3...moq-native-ietf-v0.7.4) - 2026-03-27 - -### Added - -- add Transport enum and connection path extraction - -## [0.7.3](https://github.com/cloudflare/moq-rs/compare/moq-native-ietf-v0.7.2...moq-native-ietf-v0.7.3) - 2026-02-18 - -### Other - -- Upgrade web-transport crates to v0.10.1 - -## [0.7.2](https://github.com/cloudflare/moq-rs/compare/moq-native-ietf-v0.7.1...moq-native-ietf-v0.7.2) - 2026-02-18 - -### Other - -- migrate from log crate to tracing - ## [0.7.1](https://github.com/cloudflare/moq-rs/compare/moq-native-ietf-v0.7.0...moq-native-ietf-v0.7.1) - 2026-02-03 ### Other diff --git a/moq-native-ietf/Cargo.toml b/moq-native-ietf/Cargo.toml index 9bc8efab..41eeb1c9 100644 --- a/moq-native-ietf/Cargo.toml +++ b/moq-native-ietf/Cargo.toml @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - [package] name = "moq-native-ietf" description = "Media over QUIC - Helper library for native applications" @@ -8,14 +5,14 @@ authors = ["Luke Curley"] repository = "https://github.com/englishm/moq-rs" license = "MIT OR Apache-2.0" -version = "0.8.0" +version = "0.7.1" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] categories = ["multimedia", "network-programming", "web-programming"] [dependencies] -moq-transport = { path = "../moq-transport", version = "0.14" } +moq-transport = { path = "../moq-transport", version = "0.12" } web-transport = { workspace = true } web-transport-quinn = { version = "0.11", default-features = false, features = ["ring"] } @@ -32,8 +29,7 @@ rand = "0.8" tokio = { version = "1", features = ["full"] } futures = "0.3" -socket2 = "0.5" anyhow = { version = "1", features = ["backtrace"] } clap = { version = "4", features = ["derive", "env"] } -tracing = { workspace = true } +log = { version = "0.4", features = ["std"] } diff --git a/moq-native-ietf/src/lib.rs b/moq-native-ietf/src/lib.rs index fa7ac3f3..dc71b400 100644 --- a/moq-native-ietf/src/lib.rs +++ b/moq-native-ietf/src/lib.rs @@ -1,5 +1,2 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - pub mod quic; pub mod tls; diff --git a/moq-native-ietf/src/quic.rs b/moq-native-ietf/src/quic.rs index 59effe88..e5d6ac8d 100644 --- a/moq-native-ietf/src/quic.rs +++ b/moq-native-ietf/src/quic.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{ collections::HashSet, fmt, @@ -14,11 +11,8 @@ use std::{ use anyhow::Context; use clap::Parser; -use socket2::{Domain, Protocol, Socket, Type}; use url::Url; -use moq_transport::session::Transport; - use crate::tls; use futures::future::BoxFuture; @@ -30,7 +24,7 @@ use futures::FutureExt; pub enum AddressFamily { Ipv4, Ipv6, - /// IPv6 with dual-stack support (IPV6_V6ONLY=false) + /// IPv6 with dual-stack support (Linux) Ipv6DualStack, } @@ -49,67 +43,6 @@ impl fmt::Display for AddressFamily { } } -/// Bind a UDP socket, attempting dual-stack if the address is IPv6. -/// -/// For IPv6 addresses, attempts to set `IPV6_V6ONLY = false` to enable -/// dual-stack operation (accepting both IPv4 and IPv6 traffic). This is -/// the default on Linux but must be explicitly requested on macOS/Windows. -/// -/// Returns `(socket, is_dual_stack)` where `is_dual_stack` indicates -/// whether the socket can handle both IPv4 and IPv6 destinations. -fn bind_smart(addr: net::SocketAddr) -> anyhow::Result<(net::UdpSocket, bool)> { - let domain = if addr.is_ipv6() { - Domain::IPV6 - } else { - Domain::IPV4 - }; - let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP)) - .context("failed to create UDP socket")?; - - let mut is_dual_stack = false; - - if addr.is_ipv6() { - match socket.set_only_v6(false) { - Ok(()) => { - is_dual_stack = true; - tracing::debug!(addr = %addr, "IPv6 dual-stack enabled (IPV6_V6ONLY=false)"); - } - Err(e) => { - tracing::warn!( - addr = %addr, - error = %e, - "Could not enable dual-stack on IPv6 socket; \ - IPv4-only destinations may be unreachable" - ); - } - } - } - - socket - .bind(&addr.into()) - .with_context(|| format!("failed to bind UDP socket to {}", addr))?; - - let local_addr = match socket.local_addr() { - Ok(a) => a - .as_socket() - .map(|s| s.to_string()) - .unwrap_or_else(|| "".to_string()), - Err(e) => { - tracing::warn!(error = %e, "failed to get local address after successful bind"); - "".to_string() - } - }; - - tracing::info!( - bind = %addr, - local = %local_addr, - dual_stack = is_dual_stack, - "UDP socket bound" - ); - - Ok((socket.into(), is_dual_stack)) -} - /// Build a TransportConfig with our standard settings /// /// This is used both for the base endpoint config and when creating @@ -126,11 +59,7 @@ fn build_transport_config() -> quinn::TransportConfig { #[derive(Parser, Clone)] pub struct Args { /// Listen for UDP packets on the given address. - /// - /// Defaults to [::]:0 (IPv6 with dual-stack). If the default IPv6 bind - /// fails, automatically falls back to 0.0.0.0 (IPv4-only) with a warning. - /// Explicitly provided IPv6 addresses will not fall back. - #[arg(long, default_value = Args::DEFAULT_BIND)] + #[arg(long, default_value = "[::]:0")] pub bind: net::SocketAddr, /// Directory to write qlog files (one per connection) @@ -144,7 +73,7 @@ pub struct Args { impl Default for Args { fn default() -> Self { Self { - bind: Self::DEFAULT_BIND.parse().unwrap(), + bind: "[::]:0".parse().unwrap(), qlog_dir: None, tls: Default::default(), } @@ -152,62 +81,31 @@ impl Default for Args { } impl Args { - /// The default bind address used when `--bind` is not explicitly provided. - const DEFAULT_BIND: &str = "[::]:0"; - pub fn load(&self) -> anyhow::Result { let tls = self.tls.load()?; - - match Config::new(self.bind, self.qlog_dir.clone(), tls.clone()) { - Ok(config) => Ok(config), - Err(e) if self.bind.to_string() == Self::DEFAULT_BIND => { - // IPv6 default bind failed -- try falling back to IPv4. - // Only do this for the default; if the user explicitly - // requested an IPv6 address, respect that and propagate - // the error. - let fallback = net::SocketAddr::new( - net::IpAddr::V4(net::Ipv4Addr::UNSPECIFIED), - self.bind.port(), - ); - tracing::warn!( - requested = %self.bind, - fallback = %fallback, - error = %e, - "IPv6 bind failed, falling back to IPv4" - ); - Config::new(fallback, self.qlog_dir.clone(), tls).with_context(|| { - format!("IPv4 fallback also failed (original IPv6 error: {})", e) - }) - } - Err(e) => Err(e), - } + Ok(Config::new(self.bind, self.qlog_dir.clone(), tls)) } } pub struct Config { pub bind: Option, pub socket: net::UdpSocket, - pub is_dual_stack: bool, pub qlog_dir: Option, pub tls: tls::Config, pub tags: HashSet, } impl Config { - pub fn new( - bind: net::SocketAddr, - qlog_dir: Option, - tls: tls::Config, - ) -> anyhow::Result { - let (socket, is_dual_stack) = bind_smart(bind)?; - Ok(Self { + pub fn new(bind: net::SocketAddr, qlog_dir: Option, tls: tls::Config) -> Self { + Self { bind: Some(bind), - socket, - is_dual_stack, + socket: net::UdpSocket::bind(bind) + .context("failed to bind socket") + .unwrap(), qlog_dir, tls, tags: HashSet::new(), - }) + } } pub fn with_socket( @@ -215,18 +113,9 @@ impl Config { qlog_dir: Option, tls: tls::Config, ) -> Self { - // Probe the socket to detect dual-stack capability rather than assuming. - let is_dual_stack = socket.local_addr().is_ok_and(|addr| { - addr.is_ipv6() && { - let sock_ref = socket2::SockRef::from(&socket); - sock_ref.only_v6().map(|v6only| !v6only).unwrap_or(false) - } - }); - Self { bind: None, socket, - is_dual_stack, qlog_dir, tls, tags: HashSet::new(), @@ -261,7 +150,7 @@ impl Endpoint { if !qlog_dir.is_dir() { anyhow::bail!("qlog path is not a directory: {}", qlog_dir.display()); } - tracing::info!("qlog output enabled: {}", qlog_dir.display()); + log::info!("qlog output enabled: {}", qlog_dir.display()); } // Build transport config with our standard settings @@ -303,7 +192,6 @@ impl Endpoint { quic, config: config.tls.client, transport, - is_dual_stack: config.is_dual_stack, }; Ok(Self { @@ -316,15 +204,13 @@ impl Endpoint { pub struct Server { quic: quinn::Endpoint, - accept: FuturesUnordered< - BoxFuture<'static, anyhow::Result<(web_transport::Session, String, Transport)>>, - >, + accept: FuturesUnordered>>, qlog_dir: Option>, base_server_config: Arc, } impl Server { - pub async fn accept(&mut self) -> Option<(web_transport::Session, String, Transport)> { + pub async fn accept(&mut self) -> Option<(web_transport::Session, String)> { loop { tokio::select! { res = self.quic.accept() => { @@ -337,7 +223,7 @@ impl Server { match res? { Ok(result) => return Some(result), Err(err) => { - tracing::warn!("failed to accept QUIC connection: {}", err.root_cause()); + log::warn!("failed to accept QUIC connection: {}", err.root_cause()); continue; } } @@ -350,7 +236,7 @@ impl Server { conn: quinn::Incoming, qlog_dir: Option>, base_server_config: Arc, - ) -> anyhow::Result<(web_transport::Session, String, Transport)> { + ) -> anyhow::Result<(web_transport::Session, String)> { // Capture the original destination connection ID BEFORE accepting // This is the actual QUIC CID that can be used for qlog/mlog correlation let orig_dst_cid = conn.orig_dst_cid(); @@ -376,7 +262,7 @@ impl Server { let mut server_config = (*base_server_config).clone(); server_config.transport_config(Arc::new(transport)); - tracing::debug!( + log::debug!( "qlog enabled: cid={} path={}", connection_id_hex, qlog_path.display() @@ -399,7 +285,7 @@ impl Server { let alpn = String::from_utf8_lossy(&alpn); let server_name = handshake.server_name.unwrap_or_default(); - tracing::debug!( + log::debug!( "received QUIC handshake: cid={} ip={} alpn={} server={}", connection_id_hex, conn.remote_address(), @@ -410,7 +296,7 @@ impl Server { // Wait for the QUIC connection to be established. let conn = conn.await.context("failed to establish QUIC connection")?; - tracing::debug!( + log::debug!( "established QUIC connection: cid={} stable_id={} ip={} alpn={} server={}", connection_id_hex, conn.stable_id(), @@ -420,32 +306,26 @@ impl Server { ); let alpn_bytes = alpn.as_bytes(); - let (session, transport) = if alpn_bytes == web_transport_quinn::ALPN.as_bytes() { + let session = if alpn_bytes == web_transport_quinn::ALPN.as_bytes() { // Wait for the WebTransport CONNECT request (includes H3 SETTINGS exchange). let request = web_transport_quinn::Request::accept(conn) .await .context("failed to receive WebTransport request")?; // Accept the CONNECT request. - let session = request + request .ok() .await - .context("failed to respond to WebTransport request")?; - (session, Transport::WebTransport) + .context("failed to respond to WebTransport request")? } else if alpn_bytes == moq_transport::setup::ALPN { - // Raw QUIC mode — create a "fake" WebTransport session with no H3 framing. + // Raw QUIC mode — create a session with no H3 framing. let request = url::Url::parse("moqt://localhost").unwrap(); - let session = web_transport_quinn::Session::raw( - conn, - request, - web_transport_quinn::proto::ConnectResponse::default(), - ); - (session, Transport::RawQuic) + web_transport_quinn::Session::raw(conn, request, web_transport_quinn::proto::ConnectResponse::default()) } else { anyhow::bail!("unsupported ALPN: {}", alpn) }; - Ok((session.into(), connection_id_hex, transport)) + Ok((session.into(), connection_id_hex)) } pub fn local_addr(&self) -> anyhow::Result { @@ -460,7 +340,6 @@ pub struct Client { quic: quinn::Endpoint, config: rustls::ClientConfig, transport: Arc, - is_dual_stack: bool, } impl Client { @@ -472,9 +351,6 @@ impl Client { } /// Returns the address family of the local QUIC socket. - /// - /// Uses the dual-stack state determined at bind time rather than - /// compile-time platform assumptions. pub fn address_family(&self) -> anyhow::Result { let local_addr = self .quic @@ -483,7 +359,7 @@ impl Client { if local_addr.is_ipv4() { Ok(AddressFamily::Ipv4) - } else if self.is_dual_stack { + } else if cfg!(target_os = "linux") { Ok(AddressFamily::Ipv6DualStack) } else { Ok(AddressFamily::Ipv6) @@ -494,7 +370,7 @@ impl Client { &self, url: &Url, socket_addr: Option, - ) -> anyhow::Result<(web_transport::Session, String, Transport)> { + ) -> anyhow::Result<(web_transport::Session, String)> { let mut config = self.config.clone(); // TODO support connecting to both ALPNs at the same time @@ -551,23 +427,20 @@ impl Client { .context("CID not captured")? .to_string(); - let (session, transport) = match url.scheme() { - "https" => ( - web_transport_quinn::Session::connect(connection, url.clone()).await?, - Transport::WebTransport, - ), - "moqt" => ( - web_transport_quinn::Session::raw( - connection, - url.clone(), - web_transport_quinn::proto::ConnectResponse::default(), - ), - Transport::RawQuic, - ), + let session = match url.scheme() { + "https" => { + // Build a ConnectRequest with the MoQT version as the WebTransport subprotocol. + // Per draft-15+, version negotiation uses ALPN (raw QUIC) or + // wt-available-protocols (WebTransport) instead of CLIENT_SETUP versions. + let request = web_transport_quinn::proto::ConnectRequest::new(url.clone()) + .with_protocol(std::str::from_utf8(moq_transport::setup::ALPN).unwrap()); + web_transport_quinn::Session::connect(connection, request).await? + } + "moqt" => web_transport_quinn::Session::raw(connection, url.clone(), web_transport_quinn::proto::ConnectResponse::default()), _ => unreachable!(), }; - Ok((session.into(), connection_id_hex, transport)) + Ok((session.into(), connection_id_hex)) } /// Default DNS resolution logic that filters results by address family. @@ -595,14 +468,14 @@ impl Client { } // Log all DNS results for debugging - tracing::debug!( + log::debug!( "DNS lookup for {}, family {:?}: found {} results", host, address_family, addrs.len() ); for (i, addr) in addrs.iter().enumerate() { - tracing::debug!( + log::debug!( " DNS[{}]: {} ({})", i, addr, @@ -624,12 +497,15 @@ impl Client { ))? } AddressFamily::Ipv6DualStack => { - // Dual-stack socket: any address family works, use first result - tracing::debug!("Using first DNS result (IPv6 dual-stack): {}", addrs[0]); + // IPv6 socket on Linux: dual-stack, use first result + log::debug!( + "Using first DNS result (Linux IPv6 dual-stack): {}", + addrs[0] + ); addrs[0] } AddressFamily::Ipv6 => { - // IPv6-only socket: filter to IPv6 addresses + // IPv6 socket non-Linux: filter to IPv6 addresses addrs .iter() .find(|a| a.is_ipv6()) @@ -641,7 +517,7 @@ impl Client { } }; - tracing::debug!( + log::debug!( "Connecting from {} to {} (selected from {} DNS results)", local_addr, compatible_addr, diff --git a/moq-native-ietf/src/tls.rs b/moq-native-ietf/src/tls.rs index 1f8e5ec4..5f9a018f 100644 --- a/moq-native-ietf/src/tls.rs +++ b/moq-native-ietf/src/tls.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use anyhow::Context; use clap::Parser; use ring::digest::{digest, SHA256}; diff --git a/moq-pub/CHANGELOG.md b/moq-pub/CHANGELOG.md index c4f629dd..92a70f16 100644 --- a/moq-pub/CHANGELOG.md +++ b/moq-pub/CHANGELOG.md @@ -6,51 +6,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.8.14](https://github.com/cloudflare/moq-rs/compare/moq-pub-v0.8.13...moq-pub-v0.8.14) - 2026-05-20 - -### Other - -- update Cargo.lock dependencies - -## [0.8.13](https://github.com/cloudflare/moq-rs/compare/moq-pub-v0.8.12...moq-pub-v0.8.13) - 2026-04-10 - -### Fixed - -- cross-platform dual-stack binding for IPv6 sockets - -### Other - -- Merge pull request #151 from englishm-cloudflare/me/ipv6-dual-stack-binding - -## [0.8.12](https://github.com/cloudflare/moq-rs/compare/moq-pub-v0.8.11...moq-pub-v0.8.12) - 2026-03-31 - -### Other - -- Make repo REUSE v3.3 compliant -- Bring copyright notices, license docs up to date - -## [0.8.11](https://github.com/cloudflare/moq-rs/compare/moq-pub-v0.8.10...moq-pub-v0.8.11) - 2026-03-27 - -### Added - -- add Transport enum and connection path extraction - -### Fixed - -- *(moq-pub)* combine moof+mdat into single MoQ Object - -## [0.8.10](https://github.com/cloudflare/moq-rs/compare/moq-pub-v0.8.9...moq-pub-v0.8.10) - 2026-02-18 - -### Other - -- update Cargo.lock dependencies - -## [0.8.9](https://github.com/cloudflare/moq-rs/compare/moq-pub-v0.8.8...moq-pub-v0.8.9) - 2026-02-18 - -### Other - -- migrate from log crate to tracing - ## [0.8.8](https://github.com/cloudflare/moq-rs/compare/moq-pub-v0.8.7...moq-pub-v0.8.8) - 2026-01-29 ### Other diff --git a/moq-pub/Cargo.toml b/moq-pub/Cargo.toml index 52344353..d37ebbaf 100644 --- a/moq-pub/Cargo.toml +++ b/moq-pub/Cargo.toml @@ -1,15 +1,11 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - [package] name = "moq-pub" description = "Media over QUIC" -authors = ["moq-rs contributors"] -repository = "https://github.com/cloudflare/moq-rs" +authors = ["Mike English", "Luke Curley"] +repository = "https://github.com/englishm/moq-rs" license = "MIT OR Apache-2.0" -version = "0.8.14" +version = "0.8.8" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] @@ -18,8 +14,8 @@ categories = ["multimedia", "network-programming", "web-programming"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -moq-native-ietf = { path = "../moq-native-ietf", version = "0.8" } -moq-transport = { path = "../moq-transport", version = "0.14" } +moq-native-ietf = { path = "../moq-native-ietf", version = "0.7" } +moq-transport = { path = "../moq-transport", version = "0.12" } moq-catalog = { path = "../moq-catalog", version = "0.2" } url = "2" @@ -27,12 +23,15 @@ bytes = "1" # Async stuff tokio = { version = "1", features = ["full"] } +futures = "0.3" # CLI, logging, error handling clap = { version = "4", features = ["derive"] } -tracing = { workspace = true } -tracing-subscriber = { workspace = true } +log = { workspace = true } +env_logger = { workspace = true } mp4 = "0.14" anyhow = { version = "1", features = ["backtrace"] } serde_json = "1" rfc6381-codec = "0.2" +tracing = "0.1" +tracing-subscriber = "0.3" diff --git a/moq-pub/src/cli.rs b/moq-pub/src/cli.rs index 5db370eb..a39ed412 100644 --- a/moq-pub/src/cli.rs +++ b/moq-pub/src/cli.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use clap::Parser; use std::{net, path}; use url::Url; diff --git a/moq-pub/src/lib.rs b/moq-pub/src/lib.rs index d54dda2a..4e6530ba 100644 --- a/moq-pub/src/lib.rs +++ b/moq-pub/src/lib.rs @@ -1,6 +1,2 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - mod media; pub use media::*; diff --git a/moq-pub/src/main.rs b/moq-pub/src/main.rs index 17d6403b..c259b440 100644 --- a/moq-pub/src/main.rs +++ b/moq-pub/src/main.rs @@ -1,18 +1,19 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use bytes::BytesMut; use std::net; use url::Url; use anyhow::Context; use clap::Parser; +use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use tokio::io::AsyncReadExt; use moq_native_ietf::quic; use moq_pub::Media; -use moq_transport::{coding::TrackNamespace, serve, session::Publisher}; +use moq_transport::{ + coding::TrackNamespace, + serve::{self, TracksReader}, + session::{Publisher, SessionError}, +}; #[derive(Parser, Clone)] pub struct Cli { @@ -43,16 +44,41 @@ pub struct Cli { pub tls: moq_native_ietf::tls::Args, } +async fn serve_subscriptions( + mut publisher: Publisher, + tracks: TracksReader, +) -> Result<(), SessionError> { + let mut tasks: FuturesUnordered> = + FuturesUnordered::new(); + + loop { + tokio::select! { + Some(subscribed) = publisher.subscribed() => { + let info = subscribed.info.clone(); + let tracks = tracks.clone(); + log::info!("serving subscribe: {:?}", info); + + tasks.push(async move { + if let Err(err) = Publisher::serve_subscribe(subscribed, tracks).await { + log::warn!("failed serving subscribe: {:?}, error: {}", info, err); + } + }.boxed()); + } + _ = tasks.next(), if !tasks.is_empty() => {} + else => return Ok(()), + } + } +} + #[tokio::main] async fn main() -> anyhow::Result<()> { - // Initialize tracing with env filter (respects RUST_LOG environment variable) - // Default to info level, but suppress quinn's verbose output - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info,quinn=warn")), - ) - .init(); + env_logger::init(); + + // Disable tracing so we don't get a bunch of Quinn spam. + let tracer = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::WARN) + .finish(); + tracing::subscriber::set_global_default(tracer).unwrap(); let cli = Cli::parse(); @@ -66,26 +92,35 @@ async fn main() -> anyhow::Result<()> { cli.bind, None, tls.clone(), - )?)?; + ))?; - tracing::info!("connecting to relay: url={}", cli.url); - let (session, connection_id, transport) = quic.client.connect(&cli.url, None).await?; + log::info!("connecting to relay: url={}", cli.url); + let (session, connection_id) = quic.client.connect(&cli.url, None).await?; - tracing::info!( + log::info!( "connected with CID: {} (use this to look up qlog/mlog on server)", connection_id ); - let (session, mut publisher) = Publisher::connect(session, transport) + let (session, publisher) = Publisher::connect(session) .await .context("failed to create MoQ Transport publisher")?; + let namespace = reader.namespace.clone(); + + let publish_ns = publisher + .clone() + .publish_namespace(namespace) + .await + .context("failed to register namespace")?; + + log::info!("namespace registered, starting media and subscription handling"); + tokio::select! { res = session.run() => res.context("session error")?, - res = run_media(media) => { - res.context("media error")? - }, - res = publisher.announce(reader) => res.context("publisher error")?, + res = run_media(media) => res.context("media error")?, + res = serve_subscriptions(publisher, reader) => res.context("publisher error")?, + res = publish_ns.closed() => res.context("publisher error")?, } Ok(()) diff --git a/moq-pub/src/media.rs b/moq-pub/src/media.rs index 3cb0d9e3..f1952781 100644 --- a/moq-pub/src/media.rs +++ b/moq-pub/src/media.rs @@ -1,9 +1,5 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use anyhow::{self, Context}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{Buf, Bytes}; use moq_transport::serve::{SubgroupWriter, SubgroupsWriter, TrackWriter, TracksWriter}; use mp4::{self, ReadBox, TrackType}; use std::cmp::max; @@ -266,7 +262,7 @@ impl Media { let catalog_str = serde_json::to_string_pretty(&catalog)?; - tracing::info!("catalog: {}", catalog_str); + log::info!("catalog: {}", catalog_str); // Create a single fragment for the segment. self.catalog.append(0)?.write(catalog_str.into())?; @@ -327,11 +323,6 @@ struct Track { // The current segment current: Option, - // Pending moof header bytes, waiting to be combined with mdat. - // Per CMSF (draft-ietf-moq-cmsf-00 §3.3), each MoQ Object must contain - // at least one complete CMAF Chunk (moof+mdat pair). - pending_moof: Option, - // The number of units per second. timescale: u64, @@ -344,17 +335,15 @@ impl Track { Self { track: track.subgroups().unwrap(), current: None, - pending_moof: None, timescale, handler, } } pub fn header(&mut self, raw: Bytes, fragment: Fragment) -> anyhow::Result<()> { - if self.current.is_some() { - // Use the existing segment — just stash the moof for now. - debug_assert!(self.pending_moof.is_none(), "overwriting pending moof"); - self.pending_moof = Some(raw); + if let Some(current) = self.current.as_mut() { + // Use the existing segment + current.write(raw)?; return Ok(()); } @@ -371,15 +360,15 @@ impl Track { let priority: u8 = 127; // Create a new segment. - let segment = self.track.append(priority)?; + let mut segment = self.track.append(priority)?; println!( "timestamp: {:?} segment: {:?}:{:?} priority: {:?}", fragment.timestamp, segment.info.group_id, segment.info.subgroup_id, priority ); - // Stash the moof — it will be combined with mdat in data(). - self.pending_moof = Some(raw); + // Write the fragment in it's own object. + segment.write(raw)?; // Save for the next iteration self.current = Some(segment); @@ -388,20 +377,19 @@ impl Track { } pub fn data(&mut self, raw: Bytes) -> anyhow::Result<()> { - let moof = self.pending_moof.take().context("missing pending moof")?; let segment = self.current.as_mut().context("missing current fragment")?; - // Combine moof+mdat into a single MoQ Object (CMSF §3.3 compliance). - let mut combined = BytesMut::with_capacity(moof.len() + raw.len()); - combined.put_slice(&moof); - combined.put_slice(&raw); - segment.write(combined.freeze())?; + segment.write(raw)?; Ok(()) } pub fn end_group(&mut self) { - self.current = None; - self.pending_moof = None; + // Send EndOfGroup marker before dropping the writer + if let Some(mut current) = self.current.take() { + if let Err(e) = current.end_of_group() { + log::warn!("failed to send EndOfGroup marker: {}", e); + } + } } } diff --git a/moq-relay-ietf/CHANGELOG.md b/moq-relay-ietf/CHANGELOG.md index 19e037df..e219d1ab 100644 --- a/moq-relay-ietf/CHANGELOG.md +++ b/moq-relay-ietf/CHANGELOG.md @@ -6,92 +6,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.7.18](https://github.com/cloudflare/moq-rs/compare/moq-relay-ietf-v0.7.17...moq-relay-ietf-v0.7.18) - 2026-05-20 - -### Fixed - -- tokio utils use default features -- suggestions from opencode reviewers -- apply suggestions from opencode review - -### Other - -- check for cancelled of cancellation token when waiting for subscribe open -- keep the comments for readability purpose -- Merge branch 'main' of github.com:itzmanish/moq-rs into feat/remote-manager-rewrite - -## [0.7.17](https://github.com/cloudflare/moq-rs/compare/moq-relay-ietf-v0.7.16...moq-relay-ietf-v0.7.17) - 2026-04-13 - -### Fixed - -- always register in coordinator after registering in local - -### Other - -- Merge branch 'main' of github.com:itzmanish/moq-rs into fix-register-order-namespace - -## [0.7.16](https://github.com/cloudflare/moq-rs/compare/moq-relay-ietf-v0.7.15...moq-relay-ietf-v0.7.16) - 2026-04-10 - -### Fixed - -- cross-platform dual-stack binding for IPv6 sockets - -### Other - -- Merge pull request #151 from englishm-cloudflare/me/ipv6-dual-stack-binding - -## [0.7.15](https://github.com/cloudflare/moq-rs/compare/moq-relay-ietf-v0.7.14...moq-relay-ietf-v0.7.15) - 2026-04-09 - -### Fixed - -- include destination address in upstream connection cache key - -## [0.7.14](https://github.com/cloudflare/moq-rs/compare/moq-relay-ietf-v0.7.13...moq-relay-ietf-v0.7.14) - 2026-03-31 - -### Other - -- Make repo REUSE v3.3 compliant -- Bring copyright notices, license docs up to date - -## [0.7.13](https://github.com/cloudflare/moq-rs/compare/moq-relay-ietf-v0.7.12...moq-relay-ietf-v0.7.13) - 2026-03-27 - -### Added - -- actively reject unauthorized control messages on permission-gated sessions -- add scope-aware namespace isolation to ApiCoordinator -- add Coordinator stubs for SUBSCRIBE_NAMESPACE, track PUBLISH, and lingering subscriber support -- add resolve_scope() to Coordinator trait with permission-gated sessions -- add scope parameter to Coordinator trait and thread through relay -- add Transport enum and connection path extraction - -## [0.7.12](https://github.com/cloudflare/moq-rs/compare/moq-relay-ietf-v0.7.11...moq-relay-ietf-v0.7.12) - 2026-02-18 - -### Other - -- update Cargo.toml dependencies - -## [0.7.11](https://github.com/cloudflare/moq-rs/compare/moq-relay-ietf-v0.7.10...moq-relay-ietf-v0.7.11) - 2026-02-18 - -### Added - -- add additional debug logging for troubleshooting -- add structured fields to high-value log messages -- *(metrics)* add describe_metrics() for Prometheus HELP text -- *(metrics)* distinguish graceful close from connection errors -- *(moq-relay-ietf)* add optional prometheus exporter for metrics validation -- *(moq-relay-ietf)* add metrics instrumentation via metrics crate facade - -### Fixed - -- cargo fmt and clippy lints -- *(metrics)* move upstream_connections gauge after successful connect -- *(metrics)* address review feedback for metrics instrumentation - -### Other - -- migrate from log crate to tracing -- *(metrics)* make metrics always-on, remove feature gate - ## [0.7.10](https://github.com/cloudflare/moq-rs/compare/moq-relay-ietf-v0.7.9...moq-relay-ietf-v0.7.10) - 2026-01-29 ### Other diff --git a/moq-relay-ietf/Cargo.toml b/moq-relay-ietf/Cargo.toml index 2821e634..00a1764f 100644 --- a/moq-relay-ietf/Cargo.toml +++ b/moq-relay-ietf/Cargo.toml @@ -1,14 +1,11 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - [package] name = "moq-relay-ietf" description = "Media over QUIC" -authors = ["moq-rs contributors"] +authors = ["Luke Curley", "Manish Kumar Pandit"] repository = "https://github.com/cloudflare/moq-rs" license = "MIT OR Apache-2.0" -version = "0.7.18" +version = "0.7.10" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] @@ -23,8 +20,8 @@ name = "moq-relay-ietf" path = "src/bin/moq-relay-ietf/main.rs" [dependencies] -moq-transport = { path = "../moq-transport", version = "0.14" } -moq-native-ietf = { path = "../moq-native-ietf", version = "0.8" } +moq-transport = { path = "../moq-transport", version = "0.12" } +moq-native-ietf = { path = "../moq-native-ietf", version = "0.7" } moq-api = { path = "../moq-api", version = "0.2" } web-transport = { workspace = true } @@ -33,7 +30,7 @@ url = "2" # Async stuff tokio = { version = "1", features = ["full"] } -tokio-util = "0.7" +# tokio-util = "0.7" futures = "0.3" async-trait = "0.1" @@ -60,16 +57,12 @@ clap = { version = "4", features = ["derive"] } # Logging -tracing = { workspace = true } -tracing-subscriber = { workspace = true } +log = { workspace = true } +env_logger = { workspace = true } +tracing = "0.1" +tracing-subscriber = "0.3" thiserror = "2.0.17" -# Metrics — always compiled in; the metrics crate is effectively zero-cost -# when no recorder is installed (similar to how the log crate works). -# The metrics-prometheus feature adds the optional Prometheus exporter. -metrics = "0.24" -metrics-exporter-prometheus = { version = "0.16", optional = true } - -[features] -default = [] -metrics-prometheus = ["dep:metrics-exporter-prometheus"] +# misc +#once_cell = "1.21.3" +arc-swap = "1" diff --git a/moq-relay-ietf/README.md b/moq-relay-ietf/README.md index 4cbd1f41..04d4c2e3 100644 --- a/moq-relay-ietf/README.md +++ b/moq-relay-ietf/README.md @@ -6,7 +6,7 @@ All subscriptions are deduplicated and cached, so that a single publisher can se ## Usage The publisher must choose a unique name for their broadcast, sent as the WebTransport path when connecting to the server. -Connection paths are normalized and validated: trailing slashes are trimmed, dot segments and percent-encoded characters are rejected, and empty segments are not allowed. Capitalization matters. +We currently do a dumb string comparison, so capatilization matters as do slashes. For example: `CONNECT https://relay.quic.video/BigBuckBunny` diff --git a/moq-relay-ietf/src/api.rs b/moq-relay-ietf/src/api.rs index 3a69bf12..ddafa443 100644 --- a/moq-relay-ietf/src/api.rs +++ b/moq-relay-ietf/src/api.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use url::Url; /// API client for moq-api. @@ -59,7 +56,7 @@ impl Refresh { /// Update the origin registration in moq-api. async fn update(&self) -> Result<(), moq_api::ApiError> { - tracing::debug!( + log::debug!( "registering origin: namespace={} url={}", self.namespace, self.origin.url @@ -85,7 +82,7 @@ impl Drop for Refresh { // TODO this is really lazy let namespace = self.namespace.clone(); let client = self.client.clone(); - tracing::debug!("removing origin: namespace={}", namespace,); + log::debug!("removing origin: namespace={}", namespace,); tokio::spawn(async move { client.delete_origin(&namespace).await }); } } diff --git a/moq-relay-ietf/src/bin/moq-relay-ietf/api_coordinator.rs b/moq-relay-ietf/src/bin/moq-relay-ietf/api_coordinator.rs index 0375fa17..1e2ccb1b 100644 --- a/moq-relay-ietf/src/bin/moq-relay-ietf/api_coordinator.rs +++ b/moq-relay-ietf/src/bin/moq-relay-ietf/api_coordinator.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! API-based coordinator for multi-relay deployments. //! //! This coordinator uses the moq-api HTTP server as a centralized registry @@ -63,7 +60,7 @@ impl ApiCoordinatorConfig { /// Handle that unregisters a namespace when dropped and manages TTL refresh struct NamespaceUnregisterHandle { - namespace_key: String, + namespace: TrackNamespace, client: Client, /// Channel to signal the refresh task to stop (wrapped in Option so we can take it in drop) shutdown_tx: Option>, @@ -76,24 +73,25 @@ impl Drop for NamespaceUnregisterHandle { let _ = tx.send(()); } - let namespace_key = self.namespace_key.clone(); + let namespace = self.namespace.clone(); let client = self.client.clone(); // Spawn a task to unregister since we can't do async in drop tokio::spawn(async move { - if let Err(err) = unregister_namespace_async(&client, &namespace_key).await { - tracing::warn!(namespace = %namespace_key, error = %err, "failed to unregister namespace on drop: {}", err); + if let Err(err) = unregister_namespace_async(&client, &namespace).await { + log::warn!("failed to unregister namespace on drop: {}", err); } }); } } /// Async helper for unregistering a namespace -async fn unregister_namespace_async(client: &Client, namespace_key: &str) -> Result<()> { - tracing::debug!(namespace = %namespace_key, "unregistering namespace from API: {}", namespace_key); +async fn unregister_namespace_async(client: &Client, namespace: &TrackNamespace) -> Result<()> { + let namespace_str = namespace.to_utf8_path(); + log::debug!("unregistering namespace from API: {}", namespace_str); client - .delete_origin(namespace_key) + .delete_origin(&namespace_str) .await .context("failed to delete namespace from API")?; @@ -108,13 +106,6 @@ async fn unregister_namespace_async(client: &Client, namespace_key: &str) -> Res /// - HTTP-based registration and lookup /// - TTL-based automatic expiration of stale registrations /// - Background refresh tasks to maintain registrations -/// -/// # Scope handling -/// -/// Registry keys encode the scope and namespace into a single collision-free -/// string. Namespace tuple fields are hex-encoded to handle arbitrary bytes -/// (MoQT namespaces are tuples of byte arrays, not strings). See -/// [`ApiCoordinator::registry_key()`] for format details. pub struct ApiCoordinator { /// moq-api client client: Client, @@ -123,42 +114,6 @@ pub struct ApiCoordinator { } impl ApiCoordinator { - /// Build the moq-api registry key for a namespace, scoped if applicable. - /// - /// The key unambiguously encodes `(scope, namespace)` into a single string - /// that can be used as an opaque key in the moq-api HTTP registry. - /// - /// ## Format - /// - /// Each namespace tuple field is hex-encoded and fields are joined with `.`. - /// The scope (if present) is prepended with a `:` separator: - /// - /// - Scoped: `"{scope}:{hex_field0}.{hex_field1}..."` - /// - Unscoped: `":{hex_field0}.{hex_field1}..."` - /// - /// ## Why this is collision-free - /// - /// - Hex encoding (`[0-9a-f]`) preserves arbitrary bytes without ambiguity - /// - `.` separates tuple fields (can't appear in hex output) - /// - `:` separates scope from namespace (can't appear in hex output, and - /// scopes are validated connection paths that don't contain `:`) - /// - The leading `:` on unscoped keys prevents collision with scoped keys - /// (scopes always start with `/` per `normalize_connection_path`) - /// - Different tuple field counts produce different keys (e.g., one field - /// `"ab"` → `"6162"` vs two fields `"a","b"` → `"61.62"`) - fn registry_key(scope: Option<&str>, namespace: &TrackNamespace) -> String { - let ns_hex: String = namespace - .fields - .iter() - .map(|f| hex::encode(&f.value)) - .collect::>() - .join("."); - match scope { - Some(s) => format!("{s}:{ns_hex}"), - None => format!(":{ns_hex}"), - } - } - /// Create a new API-based coordinator. /// /// # Arguments @@ -175,7 +130,7 @@ impl ApiCoordinator { /// Start a background task to refresh namespace registration fn start_refresh_task( client: Client, - namespace_key: String, + namespace: TrackNamespace, relay_url: Url, refresh_interval: Duration, mut shutdown_rx: tokio::sync::oneshot::Receiver<()>, @@ -187,19 +142,20 @@ impl ApiCoordinator { loop { tokio::select! { _ = interval.tick() => { + let namespace_str = namespace.to_utf8_path(); let origin = Origin { url: relay_url.clone() }; - match client.patch_origin(&namespace_key, origin).await { + match client.patch_origin(&namespace_str, origin).await { Ok(()) => { - tracing::trace!(namespace = %namespace_key, "refreshed namespace registration: {}", namespace_key); + log::trace!("refreshed namespace registration: {}", namespace_str); } Err(err) => { - tracing::warn!(namespace = %namespace_key, error = %err, "failed to refresh namespace registration: {}", err); + log::warn!("failed to refresh namespace registration: {}", err); } } } _ = &mut shutdown_rx => { - tracing::debug!("namespace refresh task shutting down"); + log::debug!("namespace refresh task shutting down"); break; } } @@ -212,17 +168,14 @@ impl ApiCoordinator { impl Coordinator for ApiCoordinator { async fn register_namespace( &self, - scope: Option<&str>, namespace: &TrackNamespace, ) -> CoordinatorResult { - let namespace_str = Self::registry_key(scope, namespace); + let namespace_str = namespace.to_utf8_path(); let origin = Origin { url: self.config.relay_url.clone(), }; - tracing::info!( - namespace = %namespace_str, - relay_url = %self.config.relay_url, + log::info!( "registering namespace in API: {} -> {}", namespace_str, self.config.relay_url @@ -241,14 +194,14 @@ impl Coordinator for ApiCoordinator { // Start background refresh task Self::start_refresh_task( self.client.clone(), - namespace_str.clone(), + namespace.clone(), self.config.relay_url.clone(), Duration::from_secs(self.config.refresh_interval_secs), shutdown_rx, ); let handle = NamespaceUnregisterHandle { - namespace_key: namespace_str, + namespace: namespace.clone(), client: self.client.clone(), shutdown_tx: Some(shutdown_tx), }; @@ -256,13 +209,9 @@ impl Coordinator for ApiCoordinator { Ok(NamespaceRegistration::new(handle)) } - async fn unregister_namespace( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - ) -> CoordinatorResult<()> { - let namespace_str = Self::registry_key(scope, namespace); - tracing::info!(namespace = %namespace_str, "unregistering namespace from API: {}", namespace_str); + async fn unregister_namespace(&self, namespace: &TrackNamespace) -> CoordinatorResult<()> { + let namespace_str = namespace.to_utf8_path(); + log::info!("unregistering namespace from API: {}", namespace_str); self.client .delete_origin(&namespace_str) @@ -275,11 +224,10 @@ impl Coordinator for ApiCoordinator { async fn lookup( &self, - scope: Option<&str>, namespace: &TrackNamespace, ) -> CoordinatorResult<(NamespaceOrigin, Option)> { - let namespace_str = Self::registry_key(scope, namespace); - tracing::debug!(scope = scope.unwrap_or(""), namespace = %namespace_str, "looking up namespace in API: {}", namespace_str); + let namespace_str = namespace.to_utf8_path(); + log::debug!("looking up namespace in API: {}", namespace_str); // Query the API for the namespace let result = self @@ -291,21 +239,21 @@ impl Coordinator for ApiCoordinator { match result { Some(origin) => { - tracing::debug!(namespace = %namespace_str, origin_url = %origin.url, "found namespace {} at {}", namespace_str, origin.url); + log::debug!("found namespace {} at {}", namespace_str, origin.url); Ok(( NamespaceOrigin::new(namespace.clone(), origin.url, None), None, )) } None => { - tracing::debug!(namespace = %namespace_str, "namespace not found: {}", namespace_str); + log::debug!("namespace not found: {}", namespace_str); Err(CoordinatorError::NamespaceNotFound) } } } async fn shutdown(&self) -> CoordinatorResult<()> { - tracing::info!("shutting down API coordinator"); + log::info!("shutting down API coordinator"); // The moq-api client uses reqwest which handles connection cleanup internally Ok(()) } diff --git a/moq-relay-ietf/src/bin/moq-relay-ietf/file_coordinator.rs b/moq-relay-ietf/src/bin/moq-relay-ietf/file_coordinator.rs index ca26f067..bebb9c10 100644 --- a/moq-relay-ietf/src/bin/moq-relay-ietf/file_coordinator.rs +++ b/moq-relay-ietf/src/bin/moq-relay-ietf/file_coordinator.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! File-based coordinator for multi-relay deployments. //! //! This coordinator uses a shared JSON file with file locking to coordinate @@ -27,15 +24,11 @@ use moq_relay_ietf::{ /// Data stored in the shared file #[derive(Debug, Default, Serialize, Deserialize)] struct CoordinatorData { - /// Maps connection scope to namespace map - namespaces: HashMap>, + /// Maps namespace path (e.g., "/foo/bar") to relay URL + namespaces: HashMap, } impl CoordinatorData { - fn scope_key(scope: Option<&str>) -> String { - scope.unwrap_or("").to_string() - } - fn namespace_key(namespace: &TrackNamespace) -> String { namespace.to_utf8_path() } @@ -43,23 +36,20 @@ impl CoordinatorData { /// Handle that unregisters a namespace when dropped struct NamespaceUnregisterHandle { - scope_key: String, - namespace_key: String, + namespace: TrackNamespace, file_path: PathBuf, } impl Drop for NamespaceUnregisterHandle { fn drop(&mut self) { - if let Err(err) = - unregister_namespace_sync(&self.file_path, &self.scope_key, &self.namespace_key) - { - tracing::warn!(namespace = %self.namespace_key, error = %err, "failed to unregister namespace on drop: {}", err); + if let Err(err) = unregister_namespace_sync(&self.file_path, &self.namespace) { + log::warn!("failed to unregister namespace on drop: {}", err); } } } /// Synchronous helper for unregistering namespace (used in Drop) -fn unregister_namespace_sync(file_path: &Path, scope_key: &str, namespace_key: &str) -> Result<()> { +fn unregister_namespace_sync(file_path: &Path, namespace: &TrackNamespace) -> Result<()> { let file = OpenOptions::new() .read(true) .write(true) @@ -70,13 +60,10 @@ fn unregister_namespace_sync(file_path: &Path, scope_key: &str, namespace_key: & file.lock_exclusive()?; let mut data = read_data(&file)?; - tracing::debug!(namespace = %namespace_key, scope = %scope_key, "unregistering namespace: {}", namespace_key); - if let Some(bucket) = data.namespaces.get_mut(scope_key) { - bucket.remove(namespace_key); - if bucket.is_empty() { - data.namespaces.remove(scope_key); - } - } + let key = CoordinatorData::namespace_key(namespace); + + log::debug!("unregistering namespace: {}", key); + data.namespaces.remove(&key); write_data(&file, &data)?; file.unlock()?; @@ -96,9 +83,7 @@ fn read_data(file: &File) -> Result { return Ok(CoordinatorData::default()); } - let data: CoordinatorData = - serde_json::from_str(&contents).context("failed to parse coordinator data")?; - Ok(data) + serde_json::from_str(&contents).context("failed to parse coordinator data") } /// Write coordinator data to file @@ -143,17 +128,14 @@ impl FileCoordinator { impl Coordinator for FileCoordinator { async fn register_namespace( &self, - scope: Option<&str>, namespace: &TrackNamespace, ) -> CoordinatorResult { - let scope_key = CoordinatorData::scope_key(scope); - let namespace_key = CoordinatorData::namespace_key(namespace); + let namespace = namespace.clone(); let relay_url = self.relay_url.to_string(); let file_path = self.file_path.clone(); // Run blocking file I/O in a separate thread - let scope_clone = scope_key.clone(); - let key_clone = namespace_key.clone(); + let ns_clone = namespace.clone(); tokio::task::spawn_blocking(move || { let file = OpenOptions::new() .read(true) @@ -165,12 +147,10 @@ impl Coordinator for FileCoordinator { file.lock_exclusive()?; let mut data = read_data(&file)?; - tracing::info!(namespace = %key_clone, scope = %scope_clone, relay_url = %relay_url, "registering namespace: {} -> {}", key_clone, relay_url); - data - .namespaces - .entry(scope_clone) - .or_default() - .insert(key_clone, relay_url); + let key = CoordinatorData::namespace_key(&ns_clone); + + log::info!("registering namespace: {} -> {}", key, relay_url); + data.namespaces.insert(key, relay_url); write_data(&file, &data)?; file.unlock()?; @@ -180,8 +160,7 @@ impl Coordinator for FileCoordinator { .await??; let handle = NamespaceUnregisterHandle { - scope_key, - namespace_key, + namespace, file_path: self.file_path.clone(), }; @@ -190,31 +169,21 @@ impl Coordinator for FileCoordinator { // FIXME(itzmanish): Not being called currently but we need to call this on publish_namespace_done // currently unregister happens on drop of namespace - async fn unregister_namespace( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - ) -> CoordinatorResult<()> { - let scope_key = CoordinatorData::scope_key(scope); - let namespace_key = CoordinatorData::namespace_key(namespace); + async fn unregister_namespace(&self, namespace: &TrackNamespace) -> CoordinatorResult<()> { + let namespace = namespace.clone(); let file_path = self.file_path.clone(); - tokio::task::spawn_blocking(move || { - unregister_namespace_sync(&file_path, &scope_key, &namespace_key) - }) - .await??; + tokio::task::spawn_blocking(move || unregister_namespace_sync(&file_path, &namespace)) + .await??; Ok(()) } async fn lookup( &self, - scope: Option<&str>, namespace: &TrackNamespace, ) -> CoordinatorResult<(NamespaceOrigin, Option)> { let namespace = namespace.clone(); - let scope_key = CoordinatorData::scope_key(scope); - let namespace_key = CoordinatorData::namespace_key(&namespace); let file_path = self.file_path.clone(); let result = tokio::task::spawn_blocking( @@ -229,15 +198,12 @@ impl Coordinator for FileCoordinator { file.lock_shared()?; let data = read_data(&file)?; - tracing::debug!(namespace = %namespace_key, scope = %scope_key, "looking up namespace: {}", namespace_key); + let key = CoordinatorData::namespace_key(&namespace); - let Some(bucket) = data.namespaces.get(&scope_key) else { - file.unlock()?; - return Ok(None); - }; + log::debug!("looking up namespace: {}", key); // Try exact match first - if let Some(relay_url) = bucket.get(&namespace_key) { + if let Some(relay_url) = data.namespaces.get(&key) { file.unlock()?; let url = Url::parse(relay_url)?; return Ok(Some((NamespaceOrigin::new(namespace, url, None), None))); @@ -245,12 +211,12 @@ impl Coordinator for FileCoordinator { // Try prefix matching (find longest matching prefix) let mut best_match: Option<(String, String)> = None; - for (registered_key, url) in bucket { + for (registered_key, url) in &data.namespaces { // FIXME(itzmanish): it would be much better to compare on TupleField // instead of working on strings let is_prefix = registered_key .split('/') - .zip(namespace_key.split('/')) + .zip(key.split('/')) .all(|(a, b)| a == b); match best_match { Some((ns, _)) if is_prefix && ns.len() < registered_key.len() => { diff --git a/moq-relay-ietf/src/bin/moq-relay-ietf/main.rs b/moq-relay-ietf/src/bin/moq-relay-ietf/main.rs index 9ccfe462..faea6141 100644 --- a/moq-relay-ietf/src/bin/moq-relay-ietf/main.rs +++ b/moq-relay-ietf/src/bin/moq-relay-ietf/main.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - mod api_coordinator; mod file_coordinator; @@ -12,7 +9,7 @@ use url::Url; use api_coordinator::{ApiCoordinator, ApiCoordinatorConfig}; use file_coordinator::FileCoordinator; -use moq_relay_ietf::{Coordinator, Relay, RelayConfig, Web, WebConfig}; +use moq_relay_ietf::{Coordinator, Relay, RelayConfig, TieBreakPolicy, Web, WebConfig}; #[derive(Parser, Clone)] pub struct Cli { @@ -82,64 +79,31 @@ pub struct Cli { #[arg(long, default_value = "600")] pub api_ttl: u64, - /// Address to expose Prometheus metrics on (e.g., "127.0.0.1:9090"). - /// Requires the `metrics-prometheus` feature to be enabled. - /// When set, serves metrics at http:///metrics + /// Enable TopN event logging for visualization. + /// Writes JSON events to stdout that can be parsed to generate timeline SVGs. + /// Use with: grep TOPN_EVENT relay.log > events.log + /// Then: cargo run -p moq-topn-test --bin topn-log-to-svg -- events.log timeline.svg #[arg(long)] - pub metrics_addr: Option, + pub topn_log: bool, + + /// Tie-break policy for top-N filtering when values are equal. + /// "oldest" = first registered track wins (default) + /// "recent" = most recently updated track wins + #[arg(long, default_value = "oldest")] + pub tie_break: String, } #[tokio::main] async fn main() -> anyhow::Result<()> { - // Initialize tracing with env filter (respects RUST_LOG environment variable) - // Default to info level, but suppress quinn's verbose output - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info,quinn=warn")), - ) - .init(); - - let cli = Cli::parse(); - - // Initialize Prometheus metrics exporter if --metrics-addr is provided - #[cfg(feature = "metrics-prometheus")] - if let Some(metrics_addr) = cli.metrics_addr { - use metrics_exporter_prometheus::PrometheusBuilder; - - // Configure histogram buckets for subscribe latency (1ms to 10s) - let subscribe_latency_buckets = vec![ - 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 5.0, 10.0, - ]; - - PrometheusBuilder::new() - .with_http_listener(metrics_addr) - .set_buckets_for_metric( - metrics_exporter_prometheus::Matcher::Full( - "moq_relay_subscribe_latency_seconds".to_string(), - ), - &subscribe_latency_buckets, - )? - .install() - .expect("failed to install Prometheus metrics exporter"); - - // Register metric descriptions (shows as # HELP in Prometheus output) - moq_relay_ietf::metrics::describe_metrics(); - - tracing::info!( - "metrics exporter listening on http://{}/metrics", - metrics_addr - ); - } + env_logger::init(); - #[cfg(not(feature = "metrics-prometheus"))] - if cli.metrics_addr.is_some() { - tracing::warn!( - "--metrics-addr was provided but the metrics-prometheus feature is not enabled. \ - Rebuild with --features metrics-prometheus to enable the Prometheus exporter." - ); - } + // Disable tracing so we don't get a bunch of Quinn spam. + let tracer = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::WARN) + .finish(); + tracing::subscriber::set_global_default(tracer).unwrap(); + let cli = Cli::parse(); let tls = cli.tls.load()?; if tls.server.is_none() { @@ -173,13 +137,20 @@ async fn main() -> anyhow::Result<()> { let coordinator: Arc = if let Some(api_url) = &cli.api_url { let config = ApiCoordinatorConfig::new(api_url.clone(), relay_url).with_ttl(cli.api_ttl); let api_coordinator = ApiCoordinator::new(config); - tracing::info!("using API coordinator: {}", api_url); + log::info!("using API coordinator: {}", api_url); Arc::new(api_coordinator) } else { - tracing::info!("using file coordinator: {}", cli.coordinator_file.display()); + log::info!("using file coordinator: {}", cli.coordinator_file.display()); Arc::new(FileCoordinator::new(&cli.coordinator_file, relay_url)) }; + // Parse tie-break policy + let tie_break_policy = match cli.tie_break.as_str() { + "oldest" => TieBreakPolicy::OldestWins, + "recent" => TieBreakPolicy::MostRecentWins, + other => anyhow::bail!("invalid tie-break policy '{}': must be 'oldest' or 'recent'", other), + }; + // Create a QUIC server for media. let relay = Relay::new(RelayConfig { tls: tls.clone(), @@ -190,6 +161,8 @@ async fn main() -> anyhow::Result<()> { node: cli.node, announce: cli.announce, coordinator, + topn_log: cli.topn_log, + tie_break_policy, })?; if cli.dev { diff --git a/moq-relay-ietf/src/consumer.rs b/moq-relay-ietf/src/consumer.rs index 7b25ae78..7118bc7b 100644 --- a/moq-relay-ietf/src/consumer.rs +++ b/moq-relay-ietf/src/consumer.rs @@ -1,16 +1,16 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; use anyhow::Context; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_transport::{ - serve::Tracks, - session::{Announced, SessionError, Subscriber}, + coding::KeyValuePairs, + message::PublishOk, + serve::{ServeError, Tracks}, + session::{PublishNamespaceReceived, PublishReceived, SessionError, Subscriber}, }; -use crate::{metrics::GaugeGuard, Coordinator, Locals, Producer}; +use crate::{Coordinator, Locals, Producer, SubscriberRegistry}; /// Consumer of tracks from a remote Publisher #[derive(Clone)] @@ -19,10 +19,8 @@ pub struct Consumer { locals: Locals, coordinator: Arc, forward: Option, // Forward all announcements to this subscriber - /// The resolved scope identity for this session, if any. - /// Produced by `Coordinator::resolve_scope()` from the connection path. - /// Passed to coordinator register/lookup calls to isolate namespaces. - scope: Option, + subscriber_registry: Option, + session_id: u64, } impl Consumer { @@ -31,131 +29,155 @@ impl Consumer { locals: Locals, coordinator: Arc, forward: Option, - scope: Option, ) -> Self { Self { subscriber, locals, coordinator, forward, - scope, + subscriber_registry: None, + session_id: 0, + } + } + + /// Creates a consumer with a subscriber registry for PUBLISH notifications. + pub fn with_registry( + subscriber: Subscriber, + locals: Locals, + coordinator: Arc, + forward: Option, + subscriber_registry: SubscriberRegistry, + session_id: u64, + ) -> Self { + Self { + subscriber, + locals, + coordinator, + forward, + subscriber_registry: Some(subscriber_registry), + session_id, } } - /// Run the consumer to serve announce requests. - pub async fn run(mut self) -> Result<(), SessionError> { - let mut tasks = FuturesUnordered::new(); + /// Run the consumer to serve announce requests and track-level publish messages. + pub async fn run(self) -> Result<(), SessionError> { + let mut tasks: FuturesUnordered> = + FuturesUnordered::new(); + + log::debug!("[CONSUMER] run: starting main loop"); loop { + let mut subscriber_ns = self.subscriber.clone(); + let mut subscriber_publish = self.subscriber.clone(); + + log::trace!("[CONSUMER] run: waiting on select (tasks={})", tasks.len()); + tokio::select! { - // Handle a new announce request - Some(announce) = self.subscriber.announced() => { - metrics::counter!("moq_relay_publishers_total").increment(1); + Some(publish_ns) = subscriber_ns.publish_ns_recvd() => { + let this = self.clone(); + tasks.push(async move { + let info = publish_ns.clone(); + log::info!("serving publish_namespace: {:?}", info); + + if let Err(err) = this.serve_publish_namespace(publish_ns).await { + log::warn!("failed serving publish_namespace: {:?}, error: {}", info, err) + } + }.boxed()); + }, + Some(publish) = subscriber_publish.publish_received() => { + log::debug!("[CONSUMER] run: received track-level PUBLISH"); let this = self.clone(); tasks.push(async move { - let info = announce.clone(); - let namespace = info.namespace.to_utf8_path(); - tracing::info!(namespace = %namespace, "serving announce: {:?}", info); - - // Serve the announce request - if let Err(err) = this.serve(announce).await { - tracing::warn!(namespace = %namespace, error = %err, "failed serving announce: {:?}, error: {}", info, err); - // Note: phase-specific error counters are incremented in serve() + let info = publish.info.clone(); + log::info!("serving publish (track-level): {:?}", info); + + if let Err(err) = this.serve_publish(publish).await { + log::warn!("failed serving publish: {:?}, error: {}", info, err) } - }); + }.boxed()); + }, + _ = tasks.next(), if !tasks.is_empty() => { + log::trace!("[CONSUMER] run: a task completed"); + }, + else => { + log::debug!("[CONSUMER] run: else branch triggered, returning"); + return Ok(()); }, - _ = tasks.next(), if !tasks.is_empty() => {}, - else => return Ok(()), }; } } - /// Serve an announce request. - async fn serve(mut self, mut announce: Announced) -> Result<(), anyhow::Error> { - // Track active publishers - decrements when this function returns - let _publisher_guard = GaugeGuard::new("moq_relay_active_publishers"); + async fn serve_publish_namespace( + mut self, + mut publish_ns: PublishNamespaceReceived, + ) -> Result<(), anyhow::Error> { + let mut tasks: FuturesUnordered>> = + FuturesUnordered::new(); - let mut tasks = FuturesUnordered::new(); + let (writer, mut request, reader) = Tracks::new(publish_ns.namespace.clone()).produce(); - // Produce the tracks for this announce and return the reader - let (_, mut request, reader) = Tracks::new(announce.namespace.clone()).produce(); + // NOTE(mpandit): once the track is pulled from origin, internally it will be relayed + // from this metal only, because now coordinator will have entry for the namespace. // should we allow the same namespace being served from multiple relays?? - // Manish: NO. - let ns = reader.namespace.to_utf8_path(); + // Register namespace with the coordinator + let _namespace_registration = self + .coordinator + .register_namespace(&reader.namespace) + .await?; // Register the local tracks, unregister on drop - tracing::debug!(namespace = %ns, "registering namespace in locals"); - let _register = match self - .locals - .register(self.scope.as_deref(), reader.clone()) - .await - { - Ok(reg) => reg, - Err(err) => { - metrics::counter!("moq_relay_announce_errors_total", "phase" => "local_register") - .increment(1); - return Err(err); - } - }; - tracing::debug!(namespace = %ns, "namespace registered in locals"); + let _register = self.locals.register(reader.clone(), writer).await?; - // NOTE(mpandit): once the track is pulled from origin, internally it will be relayed - // from this metal only, because now coordinator will have entry for the namespace. + publish_ns.ok()?; - // Register namespace with the coordinator - tracing::debug!(namespace = %ns, "registering namespace with coordinator"); - let _namespace_registration = match self - .coordinator - .register_namespace(self.scope.as_deref(), &reader.namespace) - .await - { - Ok(reg) => reg, - Err(err) => { - metrics::counter!("moq_relay_announce_errors_total", "phase" => "coordinator_register") - .increment(1); - return Err(err.into()); + // Notify subscriber registry of the new PUBLISH_NAMESPACE + // This will trigger forwarding to matching SUBSCRIBE_NAMESPACE subscriptions + // Uses session_id for self-exclusion + if let Some(ref registry) = self.subscriber_registry { + let notified = registry.notify_publish_namespace(&publish_ns.namespace, self.session_id); + if notified > 0 { + log::info!( + "notified {} SUBSCRIBE_NAMESPACE subscriptions of PUBLISH_NAMESPACE {:?}", + notified, + publish_ns.namespace + ); } - }; - tracing::debug!(namespace = %ns, "namespace registered with coordinator"); - - // Accept the announce with an OK response - if let Err(err) = announce.ok() { - metrics::counter!("moq_relay_announce_errors_total", "phase" => "send_ok").increment(1); - return Err(err.into()); } - tracing::debug!(namespace = %ns, "sent ANNOUNCE_OK"); - - // Successfully sent ANNOUNCE_OK - metrics::counter!("moq_relay_announce_ok_total").increment(1); - - // Forward the announce, if needed - if let Some(mut forward) = self.forward { - tasks.push( - async move { - let namespace = reader.namespace.to_utf8_path(); - tracing::info!(namespace = %namespace, "forwarding announce: {:?}", reader.info); - forward - .announce(reader) - .await - .context("failed forwarding announce") + + // Forward publish_namespace upstream - keep handle alive in this scope + let _forwarded_publish_ns = if let Some(mut forward) = self.forward.clone() { + let reader_clone = reader.clone(); + log::info!("forwarding publish_namespace: {:?}", reader_clone.info); + match forward.publish_namespace(reader_clone).await { + Ok(publish_ns) => { + if let Err(e) = publish_ns.ok().await { + log::warn!("publish_namespace not accepted: {}", e); + None + } else { + log::info!( + "publish_namespace forwarded and accepted: {:?}", + publish_ns.info.namespace + ); + Some(publish_ns) + } } - .boxed(), - ); - } + Err(e) => { + log::warn!("failed forwarding publish_namespace: {}", e); + None + } + } + } else { + None + }; // Serve subscribe requests loop { tokio::select! { - // If the announce is closed, return the error - Err(err) = announce.closed() => { - let ns = announce.namespace.to_utf8_path(); - tracing::info!(namespace = %ns, error = %err, "announce closed"); - return Err(err.into()); - }, + Err(err) = publish_ns.closed() => return Err(err.into()), // Wait for the next subscriber and serve the track. Some(track) = request.next() => { @@ -164,13 +186,11 @@ impl Consumer { // Spawn a new task to handle the subscribe tasks.push(async move { let info = track.clone(); - let namespace = info.namespace.to_utf8_path(); - let track_name = info.name.clone(); - tracing::info!(namespace = %namespace, track = %track_name, "forwarding subscribe: {:?}", info); + log::info!("forwarding subscribe: {:?}", info); // Forward the subscribe request if let Err(err) = subscriber.subscribe(track).await { - tracing::warn!(namespace = %namespace, track = %track_name, error = %err, "failed forwarding subscribe: {:?}, error: {}", info, err) + log::warn!("failed forwarding subscribe: {:?}, error: {}", info, err) } Ok(()) @@ -181,4 +201,251 @@ impl Consumer { } } } + + async fn serve_publish(mut self, publish: PublishReceived) -> Result<(), anyhow::Error> { + let namespace = publish.info.track_namespace.clone(); + let track_name = publish.info.track_name.clone(); + let track_alias = publish.info.track_alias; + let initial_forward = publish.info.forward; + let publish_request_id = publish.info.id; + let track_extensions = publish.info.track_extensions.clone(); + + log::info!( + "received PUBLISH for track: {}/{} (forward={}, extensions={:?})", + namespace, + track_name, + initial_forward, + track_extensions + ); + + // Use auto-register variant to support SUBSCRIBE_NAMESPACE flow + // where PUBLISH can arrive without prior PUBLISH_NAMESPACE + let track_info = self + .locals + .get_or_create_track_info_auto_register(&namespace, &track_name); + + let writer = match track_info.publish_arrived() { + Ok(w) => w, + Err(ServeError::Uninterested) => { + log::info!( + "PUBLISH rejected: already subscribed to {}/{}", + namespace, + track_name + ); + publish.reject(ServeError::Uninterested.code(), "Already subscribed")?; + return Err(ServeError::Uninterested.into()); + } + Err(ServeError::Duplicate) => { + log::info!( + "PUBLISH rejected: already publishing {}/{}", + namespace, + track_name + ); + publish.reject(ServeError::Duplicate.code(), "Already publishing")?; + return Err(ServeError::Duplicate.into()); + } + Err(e) => { + publish.reject(e.code(), &e.to_string())?; + return Err(e.into()); + } + }; + + let reader = track_info.get_reader(); + + self.locals + .insert_track(&namespace, reader) + .context("failed to insert track into namespace")?; + + // Store publish info for forward state management + track_info.set_publish_info(publish_request_id, initial_forward); + + // Store track extensions for forwarding to subscribers + track_info.set_track_extensions(track_extensions); + + // Include forward=1 in PUBLISH_OK to tell publisher to start sending immediately + let mut params = KeyValuePairs::default(); + params.set_intvalue(0x10, 1); // Forward = 1 + + let msg = PublishOk { + id: publish.info.id, + params, + }; + + publish.accept(writer, msg)?; + + log::info!( + "PUBLISH accepted, track {}/{} now in Publishing state (forward={})", + namespace, + track_name, + initial_forward + ); + + // Register track with TopN tracker if track_extensions contain property values + // This enables top-N filtering for SUBSCRIBE_NAMESPACE with TRACK_FILTER + if let Some(ref registry) = self.subscriber_registry { + // Check for known property types in track_extensions + // AUDIO_LEVEL_EXT = 0x12 (18) - audio level for active speaker detection + const AUDIO_LEVEL_EXT: u64 = 0x12; + + if let Some(track_exts) = track_info.track_extensions() { + if let Some(kvp) = track_exts.get(AUDIO_LEVEL_EXT) { + if let moq_transport::coding::Value::IntValue(audio_level) = kvp.value { + registry.register_track( + &namespace, + &track_name, + AUDIO_LEVEL_EXT, + audio_level, + self.session_id, + ); + log::info!( + "registered track {}/{} with TopN tracker (audio_level={})", + namespace, + track_name, + audio_level + ); + } + } + } + + // Spawn ingest observer: single task that reads objects and calls + // update_track_value once per value change (removes 799/800 redundant + // mutex locks from subscriber observer path) + let reg = registry.clone(); + let ingest_ns = namespace.clone(); + let ingest_name = track_name.clone(); + let ingest_session_id = self.session_id; + let ingest_reader = track_info.get_reader(); + tokio::spawn(async move { + Self::run_ingest_observer( + ingest_reader, + reg, + ingest_ns, + ingest_name, + track_alias, + ingest_session_id, + ) + .await; + }); + } + + // Notify subscriber registry of the new PUBLISH + // This will trigger forwarding to matching SUBSCRIBE_NAMESPACE subscriptions + // Uses session_id for self-exclusion (don't notify the same session that sent the PUBLISH) + if let Some(ref registry) = self.subscriber_registry { + let notified = registry.notify_publish(&namespace, &track_name, track_alias, self.session_id); + if notified > 0 { + log::info!( + "notified {} SUBSCRIBE_NAMESPACE subscriptions of PUBLISH {}/{}", + notified, + namespace, + track_name + ); + } + } + + // If forward=0 (paused), wait for subscribers to request forwarding + // When forward state changes to 1, send REQUEST_UPDATE to publisher + if !initial_forward { + let forward_rx = track_info.forward_receiver(); + if let Some(mut rx) = forward_rx { + log::info!( + "track {}/{} is paused (forward=0), waiting for subscriber to request forwarding", + namespace, + track_name + ); + + // Wait for forward state to change to true + loop { + rx.changed().await.ok(); + if *rx.borrow() { + // Forward state changed to true, send REQUEST_UPDATE + log::info!( + "subscriber arrived for paused track {}/{}, sending REQUEST_UPDATE with forward=1", + namespace, + track_name + ); + + let mut params = KeyValuePairs::default(); + params.set_intvalue(0x10, 1); // Forward = 1 + + let request_update = moq_transport::message::SubscribeUpdate { + id: self.subscriber.get_next_request_id(), + existing_request_id: publish_request_id, + params, + }; + + self.subscriber.send_message(request_update); + log::info!( + "sent REQUEST_UPDATE for track {}/{} (existing_request_id={})", + namespace, + track_name, + publish_request_id + ); + break; + } + } + } + } + + Ok(()) + } + + async fn run_ingest_observer( + reader: moq_transport::serve::TrackReader, + registry: SubscriberRegistry, + namespace: moq_transport::coding::TrackNamespace, + track_name: String, + track_alias: u64, + session_id: u64, + ) { + const AUDIO_LEVEL_EXT: u64 = 0x12; + let last_value = AtomicU64::new(u64::MAX); + + let mode = match reader.mode().await { + Ok(mode) => mode, + Err(_) => return, + }; + + match mode { + moq_transport::serve::TrackReaderMode::Subgroups(mut subgroups) => { + while let Ok(Some(mut subgroup)) = subgroups.next().await { + while let Ok(Some(object)) = subgroup.next().await { + if let Some(kvp) = object.extension_headers.get(AUDIO_LEVEL_EXT) { + if let moq_transport::coding::Value::IntValue(value) = kvp.value { + if last_value.swap(value, Ordering::Relaxed) != value { + registry.update_track_value( + &namespace, + &track_name, + AUDIO_LEVEL_EXT, + value, + track_alias, + session_id, + ); + } + } + } + } + } + } + moq_transport::serve::TrackReaderMode::Datagrams(mut datagrams) => { + while let Ok(Some(datagram)) = datagrams.read().await { + if let Some(kvp) = datagram.extension_headers.get(AUDIO_LEVEL_EXT) { + if let moq_transport::coding::Value::IntValue(value) = kvp.value { + if last_value.swap(value, Ordering::Relaxed) != value { + registry.update_track_value( + &namespace, + &track_name, + AUDIO_LEVEL_EXT, + value, + track_alias, + session_id, + ); + } + } + } + } + } + _ => {} + } + } } diff --git a/moq-relay-ietf/src/coordinator.rs b/moq-relay-ietf/src/coordinator.rs index 532af962..4315b438 100644 --- a/moq-relay-ietf/src/coordinator.rs +++ b/moq-relay-ietf/src/coordinator.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::net::SocketAddr; use async_trait::async_trait; @@ -125,243 +122,9 @@ impl NamespaceOrigin { } } -/// Information about the resolved scope for a connection. -/// -/// Returned by [`Coordinator::resolve_scope()`] to tell the relay: -/// - Which scope this connection belongs to (for routing and namespace isolation) -/// - What the connection is allowed to do (for permission enforcement) -/// -/// Multiple connection paths can map to the same `scope_id` — for example, -/// a publisher path and a subscriber path that share a scope but have -/// different permissions. -#[derive(Debug, Clone)] -pub struct ScopeInfo { - /// The resolved scope identity. Used as the key for namespace - /// registration and lookup in all subsequent coordinator operations. - /// - /// Multiple connection paths can map to the same `scope_id`. - pub scope_id: String, - - /// What this connection is allowed to do within the scope. - pub permissions: ScopePermissions, -} - -/// Permissions granted to a connection within its scope. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ScopePermissions { - /// Can both publish (PUBLISH_NAMESPACE) and subscribe (SUBSCRIBE/FETCH). - ReadWrite, - /// Can subscribe/fetch only. Publishing attempts will be rejected - /// by the relay (the Consumer side of the session will not be created). - ReadOnly, -} - -impl ScopePermissions { - /// Whether this permission level allows publishing (PUBLISH_NAMESPACE). - pub fn can_publish(&self) -> bool { - matches!(self, Self::ReadWrite) - } - - /// Whether this permission level allows subscribing (SUBSCRIBE/FETCH). - /// - /// Always returns `true` — both `ReadWrite` and `ReadOnly` connections - /// can subscribe. This is intentional: the asymmetry with [`can_publish()`] - /// reflects that subscribing is the baseline capability, while publishing - /// requires elevated permissions. If a future permission level needs to - /// deny subscribing, a new variant should be added. - /// - /// [`can_publish()`]: ScopePermissions::can_publish - pub fn can_subscribe(&self) -> bool { - true - } -} - -// ============================================================================ -// Types for extended Coordinator functionality -// ============================================================================ - -/// Per-scope configuration retrieved from the coordinator. -/// -/// Called after [`Coordinator::resolve_scope()`] to get operational parameters -/// for the scope. This configuration applies to all sessions within the scope. -#[derive(Debug, Clone, Default)] -#[non_exhaustive] -pub struct ScopeConfig { - /// Origin server to fall back to when namespace not found locally or on - /// other relays. The relay will attempt to subscribe from this origin - /// before returning "not found" to the subscriber. - pub origin_fallback: Option, - - /// Whether to pre-register subscriber interest for tracks that don't exist - /// yet. When true, enables "subscriber-first" workflows where subscribers - /// can wait for publishers that haven't connected yet. - /// - /// This corresponds to the "rendezvous" concept in the MoQT specification - /// (`RENDEZVOUS_TIMEOUT` parameter, see moq-transport PR #1447). The - /// "lingering subscribe" terminology from moq-transport issue #1402 is - /// used here for consistency with existing implementations. - /// - /// Future: A `rendezvous_timeout` field may be added to control how long - /// the relay waits for a publisher before giving up. - pub lingering_subscribe: bool, -} - -/// Result of subscribing to a namespace prefix via SUBSCRIBE_NAMESPACE. -/// -/// The subscription remains active until this handle is dropped. -/// On drop, cleanup is performed (e.g., unregistering from the coordinator). -pub struct NamespaceSubscription { - /// Namespaces that currently match the subscribed prefix. - /// The relay should send PUBLISH_NAMESPACE for each of these. - pub existing_namespaces: Vec, - - /// RAII handle — drop triggers unsubscription cleanup. - _registration: Box, -} - -impl Default for NamespaceSubscription { - fn default() -> Self { - Self { - existing_namespaces: vec![], - _registration: Box::new(()), - } - } -} - -impl NamespaceSubscription { - /// Create a new subscription with existing namespaces and a cleanup handle. - pub fn new(existing: Vec, inner: T) -> Self { - Self { - existing_namespaces: existing, - _registration: Box::new(inner), - } - } -} - -/// Information about a registered namespace. -/// -/// Returned in [`NamespaceSubscription`] to describe namespaces matching -/// a SUBSCRIBE_NAMESPACE prefix. -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct NamespaceInfo { - /// The namespace identity. - pub namespace: TrackNamespace, -} - -impl NamespaceInfo { - /// Create a new NamespaceInfo. - pub fn new(namespace: TrackNamespace) -> Self { - Self { namespace } - } -} - -/// Information about a relay to forward messages to. -/// -/// Returned by [`Coordinator::lookup_namespace_subscribers()`] and -/// [`Coordinator::lookup_track_subscribers()`] to tell the relay where -/// to forward PUBLISH_NAMESPACE or track availability notifications. -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct RelayInfo { - /// Relay URL (used for TLS SNI and connection establishment). - pub url: Url, - - /// Optional direct socket address (bypasses DNS resolution). - pub addr: Option, -} - -impl RelayInfo { - /// Create a new RelayInfo with URL only. - pub fn new(url: Url) -> Self { - Self { url, addr: None } - } - - /// Create a new RelayInfo with URL and direct socket address. - pub fn with_addr(url: Url, addr: SocketAddr) -> Self { - Self { - url, - addr: Some(addr), - } - } -} - -/// Handle returned when a track is registered with the coordinator. -/// -/// Dropping this handle automatically unregisters the track. -/// This provides RAII-based cleanup for track-level PUBLISH. -pub struct TrackRegistration { - _registration: Box, -} - -impl Default for TrackRegistration { - fn default() -> Self { - Self { - _registration: Box::new(()), - } - } -} - -impl TrackRegistration { - /// Create a new track registration handle wrapping any Send + Sync type. - pub fn new(inner: T) -> Self { - Self { - _registration: Box::new(inner), - } - } -} - -/// Information about a registered track. -/// -/// Returned by [`Coordinator::list_tracks()`] to describe tracks -/// registered under a namespace. -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct TrackEntry { - /// The namespace this track belongs to. - pub namespace: TrackNamespace, - - /// The track name within the namespace. - pub name: String, -} - -impl TrackEntry { - /// Create a new TrackEntry. - pub fn new(namespace: TrackNamespace, name: String) -> Self { - Self { namespace, name } - } -} - -/// Handle returned when subscribing to a track for rendezvous/lingering subscriber support. -/// -/// Dropping this handle automatically unregisters the track subscription. -/// This provides RAII-based cleanup for pre-registered subscriber interest -/// (the "rendezvous" concept from MoQT's `RENDEZVOUS_TIMEOUT` parameter). -pub struct TrackSubscription { - _registration: Box, -} - -impl Default for TrackSubscription { - fn default() -> Self { - Self { - _registration: Box::new(()), - } - } -} - -impl TrackSubscription { - /// Create a new track subscription handle wrapping any Send + Sync type. - pub fn new(inner: T) -> Self { - Self { - _registration: Box::new(inner), - } - } -} - /// Coordinator handles namespace registration/discovery across relays. /// /// Implementations are responsible for: -/// - Resolving connection paths to scopes (identity + permissions) /// - Tracking which namespaces are served locally /// - Caching remote namespace lookups /// - Communicating with external registries (HTTP API, Redis, etc.) @@ -372,66 +135,8 @@ impl TrackSubscription { /// /// All methods take `&self` and implementations must be thread-safe. /// Multiple tasks will call these methods concurrently. -/// -/// ## Scope Resolution -/// -/// When a new session is accepted, the relay calls [`resolve_scope()`] with -/// the raw connection path (from WebTransport URL or CLIENT_SETUP PATH -/// parameter). The coordinator returns a [`ScopeInfo`] containing: -/// -/// - **`scope_id`**: The resolved scope identity, used as the key for all -/// subsequent `register_namespace()` and `lookup()` calls. This is -/// intentionally separate from the raw connection path — multiple paths -/// can map to the same scope. -/// -/// - **`permissions`**: What the connection is allowed to do. The relay -/// enforces permissions by selectively enabling the publish and/or -/// subscribe sides of the session. -/// -/// If `resolve_scope()` returns `None`, the session is unscoped — all -/// subsequent operations use `scope: None` and both publish and subscribe -/// are allowed. -/// -/// [`resolve_scope()`]: Coordinator::resolve_scope #[async_trait] pub trait Coordinator: Send + Sync { - /// Resolve a connection path to scope information. - /// - /// Called once per accepted session, before any register/lookup calls. - /// The relay uses the returned [`ScopeInfo`] to: - /// - Scope all subsequent coordinator operations to `scope_id` - /// - Enforce permissions (e.g., skip creating the publish side for - /// `ReadOnly` connections) - /// - /// # Arguments - /// - /// * `connection_path` - The raw connection path from the WebTransport - /// URL or CLIENT_SETUP PATH parameter. `None` if no path was present. - /// - /// # Returns - /// - /// - `Ok(Some(ScopeInfo))` - Connection is scoped with the given - /// identity and permissions. - /// - `Ok(None)` - Connection is unscoped. The relay will pass - /// `scope: None` to all subsequent coordinator calls and allow - /// both publish and subscribe. - /// - `Err(...)` - Connection should be rejected (e.g., unrecognized - /// path, unauthorized). - /// - /// # Default Implementation - /// - /// Passes through the connection path as the `scope_id` with - /// `ReadWrite` permissions. Connections without a path are unscoped. - async fn resolve_scope( - &self, - connection_path: Option<&str>, - ) -> CoordinatorResult> { - Ok(connection_path.map(|path| ScopeInfo { - scope_id: path.to_string(), - permissions: ScopePermissions::ReadWrite, - })) - } - /// Register a namespace as locally available on this relay. /// /// Called when a publisher sends PUBLISH_NAMESPACE. @@ -443,21 +148,14 @@ pub trait Coordinator: Send + Sync { /// /// # Arguments /// - /// * `scope` - The resolved scope identity from [`resolve_scope()`], - /// or `None` for unscoped sessions. Used to isolate namespace - /// registrations — the same namespace in different scopes may - /// route independently. /// * `namespace` - The namespace being registered /// /// # Returns /// /// A `NamespaceRegistration` handle. The namespace remains registered /// as long as this handle is held. Dropping it unregisters the namespace. - /// - /// [`resolve_scope()`]: Coordinator::resolve_scope async fn register_namespace( &self, - scope: Option<&str>, namespace: &TrackNamespace, ) -> CoordinatorResult; @@ -469,13 +167,8 @@ pub trait Coordinator: Send + Sync { /// /// # Arguments /// - /// * `scope` - The resolved scope identity, or `None` for unscoped sessions. /// * `namespace` - The namespace to unregister - async fn unregister_namespace( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - ) -> CoordinatorResult<()>; + async fn unregister_namespace(&self, namespace: &TrackNamespace) -> CoordinatorResult<()>; /// Lookup where a namespace is served from. /// @@ -487,9 +180,6 @@ pub trait Coordinator: Send + Sync { /// /// # Arguments /// - /// * `scope` - The resolved scope identity, or `None` for unscoped - /// sessions. Coordinators use this to scope lookups (e.g., to route - /// to the correct origin for a particular application). /// * `namespace` - The namespace to look up /// /// # Returns @@ -498,7 +188,6 @@ pub trait Coordinator: Send + Sync { /// - `Err` - Namespace not found anywhere async fn lookup( &self, - scope: Option<&str>, namespace: &TrackNamespace, ) -> CoordinatorResult<(NamespaceOrigin, Option)>; @@ -511,1552 +200,4 @@ pub trait Coordinator: Send + Sync { async fn shutdown(&self) -> CoordinatorResult<()> { Ok(()) } - - // ======================================================================== - // Scope configuration - // ======================================================================== - - /// Get configuration for a resolved scope. - /// - /// Called after [`resolve_scope()`] to retrieve operational parameters - /// for the scope, such as origin fallback URLs and lingering subscriber - /// settings. - /// - /// # Arguments - /// - /// * `scope` - The resolved scope identity from [`resolve_scope()`], - /// or `None` for unscoped sessions. - /// - /// # Default Implementation - /// - /// Returns default configuration (no origin fallback, lingering subscribe - /// disabled). - /// - /// [`resolve_scope()`]: Coordinator::resolve_scope - async fn get_scope_config(&self, _scope: Option<&str>) -> CoordinatorResult { - Ok(ScopeConfig::default()) - } - - // ======================================================================== - // SUBSCRIBE_NAMESPACE support - // ======================================================================== - - /// Register interest in a namespace prefix (SUBSCRIBE_NAMESPACE). - /// - /// Called when a subscriber sends SUBSCRIBE_NAMESPACE. The coordinator - /// should: - /// 1. Record that this relay is interested in the prefix - /// 2. Return currently-matching namespaces - /// 3. Return an RAII handle for cleanup on disconnect - /// - /// When publishers later register namespaces matching this prefix, the - /// relay uses [`lookup_namespace_subscribers()`] to find interested relays - /// and forward PUBLISH_NAMESPACE to them. - /// - /// # Arguments - /// - /// * `scope` - The resolved scope identity, or `None` for unscoped sessions. - /// * `prefix` - The namespace prefix to subscribe to. - /// - /// # Default Implementation - /// - /// Returns an empty subscription (no existing namespaces, no-op cleanup). - /// - /// [`lookup_namespace_subscribers()`]: Coordinator::lookup_namespace_subscribers - async fn subscribe_namespace( - &self, - _scope: Option<&str>, - _prefix: &TrackNamespace, - ) -> CoordinatorResult { - Ok(NamespaceSubscription::default()) - } - - /// Unregister interest in a namespace prefix (UNSUBSCRIBE_NAMESPACE). - /// - /// Called when a subscriber sends UNSUBSCRIBE_NAMESPACE or disconnects. - /// This is an explicit unregistration — the subscription handle may still - /// exist but interest should be removed from the registry. - /// - /// # Arguments - /// - /// * `scope` - The resolved scope identity, or `None` for unscoped sessions. - /// * `prefix` - The namespace prefix to unsubscribe from. - /// - /// # Default Implementation - /// - /// No-op (returns success). - async fn unsubscribe_namespace( - &self, - _scope: Option<&str>, - _prefix: &TrackNamespace, - ) -> CoordinatorResult<()> { - Ok(()) - } - - /// Find relays interested in a namespace (reverse lookup). - /// - /// Called when a publisher registers a new namespace. The relay uses this - /// to find other relays that have active SUBSCRIBE_NAMESPACE subscriptions - /// matching this namespace, then forwards PUBLISH_NAMESPACE to them. - /// - /// # Arguments - /// - /// * `scope` - The resolved scope identity, or `None` for unscoped sessions. - /// * `namespace` - The newly-registered namespace. - /// - /// # Returns - /// - /// List of relay endpoints to forward PUBLISH_NAMESPACE to. - /// - /// # Default Implementation - /// - /// Returns an empty list (no subscribers). - async fn lookup_namespace_subscribers( - &self, - _scope: Option<&str>, - _namespace: &TrackNamespace, - ) -> CoordinatorResult> { - Ok(vec![]) - } - - // ======================================================================== - // Track-level PUBLISH support - // ======================================================================== - - /// Register a track as available on this relay (track-level PUBLISH). - /// - /// Called when a publisher sends PUBLISH for a specific track. This - /// provides finer-grained routing than namespace-level registration. - /// - /// # Arguments - /// - /// * `scope` - The resolved scope identity, or `None` for unscoped sessions. - /// * `namespace` - The namespace the track belongs to. - /// * `track` - The track name within the namespace. - /// - /// # Returns - /// - /// A `TrackRegistration` handle. The track remains registered as long as - /// this handle is held. Dropping it unregisters the track. - /// - /// # Default Implementation - /// - /// Returns a no-op registration handle. - async fn register_track( - &self, - _scope: Option<&str>, - _namespace: &TrackNamespace, - _track: &str, - ) -> CoordinatorResult { - Ok(TrackRegistration::default()) - } - - /// Unregister a track. - /// - /// Called when a publisher sends PUBLISH_DONE or disconnects. - /// - /// # Arguments - /// - /// * `scope` - The resolved scope identity, or `None` for unscoped sessions. - /// * `namespace` - The namespace the track belongs to. - /// * `track` - The track name to unregister. - /// - /// # Default Implementation - /// - /// No-op (returns success). - async fn unregister_track( - &self, - _scope: Option<&str>, - _namespace: &TrackNamespace, - _track: &str, - ) -> CoordinatorResult<()> { - Ok(()) - } - - /// List tracks registered under a namespace. - /// - /// Used for track discovery within a namespace, supporting SUBSCRIBE_NAMESPACE - /// workflows where subscribers need to know what tracks are available. - /// - /// # Arguments - /// - /// * `scope` - The resolved scope identity, or `None` for unscoped sessions. - /// * `namespace` - The namespace to list tracks from. - /// - /// # Default Implementation - /// - /// Returns an empty list. - async fn list_tracks( - &self, - _scope: Option<&str>, - _namespace: &TrackNamespace, - ) -> CoordinatorResult> { - Ok(vec![]) - } - - // ======================================================================== - // Lingering subscriber / rendezvous support - // ======================================================================== - // - // These methods implement the "rendezvous" concept from the MoQT - // specification (RENDEZVOUS_TIMEOUT parameter, moq-transport PR #1447), - // also known as "lingering subscribe" (moq-transport issue #1402) or - // "early media" (Cisco's original framing at IETF 122). - // - // The relay uses these to pre-register subscriber interest before a - // publisher exists, enabling subscriber-first workflows. - // - // Future: timeout handling for rendezvous — how long to wait before - // giving up on a publisher. - - /// Pre-register interest in a track that may not exist yet (rendezvous). - /// - /// Enables "subscriber-first" workflows where a subscriber can wait for a - /// publisher that hasn't connected yet. Called when - /// [`ScopeConfig::lingering_subscribe`] is true and a subscriber requests - /// a track that doesn't exist. - /// - /// Also known as "lingering subscribe" (moq-transport issue #1402) or - /// "rendezvous" (MoQT spec's `RENDEZVOUS_TIMEOUT` parameter). - /// - /// When a publisher later registers the track, the relay uses - /// [`lookup_track_subscribers()`] to find waiting subscribers. - /// - /// # Arguments - /// - /// * `scope` - The resolved scope identity, or `None` for unscoped sessions. - /// * `namespace` - The namespace the track would belong to. - /// * `track` - The track name to pre-register interest in. - /// - /// # Returns - /// - /// A `TrackSubscription` handle. Interest remains registered as long as - /// this handle is held. Dropping it removes the interest. - /// - /// # Default Implementation - /// - /// Returns a no-op subscription handle. - /// - /// [`lookup_track_subscribers()`]: Coordinator::lookup_track_subscribers - async fn subscribe_track( - &self, - _scope: Option<&str>, - _namespace: &TrackNamespace, - _track: &str, - ) -> CoordinatorResult { - Ok(TrackSubscription::default()) - } - - /// Unregister track subscription interest. - /// - /// Called when a subscriber disconnects or no longer needs the track. - /// - /// # Arguments - /// - /// * `scope` - The resolved scope identity, or `None` for unscoped sessions. - /// * `namespace` - The namespace the track belongs to. - /// * `track` - The track name to unsubscribe from. - /// - /// # Default Implementation - /// - /// No-op (returns success). - async fn unsubscribe_track( - &self, - _scope: Option<&str>, - _namespace: &TrackNamespace, - _track: &str, - ) -> CoordinatorResult<()> { - Ok(()) - } - - /// Find relays with subscribers waiting for a track (reverse lookup). - /// - /// Called when a publisher registers a track, to notify lingering - /// subscribers that the track is now available. - /// - /// # Arguments - /// - /// * `scope` - The resolved scope identity, or `None` for unscoped sessions. - /// * `namespace` - The namespace the track belongs to. - /// * `track` - The track name. - /// - /// # Returns - /// - /// List of relay endpoints with waiting subscribers. - /// - /// # Default Implementation - /// - /// Returns an empty list. - async fn lookup_track_subscribers( - &self, - _scope: Option<&str>, - _namespace: &TrackNamespace, - _track: &str, - ) -> CoordinatorResult> { - Ok(vec![]) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::collections::HashMap; - use std::sync::Mutex; - - // ======================================================================== - // Test helpers and fixtures - // ======================================================================== - - /// Helper to build a TrackNamespace from slash-separated path segments. - fn ns(path: &str) -> TrackNamespace { - TrackNamespace::from_utf8_path(path) - } - - /// Returns true if `namespace` starts with all the fields in `prefix`. - fn ns_has_prefix(namespace: &TrackNamespace, prefix: &TrackNamespace) -> bool { - namespace.fields.len() >= prefix.fields.len() - && prefix - .fields - .iter() - .zip(namespace.fields.iter()) - .all(|(p, n)| p == n) - } - - // -------------------------------------------------------------------- - // MockCoordinator — a fully in-memory Coordinator for testing - // - // The in-tree FileCoordinator and ApiCoordinator only implement the - // three required methods (register_namespace, unregister_namespace, - // lookup) and rely on defaults for all new stub methods. This mock - // provides a complete reference implementation of the full trait, - // including SUBSCRIBE_NAMESPACE, track-level PUBLISH, and lingering - // subscriber support. - // - // It serves two purposes: - // 1. Executable documentation of intended method semantics for - // external implementors who can't see the bin-only coordinators - // 2. A test fixture that exercises non-trivial behavior for the new - // methods (which no existing coordinator implements yet) - // - // Test data models a broadcast/live-streaming scenario: - // - Scopes represent content providers or tenants - // (e.g., "content-provider-123") - // - Namespaces represent broadcast events or channels - // (e.g., "sports/football/match-42", "sports/football/match-42/camera-1") - // - Tracks represent individual media renditions - // (e.g., "video-1080p", "video-480p", "audio-en", "audio-es") - // - Multiple relays in a CDN cluster subscribe to namespace prefixes - // to discover new broadcasts and forward them to edge viewers - // -------------------------------------------------------------------- - - /// In-memory state for the mock coordinator. - struct MockState { - /// Maps scope → registered namespaces (keyed by TrackNamespace) - /// Value is the relay URL that registered it. - namespaces: HashMap>, - - /// Maps scope → registered tracks → relay URL - /// Key: (namespace, track_name) - tracks: HashMap>, - - /// Maps scope → SUBSCRIBE_NAMESPACE prefixes → list of relay URLs - namespace_subscribers: HashMap>>, - - /// Maps scope → subscribed tracks → list of relay URLs - /// Key: (namespace, track_name) - track_subscribers: HashMap>>, - - /// Maps scope → ScopeConfig - scope_configs: HashMap, - - /// Maps raw connection path → ScopeInfo (for resolve_scope) - path_to_scope: HashMap, - } - - impl MockState { - fn scope_key(scope: Option<&str>) -> String { - scope.unwrap_or("").to_string() - } - - fn track_key(namespace: &TrackNamespace, track: &str) -> (TrackNamespace, String) { - (namespace.clone(), track.to_string()) - } - } - - /// Drop-based handle for namespace unregistration. - struct MockNamespaceHandle { - state: std::sync::Arc>, - scope_key: String, - namespace: TrackNamespace, - } - - impl Drop for MockNamespaceHandle { - fn drop(&mut self) { - let mut state = self.state.lock().unwrap(); - if let Some(bucket) = state.namespaces.get_mut(&self.scope_key) { - bucket.remove(&self.namespace); - } - } - } - - /// Drop-based handle for track unregistration. - struct MockTrackHandle { - state: std::sync::Arc>, - scope_key: String, - track_key: (TrackNamespace, String), - } - - impl Drop for MockTrackHandle { - fn drop(&mut self) { - let mut state = self.state.lock().unwrap(); - if let Some(bucket) = state.tracks.get_mut(&self.scope_key) { - bucket.remove(&self.track_key); - } - } - } - - /// Drop-based handle for namespace subscription cleanup. - struct MockNamespaceSubHandle { - state: std::sync::Arc>, - scope_key: String, - prefix: TrackNamespace, - relay_url: String, - } - - impl Drop for MockNamespaceSubHandle { - fn drop(&mut self) { - let mut state = self.state.lock().unwrap(); - if let Some(bucket) = state.namespace_subscribers.get_mut(&self.scope_key) { - if let Some(relays) = bucket.get_mut(&self.prefix) { - relays.retain(|r| r != &self.relay_url); - } - } - } - } - - /// Drop-based handle for track subscription cleanup. - struct MockTrackSubHandle { - state: std::sync::Arc>, - scope_key: String, - track_key: (TrackNamespace, String), - relay_url: String, - } - - impl Drop for MockTrackSubHandle { - fn drop(&mut self) { - let mut state = self.state.lock().unwrap(); - if let Some(bucket) = state.track_subscribers.get_mut(&self.scope_key) { - if let Some(relays) = bucket.get_mut(&self.track_key) { - relays.retain(|r| r != &self.relay_url); - } - } - } - } - - /// A mock coordinator that stores all state in memory. - /// - /// Provides a complete reference implementation of the Coordinator trait - /// including all new stub methods. Useful for testing relay integration - /// and as executable documentation of the intended method semantics. - struct MockCoordinator { - state: std::sync::Arc>, - /// URL of "this" relay (used when registering namespaces/tracks) - relay_url: Url, - } - - impl MockCoordinator { - fn new(relay_url: &str) -> Self { - Self { - state: std::sync::Arc::new(Mutex::new(MockState { - namespaces: HashMap::new(), - tracks: HashMap::new(), - namespace_subscribers: HashMap::new(), - track_subscribers: HashMap::new(), - scope_configs: HashMap::new(), - path_to_scope: HashMap::new(), - })), - relay_url: Url::parse(relay_url).unwrap(), - } - } - - /// Configure scope resolution: connection path → ScopeInfo. - fn add_path_mapping(&self, path: &str, scope_id: &str, permissions: ScopePermissions) { - let mut state = self.state.lock().unwrap(); - state.path_to_scope.insert( - path.to_string(), - ScopeInfo { - scope_id: scope_id.to_string(), - permissions, - }, - ); - } - - /// Configure per-scope settings. - fn set_scope_config(&self, scope: &str, config: ScopeConfig) { - let mut state = self.state.lock().unwrap(); - state.scope_configs.insert(scope.to_string(), config); - } - } - - #[async_trait] - impl Coordinator for MockCoordinator { - async fn resolve_scope( - &self, - connection_path: Option<&str>, - ) -> CoordinatorResult> { - let state = self.state.lock().unwrap(); - match connection_path { - Some(path) => { - state - .path_to_scope - .get(path) - .cloned() - .map(Some) - .ok_or(CoordinatorError::Other(anyhow::anyhow!( - "unknown path: {}", - path - ))) - } - None => Ok(None), - } - } - - async fn register_namespace( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - ) -> CoordinatorResult { - let scope_key = MockState::scope_key(scope); - let relay_url = self.relay_url.to_string(); - - { - let mut state = self.state.lock().unwrap(); - let bucket = state.namespaces.entry(scope_key.clone()).or_default(); - if bucket.contains_key(namespace) { - return Err(CoordinatorError::NamespaceAlreadyRegistered); - } - bucket.insert(namespace.clone(), relay_url); - } - - let handle = MockNamespaceHandle { - state: self.state.clone(), - scope_key, - namespace: namespace.clone(), - }; - Ok(NamespaceRegistration::new(handle)) - } - - async fn unregister_namespace( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - ) -> CoordinatorResult<()> { - let scope_key = MockState::scope_key(scope); - let mut state = self.state.lock().unwrap(); - if let Some(bucket) = state.namespaces.get_mut(&scope_key) { - bucket.remove(namespace); - } - Ok(()) - } - - async fn lookup( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - ) -> CoordinatorResult<(NamespaceOrigin, Option)> { - let scope_key = MockState::scope_key(scope); - let state = self.state.lock().unwrap(); - - let bucket = state - .namespaces - .get(&scope_key) - .ok_or(CoordinatorError::NamespaceNotFound)?; - - // Exact match first - if let Some(relay_url) = bucket.get(namespace) { - let url = Url::parse(relay_url).unwrap(); - return Ok((NamespaceOrigin::new(namespace.clone(), url, None), None)); - } - - // Prefix match (longest wins) - let mut best: Option<(&TrackNamespace, &String)> = None; - for (registered, url) in bucket { - if ns_has_prefix(namespace, registered) { - match &best { - Some((prev, _)) if registered.fields.len() > prev.fields.len() => { - best = Some((registered, url)); - } - None => { - best = Some((registered, url)); - } - _ => {} - } - } - } - - match best { - Some((matched_ns, relay_url)) => { - let url = Url::parse(relay_url).unwrap(); - Ok((NamespaceOrigin::new(matched_ns.clone(), url, None), None)) - } - None => Err(CoordinatorError::NamespaceNotFound), - } - } - - async fn get_scope_config(&self, scope: Option<&str>) -> CoordinatorResult { - let state = self.state.lock().unwrap(); - let scope_key = MockState::scope_key(scope); - Ok(state - .scope_configs - .get(&scope_key) - .cloned() - .unwrap_or_default()) - } - - async fn subscribe_namespace( - &self, - scope: Option<&str>, - prefix: &TrackNamespace, - ) -> CoordinatorResult { - let scope_key = MockState::scope_key(scope); - let relay_url = self.relay_url.to_string(); - - let mut state = self.state.lock().unwrap(); - - // Find currently-registered namespaces that match the prefix - let existing: Vec = state - .namespaces - .get(&scope_key) - .map(|bucket| { - bucket - .keys() - .filter(|ns| ns_has_prefix(ns, prefix)) - .map(|ns| NamespaceInfo::new(ns.clone())) - .collect() - }) - .unwrap_or_default(); - - // Register this relay as interested in the prefix - state - .namespace_subscribers - .entry(scope_key.clone()) - .or_default() - .entry(prefix.clone()) - .or_default() - .push(relay_url.clone()); - - let handle = MockNamespaceSubHandle { - state: self.state.clone(), - scope_key, - prefix: prefix.clone(), - relay_url, - }; - - Ok(NamespaceSubscription::new(existing, handle)) - } - - async fn unsubscribe_namespace( - &self, - scope: Option<&str>, - prefix: &TrackNamespace, - ) -> CoordinatorResult<()> { - let scope_key = MockState::scope_key(scope); - let relay_url = self.relay_url.to_string(); - let mut state = self.state.lock().unwrap(); - if let Some(bucket) = state.namespace_subscribers.get_mut(&scope_key) { - if let Some(relays) = bucket.get_mut(prefix) { - relays.retain(|r| r != &relay_url); - } - } - Ok(()) - } - - async fn lookup_namespace_subscribers( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - ) -> CoordinatorResult> { - let scope_key = MockState::scope_key(scope); - let state = self.state.lock().unwrap(); - - let mut relays = Vec::new(); - if let Some(subs) = state.namespace_subscribers.get(&scope_key) { - for (prefix, relay_urls) in subs { - if ns_has_prefix(namespace, prefix) { - for url_str in relay_urls { - relays.push(RelayInfo::new(Url::parse(url_str).unwrap())); - } - } - } - } - Ok(relays) - } - - async fn register_track( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - track: &str, - ) -> CoordinatorResult { - let scope_key = MockState::scope_key(scope); - let track_key = MockState::track_key(namespace, track); - - { - let mut state = self.state.lock().unwrap(); - state - .tracks - .entry(scope_key.clone()) - .or_default() - .insert(track_key.clone(), self.relay_url.to_string()); - } - - let handle = MockTrackHandle { - state: self.state.clone(), - scope_key, - track_key, - }; - Ok(TrackRegistration::new(handle)) - } - - async fn unregister_track( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - track: &str, - ) -> CoordinatorResult<()> { - let scope_key = MockState::scope_key(scope); - let track_key = MockState::track_key(namespace, track); - let mut state = self.state.lock().unwrap(); - if let Some(bucket) = state.tracks.get_mut(&scope_key) { - bucket.remove(&track_key); - } - Ok(()) - } - - async fn list_tracks( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - ) -> CoordinatorResult> { - let scope_key = MockState::scope_key(scope); - let state = self.state.lock().unwrap(); - - let entries = state - .tracks - .get(&scope_key) - .map(|bucket| { - bucket - .keys() - .filter_map(|(ns, track_name)| { - if ns == namespace { - Some(TrackEntry::new(ns.clone(), track_name.clone())) - } else { - None - } - }) - .collect() - }) - .unwrap_or_default(); - - Ok(entries) - } - - async fn subscribe_track( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - track: &str, - ) -> CoordinatorResult { - let scope_key = MockState::scope_key(scope); - let track_key = MockState::track_key(namespace, track); - let relay_url = self.relay_url.to_string(); - - { - let mut state = self.state.lock().unwrap(); - state - .track_subscribers - .entry(scope_key.clone()) - .or_default() - .entry(track_key.clone()) - .or_default() - .push(relay_url.clone()); - } - - let handle = MockTrackSubHandle { - state: self.state.clone(), - scope_key, - track_key, - relay_url, - }; - Ok(TrackSubscription::new(handle)) - } - - async fn unsubscribe_track( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - track: &str, - ) -> CoordinatorResult<()> { - let scope_key = MockState::scope_key(scope); - let track_key = MockState::track_key(namespace, track); - let relay_url = self.relay_url.to_string(); - let mut state = self.state.lock().unwrap(); - if let Some(bucket) = state.track_subscribers.get_mut(&scope_key) { - if let Some(relays) = bucket.get_mut(&track_key) { - relays.retain(|r| r != &relay_url); - } - } - Ok(()) - } - - async fn lookup_track_subscribers( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - track: &str, - ) -> CoordinatorResult> { - let scope_key = MockState::scope_key(scope); - let track_key = MockState::track_key(namespace, track); - let state = self.state.lock().unwrap(); - - let relays = state - .track_subscribers - .get(&scope_key) - .and_then(|bucket| bucket.get(&track_key)) - .map(|urls| { - urls.iter() - .map(|u| RelayInfo::new(Url::parse(u).unwrap())) - .collect() - }) - .unwrap_or_default(); - - Ok(relays) - } - } - - // ======================================================================== - // Type construction and defaults - // ======================================================================== - - #[test] - fn scope_config_defaults() { - let config = ScopeConfig::default(); - assert!(config.origin_fallback.is_none()); - assert!(!config.lingering_subscribe); - } - - #[test] - fn scope_config_with_origin_fallback() { - let config = ScopeConfig { - origin_fallback: Some(Url::parse("https://origin.example.com").unwrap()), - lingering_subscribe: true, - }; - assert_eq!( - config.origin_fallback.unwrap().as_str(), - "https://origin.example.com/" - ); - assert!(config.lingering_subscribe); - } - - #[test] - fn namespace_info_construction() { - let info = NamespaceInfo::new(ns("sports/football/match-42")); - assert_eq!(info.namespace.to_utf8_path(), "/sports/football/match-42"); - } - - #[test] - fn relay_info_without_addr() { - let info = RelayInfo::new(Url::parse("https://relay-us-east.example.com").unwrap()); - assert_eq!(info.url.as_str(), "https://relay-us-east.example.com/"); - assert!(info.addr.is_none()); - } - - #[test] - fn relay_info_with_direct_addr() { - let addr: SocketAddr = "10.0.1.5:4443".parse().unwrap(); - let info = RelayInfo::with_addr( - Url::parse("https://relay-us-east.example.com").unwrap(), - addr, - ); - assert_eq!(info.url.as_str(), "https://relay-us-east.example.com/"); - assert_eq!(info.addr.unwrap(), addr); - } - - #[test] - fn track_entry_construction() { - let entry = TrackEntry::new(ns("sports/football/match-42"), "video-1080p".to_string()); - assert_eq!(entry.namespace.to_utf8_path(), "/sports/football/match-42"); - assert_eq!(entry.name, "video-1080p"); - } - - #[test] - fn namespace_subscription_default_is_empty() { - let sub = NamespaceSubscription::default(); - assert!(sub.existing_namespaces.is_empty()); - } - - #[test] - fn track_registration_default_is_no_op() { - // Default handle should not panic on drop - let _reg = TrackRegistration::default(); - } - - #[test] - fn track_subscription_default_is_no_op() { - // Default handle should not panic on drop - let _sub = TrackSubscription::default(); - } - - #[test] - fn scope_permissions_publish_and_subscribe() { - assert!(ScopePermissions::ReadWrite.can_publish()); - assert!(ScopePermissions::ReadWrite.can_subscribe()); - assert!(!ScopePermissions::ReadOnly.can_publish()); - assert!(ScopePermissions::ReadOnly.can_subscribe()); - } - - // ======================================================================== - // Scope resolution - // ======================================================================== - - #[tokio::test] - async fn resolve_scope_maps_path_to_scope_identity() { - // A broadcast platform might use connection paths that encode - // a content provider identity and role. Multiple paths can map - // to the same scope with different permissions. - let coord = MockCoordinator::new("https://relay-1.example.com"); - coord.add_path_mapping( - "/provider/acme-sports/ingest", - "content-provider-123", - ScopePermissions::ReadWrite, - ); - coord.add_path_mapping( - "/provider/acme-sports/watch", - "content-provider-123", - ScopePermissions::ReadOnly, - ); - - let ingest_scope = coord - .resolve_scope(Some("/provider/acme-sports/ingest")) - .await - .unwrap() - .unwrap(); - assert_eq!(ingest_scope.scope_id, "content-provider-123"); - assert!(ingest_scope.permissions.can_publish()); - - let watch_scope = coord - .resolve_scope(Some("/provider/acme-sports/watch")) - .await - .unwrap() - .unwrap(); - assert_eq!(watch_scope.scope_id, "content-provider-123"); - assert!(!watch_scope.permissions.can_publish()); - assert!(watch_scope.permissions.can_subscribe()); - } - - #[tokio::test] - async fn resolve_scope_none_path_returns_unscoped() { - let coord = MockCoordinator::new("https://relay-1.example.com"); - let result = coord.resolve_scope(None).await.unwrap(); - assert!(result.is_none()); - } - - #[tokio::test] - async fn resolve_scope_unknown_path_returns_error() { - let coord = MockCoordinator::new("https://relay-1.example.com"); - let result = coord.resolve_scope(Some("/unknown/path")).await; - assert!(result.is_err()); - } - - // ======================================================================== - // Scope configuration - // ======================================================================== - - #[tokio::test] - async fn get_scope_config_returns_configured_settings() { - // A content provider with an origin ingest server and lingering - // subscriber support (viewers can tune in before the broadcast starts) - let coord = MockCoordinator::new("https://relay-1.example.com"); - coord.set_scope_config( - "content-provider-123", - ScopeConfig { - origin_fallback: Some(Url::parse("https://ingest.example.com/origin").unwrap()), - lingering_subscribe: true, - }, - ); - - let config = coord - .get_scope_config(Some("content-provider-123")) - .await - .unwrap(); - assert!(config.lingering_subscribe); - assert!(config.origin_fallback.is_some()); - } - - #[tokio::test] - async fn get_scope_config_unconfigured_returns_defaults() { - let coord = MockCoordinator::new("https://relay-1.example.com"); - let config = coord.get_scope_config(Some("unknown-scope")).await.unwrap(); - assert!(!config.lingering_subscribe); - assert!(config.origin_fallback.is_none()); - } - - // ======================================================================== - // Namespace registration and lookup - // ======================================================================== - - #[tokio::test] - async fn register_and_lookup_namespace() { - // Ingest server registers a broadcast namespace; edge relay looks it up - let coord = MockCoordinator::new("https://relay-1.example.com"); - let scope = Some("content-provider-123"); - - let _reg = coord - .register_namespace(scope, &ns("sports/football/match-42")) - .await - .unwrap(); - - let (origin, _client) = coord - .lookup(scope, &ns("sports/football/match-42")) - .await - .unwrap(); - - assert_eq!(origin.url().as_str(), "https://relay-1.example.com/"); - } - - #[tokio::test] - async fn lookup_prefix_matching() { - // A broadcaster registers a top-level event namespace; subscribers - // looking up specific camera angles under it should still resolve. - let coord = MockCoordinator::new("https://relay-1.example.com"); - let scope = Some("content-provider-123"); - - let _reg = coord - .register_namespace(scope, &ns("sports/football/match-42")) - .await - .unwrap(); - - // Lookup a more specific namespace (camera angle) under the event - let (origin, _) = coord - .lookup(scope, &ns("sports/football/match-42/camera-1")) - .await - .unwrap(); - assert_eq!(origin.url().as_str(), "https://relay-1.example.com/"); - } - - #[tokio::test] - async fn lookup_not_found() { - let coord = MockCoordinator::new("https://relay-1.example.com"); - let result = coord - .lookup(Some("content-provider-123"), &ns("nonexistent")) - .await; - assert!(matches!(result, Err(CoordinatorError::NamespaceNotFound))); - } - - #[tokio::test] - async fn scopes_are_isolated() { - // Two content providers using the same namespace structure should - // not see each other's registrations. - let coord = MockCoordinator::new("https://relay-1.example.com"); - - let _reg = coord - .register_namespace(Some("provider-abc"), &ns("live/main")) - .await - .unwrap(); - - // Different provider can't see it - let result = coord.lookup(Some("provider-xyz"), &ns("live/main")).await; - assert!(matches!(result, Err(CoordinatorError::NamespaceNotFound))); - - // Same provider can - let (origin, _) = coord - .lookup(Some("provider-abc"), &ns("live/main")) - .await - .unwrap(); - assert_eq!(origin.url().as_str(), "https://relay-1.example.com/"); - } - - #[tokio::test] - async fn duplicate_registration_rejected() { - let coord = MockCoordinator::new("https://relay-1.example.com"); - let scope = Some("content-provider-123"); - - let _reg = coord - .register_namespace(scope, &ns("sports/football/match-42")) - .await - .unwrap(); - - let result = coord - .register_namespace(scope, &ns("sports/football/match-42")) - .await; - assert!(matches!( - result, - Err(CoordinatorError::NamespaceAlreadyRegistered) - )); - } - - #[tokio::test] - async fn namespace_unregistered_on_handle_drop() { - let coord = MockCoordinator::new("https://relay-1.example.com"); - let scope = Some("content-provider-123"); - - { - let _reg = coord - .register_namespace(scope, &ns("sports/football/match-42")) - .await - .unwrap(); - - // Should be findable while registration is held - assert!(coord - .lookup(scope, &ns("sports/football/match-42")) - .await - .is_ok()); - } - // _reg dropped — broadcast ended, ingest disconnected - - let result = coord.lookup(scope, &ns("sports/football/match-42")).await; - assert!(matches!(result, Err(CoordinatorError::NamespaceNotFound))); - } - - // ======================================================================== - // SUBSCRIBE_NAMESPACE — namespace prefix subscriptions - // ======================================================================== - - #[tokio::test] - async fn subscribe_namespace_returns_existing_matches() { - // An edge relay subscribes to all broadcasts from a content provider's - // "sports/football" prefix. Two matches are already live; a third - // match under "sports/tennis" should NOT match. - let coord = MockCoordinator::new("https://relay-1.example.com"); - let scope = Some("content-provider-123"); - - // Two football matches already live - let _reg_match42 = coord - .register_namespace(scope, &ns("sports/football/match-42")) - .await - .unwrap(); - let _reg_match43 = coord - .register_namespace(scope, &ns("sports/football/match-43")) - .await - .unwrap(); - - // A tennis match (should NOT match the football prefix) - let _reg_tennis = coord - .register_namespace(scope, &ns("sports/tennis/open-7")) - .await - .unwrap(); - - // Subscribe to the football prefix - let sub = coord - .subscribe_namespace(scope, &ns("sports/football")) - .await - .unwrap(); - - assert_eq!(sub.existing_namespaces.len(), 2); - let paths: Vec = sub - .existing_namespaces - .iter() - .map(|n| n.namespace.to_utf8_path()) - .collect(); - assert!(paths.contains(&"/sports/football/match-42".to_string())); - assert!(paths.contains(&"/sports/football/match-43".to_string())); - } - - #[tokio::test] - async fn lookup_namespace_subscribers_finds_interested_relays() { - // An edge relay has subscribers interested in football broadcasts. - // When a new match starts (namespace registered), the origin relay - // calls lookup_namespace_subscribers to discover interested edge - // relays and forward PUBLISH_NAMESPACE to them. - let coord = MockCoordinator::new("https://edge-us-west.example.com"); - let scope = Some("content-provider-123"); - - // Edge relay subscribes to the football prefix - let _sub = coord - .subscribe_namespace(scope, &ns("sports/football")) - .await - .unwrap(); - - // New match starts — who needs to know? - let interested = coord - .lookup_namespace_subscribers(scope, &ns("sports/football/match-44")) - .await - .unwrap(); - - assert_eq!(interested.len(), 1); - assert_eq!( - interested[0].url.as_str(), - "https://edge-us-west.example.com/" - ); - } - - #[tokio::test] - async fn namespace_subscription_cleaned_up_on_drop() { - let coord = MockCoordinator::new("https://edge-us-west.example.com"); - let scope = Some("content-provider-123"); - - { - let _sub = coord - .subscribe_namespace(scope, &ns("sports/football")) - .await - .unwrap(); - - let interested = coord - .lookup_namespace_subscribers(scope, &ns("sports/football/match-44")) - .await - .unwrap(); - assert_eq!(interested.len(), 1); - } - // _sub dropped — edge relay disconnected - - let interested = coord - .lookup_namespace_subscribers(scope, &ns("sports/football/match-44")) - .await - .unwrap(); - assert!(interested.is_empty()); - } - - // ======================================================================== - // Track-level PUBLISH registration - // ======================================================================== - - #[tokio::test] - async fn register_and_list_tracks() { - // A broadcaster publishes multiple renditions (video qualities, - // audio languages) as individual tracks under a match namespace. - let coord = MockCoordinator::new("https://relay-1.example.com"); - let scope = Some("content-provider-123"); - let match_ns = ns("sports/football/match-42"); - - let _reg_1080 = coord - .register_track(scope, &match_ns, "video-1080p") - .await - .unwrap(); - let _reg_480 = coord - .register_track(scope, &match_ns, "video-480p") - .await - .unwrap(); - let _reg_audio = coord - .register_track(scope, &match_ns, "audio-en") - .await - .unwrap(); - - let tracks = coord.list_tracks(scope, &match_ns).await.unwrap(); - assert_eq!(tracks.len(), 3); - - let names: Vec<&str> = tracks.iter().map(|t| t.name.as_str()).collect(); - assert!(names.contains(&"video-1080p")); - assert!(names.contains(&"video-480p")); - assert!(names.contains(&"audio-en")); - } - - #[tokio::test] - async fn track_unregistered_on_handle_drop() { - let coord = MockCoordinator::new("https://relay-1.example.com"); - let scope = Some("content-provider-123"); - let match_ns = ns("sports/football/match-42"); - - { - let _reg = coord - .register_track(scope, &match_ns, "video-1080p") - .await - .unwrap(); - - let tracks = coord.list_tracks(scope, &match_ns).await.unwrap(); - assert_eq!(tracks.len(), 1); - } - // _reg dropped — broadcaster stopped the 1080p rendition - - let tracks = coord.list_tracks(scope, &match_ns).await.unwrap(); - assert!(tracks.is_empty()); - } - - // ======================================================================== - // Lingering subscriber / rendezvous - // ======================================================================== - - #[tokio::test] - async fn subscribe_track_before_publisher_exists() { - // A viewer tunes into a pre-game show before the main broadcast has - // started. The edge relay pre-registers interest in the track so - // that when the broadcaster begins, it can be notified immediately. - let coord = MockCoordinator::new("https://edge-eu-west.example.com"); - let scope = Some("content-provider-123"); - let match_ns = ns("sports/football/match-42"); - - // Viewer's edge relay pre-registers interest (lingering/rendezvous) - let _sub = coord - .subscribe_track(scope, &match_ns, "video-1080p") - .await - .unwrap(); - - // Broadcaster starts — origin relay checks who's waiting - let waiting = coord - .lookup_track_subscribers(scope, &match_ns, "video-1080p") - .await - .unwrap(); - assert_eq!(waiting.len(), 1); - assert_eq!(waiting[0].url.as_str(), "https://edge-eu-west.example.com/"); - - // No one is waiting for the Spanish audio (not pre-subscribed) - let waiting_es = coord - .lookup_track_subscribers(scope, &match_ns, "audio-es") - .await - .unwrap(); - assert!(waiting_es.is_empty()); - } - - #[tokio::test] - async fn track_subscription_cleaned_up_on_drop() { - let coord = MockCoordinator::new("https://edge-eu-west.example.com"); - let scope = Some("content-provider-123"); - let match_ns = ns("sports/football/match-42"); - - { - let _sub = coord - .subscribe_track(scope, &match_ns, "video-1080p") - .await - .unwrap(); - - let waiting = coord - .lookup_track_subscribers(scope, &match_ns, "video-1080p") - .await - .unwrap(); - assert_eq!(waiting.len(), 1); - } - // _sub dropped — viewer left - - let waiting = coord - .lookup_track_subscribers(scope, &match_ns, "video-1080p") - .await - .unwrap(); - assert!(waiting.is_empty()); - } - - // ======================================================================== - // End-to-end scenario: live broadcast across a relay cluster - // ======================================================================== - - #[tokio::test] - async fn broadcast_multi_relay_scenario() { - // A content provider ("content-provider-123") broadcasts a football - // match through a relay cluster: - // - // origin relay (us-east): broadcaster ingests video + audio - // edge relay (eu-west): viewers in Europe subscribe, including - // one who tunes in before halftime coverage starts - // - // This exercises namespace registration, SUBSCRIBE_NAMESPACE for - // event discovery, track-level PUBLISH, and lingering subscriber - // for pre-broadcast rendezvous. - - let origin = MockCoordinator::new("https://relay-us-east.example.com"); - let edge = MockCoordinator::new("https://edge-eu-west.example.com"); - - let scope = Some("content-provider-123"); - - // --- Origin relay: broadcaster starts the match --- - - // Register the match namespace - let _match_reg = origin - .register_namespace(scope, &ns("sports/football/match-42")) - .await - .unwrap(); - - // Register individual track renditions - let _video_1080 = origin - .register_track(scope, &ns("sports/football/match-42"), "video-1080p") - .await - .unwrap(); - let _video_480 = origin - .register_track(scope, &ns("sports/football/match-42"), "video-480p") - .await - .unwrap(); - let _audio_en = origin - .register_track(scope, &ns("sports/football/match-42"), "audio-en") - .await - .unwrap(); - - // --- Edge relay: viewers subscribe --- - - // Edge subscribes to all football events (SUBSCRIBE_NAMESPACE) - let _football_sub = edge - .subscribe_namespace(scope, &ns("sports/football")) - .await - .unwrap(); - - // A viewer pre-subscribes to halftime analysis (not started yet) - let _halftime_sub = edge - .subscribe_track( - scope, - &ns("sports/football/match-42/halftime"), - "video-720p", - ) - .await - .unwrap(); - - // When halftime coverage starts, the origin can find waiting viewers - let waiting = edge - .lookup_track_subscribers( - scope, - &ns("sports/football/match-42/halftime"), - "video-720p", - ) - .await - .unwrap(); - assert_eq!(waiting.len(), 1); - assert_eq!(waiting[0].url.as_str(), "https://edge-eu-west.example.com/"); - - // Verify the origin's track inventory - let tracks = origin - .list_tracks(scope, &ns("sports/football/match-42")) - .await - .unwrap(); - assert_eq!(tracks.len(), 3); - - // Verify scope isolation — a different provider sees nothing - let other_result = origin - .lookup(Some("other-provider"), &ns("sports/football/match-42")) - .await; - assert!(matches!( - other_result, - Err(CoordinatorError::NamespaceNotFound) - )); - } - - // ======================================================================== - // Default trait implementations (no-op behavior) - // ======================================================================== - - /// A minimal coordinator that only implements the required methods. - /// Used to verify that all defaulted methods work correctly — this is - /// what existing implementors experience after upgrading. - struct MinimalCoordinator; - - #[async_trait] - impl Coordinator for MinimalCoordinator { - async fn register_namespace( - &self, - _scope: Option<&str>, - _namespace: &TrackNamespace, - ) -> CoordinatorResult { - Ok(NamespaceRegistration::new(())) - } - - async fn unregister_namespace( - &self, - _scope: Option<&str>, - _namespace: &TrackNamespace, - ) -> CoordinatorResult<()> { - Ok(()) - } - - async fn lookup( - &self, - _scope: Option<&str>, - _namespace: &TrackNamespace, - ) -> CoordinatorResult<(NamespaceOrigin, Option)> { - Err(CoordinatorError::NamespaceNotFound) - } - } - - #[tokio::test] - async fn default_resolve_scope_passes_through_path() { - let coord = MinimalCoordinator; - let scope = coord - .resolve_scope(Some("/provider/acme-sports")) - .await - .unwrap() - .unwrap(); - assert_eq!(scope.scope_id, "/provider/acme-sports"); - assert!(scope.permissions.can_publish()); - assert!(scope.permissions.can_subscribe()); - } - - #[tokio::test] - async fn default_resolve_scope_none_is_unscoped() { - let coord = MinimalCoordinator; - let result = coord.resolve_scope(None).await.unwrap(); - assert!(result.is_none()); - } - - #[tokio::test] - async fn default_get_scope_config_returns_defaults() { - let coord = MinimalCoordinator; - let config = coord.get_scope_config(Some("any-scope")).await.unwrap(); - assert!(!config.lingering_subscribe); - assert!(config.origin_fallback.is_none()); - } - - #[tokio::test] - async fn default_subscribe_namespace_returns_empty() { - let coord = MinimalCoordinator; - let sub = coord - .subscribe_namespace(Some("scope"), &ns("sports/football")) - .await - .unwrap(); - assert!(sub.existing_namespaces.is_empty()); - } - - #[tokio::test] - async fn default_lookup_namespace_subscribers_returns_empty() { - let coord = MinimalCoordinator; - let relays = coord - .lookup_namespace_subscribers(Some("scope"), &ns("sports/football/match-42")) - .await - .unwrap(); - assert!(relays.is_empty()); - } - - #[tokio::test] - async fn default_register_track_returns_no_op_handle() { - let coord = MinimalCoordinator; - let _reg = coord - .register_track( - Some("scope"), - &ns("sports/football/match-42"), - "video-1080p", - ) - .await - .unwrap(); - // Handle drops without panic - } - - #[tokio::test] - async fn default_list_tracks_returns_empty() { - let coord = MinimalCoordinator; - let tracks = coord - .list_tracks(Some("scope"), &ns("sports/football/match-42")) - .await - .unwrap(); - assert!(tracks.is_empty()); - } - - #[tokio::test] - async fn default_subscribe_track_returns_no_op_handle() { - let coord = MinimalCoordinator; - let _sub = coord - .subscribe_track( - Some("scope"), - &ns("sports/football/match-42"), - "video-1080p", - ) - .await - .unwrap(); - // Handle drops without panic - } - - #[tokio::test] - async fn default_lookup_track_subscribers_returns_empty() { - let coord = MinimalCoordinator; - let relays = coord - .lookup_track_subscribers( - Some("scope"), - &ns("sports/football/match-42"), - "video-1080p", - ) - .await - .unwrap(); - assert!(relays.is_empty()); - } } diff --git a/moq-relay-ietf/src/lib.rs b/moq-relay-ietf/src/lib.rs index c469a730..6b96c60d 100644 --- a/moq-relay-ietf/src/lib.rs +++ b/moq-relay-ietf/src/lib.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! MoQ Relay library for building Media over QUIC relay servers. //! //! This crate provides the core relay functionality that can be embedded @@ -35,11 +32,12 @@ mod api; mod consumer; mod coordinator; mod local; -pub mod metrics; mod producer; mod relay; mod remote; mod session; +mod subscriber_registry; +mod top_n_tracker; mod web; pub use api::*; @@ -48,6 +46,8 @@ pub use coordinator::*; pub use local::*; pub use producer::*; pub use relay::*; -pub use remote::RemoteManager; +pub use remote::*; pub use session::*; +pub use subscriber_registry::*; +pub use top_n_tracker::*; pub use web::*; diff --git a/moq-relay-ietf/src/local.rs b/moq-relay-ietf/src/local.rs index 26758ced..e56211b3 100644 --- a/moq-relay-ietf/src/local.rs +++ b/moq-relay-ietf/src/local.rs @@ -1,40 +1,252 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::collections::hash_map; use std::collections::HashMap; - -use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; use moq_transport::{ coding::TrackNamespace, - serve::{ServeError, TracksReader}, + data::ExtensionHeaders, + serve::{ServeError, Track, TrackReader, TrackWriter, TracksReader, TracksWriter}, }; +use tokio::sync::watch; + +#[repr(u8)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum TrackState { + Pending = 0, + Subscribing = 1, + Subscribed = 2, + Publishing = 3, + Closed = 4, +} + +impl TrackState { + fn from_u8(v: u8) -> Self { + match v { + 0 => TrackState::Pending, + 1 => TrackState::Subscribing, + 2 => TrackState::Subscribed, + 3 => TrackState::Publishing, + _ => TrackState::Closed, + } + } +} + +pub struct TrackInfo { + pub namespace: TrackNamespace, + pub name: String, + + state: AtomicU8, + track_reader: OnceLock, + track_writer: Mutex>, + upstream_subscribe_sent: AtomicBool, + upstream_request_id: Mutex>, + + /// The PUBLISH request ID (set when publisher sends PUBLISH) + publish_request_id: Mutex>, + /// Forward state: true = forwarding, false = paused + /// Publisher watches this to know when to start/stop sending + forward_state: Mutex>>, + /// Receiver side for forward state changes + forward_receiver: Mutex>>, + /// Track extensions from the original PUBLISH message + track_extensions: Mutex>, +} + +impl TrackInfo { + pub fn new(namespace: TrackNamespace, name: String) -> Self { + Self { + namespace, + name, + state: AtomicU8::new(TrackState::Pending as u8), + track_reader: OnceLock::new(), + track_writer: Mutex::new(None), + upstream_subscribe_sent: AtomicBool::new(false), + upstream_request_id: Mutex::new(None), + publish_request_id: Mutex::new(None), + forward_state: Mutex::new(None), + forward_receiver: Mutex::new(None), + track_extensions: Mutex::new(None), + } + } + + pub fn get_reader(&self) -> TrackReader { + self.ensure_track_created(); + self.track_reader.get().unwrap().clone() + } + + pub fn should_subscribe_upstream(&self) -> bool { + let state = self.state(); + + if state == TrackState::Publishing { + return false; + } + + !self.upstream_subscribe_sent.swap(true, Ordering::SeqCst) + } + + pub fn mark_subscribe_sent(&self, request_id: u64) { + *self.upstream_request_id.lock().unwrap() = Some(request_id); + + let _ = self.state.compare_exchange( + TrackState::Pending as u8, + TrackState::Subscribing as u8, + Ordering::SeqCst, + Ordering::SeqCst, + ); + } + + pub fn subscribe_ok_received(&self) { + let _ = self.state.compare_exchange( + TrackState::Subscribing as u8, + TrackState::Subscribed as u8, + Ordering::SeqCst, + Ordering::SeqCst, + ); + } + + pub fn publish_arrived(&self) -> Result { + self.ensure_track_created(); + + let current_state = self.state(); + + if current_state == TrackState::Subscribed { + return Err(ServeError::Uninterested); + } + + if current_state == TrackState::Publishing { + return Err(ServeError::Duplicate); + } + + self.state + .store(TrackState::Publishing as u8, Ordering::SeqCst); + + self.track_writer + .lock() + .unwrap() + .take() + .ok_or(ServeError::Duplicate) + } + + + pub fn state(&self) -> TrackState { + TrackState::from_u8(self.state.load(Ordering::SeqCst)) + } + + pub fn is_publishing(&self) -> bool { + self.state() == TrackState::Publishing + } + + /// Set up forward state tracking when PUBLISH is received. + /// Returns the initial forward value that was set. + pub fn set_publish_info(&self, request_id: u64, initial_forward: bool) { + *self.publish_request_id.lock().unwrap() = Some(request_id); + + let (tx, rx) = watch::channel(initial_forward); + *self.forward_state.lock().unwrap() = Some(tx); + *self.forward_receiver.lock().unwrap() = Some(rx); + + log::debug!( + "set_publish_info: track {}/{} request_id={} initial_forward={}", + self.namespace, + self.name, + request_id, + initial_forward + ); + } + + /// Get the PUBLISH request ID + pub fn publish_request_id(&self) -> Option { + *self.publish_request_id.lock().unwrap() + } + + /// Get current forward state + pub fn is_forwarding(&self) -> bool { + self.forward_receiver + .lock() + .unwrap() + .as_ref() + .map(|rx| *rx.borrow()) + .unwrap_or(true) // Default to true if not set (legacy behavior) + } + + /// Request forwarding to start (called when a subscriber arrives). + /// Returns true if the state changed from false to true. + pub fn request_forward(&self) -> bool { + if let Some(tx) = self.forward_state.lock().unwrap().as_ref() { + let current = *tx.borrow(); + if !current { + let _ = tx.send(true); + log::info!( + "request_forward: track {}/{} forward state changed 0 -> 1", + self.namespace, + self.name + ); + return true; + } + } + false + } + + /// Get a receiver for forward state changes (for the publisher to watch) + pub fn forward_receiver(&self) -> Option> { + self.forward_receiver.lock().unwrap().clone() + } + + /// Set track extensions from the original PUBLISH message + pub fn set_track_extensions(&self, extensions: ExtensionHeaders) { + *self.track_extensions.lock().unwrap() = Some(extensions); + } + + /// Get track extensions (cloned) + pub fn track_extensions(&self) -> Option { + self.track_extensions.lock().unwrap().clone() + } + + pub fn take_writer_for_upstream(&self) -> Result { + self.ensure_track_created(); + + let current_state = self.state(); -use crate::metrics::GaugeGuard; - -/// Scope key for the outer level of the two-level registry. -/// -/// An empty string (`""`) represents the global/unscoped bucket. All unscoped -/// connections share this bucket — any publisher without a scope can be reached -/// by any subscriber without a scope. This is the default behavior for backward -/// compatibility with pre-scope deployments. -/// -/// We use `String` rather than `Option` so that `HashMap::get` can -/// accept a `&str` via the `Borrow` trait, avoiding a heap allocation on -/// every lookup in `retrieve()`. -type ScopeKey = String; - -/// The scope key used for unscoped (global) registrations. -const UNSCOPED: &str = ""; - -/// Registry of local tracks, indexed by (scope, namespace). -/// -/// Uses a two-level map so that `retrieve()` only scans namespaces within -/// the matching scope, rather than iterating all namespaces across all scopes. + if current_state == TrackState::Publishing { + return Err(ServeError::Duplicate); + } + + if current_state == TrackState::Subscribing || current_state == TrackState::Subscribed { + return Err(ServeError::Duplicate); + } + + self.state + .store(TrackState::Subscribing as u8, Ordering::SeqCst); + + self.track_writer + .lock() + .unwrap() + .take() + .ok_or(ServeError::Duplicate) + } + + fn ensure_track_created(&self) { + self.track_reader.get_or_init(|| { + let (writer, reader) = Track::new(self.namespace.clone(), self.name.clone()).produce(); + *self.track_writer.lock().unwrap() = Some(writer); + reader + }); + } +} + +struct LocalsEntry { + /// reader and writer hold the readers and writers for a namespace + reader: TracksReader, + writer: TracksWriter, + /// tracks holds the individual tracks for a namespace + tracks: Mutex>>, +} + +/// Locals is a map of TrackNamespace to LocalsEntry #[derive(Clone)] pub struct Locals { - lookup: Arc>>>, + lookup: Arc>>, } impl Default for Locals { @@ -43,7 +255,6 @@ impl Default for Locals { } } -/// Local tracks registry. impl Locals { pub fn new() -> Self { Self { @@ -51,60 +262,172 @@ impl Locals { } } - /// Register new local tracks. - /// - /// `scope` is the resolved scope identity from `Coordinator::resolve_scope()`, - /// or `None` for unscoped sessions. Registrations are keyed by `(scope, namespace)`, - /// so the same namespace in different scopes routes independently. pub async fn register( &mut self, - scope: Option<&str>, - tracks: TracksReader, + reader: TracksReader, + writer: TracksWriter, ) -> anyhow::Result { - let namespace = tracks.namespace.clone(); - let scope_key = scope.unwrap_or(UNSCOPED).to_string(); + let namespace = reader.namespace.clone(); - // Insert the tracks into the scope bucket - let mut lookup = self.lookup.lock().unwrap(); - let bucket = lookup.entry(scope_key.clone()).or_default(); - match bucket.entry(namespace.clone()) { - hash_map::Entry::Vacant(entry) => entry.insert(tracks), + match self.lookup.lock().unwrap().entry(namespace.clone()) { + hash_map::Entry::Vacant(entry) => entry.insert(LocalsEntry { + reader, + writer, + tracks: Mutex::new(HashMap::new()), + }), hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), }; let registration = Registration { locals: self.clone(), - scope_key, namespace, - _gauge_guard: GaugeGuard::new("moq_relay_announced_namespaces"), }; Ok(registration) } - /// Retrieve local tracks by namespace using hierarchical prefix matching. - /// Returns the TracksReader for the longest matching namespace prefix. - /// - /// `scope` is the resolved scope identity from `Coordinator::resolve_scope()`, - /// or `None` for unscoped sessions. When `scope` is `None`, only tracks - /// registered without a scope (the global/unscoped bucket) are searched. - pub fn retrieve( + pub fn retrieve(&self, namespace: &TrackNamespace) -> Option { + let lookup = self.lookup.lock().unwrap(); + + let mut best_match: Option = None; + let mut best_len = 0; + + for (registered_ns, entry) in lookup.iter() { + if namespace.fields.len() >= registered_ns.fields.len() { + let is_prefix = registered_ns + .fields + .iter() + .zip(namespace.fields.iter()) + .all(|(a, b)| a == b); + + if is_prefix && registered_ns.fields.len() > best_len { + best_match = Some(entry.reader.clone()); + best_len = registered_ns.fields.len(); + } + } + } + + best_match + } + + pub fn get_or_create_track_info( &self, - scope: Option<&str>, namespace: &TrackNamespace, - ) -> Option { + track_name: &str, + ) -> Option> { let lookup = self.lookup.lock().unwrap(); - // Look up the scope bucket directly — O(1), zero allocation. - // HashMap::get accepts &str via Borrow. - let bucket = lookup.get(scope.unwrap_or(UNSCOPED))?; + let entry = Self::find_best_match_entry(&lookup, namespace)?; - // Find the longest matching prefix within this scope - let mut best_match: Option = None; + // Use full namespace + track_name as key to avoid collisions + let track_key = format!("{}:{}", namespace, track_name); + + let mut tracks = entry.tracks.lock().unwrap(); + + let track_info = tracks + .entry(track_key) + .or_insert_with(|| Arc::new(TrackInfo::new(namespace.clone(), track_name.to_string()))) + .clone(); + + Some(track_info) + } + + /// Get or create track info, auto-registering the namespace if needed. + /// This supports the SUBSCRIBE_NAMESPACE flow where PUBLISH can arrive + /// without a prior PUBLISH_NAMESPACE. + pub fn get_or_create_track_info_auto_register( + &self, + namespace: &TrackNamespace, + track_name: &str, + ) -> Arc { + let mut lookup = self.lookup.lock().unwrap(); + + // Use full namespace + track_name as key to avoid collisions + // when different namespaces have the same track_name + let track_key = format!("{}:{}", namespace, track_name); + + // Check if there's an existing exact-match namespace entry that's stale + // and needs to be removed (this happens when publisher disconnects and reconnects) + let should_remove_namespace = if let Some(entry) = lookup.get(namespace) { + let tracks = entry.tracks.lock().unwrap(); + if let Some(existing) = tracks.get(&track_key) { + // Track exists and is in Publishing state but has no writer = stale + existing.state() == TrackState::Publishing + && existing.track_writer.lock().unwrap().is_none() + } else { + false + } + } else { + false + }; + + if should_remove_namespace { + log::info!( + "removing stale namespace entry {} (track {}/{} was Publishing with no writer)", + namespace, + namespace, + track_name + ); + lookup.remove(namespace); + } + + // First try to find an existing matching namespace entry + if let Some(entry) = Self::find_best_match_entry(&lookup, namespace) { + let mut tracks = entry.tracks.lock().unwrap(); + + return tracks + .entry(track_key.clone()) + .or_insert_with(|| { + Arc::new(TrackInfo::new(namespace.clone(), track_name.to_string())) + }) + .clone(); + } + + // No matching namespace found - auto-register for SUBSCRIBE_NAMESPACE flow + log::info!( + "auto-registering namespace {} for PUBLISH (no prior PUBLISH_NAMESPACE)", + namespace + ); + + let (writer, _request, reader) = + moq_transport::serve::Tracks::new(namespace.clone()).produce(); + + let entry = lookup.entry(namespace.clone()).or_insert(LocalsEntry { + reader, + writer, + tracks: Mutex::new(HashMap::new()), + }); + + let mut tracks = entry.tracks.lock().unwrap(); + tracks + .entry(track_key) + .or_insert_with(|| Arc::new(TrackInfo::new(namespace.clone(), track_name.to_string()))) + .clone() + } + + pub fn get_track_info( + &self, + namespace: &TrackNamespace, + track_name: &str, + ) -> Option> { + let lookup = self.lookup.lock().unwrap(); + + let entry = Self::find_best_match_entry(&lookup, namespace)?; + + // Use full namespace + track_name as key to match get_or_create_track_info + let track_key = format!("{}:{}", namespace, track_name); + let tracks = entry.tracks.lock().unwrap(); + tracks.get(&track_key).cloned() + } + + fn find_best_match_entry<'a>( + lookup: &'a HashMap, + namespace: &TrackNamespace, + ) -> Option<&'a LocalsEntry> { + let mut best_match: Option<&LocalsEntry> = None; let mut best_len = 0; - for (registered_ns, tracks) in bucket.iter() { - // Check if registered_ns is a prefix of namespace + for (registered_ns, entry) in lookup.iter() { if namespace.fields.len() >= registered_ns.fields.len() { let is_prefix = registered_ns .fields @@ -113,7 +436,7 @@ impl Locals { .all(|(a, b)| a == b); if is_prefix && registered_ns.fields.len() > best_len { - best_match = Some(tracks.clone()); + best_match = Some(entry); best_len = registered_ns.fields.len(); } } @@ -121,34 +444,107 @@ impl Locals { best_match } + + pub fn insert_track( + &self, + namespace: &TrackNamespace, + track_reader: TrackReader, + ) -> Option<()> { + let mut lookup = self.lookup.lock().unwrap(); + + if let Some(entry) = lookup.get_mut(namespace) { + entry.writer.insert(track_reader) + } else { + None + } + } + + pub fn subscribe_upstream(&self, track_info: Arc) -> Option { + let mut lookup = self.lookup.lock().unwrap(); + + let entry = lookup.get_mut(&track_info.namespace)?; + + let writer = track_info.take_writer_for_upstream().ok()?; + let reader = track_info.get_reader(); + + entry.reader.forward_upstream(writer)?; + + let namespace = track_info.namespace.clone(); + + let entry_mut = lookup + .iter_mut() + .find(|(ns, _)| { + namespace.fields.len() >= ns.fields.len() + && ns + .fields + .iter() + .zip(namespace.fields.iter()) + .all(|(a, b)| a == b) + }) + .map(|(_, e)| e)?; + + entry_mut.writer.insert(reader.clone()); + + Some(reader) + } + + pub fn matching_namespaces(&self, prefix: &TrackNamespace) -> Vec { + let lookup = self.lookup.lock().unwrap(); + + lookup + .keys() + .filter(|ns| { + if ns.fields.len() >= prefix.fields.len() { + prefix + .fields + .iter() + .zip(ns.fields.iter()) + .all(|(a, b)| a == b) + } else { + false + } + }) + .cloned() + .collect() + } + + /// Get all tracks in namespaces matching a prefix that are in Publishing state. + /// Returns (namespace, track_name, TrackInfo) tuples. + pub fn matching_tracks(&self, prefix: &TrackNamespace) -> Vec<(TrackNamespace, String, Arc)> { + let lookup = self.lookup.lock().unwrap(); + + let mut result = Vec::new(); + + for (ns, entry) in lookup.iter() { + // Check if namespace matches prefix + if ns.fields.len() >= prefix.fields.len() + && prefix + .fields + .iter() + .zip(ns.fields.iter()) + .all(|(a, b)| a == b) + { + // Get all tracks in this namespace that are publishing + let tracks = entry.tracks.lock().unwrap(); + for (key, track_info) in tracks.iter() { + if track_info.is_publishing() { + result.push((ns.clone(), track_info.name.clone(), track_info.clone())); + } + } + } + } + + result + } } pub struct Registration { locals: Locals, - scope_key: ScopeKey, namespace: TrackNamespace, - /// Gauge guard for tracking announced namespaces - decrements on drop - _gauge_guard: GaugeGuard, } -/// Deregister local tracks on drop. impl Drop for Registration { fn drop(&mut self) { - let ns = self.namespace.to_utf8_path(); - let scope = if self.scope_key.is_empty() { - "" - } else { - &self.scope_key - }; - tracing::debug!(namespace = %ns, scope = %scope, "deregistering namespace from locals"); - - let mut lookup = self.locals.lookup.lock().unwrap(); - if let Some(bucket) = lookup.get_mut(self.scope_key.as_str()) { - bucket.remove(&self.namespace); - // Clean up empty scope buckets to avoid memory leaks - if bucket.is_empty() { - lookup.remove(self.scope_key.as_str()); - } - } + self.locals.lookup.lock().unwrap().remove(&self.namespace); } } diff --git a/moq-relay-ietf/src/metrics.rs b/moq-relay-ietf/src/metrics.rs deleted file mode 100644 index aac3455b..00000000 --- a/moq-relay-ietf/src/metrics.rs +++ /dev/null @@ -1,210 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -//! Metrics instrumentation for moq-relay-ietf -//! -//! Metrics are always compiled in via the [`metrics`] crate facade. When no -//! recorder is installed the overhead is negligible (an atomic load + early -//! return per call site), similar to how the `log` crate works when no logger -//! is configured. -//! -//! To actually collect metrics, install a recorder at startup. The optional -//! `metrics-prometheus` feature adds a Prometheus exporter — see the binary -//! in `src/bin/moq-relay-ietf/main.rs` for an example. -//! -//! # Available Metrics -//! -//! All metrics are prefixed with `moq_relay_` to avoid collisions. -//! -//! ## Counters -//! -//! | Name | Labels | Description | -//! |------|--------|-------------| -//! | `moq_relay_connections_total` | - | Total incoming connections accepted | -//! | `moq_relay_connections_closed_total` | - | Total connections that have closed (graceful or error) | -//! | `moq_relay_connection_errors_total` | `stage` | Connection failures (stage: session_accept, session_run) | -//! | `moq_relay_publishers_total` | - | Total publishers (ANNOUNCE requests) received | -//! | `moq_relay_announce_ok_total` | - | Successful ANNOUNCE_OK responses sent | -//! | `moq_relay_announce_errors_total` | `phase` | Announce failures (phase: coordinator_register, local_register, send_ok) | -//! | `moq_relay_subscribers_total` | - | Total subscribers (SUBSCRIBE requests) received | -//! | `moq_relay_subscribe_not_found_total` | - | Track not found after checking all sources | -//! | `moq_relay_subscribe_route_errors_total` | - | Infrastructure failure when routing to remote | -//! | `moq_relay_upstream_errors_total` | `stage` | Upstream connection failures (stage: connect, session) | -//! -//! ## Gauges -//! -//! | Name | Description | -//! |------|-------------| -//! | `moq_relay_active_connections` | Current number of active client connections | -//! | `moq_relay_active_publishers` | Current number of active publishers | -//! | `moq_relay_active_subscriptions` | Current number of active subscriptions | -//! | `moq_relay_active_tracks` | Current number of tracks being served | -//! | `moq_relay_announced_namespaces` | Current number of registered namespaces | -//! | `moq_relay_upstream_connections` | Current number of upstream/origin connections | -//! -//! ## Histograms -//! -//! | Name | Labels | Description | -//! |------|--------|-------------| -//! | `moq_relay_subscribe_latency_seconds` | `source` | Time to resolve subscription (source: local, remote, not_found, route_error) | - -use metrics::{describe_counter, describe_gauge, describe_histogram, Unit}; - -// ============================================================================ -// describe_metrics - Register metric descriptions for Prometheus HELP text -// ============================================================================ - -/// Register metric descriptions with the metrics recorder. -/// -/// Call this once after installing a metrics recorder (e.g., Prometheus exporter). -/// The descriptions appear as `# HELP` comments in Prometheus output. -pub fn describe_metrics() { - // Counters - describe_counter!( - "moq_relay_connections_total", - "Total incoming connections accepted" - ); - describe_counter!( - "moq_relay_connections_closed_total", - "Total connections that have closed (graceful or error)" - ); - describe_counter!( - "moq_relay_connection_errors_total", - "Connection failures by stage (session_accept, session_run)" - ); - describe_counter!( - "moq_relay_publishers_total", - "Total publishers (ANNOUNCE requests) received" - ); - describe_counter!( - "moq_relay_announce_ok_total", - "Successful ANNOUNCE_OK responses sent" - ); - describe_counter!( - "moq_relay_announce_errors_total", - "Announce failures by phase (coordinator_register, local_register, send_ok)" - ); - describe_counter!( - "moq_relay_subscribers_total", - "Total subscribers (SUBSCRIBE requests) received" - ); - describe_counter!( - "moq_relay_subscribe_not_found_total", - "Track not found after checking all sources" - ); - describe_counter!( - "moq_relay_subscribe_route_errors_total", - "Infrastructure failure when routing to remote" - ); - describe_counter!( - "moq_relay_upstream_errors_total", - "Upstream connection failures by stage (connect, session)" - ); - - // Gauges - describe_gauge!( - "moq_relay_active_connections", - "Current number of active client connections" - ); - describe_gauge!( - "moq_relay_active_publishers", - "Current number of active publishers" - ); - describe_gauge!( - "moq_relay_active_subscriptions", - "Current number of active subscriptions" - ); - describe_gauge!( - "moq_relay_active_tracks", - "Current number of tracks being served" - ); - describe_gauge!( - "moq_relay_announced_namespaces", - "Current number of registered namespaces" - ); - describe_gauge!( - "moq_relay_upstream_connections", - "Current number of upstream/origin connections" - ); - - // Histograms - describe_histogram!( - "moq_relay_subscribe_latency_seconds", - Unit::Seconds, - "Time to resolve subscription by source (local, remote, not_found, route_error)" - ); -} - -// ============================================================================ -// GaugeGuard - RAII guard for gauge increment/decrement -// ============================================================================ - -/// RAII guard that increments a gauge on creation and decrements on drop. -#[must_use = "GaugeGuard must be held for the duration you want the gauge incremented"] -pub struct GaugeGuard { - name: &'static str, -} - -impl GaugeGuard { - pub fn new(name: &'static str) -> Self { - metrics::gauge!(name).increment(1.0); - Self { name } - } -} - -impl Drop for GaugeGuard { - fn drop(&mut self) { - metrics::gauge!(self.name).decrement(1.0); - } -} - -// ============================================================================ -// TimingGuard - RAII guard for recording duration histograms -// ============================================================================ - -/// RAII guard that records elapsed time to a histogram on drop. -#[must_use = "TimingGuard must be held for the duration you want to measure"] -pub struct TimingGuard { - name: &'static str, - start: std::time::Instant, - labels: Option<(&'static str, &'static str)>, -} - -impl TimingGuard { - #[allow(dead_code)] // Keep API available for future histograms without labels - pub fn new(name: &'static str) -> Self { - Self { - name, - start: std::time::Instant::now(), - labels: None, - } - } - - pub fn with_label( - name: &'static str, - label_key: &'static str, - label_value: &'static str, - ) -> Self { - Self { - name, - start: std::time::Instant::now(), - labels: Some((label_key, label_value)), - } - } - - /// Update the label value (useful when outcome determines the label) - pub fn set_label(&mut self, label_key: &'static str, label_value: &'static str) { - self.labels = Some((label_key, label_value)); - } -} - -impl Drop for TimingGuard { - fn drop(&mut self) { - let elapsed = self.start.elapsed().as_secs_f64(); - if let Some((key, value)) = self.labels { - metrics::histogram!(self.name, key => value).record(elapsed); - } else { - metrics::histogram!(self.name).record(elapsed); - } - } -} diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 9387b6a1..72b703fd 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -1,75 +1,86 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_transport::{ serve::{ServeError, TracksReader}, - session::{Publisher, SessionError, Subscribed, TrackStatusRequested}, + session::{ + PublishNamespace, Publisher, SessionError, SubscribeNamespaceReceived, Subscribed, + TrackStatusRequested, + }, }; -use crate::{ - metrics::{GaugeGuard, TimingGuard}, - Locals, RemoteManager, -}; +use crate::{Locals, RemotesConsumer, SubscriberRegistry}; /// Producer of tracks to a remote Subscriber #[derive(Clone)] pub struct Producer { publisher: Publisher, locals: Locals, - remotes: RemoteManager, - /// The resolved scope identity for this session, if any. - /// Produced by `Coordinator::resolve_scope()` from the connection path. - /// Passed to locals/remotes to isolate namespace lookups. - scope: Option, + remotes: Option, + subscriber_registry: Option, + session_id: u64, } impl Producer { - pub fn new( + pub fn new(publisher: Publisher, locals: Locals, remotes: Option) -> Self { + Self { + publisher, + locals, + remotes, + subscriber_registry: None, + session_id: 0, + } + } + + /// Creates a producer with a subscriber registry. + pub fn with_registry( publisher: Publisher, locals: Locals, - remotes: RemoteManager, - scope: Option, + remotes: Option, + subscriber_registry: SubscriberRegistry, + session_id: u64, ) -> Self { Self { publisher, locals, remotes, - scope, + subscriber_registry: Some(subscriber_registry), + session_id, } } - /// Announce new tracks to the remote server. - pub async fn announce(&mut self, tracks: TracksReader) -> Result<(), SessionError> { - self.publisher.announce(tracks).await + pub async fn publish_namespace( + &mut self, + tracks: TracksReader, + ) -> Result { + self.publisher + .publish_namespace(tracks.namespace.clone()) + .await } - /// Run the producer to serve subscribe requests. pub async fn run(self) -> Result<(), SessionError> { + //let mut tasks = FuturesUnordered::new(); let mut tasks: FuturesUnordered> = FuturesUnordered::new(); loop { let mut publisher_subscribed = self.publisher.clone(); let mut publisher_track_status = self.publisher.clone(); + let mut publisher_subscribe_ns = self.publisher.clone(); tokio::select! { // Handle a new subscribe request Some(subscribed) = publisher_subscribed.subscribed() => { - metrics::counter!("moq_relay_subscribers_total").increment(1); - let this = self.clone(); // Spawn a new task to handle the subscribe tasks.push(async move { let info = subscribed.clone(); - let namespace = info.track_namespace.to_utf8_path(); - let track_name = info.track_name.clone(); - tracing::info!(namespace = %namespace, track = %track_name, "serving subscribe: {:?}", info); + log::info!("serving subscribe: {:?}", info); // Serve the subscribe request if let Err(err) = this.serve_subscribe(subscribed).await { - tracing::warn!(namespace = %namespace, track = %track_name, error = %err, "failed serving subscribe: {:?}, error: {}", info, err); + log::warn!("failed serving subscribe: {:?}, error: {}", info, err); } }.boxed()) }, @@ -80,13 +91,23 @@ impl Producer { // Spawn a new task to handle the track_status request tasks.push(async move { let info = track_status_requested.request_msg.clone(); - let namespace = info.track_namespace.to_utf8_path(); - let track_name = info.track_name.clone(); - tracing::info!(namespace = %namespace, track = %track_name, "serving track_status: {:?}", info); + log::info!("serving track_status: {:?}", info); // Serve the track_status request if let Err(err) = this.serve_track_status(track_status_requested).await { - tracing::warn!(namespace = %namespace, track = %track_name, error = %err, "failed serving track_status: {:?}, error: {}", info, err) + log::warn!("failed serving track_status: {:?}, error: {}", info, err) + } + }.boxed()) + }, + Some(subscribe_ns) = publisher_subscribe_ns.subscribe_namespace_received() => { + let this = self.clone(); + + tasks.push(async move { + let info = subscribe_ns.info.clone(); + log::info!("serving subscribe_namespace: {:?}", info); + + if let Err(err) = this.serve_subscribe_namespace(subscribe_ns).await { + log::warn!("failed serving subscribe_namespace: {:?}, error: {}", info, err) } }.boxed()) }, @@ -96,70 +117,68 @@ impl Producer { } } - /// Serve a subscribe request. async fn serve_subscribe(self, subscribed: Subscribed) -> Result<(), anyhow::Error> { - // Track subscribe latency from request to track resolution (records on drop) - let mut timing_guard = - TimingGuard::with_label("moq_relay_subscribe_latency_seconds", "source", "not_found"); - // Track active subscriptions - decrements when this function returns - let _sub_guard = GaugeGuard::new("moq_relay_active_subscriptions"); - let namespace = subscribed.track_namespace.clone(); let track_name = subscribed.track_name.clone(); - // Check local tracks first, and serve from local if possible - if let Some(mut local) = self.locals.retrieve(self.scope.as_deref(), &namespace) { - // Pass the full requested namespace, not the announced prefix - if let Some(track) = local.subscribe(namespace.clone(), &track_name) { - let ns = namespace.to_utf8_path(); - tracing::info!(namespace = %ns, track = %track_name, source = "local", "serving subscribe from local: {:?}", track.info); - // Update label to indicate local source, timing recorded on drop - timing_guard.set_label("source", "local"); - // Track active tracks - decrements when serve completes - let _track_guard = GaugeGuard::new("moq_relay_active_tracks"); - return Ok(subscribed.serve(track).await?); - } - } - - // Check remote tracks second, and serve from remote if possible - match self - .remotes - .subscribe(self.scope.as_deref(), &namespace, &track_name) - .await + if let Some(track_info) = self + .locals + .get_or_create_track_info(&namespace, &track_name) { - Ok(track) => { - if let Some(track) = track { - let ns = namespace.to_utf8_path(); - tracing::info!(namespace = %ns, track = %track_name, source = "remote", "serving subscribe from remote: {:?}", track.info); - // Update label to indicate remote source, timing recorded on drop - timing_guard.set_label("source", "remote"); - // Track active tracks - decrements when serve completes - let _track_guard = GaugeGuard::new("moq_relay_active_tracks"); - return Ok(subscribed.serve(track).await?); + if track_info.should_subscribe_upstream() { + log::info!( + "subscribe needs upstream request: {}/{}", + namespace, + track_name + ); + + if let Some(reader) = self.locals.subscribe_upstream(track_info.clone()) { + log::info!( + "forwarding subscribe upstream via TrackInfo: {}/{}", + namespace, + track_name + ); + return Ok(subscribed.serve(reader).await?); } } - Err(e) => { - // Route error = infrastructure failure (couldn't reach coordinator/upstream) - // This is different from "not found" - we don't know if the track exists - let ns = namespace.to_utf8_path(); - tracing::error!(namespace = %ns, track = %track_name, error = %e, "failed to route to remote: {}", e); - timing_guard.set_label("source", "route_error"); - metrics::counter!("moq_relay_subscribe_route_errors_total").increment(1); - - // Return an internal error rather than "not found" since we couldn't check - // TODO: Consider returning a more specific error to the subscriber - let err = ServeError::internal_ctx(format!( - "route error for namespace '{}': {}", - namespace, e - )); - subscribed.close(err.clone())?; - return Err(err.into()); + + // If the track is in Publishing state and forward=0, request forwarding + // This will trigger the consumer to send REQUEST_UPDATE to the publisher + if track_info.is_publishing() && !track_info.is_forwarding() { + log::info!( + "subscriber arrived for paused track {}/{}, requesting forward", + namespace, + track_name + ); + track_info.request_forward(); } + + let reader = track_info.get_reader(); + log::info!( + "serving subscribe from local: {}/{} (state: {:?}, forwarding: {})", + namespace, + track_name, + track_info.state(), + track_info.is_forwarding() + ); + return Ok(subscribed.serve(reader).await?); } - // Track not found - we checked all sources and the track doesn't exist - // timing_guard label already set to "not_found", will record on drop - metrics::counter!("moq_relay_subscribe_not_found_total").increment(1); + if let Some(remotes) = self.remotes { + match remotes.route(&namespace).await { + Ok(remote) => { + if let Some(remote) = remote { + if let Some(track) = remote.subscribe(&namespace, &track_name)? { + log::info!("serving subscribe from remote: {:?}", track.info); + return Ok(subscribed.serve(track.reader).await?); + } + } + } + Err(e) => { + log::error!("failed to route to remote: {}", e); + } + } + } let err = ServeError::not_found_ctx(format!( "track '{}/{}' not found in local or remote tracks", @@ -169,26 +188,337 @@ impl Producer { Err(err.into()) } - /// Serve a track_status request. + async fn serve_subscribe_namespace( + mut self, + mut subscribe_ns: SubscribeNamespaceReceived, + ) -> Result<(), anyhow::Error> { + let namespace_prefix = subscribe_ns.namespace_prefix.clone(); + + // Parse TRACK_FILTER from params if present + // TRACK_FILTER key is 0x12 (even = int value) + // Value format: (property_type << 8) | max_selected packed into u64 + const TRACK_FILTER_KEY: u64 = 0x12; + let track_filter = subscribe_ns.info.params.get(TRACK_FILTER_KEY).and_then(|kvp| { + if let moq_transport::coding::Value::IntValue(packed) = &kvp.value { + // Unpack: property_type in high byte, max_selected in low byte + let property_type = (*packed >> 8) & 0xFF; + let max_selected = (*packed & 0xFF) as u8; + log::info!( + "parsed TRACK_FILTER: property_type={}, max_selected={}", + property_type, + max_selected + ); + Some(crate::TrackFilter { + property_type, + max_selected, + }) + } else { + None + } + }); + + // Register with subscriber registry to receive PUBLISH and PUBLISH_NAMESPACE notifications + // Uses session_id so we can exclude PUBLISH messages from the same session (self-exclusion) + let (_subscription_guard, mut publish_rx, mut publish_ns_rx) = + if let Some(ref registry) = self.subscriber_registry { + let (id, rx, rx_ns) = registry.register_with_filter( + namespace_prefix.clone(), + self.session_id, + track_filter, + ); + ( + Some(crate::SubscriptionGuard::new(registry.clone(), id)), + Some(rx), + Some(rx_ns), + ) + } else { + (None, None, None) + }; + + // Accept the subscription (even if no current matches - publisher may arrive later) + subscribe_ns.ok()?; + + log::info!( + "accepted SUBSCRIBE_NAMESPACE for prefix {:?}", + namespace_prefix + ); + + // Send PUBLISH for existing tracks in matching namespaces + // This triggers the client's onMatch callback for track discovery + // Note: We skip PUBLISH_NAMESPACE and send PUBLISH directly - client expects PUBLISH for tracks + let matching_tracks = self.locals.matching_tracks(&namespace_prefix); + log::info!( + "found {} existing tracks matching prefix {:?}", + matching_tracks.len(), + namespace_prefix + ); + + for (ns, track_name, track_info) in matching_tracks { + let track_extensions = track_info.track_extensions().unwrap_or_default(); + log::info!( + "sending PUBLISH for existing track {}/{} (matched prefix {:?}, extensions={:?})", + ns, + track_name, + namespace_prefix, + track_extensions + ); + + let track_reader = track_info.get_reader(); + let mut publisher = self.publisher.clone(); + let registry = self.subscriber_registry.clone(); + let session_id = self.session_id; + + tokio::spawn(async move { + match publisher.publish_with_extensions(track_reader.clone(), track_extensions).await { + Ok(published) => { + log::info!( + "sent PUBLISH for existing track {}/{}, waiting for PUBLISH_OK", + ns, + track_name + ); + // Create filter-only observer (update_track_value is handled by ingest observer in Consumer) + let observer = if let Some(ref reg) = registry { + let reg = reg.clone(); + let ns_for_observer = ns.clone(); + let name_for_observer = track_name.clone(); + let track_filter = reg.get_track_filter_for_session(session_id); + let epoch = reg.snapshot_epoch(); + let cached_epoch = AtomicU64::new(u64::MAX); + let cached_result = AtomicBool::new(true); + if track_filter.is_some() { + Some(moq_transport::session::ObjectObserverFn::from( + Box::new(move |_group_id: u64, _object_id: u64, _ext_headers: &moq_transport::data::ExtensionHeaders| { + if let Some(ref filter) = track_filter { + let current_epoch = epoch.load(Ordering::Acquire); + if current_epoch != cached_epoch.load(Ordering::Relaxed) { + let in_top_n = reg.is_track_in_top_n( + &ns_for_observer, + &name_for_observer, + session_id, + filter.property_type, + filter.max_selected, + ); + cached_epoch.store(current_epoch, Ordering::Relaxed); + cached_result.store(in_top_n, Ordering::Relaxed); + } + cached_result.load(Ordering::Relaxed) + } else { + true + } + }) as Box bool + Send + Sync> + )) + } else { + None + } + } else { + None + }; + + let result = if let Some(obs) = observer { + published.serve_with_observer(track_reader, obs).await + } else { + published.serve(track_reader).await + }; + + match result { + Ok(()) => { + log::info!("existing track {}/{} serving completed", ns, track_name); + } + Err(e) => { + log::warn!("existing track {}/{} serving ended: {}", ns, track_name, e); + } + } + } + Err(e) => { + log::warn!("failed to send PUBLISH for existing track {}/{}: {}", ns, track_name, e); + } + } + }); + } + + // If we have a publish receiver, listen for new PUBLISH and PUBLISH_NAMESPACE notifications + if publish_rx.is_some() || publish_ns_rx.is_some() { + loop { + tokio::select! { + // Wait for the subscription to close + result = subscribe_ns.closed() => { + result?; + break; + } + // Wait for PUBLISH notifications -> forward PUBLISH to subscriber + // Subscriber sends PUBLISH_OK, then relay starts streaming data + notification = async { + if let Some(ref mut rx) = publish_rx { + rx.recv().await + } else { + std::future::pending().await + } + } => { + match notification { + Ok(publish_notif) => { + log::info!( + "received PUBLISH notification for {}/{} on subscription prefix {:?}", + publish_notif.namespace, + publish_notif.track_name, + namespace_prefix + ); + + // Get the TrackReader for this track so we can stream data + if let Some(track_info) = self.locals.get_track_info( + &publish_notif.namespace, + &publish_notif.track_name, + ) { + let track_reader = track_info.get_reader(); + let track_extensions = track_info.track_extensions().unwrap_or_default(); + + // Send PUBLISH and wait for PUBLISH_OK before streaming + let mut publisher = self.publisher.clone(); + let ns = publish_notif.namespace.clone(); + let name = publish_notif.track_name.clone(); + let registry = self.subscriber_registry.clone(); + let session_id = self.session_id; + log::info!( + "forwarding PUBLISH for {}/{} with extensions {:?}", + ns, name, track_extensions + ); + tokio::spawn(async move { + match publisher.publish_with_extensions(track_reader.clone(), track_extensions).await { + Ok(published) => { + log::info!( + "sent PUBLISH for {}/{}, waiting for PUBLISH_OK", + ns, name + ); + // Create filter-only observer (update_track_value handled by ingest observer in Consumer) + let observer = if let Some(ref reg) = registry { + let reg = reg.clone(); + let ns_for_observer = ns.clone(); + let name_for_observer = name.clone(); + let track_filter = reg.get_track_filter_for_session(session_id); + let epoch = reg.snapshot_epoch(); + let cached_epoch = AtomicU64::new(u64::MAX); + let cached_result = AtomicBool::new(true); + if track_filter.is_some() { + Some(moq_transport::session::ObjectObserverFn::from( + Box::new(move |_group_id: u64, _object_id: u64, _ext_headers: &moq_transport::data::ExtensionHeaders| { + if let Some(ref filter) = track_filter { + let current_epoch = epoch.load(Ordering::Acquire); + if current_epoch != cached_epoch.load(Ordering::Relaxed) { + let in_top_n = reg.is_track_in_top_n( + &ns_for_observer, + &name_for_observer, + session_id, + filter.property_type, + filter.max_selected, + ); + cached_epoch.store(current_epoch, Ordering::Relaxed); + cached_result.store(in_top_n, Ordering::Relaxed); + } + cached_result.load(Ordering::Relaxed) + } else { + true + } + }) as Box bool + Send + Sync> + )) + } else { + None + } + } else { + None + }; + + // serve with observer to update TopN from object extension headers + let result = if let Some(obs) = observer { + published.serve_with_observer(track_reader, obs).await + } else { + published.serve(track_reader).await + }; + + match result { + Ok(()) => { + log::info!("track {}/{} serving completed", ns, name); + } + Err(e) => { + log::warn!( + "track {}/{} serving ended: {}", + ns, name, e + ); + } + } + } + Err(e) => { + log::warn!( + "failed to send PUBLISH for {}/{}: {}", + ns, name, e + ); + } + } + }); + } else { + log::warn!( + "no track info found for {}/{}, cannot forward PUBLISH", + publish_notif.namespace, + publish_notif.track_name + ); + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + log::warn!("subscription lagged by {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + log::debug!("publish notification channel closed"); + break; + } + } + } + // PUBLISH_NAMESPACE notifications - we don't forward these as NAMESPACE messages + // Client expects PUBLISH for individual tracks, not namespace announcements + notification = async { + if let Some(ref mut rx) = publish_ns_rx { + rx.recv().await + } else { + std::future::pending().await + } + } => { + match notification { + Ok(ns_notif) => { + log::debug!( + "ignoring PUBLISH_NAMESPACE notification for {:?} (client expects PUBLISH for tracks)", + ns_notif.namespace + ); + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + log::warn!("namespace subscription lagged by {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + log::debug!("publish_namespace notification channel closed"); + break; + } + } + } + } + } + } else { + // No registry, just wait for close + subscribe_ns.closed().await?; + } + + Ok(()) + } + async fn serve_track_status( self, mut track_status_requested: TrackStatusRequested, ) -> Result<(), anyhow::Error> { // Check local tracks first, and serve from local if possible - if let Some(mut local_tracks) = self.locals.retrieve( - self.scope.as_deref(), - &track_status_requested.request_msg.track_namespace, - ) { + if let Some(mut local_tracks) = self + .locals + .retrieve(&track_status_requested.request_msg.track_namespace) + { if let Some(track) = local_tracks.get_track_reader( &track_status_requested.request_msg.track_namespace, &track_status_requested.request_msg.track_name, ) { - let namespace = track_status_requested - .request_msg - .track_namespace - .to_utf8_path(); - let track_name = &track_status_requested.request_msg.track_name; - tracing::info!(namespace = %namespace, track = %track_name, source = "local", "serving track_status from local: {:?}", track.info); + log::info!("serving track_status from local: {:?}", track.info); return Ok(track_status_requested.respond_ok(&track)?); } } @@ -202,7 +532,7 @@ impl Producer { if let Some(track) = remote.subscribe(subscribe.track_namespace.clone(), subscribe.track_name.clone())? { - tracing::info!("serving from remote: {:?} {:?}", remote.info, track.info); + log::info!("serving from remote: {:?} {:?}", remote.info, track.info); // NOTE: Depends on drop(track) being called afterwards return Ok(subscribe.serve(track.reader).await?); diff --git a/moq-relay-ietf/src/relay.rs b/moq-relay-ietf/src/relay.rs index 06a43c9e..389d88b2 100644 --- a/moq-relay-ietf/src/relay.rs +++ b/moq-relay-ietf/src/relay.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{future::Future, net, path::PathBuf, pin::Pin, sync::Arc}; use anyhow::Context; @@ -9,18 +6,17 @@ use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_native_ietf::quic::{self, Endpoint}; use url::Url; -use crate::{metrics::GaugeGuard, Consumer, Coordinator, Locals, Producer, RemoteManager, Session}; +use crate::{ + Consumer, Coordinator, Locals, Producer, Remotes, RemotesConsumer, RemotesProducer, Session, + SubscriberRegistry, TieBreakPolicy, +}; // A type alias for boxed future type ServerFuture = Pin< Box< dyn Future< Output = ( - anyhow::Result<( - web_transport::Session, - String, - moq_transport::session::Transport, - )>, + anyhow::Result<(web_transport::Session, String)>, quic::Server, ), >, @@ -53,6 +49,13 @@ pub struct RelayConfig { /// The coordinator for namespace/track registration and discovery. pub coordinator: Arc, + + /// Enable TopN event logging for visualization + /// Logs JSON events to stdout that can be used to generate timeline SVGs + pub topn_log: bool, + + /// Tie-break policy for top-N filtering + pub tie_break_policy: TieBreakPolicy, } /// MoQ Relay server. @@ -61,8 +64,9 @@ pub struct Relay { announce_url: Option, mlog_dir: Option, locals: Locals, - remotes: RemoteManager, + remotes: Option<(RemotesProducer, RemotesConsumer)>, coordinator: Arc, + subscriber_registry: SubscriberRegistry, } impl Relay { @@ -76,7 +80,7 @@ impl Relay { bind, config.qlog_dir.clone(), config.tls.clone(), - )?)?; + ))?; vec![endpoint] } else { config.endpoints @@ -94,7 +98,7 @@ impl Relay { if !mlog_dir.is_dir() { anyhow::bail!("mlog path is not a directory: {}", mlog_dir.display()); } - tracing::info!("mlog output enabled: {}", mlog_dir.display()); + log::info!("mlog output enabled: {}", mlog_dir.display()); } let locals = Locals::new(); @@ -106,263 +110,189 @@ impl Relay { .collect::>(); // Create remote manager - uses coordinator for namespace lookups - let remotes = RemoteManager::new(config.coordinator.clone(), remote_clients); + let remotes = Remotes { + coordinator: config.coordinator.clone(), + quic: remote_clients[0].clone(), + } + .produce(); + + // Create subscriber registry for SUBSCRIBE_NAMESPACE tracking + let subscriber_registry = if config.topn_log { + log::info!("TopN event logging enabled - JSON events will be written to stdout"); + log::info!("TopN tie-break policy: {:?}", config.tie_break_policy); + SubscriberRegistry::with_config(true, config.tie_break_policy) + } else { + SubscriberRegistry::with_config(false, config.tie_break_policy) + }; Ok(Self { quic_endpoints: endpoints, announce_url: config.announce, mlog_dir: config.mlog_dir, locals, - remotes, + remotes: Some(remotes), coordinator: config.coordinator, + subscriber_registry, }) } /// Run the relay server. pub async fn run(self) -> anyhow::Result<()> { - let Self { - quic_endpoints, - announce_url, - mlog_dir, - locals, - remotes, - coordinator, - } = self; - - let run_result = async { - let mut tasks = FuturesUnordered::new(); - - // Use the remote manager for routing to remote relays. - let remote_manager = remotes.clone(); - - // Start the forwarder, if any - let forward_producer = if let Some(url) = &announce_url { - tracing::info!("forwarding announces to {}", url); - - // Establish a QUIC connection to the forward URL - let (session, _quic_client_initial_cid, transport) = quic_endpoints[0] - .client - .connect(url, None) + let mut tasks = FuturesUnordered::new(); + + // Split remotes producer/consumer and spawn producer task + let remotes = self.remotes.map(|(producer, consumer)| { + tasks.push(producer.run().boxed()); + consumer + }); + + // Start the forwarder, if any + let forward_producer = if let Some(url) = &self.announce_url { + log::info!("forwarding announces to {}", url); + + // Establish a QUIC connection to the forward URL + let (session, _quic_client_initial_cid) = self.quic_endpoints[0] + .client + .connect(url, None) + .await + .context("failed to establish forward connection")?; + + // Create the MoQ session over the connection + let (session, publisher, subscriber) = + moq_transport::session::Session::connect(session, None) .await - .context("failed to establish forward connection")?; - - // Create the MoQ session over the connection - let (session, publisher, subscriber) = - moq_transport::session::Session::connect(session, None, transport) - .await - .context("failed to establish forward session")?; - - // Use the connection path already validated and stored by Session::connect(). - // The forward session is scoped to whatever path the announce URL specifies. - // - // Note: the forward connection intentionally does not call - // coordinator.resolve_scope(). The announce URL is operator-configured - // (via --announce), not client-supplied, so it doesn't need the same - // auth/permission checks that incoming client connections get. The - // forward session always gets both Producer and Consumer (full - // read-write) since it's acting as a relay peer, not a client. - // - // Limitation: all incoming scopes are forwarded to this single upstream scope. - // Multi-scope forwarding (routing different incoming scopes to different - // upstream paths) would require per-scope forward connections. - let forward_scope = session.connection_path().map(|s| s.to_string()); - - let forward_coordinator = coordinator.clone(); - let session = Session { - session, - producer: Some(Producer::new( - publisher, - locals.clone(), - remote_manager.clone(), - forward_scope.clone(), - )), - consumer: Some(Consumer::new( - subscriber, - locals.clone(), - forward_coordinator, - None, - forward_scope, - )), - // Forward connections are always full read-write relay peers, - // so no reject loops needed. - reject_publishes: None, - reject_subscribes: None, - }; - - let forward_producer = session.producer.clone(); - - tasks.push(async move { session.run().await.context("forwarding failed") }.boxed()); - - forward_producer - } else { - None + .context("failed to establish forward session")?; + + // Create a normal looking session, except we never forward or register announces. + let coordinator = self.coordinator.clone(); + let session = Session { + session, + producer: Some(Producer::new( + publisher, + self.locals.clone(), + remotes.clone(), + )), + consumer: Some(Consumer::new( + subscriber, + self.locals.clone(), + coordinator, + None, + )), }; - let servers: Vec = quic_endpoints - .into_iter() - .map(|endpoint| endpoint.server.context("missing TLS certificate for server")) - .collect::>()?; - - // This will hold the futures for all our listening servers. - let mut accepts: FuturesUnordered = FuturesUnordered::new(); - for mut server in servers { - tracing::info!("listening on {}", server.local_addr()?); - - // Create a future, box it, and push it to the collection. - accepts.push( - async move { - let conn = server.accept().await.context("accept failed"); - (conn, server) - } - .boxed(), - ); - } + let forward_producer = session.producer.clone(); - loop { - tokio::select! { - // This branch polls all the `accept` futures concurrently. - Some((conn_result, mut server)) = accepts.next() => { - // An accept operation has completed. - // First, immediately queue up the next accept() call for this server. - accepts.push( - async move { - let conn = server.accept().await.context("accept failed"); - (conn, server) - } - .boxed(), - ); - - let (conn, connection_id, transport) = conn_result.context("failed to accept QUIC connection")?; - - metrics::counter!("moq_relay_connections_total").increment(1); - - // Construct mlog path from connection ID if mlog directory is configured - let mlog_path = mlog_dir.as_ref() - .map(|dir| dir.join(format!("{}_server.mlog", connection_id))); - - let locals = locals.clone(); - let remotes = remote_manager.clone(); - let forward = forward_producer.clone(); - let coordinator = coordinator.clone(); - - // Spawn a new task to handle the connection - tasks.push(async move { - // Track active connections - decrements when task completes - let _conn_guard = GaugeGuard::new("moq_relay_active_connections"); - - // Clone the raw connection so we can close it with a proper - // error code if scope resolution fails after the MoQ handshake. - let raw_conn = conn.clone(); - - // Create the MoQ session over the connection (setup handshake etc) - let (session, publisher, subscriber) = match moq_transport::session::Session::accept(conn, mlog_path, transport).await { - Ok(session) => session, - Err(err) => { - tracing::warn!(error = %err, "failed to accept MoQ session: {}", err); - metrics::counter!("moq_relay_connection_errors_total", "stage" => "session_accept").increment(1); - // Maintain invariant: connections_total - connections_closed_total == active_connections - metrics::counter!("moq_relay_connections_closed_total").increment(1); - return Ok(()); - } - }; - - // Create our MoQ relay session - let moq_session = session; - - // Resolve the connection path to a scope (identity + permissions). - // This translates the raw transport-level path into an application-level - // scope_id and determines what the connection is allowed to do. - let scope_info = match coordinator.resolve_scope(moq_session.connection_path()).await { - Ok(info) => info, - Err(err) => { - tracing::warn!( - connection_path = moq_session.connection_path(), - error = %err, - "scope resolution failed, rejecting session" - ); - // Close with PROTOCOL_VIOLATION (0x3) so the client - // gets a meaningful error instead of an abrupt reset. - // This is a QUIC APPLICATION_CLOSE, not a MoQT SESSION_CLOSE - // control message. Sending a proper SESSION_CLOSE would require - // running the MoQ session's send loop, which is not warranted - // for a pre-session rejection. The QUIC close code and reason - // string are visible to the client's transport layer. - raw_conn.close(0x3, "scope resolution failed"); - metrics::counter!("moq_relay_connection_errors_total", "stage" => "scope_resolve").increment(1); - metrics::counter!("moq_relay_connections_closed_total").increment(1); - return Ok(()); - } - }; - - let scope_id = scope_info.as_ref().map(|s| s.scope_id.clone()); - let can_publish = scope_info.as_ref().is_none_or(|s| s.permissions.can_publish()); - let can_subscribe = scope_info.as_ref().is_none_or(|s| s.permissions.can_subscribe()); - - if let Some(ref info) = scope_info { - tracing::debug!( - connection_path = moq_session.connection_path(), - scope_id = %info.scope_id, - permissions = ?info.permissions, - "scope resolved" - ); - } + tasks.push(async move { session.run().await.context("forwarding failed") }.boxed()); - // Gate Producer/Consumer creation on permissions. - // Note the intentional inversion: - // - Producer serves SUBSCRIBEs → gated on can_subscribe - // - Consumer handles PUBLISH_NAMESPACEs → gated on can_publish - // - // When a half is disabled, we pass its transport counterpart - // to the Session's reject fields so unauthorized messages get - // an explicit error response instead of being silently ignored. - let (producer, reject_subscribes) = if can_subscribe { - (publisher.map(|publisher| Producer::new(publisher, locals.clone(), remotes, scope_id.clone())), None) - } else { - (None, publisher) - }; - - let (consumer, reject_publishes) = if can_publish { - (subscriber.map(|subscriber| Consumer::new(subscriber, locals, coordinator, forward, scope_id)), None) - } else { - (None, subscriber) - }; - - let session = Session { - session: moq_session, - producer, - consumer, - reject_publishes, - reject_subscribes, - }; - - match session.run().await { - Ok(()) => { - // Session ended cleanly (uncommon - usually ends via close) - metrics::counter!("moq_relay_connections_closed_total").increment(1); - } - Err(err) if err.is_graceful_close() => { - // Graceful close - peer sent APPLICATION_CLOSE with code 0 - tracing::debug!("MoQ session closed gracefully"); - metrics::counter!("moq_relay_connections_closed_total").increment(1); - } - Err(err) => { - // Actual error - protocol violation, timeout, etc. - tracing::warn!(error = %err, "MoQ session error: {}", err); - metrics::counter!("moq_relay_connection_errors_total", "stage" => "session_run").increment(1); - metrics::counter!("moq_relay_connections_closed_total").increment(1); - } - } + forward_producer + } else { + None + }; - Ok(()) - }.boxed()); - }, - res = tasks.next(), if !tasks.is_empty() => res.unwrap()?, + let servers: Vec = self + .quic_endpoints + .into_iter() + .map(|endpoint| { + endpoint + .server + .context("missing TLS certificate for server") + }) + .collect::>()?; + + // This will hold the futures for all our listening servers. + let mut accepts: FuturesUnordered = FuturesUnordered::new(); + for mut server in servers { + log::info!("listening on {}", server.local_addr()?); + + // Create a future, box it, and push it to the collection. + accepts.push( + async move { + let conn = server.accept().await.context("accept failed"); + (conn, server) } - } + .boxed(), + ); } - .await; - remotes.shutdown().await; - run_result + loop { + tokio::select! { + // This branch polls all the `accept` futures concurrently. + Some((conn_result, mut server)) = accepts.next() => { + // An accept operation has completed. + // First, immediately queue up the next accept() call for this server. + accepts.push( + async move { + let conn = server.accept().await.context("accept failed"); + (conn, server) + } + .boxed(), + ); + + let (conn, connection_id) = conn_result.context("failed to accept QUIC connection")?; + + // Construct mlog path from connection ID if mlog directory is configured + let mlog_path = self.mlog_dir.as_ref() + .map(|dir| dir.join(format!("{}_server.mlog", connection_id))); + + let locals = self.locals.clone(); + let remotes = remotes.clone(); + let forward = forward_producer.clone(); + let coordinator = self.coordinator.clone(); + let subscriber_registry = self.subscriber_registry.clone(); + + // Spawn a new task to handle the connection + tasks.push(async move { + // Create the MoQ session over the connection (setup handshake etc) + let (session, publisher, subscriber) = match moq_transport::session::Session::accept(conn, mlog_path).await { + Ok(session) => session, + Err(err) => { + log::warn!("failed to accept MoQ session: {}", err); + return Ok(()); + } + }; + + // Create our MoQ relay session + // Use connection_id hash as session_id for self-exclusion in pub/sub + use std::hash::{Hash, Hasher}; + let session_id = { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + connection_id.hash(&mut hasher); + hasher.finish() + }; + + let moq_session = session; + let session = Session { + session: moq_session, + producer: publisher.map(|publisher| { + Producer::with_registry( + publisher, + locals.clone(), + remotes, + subscriber_registry.clone(), + session_id, + ) + }), + consumer: subscriber.map(|subscriber| { + Consumer::with_registry( + subscriber, + locals, + coordinator, + forward, + subscriber_registry, + session_id, + ) + }), + }; + + if let Err(err) = session.run().await { + log::warn!("failed to run MoQ session: {}", err); + } + + Ok(()) + }.boxed()); + }, + res = tasks.next(), if !tasks.is_empty() => res.unwrap()?, + } + } } } diff --git a/moq-relay-ietf/src/remote.rs b/moq-relay-ietf/src/remote.rs index f86ba374..88815601 100644 --- a/moq-relay-ietf/src/remote.rs +++ b/moq-relay-ietf/src/remote.rs @@ -1,470 +1,419 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::collections::HashMap; + +use std::collections::VecDeque; +use std::fmt; use std::net::SocketAddr; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Weak}; +use std::ops; +use std::sync::Arc; +use std::sync::Weak; +use futures::stream::FuturesUnordered; +use futures::FutureExt; +use futures::StreamExt; use moq_native_ietf::quic; use moq_transport::coding::TrackNamespace; -use moq_transport::serve::{Track, TrackReader}; -use tokio::sync::Mutex; -use tokio_util::sync::CancellationToken; +use moq_transport::serve::{Track, TrackReader, TrackWriter}; +use moq_transport::watch::State; use url::Url; -use crate::{metrics::GaugeGuard, Coordinator, CoordinatorError}; - -/// Cache key for upstream relay-to-relay connections. -/// -/// Keyed by both URL and destination address so that connections are reused -/// only when both match. -type RemoteCacheKey = (Url, Option); -type RemoteSlot = Arc>>; -type TrackCacheKey = (TrackNamespace, String); -type TrackSlot = Arc>>; - -/// Manages connections to remote relays. -/// -/// When a subscription request comes in for a namespace that isn't local, -/// RemoteManager uses the coordinator to find which remote relay serves it, -/// establishes a connection if needed, and subscribes to the track. -#[derive(Clone)] -pub struct RemoteManager { - coordinator: Arc, - clients: Vec, - remotes: Arc>>, +use crate::Coordinator; + +/// Information about remote origins. +pub struct Remotes { + /// The client we use to fetch/store origin information. + pub coordinator: Arc, + + // A QUIC endpoint we'll use to fetch from other origins. + pub quic: quic::Client, } -impl RemoteManager { - /// Create a new RemoteManager. - pub fn new(coordinator: Arc, clients: Vec) -> Self { - Self { - coordinator, - clients, - remotes: Arc::new(Mutex::new(HashMap::new())), - } +impl Remotes { + pub fn produce(self) -> (RemotesProducer, RemotesConsumer) { + let (send, recv) = State::default().split(); + let info = Arc::new(self); + + let producer = RemotesProducer::new(info.clone(), send); + let consumer = RemotesConsumer::new(info, recv); + + (producer, consumer) } +} - /// Subscribe to a track from a remote relay. - /// - /// `scope` is the resolved scope identity from `Coordinator::resolve_scope()`, - /// passed through to the coordinator's `lookup()` to scope the search. - /// - /// Returns None if the namespace isn't found in any remote relay. - pub async fn subscribe( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - track_name: &str, - ) -> anyhow::Result> { - let (origin, client) = match self.coordinator.lookup(scope, namespace).await { - Ok(result) => result, - Err(CoordinatorError::NamespaceNotFound) => return Ok(None), - Err(err) => return Err(err.into()), - }; +#[derive(Default)] +struct RemotesState { + lookup: HashMap, + requested: VecDeque, +} - let url = origin.url(); - let cache_key = (url.clone(), origin.addr()); - - let remote = match self - .get_or_connect(cache_key.clone(), client.as_ref()) - .await - { - Ok(remote) => remote, - Err(err) => { - tracing::error!(remote_url = %url, error = %err, "failed to connect to remote relay: {}", err); - return Err(err); - } - }; +// Clone for convenience, but there should only be one instance of this +#[derive(Clone)] +pub struct RemotesProducer { + info: Arc, + state: State, +} - match remote - .subscribe(namespace.clone(), track_name.to_string()) - .await - { - Ok(reader) => Ok(reader), - Err(err) => { - tracing::warn!(remote_url = %url, error = %err, "remote subscribe failed, removing from cache"); - self.remove_if_same_remote(&cache_key, &remote).await; +impl RemotesProducer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } + } - Err(err) + /// Block until the next remote requested by a consumer. + async fn next(&mut self) -> Option { + loop { + { + let state = self.state.lock(); + if !state.requested.is_empty() { + return state.into_mut()?.requested.pop_front(); + } + + state.modified()? } + .await; } } - /// Get an existing remote connection or create a new one. - async fn get_or_connect( - &self, - cache_key: RemoteCacheKey, - client: Option<&quic::Client>, - ) -> anyhow::Result { - let client = match client { - Some(client) => client, - None => self.clients.first().ok_or_else(|| { - anyhow::anyhow!("no QUIC clients configured for remote connections") - })?, - }; + /// Run the remotes producer to serve remote requests. + pub async fn run(mut self) -> anyhow::Result<()> { + let mut tasks = FuturesUnordered::new(); loop { - // The manager lock only protects the map. The per-key slot lock protects - // that key's connection state, so unrelated remotes can connect in parallel. - let slot = { - let mut remotes = self.remotes.lock().await; - remotes - .entry(cache_key.clone()) - .or_insert_with(|| Arc::new(Mutex::new(None))) - .clone() - }; + tokio::select! { + Some(mut remote) = self.next() => { + let url = remote.url.clone(); - let mut cached = slot.lock().await; + // Spawn a task to serve the remote + tasks.push(async move { + let info = remote.info.clone(); - let is_current_slot = { - let remotes = self.remotes.lock().await; - matches!(remotes.get(&cache_key), Some(current) if Arc::ptr_eq(current, &slot)) - }; + log::warn!("serving remote: {:?}", info); - if !is_current_slot { - continue; - } + // Run the remote producer + if let Err(err) = remote.run().await { + log::warn!("failed serving remote: {:?}, error: {}", info, err); + } - if let Some(remote) = cached.as_ref() { - if remote.is_connected() { - return Ok(remote.clone()); + url + }); } - tracing::info!(remote_url = %cache_key.0, "removing dead connection to remote relay"); - }; + // Handle finished remote producers + res = tasks.next(), if !tasks.is_empty() => { + let url = res.unwrap(); - if let Some(remote) = cached.take() { - remote.shutdown().await; + if let Some(mut state) = self.state.lock_mut() { + state.lookup.remove(&url); + } + }, + else => return Ok(()), } + } + } +} - tracing::info!(remote_url = %cache_key.0, "connecting to remote relay"); - let remote = match Remote::connect( - cache_key.0.clone(), - cache_key.1, - client, - Arc::downgrade(&self.remotes), - cache_key.clone(), - Arc::downgrade(&slot), - ) - .await - { - Ok(remote) => remote, - Err(err) => { - drop(cached); - remove_empty_remote_slot(&self.remotes, &cache_key, &slot).await; - return Err(err); - } - }; +impl ops::Deref for RemotesProducer { + type Target = Remotes; - *cached = Some(remote.clone()); - return Ok(remote); - } + fn deref(&self) -> &Self::Target { + &self.info } +} - async fn remove_if_same_remote(&self, cache_key: &RemoteCacheKey, remote: &Remote) { - let slot = { - let remotes = self.remotes.lock().await; - remotes.get(cache_key).cloned() - }; +#[derive(Clone)] +pub struct RemotesConsumer { + pub info: Arc, + state: State, +} - if let Some(slot) = slot { - let removed = { - let mut cached = slot.lock().await; - match cached.as_ref() { - Some(current) if current.is_same_connection(remote) => cached.take(), - _ => None, - } - }; +impl RemotesConsumer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } + } - if let Some(remote) = removed { - remote.shutdown().await; - remove_empty_remote_slot(&self.remotes, cache_key, &slot).await; - } + /// Route to a remote origin based on the namespace. + pub async fn route( + &self, + namespace: &TrackNamespace, + ) -> anyhow::Result> { + // Always fetch the origin instead of using the (potentially invalid) cache. + let (origin, client) = self.coordinator.lookup(namespace).await?; + + // Check if we already have a remote for this origin + let state = self.state.lock(); + if let Some(remote) = state.lookup.get(&origin.url()).cloned() { + return Ok(Some(remote)); } - } - /// Shutdown all remote connections. - pub(crate) async fn shutdown(&self) { - let remotes = { - let mut remotes = self.remotes.lock().await; - remotes.drain().collect::>() + // Create a new remote for this origin + let mut state = match state.into_mut() { + Some(state) => state, + None => return Ok(None), }; - for (cache_key, slot) in remotes { - tracing::info!(remote_url = %cache_key.0, "shutting down remote connection"); - let mut remote = slot.lock().await; - if let Some(remote) = remote.take() { - remote.shutdown().await; - } - } + let remote = Remote { + url: origin.url(), + remotes: self.info.clone(), + addr: origin.addr(), + client, + }; + + // Produce the remote + let (writer, reader) = remote.produce(); + state.requested.push_back(writer); + + // Insert the remote into our Map + state.lookup.insert(origin.url(), reader.clone()); + + Ok(Some(reader)) } } -async fn remove_empty_remote_slot( - remotes: &Arc>>, - cache_key: &RemoteCacheKey, - slot: &RemoteSlot, -) { - let cached = slot.lock().await; - if cached.is_some() { - return; +impl ops::Deref for RemotesConsumer { + type Target = Remotes; + + fn deref(&self) -> &Self::Target { + &self.info } +} - let mut remotes = remotes.lock().await; - if matches!(remotes.get(cache_key), Some(current) if Arc::ptr_eq(current, slot)) { - remotes.remove(cache_key); +pub struct Remote { + pub remotes: Arc, + pub url: Url, + pub addr: Option, + pub client: Option, +} + +impl fmt::Debug for Remote { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Remote") + .field("url", &self.url.to_string()) + .finish() } } -async fn remove_empty_track_slot( - tracks: &Arc>>, - key: &TrackCacheKey, - slot: &TrackSlot, -) { - let cached = slot.lock().await; - if cached.is_some() { - return; +impl ops::Deref for Remote { + type Target = Remotes; + + fn deref(&self) -> &Self::Target { + &self.remotes } +} - let mut tracks = tracks.lock().await; - if matches!(tracks.get(key), Some(current) if Arc::ptr_eq(current, slot)) { - tracks.remove(key); +impl Remote { + /// Create a new broadcast. + pub fn produce(self) -> (RemoteProducer, RemoteConsumer) { + let (send, recv) = State::default().split(); + let info = Arc::new(self); + + let consumer = RemoteConsumer::new(info.clone(), recv); + let producer = RemoteProducer::new(info, send); + + (producer, consumer) } } -/// A connection to a single remote relay with its own QUIC client. -#[derive(Clone)] -struct Remote { - url: Url, - subscriber: moq_transport::session::Subscriber, - /// Track subscriptions keyed by full track name. - tracks: Arc>>, - /// Flag indicating if the connection is still alive. - connected: Arc, - /// Cancellation token for the session task. - cancel: CancellationToken, +#[derive(Default)] +struct RemoteState { + tracks: HashMap<(TrackNamespace, String), RemoteTrackWeak>, + requested: VecDeque, } -impl Remote { - /// Connect to a remote relay with a dedicated QUIC client. - async fn connect( - url: Url, - addr: Option, - client: &quic::Client, - remotes: Weak>>, - cache_key: RemoteCacheKey, - cache_slot: Weak>>, - ) -> anyhow::Result { - let (session, _quic_client_initial_cid, transport) = match client.connect(&url, addr).await - { - Ok(session) => session, - Err(err) => { - metrics::counter!("moq_relay_upstream_errors_total", "stage" => "connect") - .increment(1); - return Err(err); - } - }; +pub struct RemoteProducer { + pub info: Arc, + state: State, +} - let (session, subscriber) = - match moq_transport::session::Subscriber::connect(session, transport).await { - Ok(session) => session, - Err(err) => { - metrics::counter!("moq_relay_upstream_errors_total", "stage" => "session") - .increment(1); - return Err(err.into()); - } - }; +impl RemoteProducer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } + } - let connected = Arc::new(AtomicBool::new(true)); - let cancel = CancellationToken::new(); - let upstream_guard = GaugeGuard::new("moq_relay_upstream_connections"); + pub async fn run(&mut self) -> anyhow::Result<()> { + let client = if let Some(client) = &self.info.client { + client + } else { + &self.quic + }; + // TODO reuse QUIC and MoQ sessions + let (session, _quic_client_initial_cid) = client.connect(&self.url, self.addr).await?; + let (session, subscriber) = moq_transport::session::Subscriber::connect(session).await?; - let session_url = url.clone(); - let session_connected = connected.clone(); - let session_cancel = cancel.clone(); + // Run the session + let mut session = session.run().boxed(); + let mut tasks = FuturesUnordered::new(); - tokio::spawn(async move { - let _upstream_guard = upstream_guard; + let mut done = None; + + // Serve requested tracks + loop { tokio::select! { - result = session.run() => { - if let Err(err) = result { - tracing::warn!(remote_url = %session_url, error = %err, "remote session closed: {}", err); - } else { - tracing::info!(remote_url = %session_url, "remote session closed normally"); - } - } - _ = session_cancel.cancelled() => { - tracing::info!(remote_url = %session_url, "remote session cancelled"); + track = self.next(), if done.is_none() => { + let track = match track { + Ok(Some(track)) => track, + Ok(None) => { done = Some(Ok(())); continue }, + Err(err) => { done = Some(Err(err)); continue }, + }; + + let info = track.info.clone(); + let mut subscriber = subscriber.clone(); + + tasks.push(async move { + if let Err(err) = subscriber.subscribe(track).await { + log::warn!("failed serving track: {:?}, error: {}", info, err); + } + }); } - } + _ = tasks.next(), if !tasks.is_empty() => {}, - session_connected.store(false, Ordering::Release); + // Keep running the session + res = &mut session, if !tasks.is_empty() || done.is_none() => return Ok(res?), - if let Some(cache_slot) = cache_slot.upgrade() { - let mut cleared = false; - let mut cached = cache_slot.lock().await; - if matches!(cached.as_ref(), Some(remote) if Arc::ptr_eq(&remote.connected, &session_connected)) - { - cached.take(); - cleared = true; - tracing::info!(remote_url = %session_url, "cleared closed remote connection from cache"); + else => return done.unwrap(), + } + } + } + + /// Block until the next track requested by a consumer. + async fn next(&self) -> anyhow::Result> { + loop { + let notify = { + let state = self.state.lock(); + + // Check if we have any requested tracks + if !state.requested.is_empty() { + return Ok(state + .into_mut() + .and_then(|mut state| state.requested.pop_front())); } - drop(cached); - if cleared { - if let Some(remotes) = remotes.upgrade() { - remove_empty_remote_slot(&remotes, &cache_key, &cache_slot).await; - } + match state.modified() { + Some(notified) => notified, + None => return Ok(None), } - } - }); + }; - Ok(Self { - url, - subscriber, - tracks: Arc::new(Mutex::new(HashMap::new())), - connected, - cancel, - }) + notify.await + } } +} - /// Check if the connection is still alive. - fn is_connected(&self) -> bool { - self.connected.load(Ordering::Acquire) - } +impl ops::Deref for RemoteProducer { + type Target = Remote; - fn is_same_connection(&self, other: &Self) -> bool { - Arc::ptr_eq(&self.connected, &other.connected) + fn deref(&self) -> &Self::Target { + &self.info } +} - /// Shutdown the remote connection. - async fn shutdown(&self) { - self.cancel.cancel(); - self.connected.store(false, Ordering::Release); - self.tracks.lock().await.clear(); +#[derive(Clone)] +pub struct RemoteConsumer { + pub info: Arc, + state: State, +} + +impl RemoteConsumer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } } - /// Subscribe to a track on this remote relay. - async fn subscribe( + /// Request a track from the broadcast. + pub fn subscribe( &self, - namespace: TrackNamespace, - track_name: String, - ) -> anyhow::Result> { - let key = (namespace.clone(), track_name.clone()); - - loop { - if !self.is_connected() { - anyhow::bail!("remote connection to {} is closed", self.url); + namespace: &TrackNamespace, + name: &str, + ) -> anyhow::Result> { + let key = (namespace.clone(), name.to_string()); + let state = self.state.lock(); + if let Some(track) = state.tracks.get(&key) { + if let Some(track) = track.upgrade() { + return Ok(Some(track)); } + } - let slot = { - let mut tracks = self.tracks.lock().await; - tracks - .entry(key.clone()) - .or_insert_with(|| Arc::new(Mutex::new(None))) - .clone() - }; - - let mut cached = slot.lock().await; + let mut state = match state.into_mut() { + Some(state) => state, + None => return Ok(None), + }; - let is_current_slot = { - let tracks = self.tracks.lock().await; - matches!(tracks.get(&key), Some(current) if Arc::ptr_eq(current, &slot)) - }; + let (writer, reader) = Track::new(namespace.clone(), name.to_string()).produce(); + let reader = RemoteTrackReader::new(reader, self.state.clone()); - if !is_current_slot { - continue; - } + // Insert the track into our Map so we deduplicate future requests. + state.tracks.insert(key, reader.downgrade()); + state.requested.push_back(writer); - if let Some(reader) = cached.as_ref() { - if !reader.is_closed() { - return Ok(Some(reader.clone())); - } + Ok(Some(reader)) + } +} - tracing::debug!(remote_url = %self.url, namespace = %key.0, track = %key.1, "removing closed remote track from cache"); - } +impl ops::Deref for RemoteConsumer { + type Target = Remote; - cached.take(); + fn deref(&self) -> &Self::Target { + &self.info + } +} - let mut subscriber = self.subscriber.clone(); - let url = self.url.clone(); - let tracks = Arc::downgrade(&self.tracks); - let cancel = self.cancel.clone(); +#[derive(Clone)] +pub struct RemoteTrackReader { + pub reader: TrackReader, + drop: Arc, +} - tracing::info!(remote_url = %url, namespace = %key.0, track = %key.1, "subscribing to remote track"); +impl RemoteTrackReader { + fn new(reader: TrackReader, parent: State) -> Self { + let drop = Arc::new(RemoteTrackDrop { + parent, + key: (reader.namespace.clone(), reader.name.clone()), + }); - let (writer, reader) = Track::new(namespace.clone(), track_name.clone()).produce(); - let subscribe_result = tokio::select! { - result = subscriber.subscribe_open(writer) => result, - _ = cancel.cancelled() => { - drop(cached); - remove_empty_track_slot(&self.tracks, &key, &slot).await; - anyhow::bail!("subscribe cancelled, remote connection to {} is closed", self.url); - } - }; + Self { reader, drop } + } - let subscribe = match subscribe_result { - Ok(subscribe) => subscribe, - Err(err) => { - drop(cached); - remove_empty_track_slot(&self.tracks, &key, &slot).await; - return Err(err.into()); - } - }; + fn downgrade(&self) -> RemoteTrackWeak { + RemoteTrackWeak { + reader: self.reader.clone(), + drop: Arc::downgrade(&self.drop), + } + } +} - if !self.is_connected() { - drop(cached); - remove_empty_track_slot(&self.tracks, &key, &slot).await; - anyhow::bail!("remote connection to {} is closed", self.url); - } +impl ops::Deref for RemoteTrackReader { + type Target = TrackReader; - *cached = Some(reader.clone()); - drop(cached); - - let cleanup_key = key.clone(); - let cleanup_reader = reader.clone(); - let cleanup_slot = slot.clone(); - tokio::spawn(async move { - tokio::select! { - result = subscribe.closed() => { - match result { - Ok(()) => { - tracing::debug!(remote_url = %url, namespace = %cleanup_key.0, track = %cleanup_key.1, "remote track subscription ended"); - } - Err(err) => { - tracing::warn!(remote_url = %url, namespace = %cleanup_key.0, track = %cleanup_key.1, error = %err, "remote track subscription ended with error: {}", err); - } - } - } - _ = cancel.cancelled() => { - tracing::debug!(remote_url = %url, namespace = %cleanup_key.0, track = %cleanup_key.1, "remote track subscription cancelled"); - } - } + fn deref(&self) -> &Self::Target { + &self.reader + } +} - if let Some(tracks) = tracks.upgrade() { - let mut cached = cleanup_slot.lock().await; - if matches!(cached.as_ref(), Some(current) if Arc::ptr_eq(¤t.info, &cleanup_reader.info)) - { - cached.take(); - } - drop(cached); +impl ops::DerefMut for RemoteTrackReader { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.reader + } +} - remove_empty_track_slot(&tracks, &cleanup_key, &cleanup_slot).await; - } - }); +struct RemoteTrackWeak { + reader: TrackReader, + drop: Weak, +} - return Ok(Some(reader)); - } +impl RemoteTrackWeak { + fn upgrade(&self) -> Option { + Some(RemoteTrackReader { + reader: self.reader.clone(), + drop: self.drop.upgrade()?, + }) } } -impl std::fmt::Debug for Remote { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Remote") - .field("url", &self.url.to_string()) - .field("connected", &self.is_connected()) - .finish() +struct RemoteTrackDrop { + parent: State, + key: (TrackNamespace, String), +} + +impl Drop for RemoteTrackDrop { + fn drop(&mut self) { + if let Some(mut parent) = self.parent.lock_mut() { + parent.tracks.remove(&self.key); + } } } diff --git a/moq-relay-ietf/src/session.rs b/moq-relay-ietf/src/session.rs index d83d8595..b55748b8 100644 --- a/moq-relay-ietf/src/session.rs +++ b/moq-relay-ietf/src/session.rs @@ -1,8 +1,5 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; -use moq_transport::session::{Publisher, SessionError, Subscriber}; +use moq_transport::session::SessionError; use crate::{Consumer, Producer}; @@ -10,19 +7,6 @@ pub struct Session { pub session: moq_transport::session::Session, pub producer: Option, pub consumer: Option, - - /// When `consumer` is `None` (publish not permitted), the transport - /// `Subscriber` half still exists and will queue incoming - /// PUBLISH_NAMESPACEs from the peer. We hold it here so we can - /// actively drain and reject those messages instead of silently - /// ignoring them. - pub reject_publishes: Option, - - /// When `producer` is `None` (subscribe not permitted), the transport - /// `Publisher` half still exists and will queue incoming SUBSCRIBEs - /// from the peer. We hold it here so we can actively drain and reject - /// those messages instead of silently ignoring them. - pub reject_subscribes: Option, } impl Session { @@ -39,53 +23,6 @@ impl Session { tasks.push(consumer.run().boxed()); } - // Reject unauthorized messages for disabled session halves. - // Without these, a peer that sends a disallowed control message - // would get no response (no OK, no error) because nobody is - // draining the transport queue for that message type. - if let Some(subscriber) = self.reject_publishes { - tasks.push(Self::drain_and_reject_publishes(subscriber).boxed()); - } - - if let Some(publisher) = self.reject_subscribes { - tasks.push(Self::drain_and_reject_subscribes(publisher).boxed()); - } - tasks.select_next_some().await } - - /// Drain incoming PUBLISH_NAMESPACEs and reject each one. - /// - /// The transport `Subscriber` queues incoming PUBLISH_NAMESPACE messages - /// as `Announced` events. Dropping an `Announced` without calling `ok()` - /// triggers its `Drop` impl, which sends PUBLISH_NAMESPACE_ERROR back - /// to the peer. - async fn drain_and_reject_publishes(mut subscriber: Subscriber) -> Result<(), SessionError> { - while let Some(announced) = subscriber.announced().await { - tracing::debug!( - namespace = %announced.namespace, - "rejecting PUBLISH_NAMESPACE: publish not permitted for this session" - ); - drop(announced); - } - Ok(()) - } - - /// Drain incoming SUBSCRIBEs and reject each one. - /// - /// The transport `Publisher` queues incoming SUBSCRIBE messages as - /// `Subscribed` events. Dropping a `Subscribed` without calling `ok()` - /// triggers its `Drop` impl, which sends SUBSCRIBE_ERROR back to the - /// peer. - async fn drain_and_reject_subscribes(mut publisher: Publisher) -> Result<(), SessionError> { - while let Some(subscribed) = publisher.subscribed().await { - tracing::debug!( - namespace = %subscribed.track_namespace, - track = %subscribed.track_name, - "rejecting SUBSCRIBE: subscribe not permitted for this session" - ); - drop(subscribed); - } - Ok(()) - } } diff --git a/moq-relay-ietf/src/subscriber_registry.rs b/moq-relay-ietf/src/subscriber_registry.rs new file mode 100644 index 00000000..9cd12216 --- /dev/null +++ b/moq-relay-ietf/src/subscriber_registry.rs @@ -0,0 +1,917 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; + +use moq_transport::coding::TrackNamespace; +use tokio::sync::broadcast; + +use crate::top_n_tracker::{TieBreakPolicy, TopNTracker, TopNTrackerConfig}; + +/// Key for indexing subscriptions by their namespace prefix and property type +#[derive(Clone, Debug, Hash, Eq, PartialEq)] +struct FilterGroupKey { + prefix: TrackNamespace, + property_type: u64, +} + +/// Group of subscriber IDs that share the same (prefix, property_type) +/// Enables O(group_size) iteration instead of O(all_subscriptions) +struct FilterGroup { + subscription_ids: Vec, +} + +/// TRACK_FILTER configuration for a subscription +#[derive(Clone, Debug, Default)] +pub struct TrackFilter { + /// Property type to sort by (e.g., 0x100 for viewers) + pub property_type: u64, + /// Maximum number of tracks to receive (N value) + pub max_selected: u8, +} + +/// Information about an active SUBSCRIBE_NAMESPACE subscription +#[derive(Clone)] +pub struct NamespaceSubscription { + /// The namespace prefix this subscription is for + pub prefix: TrackNamespace, + /// Session ID of the subscriber (for self-exclusion) + pub session_id: u64, + /// Channel to send PUBLISH notifications to this subscriber + pub publish_tx: broadcast::Sender, + /// Channel to send PUBLISH_NAMESPACE notifications to this subscriber + pub publish_ns_tx: broadcast::Sender, + /// Optional TRACK_FILTER configuration + pub track_filter: Option, +} + +/// Notification sent when a PUBLISH arrives that matches a subscription +#[derive(Clone, Debug)] +pub struct PublishNotification { + pub namespace: TrackNamespace, + pub track_name: String, + pub track_alias: u64, +} + +/// Notification sent when a PUBLISH_NAMESPACE arrives that matches a subscription +#[derive(Clone, Debug)] +pub struct PublishNamespaceNotification { + pub namespace: TrackNamespace, +} + +/// Registry for tracking active SUBSCRIBE_NAMESPACE subscriptions +/// +/// When a subscriber sends SUBSCRIBE_NAMESPACE, they register here. +/// When a publisher sends PUBLISH, we find matching subscriptions and notify. +#[derive(Clone)] +pub struct SubscriberRegistry { + inner: Arc>, + /// Shared version counter, bumped only on actual snapshot rebuild in update_track_value path. + /// Observers read this lock-free to detect when to recompute their cached top-N result. + snapshot_epoch: Arc, +} + +struct SubscriberRegistryInner { + /// Map from subscription ID to subscription info + subscriptions: HashMap, + /// Next subscription ID + next_id: u64, + /// TopN trackers per (namespace_prefix, property_type) + /// Key is (prefix, property_type) + top_n_trackers: HashMap<(TrackNamespace, u64), TopNTracker>, + /// Enable event logging for TopN trackers (for visualization) + topn_event_logging: bool, + /// Tie-break policy for TopN trackers + tie_break_policy: TieBreakPolicy, + /// Tracks which (subscription_id, namespace, track_name) have been sent PUBLISH + /// Used to avoid sending duplicate PUBLISH when track re-enters top-N + published_tracks: HashSet<(u64, TrackNamespace, String)>, + /// Index: session_id -> subscription_id (for O(1) lookup by session) + session_to_subscription: HashMap, + /// Index: subscription_id -> set of published track keys (for O(1) cleanup on unregister) + subscription_published: HashMap>, + /// Groups of filtered subscriptions by (prefix, property_type) + filter_groups: HashMap, +} + +impl SubscriberRegistry { + pub fn new() -> Self { + Self::with_config(false, TieBreakPolicy::OldestWins) + } + + /// Create a new registry with TopN event logging enabled (for visualization) + pub fn with_topn_logging() -> Self { + Self::with_config(true, TieBreakPolicy::OldestWins) + } + + /// Create a new registry with custom configuration + pub fn with_config(topn_event_logging: bool, tie_break_policy: TieBreakPolicy) -> Self { + Self { + inner: Arc::new(Mutex::new(SubscriberRegistryInner { + subscriptions: HashMap::new(), + next_id: 0, + top_n_trackers: HashMap::new(), + topn_event_logging, + tie_break_policy, + published_tracks: HashSet::new(), + session_to_subscription: HashMap::new(), + subscription_published: HashMap::new(), + filter_groups: HashMap::new(), + })), + snapshot_epoch: Arc::new(AtomicU64::new(0)), + } + } + + /// Enable or disable TopN event logging at runtime + pub fn set_topn_event_logging(&self, enabled: bool) { + let mut inner = self.inner.lock().unwrap(); + inner.topn_event_logging = enabled; + // Update all existing trackers + for tracker in inner.top_n_trackers.values() { + tracker.set_event_logging(enabled); + } + } + + /// Register a SUBSCRIBE_NAMESPACE subscription without TRACK_FILTER + /// Returns (subscription_id, receiver for PUBLISH notifications, receiver for PUBLISH_NAMESPACE notifications) + pub fn register( + &self, + prefix: TrackNamespace, + session_id: u64, + ) -> ( + u64, + broadcast::Receiver, + broadcast::Receiver, + ) { + self.register_with_filter(prefix, session_id, None) + } + + /// Register a SUBSCRIBE_NAMESPACE subscription with optional TRACK_FILTER + /// Returns (subscription_id, receiver for PUBLISH notifications, receiver for PUBLISH_NAMESPACE notifications) + pub fn register_with_filter( + &self, + prefix: TrackNamespace, + session_id: u64, + track_filter: Option, + ) -> ( + u64, + broadcast::Receiver, + broadcast::Receiver, + ) { + let mut inner = self.inner.lock().unwrap(); + + let id = inner.next_id; + inner.next_id += 1; + + // Create broadcast channels for PUBLISH and PUBLISH_NAMESPACE notifications + let (publish_tx, publish_rx) = broadcast::channel(64); + let (publish_ns_tx, publish_ns_rx) = broadcast::channel(64); + + // If TRACK_FILTER is specified, register with the TopN tracker + if let Some(ref filter) = track_filter { + let tracker_key = (prefix.clone(), filter.property_type); + let enable_logging = inner.topn_event_logging; + let tie_break_policy = inner.tie_break_policy; + + // Get or create tracker for this (prefix, property_type) + let tracker = inner + .top_n_trackers + .entry(tracker_key) + .or_insert_with(|| { + let config = TopNTrackerConfig { + enable_event_logging: enable_logging, + tie_break_policy, + ..Default::default() + }; + TopNTracker::with_config(filter.property_type, config) + }); + + // Update max_n if this subscription has higher N + let current_max_n = tracker.max_n(); + if filter.max_selected > current_max_n { + tracker.update_max_n(filter.max_selected); + } + + // Add to filter group index + let group_key = FilterGroupKey { + prefix: prefix.clone(), + property_type: filter.property_type, + }; + inner + .filter_groups + .entry(group_key) + .or_insert_with(|| FilterGroup { + subscription_ids: Vec::new(), + }) + .subscription_ids + .push(id); + + log::debug!( + "registered filtered subscription id={} session_id={} filter={:?}", + id, + session_id, + filter + ); + } + + let subscription = NamespaceSubscription { + prefix, + session_id, + publish_tx, + publish_ns_tx, + track_filter, + }; + + inner.subscriptions.insert(id, subscription); + inner.session_to_subscription.insert(session_id, id); + + log::debug!( + "registered namespace subscription id={} session_id={}", + id, + session_id + ); + + (id, publish_rx, publish_ns_rx) + } + + /// Register a track with a property value (called on PUBLISH with track_extensions) + pub fn register_track( + &self, + namespace: &TrackNamespace, + track_name: &str, + property_type: u64, + property_value: u64, + publisher_session_id: u64, + ) { + let mut inner = self.inner.lock().unwrap(); + + // Find all trackers that match this namespace prefix + for ((prefix, pt), tracker) in inner.top_n_trackers.iter_mut() { + if *pt != property_type { + continue; + } + + // Check if the namespace matches this tracker's prefix + if Self::prefix_matches(prefix, namespace) { + tracker.register_track( + namespace.clone(), + track_name.to_string(), + property_value, + publisher_session_id, + ); + log::debug!( + "registered track {}/{} with TopN tracker (prefix={}, property_type={}, value={})", + namespace, + track_name, + prefix, + property_type, + property_value + ); + } + } + } + + /// Update a track's property value and notify subscribers if track enters top-N + /// Only sends PUBLISH once per (subscription, track) - avoids duplicates + /// Returns count of new subscriptions notified + pub fn update_track_value( + &self, + namespace: &TrackNamespace, + track_name: &str, + property_type: u64, + new_value: u64, + track_alias: u64, + origin_session_id: u64, + ) -> usize { + let mut inner = self.inner.lock().unwrap(); + let mut notified = 0; + let mut keys_to_add: Vec<(u64, TrackNamespace, String)> = Vec::new(); + + // Find the matching tracker and update value + let matching_prefix: Option = inner + .top_n_trackers + .iter() + .find(|((prefix, pt), _)| *pt == property_type && Self::prefix_matches(prefix, namespace)) + .map(|((prefix, _), _)| prefix.clone()); + + let Some(prefix) = matching_prefix else { + return 0; + }; + + let tracker_key = (prefix.clone(), property_type); + let version_before = inner + .top_n_trackers + .get(&tracker_key) + .map(|t| t.snapshot_version()) + .unwrap_or(0); + + if let Some(tracker) = inner.top_n_trackers.get(&tracker_key) { + tracker.update_value(namespace, track_name, new_value); + } + + let snapshot = inner + .top_n_trackers + .get(&tracker_key) + .map(|t| t.load_snapshot()); + let Some(snapshot) = snapshot else { return 0 }; + + // Bump epoch only if snapshot was actually rebuilt + let version_after = inner + .top_n_trackers + .get(&tracker_key) + .map(|t| t.snapshot_version()) + .unwrap_or(0); + if version_after != version_before { + self.snapshot_epoch.fetch_add(1, Ordering::Release); + } + + // Use filter group index to only iterate relevant subscribers + let group_key = FilterGroupKey { + prefix: prefix.clone(), + property_type, + }; + + let sub_ids: Vec = inner + .filter_groups + .get(&group_key) + .map(|g| g.subscription_ids.clone()) + .unwrap_or_default(); + + let notification = PublishNotification { + namespace: namespace.clone(), + track_name: track_name.to_string(), + track_alias, + }; + + for sub_id in sub_ids { + let Some(sub) = inner.subscriptions.get(&sub_id) else { + continue; + }; + + if sub.session_id == origin_session_id { + continue; + } + + let Some(ref filter) = sub.track_filter else { + continue; + }; + + // Use the pre-loaded snapshot for the check + let tracker = inner.top_n_trackers.get(&tracker_key).unwrap(); + if tracker.is_in_top_n_with_snapshot( + namespace, + track_name, + sub.session_id, + filter.max_selected, + &snapshot, + ) { + let publish_key = (sub_id, namespace.clone(), track_name.to_string()); + if inner.published_tracks.contains(&publish_key) { + continue; + } + + if sub.publish_tx.send(notification.clone()).is_ok() { + notified += 1; + keys_to_add.push(publish_key); + } + } + } + + // Record sent PUBLISH notifications with per-subscription index + for key in keys_to_add { + let sub_id = key.0; + let ns = key.1.clone(); + let track = key.2.clone(); + inner.published_tracks.insert(key); + inner + .subscription_published + .entry(sub_id) + .or_insert_with(Vec::new) + .push((ns, track)); + } + + notified + } + + /// Remove a track from TopN tracking + pub fn remove_track(&self, namespace: &TrackNamespace, track_name: &str) { + let inner = self.inner.lock().unwrap(); + + for ((prefix, _pt), tracker) in inner.top_n_trackers.iter() { + if Self::prefix_matches(prefix, namespace) { + tracker.remove_track(namespace, track_name); + log::debug!( + "removed track {}/{} from TopN tracker (prefix={})", + namespace, + track_name, + prefix + ); + } + } + } + + /// Check if a track is in the top-N for a given session (considering self-exclusion) + /// This is used for per-object filtering during streaming + pub fn is_track_in_top_n( + &self, + namespace: &TrackNamespace, + track_name: &str, + session_id: u64, + property_type: u64, + max_n: u8, + ) -> bool { + let inner = self.inner.lock().unwrap(); + + // Find the tracker for this namespace prefix and property type + for ((prefix, pt), tracker) in inner.top_n_trackers.iter() { + if *pt != property_type { + continue; + } + + if Self::prefix_matches(prefix, namespace) { + // N-covers-all fast path: if N >= total tracks, always in top-N + let snapshot = tracker.load_snapshot(); + let non_self_count = snapshot + .iter() + .filter(|t| t.publisher_session_id != session_id) + .count(); + if max_n as usize >= non_self_count { + return true; + } + return tracker.is_in_top_n_with_snapshot( + namespace, track_name, session_id, max_n, &snapshot, + ); + } + } + + false + } + + /// Get the snapshot epoch handle (lock-free read for observers). + /// Observers cache this Arc and read it on every object to detect ranking changes. + pub fn snapshot_epoch(&self) -> Arc { + self.snapshot_epoch.clone() + } + + /// Get the TRACK_FILTER configuration for a subscriber session + /// Returns None if the session has no filtered subscription + pub fn get_track_filter_for_session(&self, session_id: u64) -> Option { + let inner = self.inner.lock().unwrap(); + + // O(1) lookup via session index + if let Some(&sub_id) = inner.session_to_subscription.get(&session_id) { + if let Some(sub) = inner.subscriptions.get(&sub_id) { + return sub.track_filter.clone(); + } + } + + None + } + + /// Unregister a subscription + pub fn unregister(&self, id: u64) { + let mut inner = self.inner.lock().unwrap(); + if let Some(sub) = inner.subscriptions.remove(&id) { + // Remove from session index + inner.session_to_subscription.remove(&sub.session_id); + + // Remove from filter group + if let Some(ref filter) = sub.track_filter { + let group_key = FilterGroupKey { + prefix: sub.prefix.clone(), + property_type: filter.property_type, + }; + if let Some(group) = inner.filter_groups.get_mut(&group_key) { + group.subscription_ids.retain(|&sid| sid != id); + } + } + + // Clean up published_tracks using the per-subscription index (O(tracks_for_this_sub)) + if let Some(keys) = inner.subscription_published.remove(&id) { + for (ns, track) in keys { + inner.published_tracks.remove(&(id, ns, track)); + } + } + + log::debug!("unregistered namespace subscription id={}", id); + } + } + + /// Find all subscriptions that match a given namespace and notify them of a PUBLISH + /// Excludes the session that originated the PUBLISH (self-exclusion) + /// For subscriptions with TRACK_FILTER, only notifies if track is in their top-N + /// Only sends PUBLISH once per (subscription, track) - avoids duplicates + /// Returns the number of matching subscriptions notified + pub fn notify_publish( + &self, + namespace: &TrackNamespace, + track_name: &str, + track_alias: u64, + origin_session_id: u64, + ) -> usize { + let mut inner = self.inner.lock().unwrap(); + + let notification = PublishNotification { + namespace: namespace.clone(), + track_name: track_name.to_string(), + track_alias, + }; + + let mut notified = 0; + let mut keys_to_add: Vec<(u64, TrackNamespace, String)> = Vec::new(); + + // Pre-load snapshots for trackers that match this namespace + let mut tracker_snapshots: HashMap<(TrackNamespace, u64), Arc>> = HashMap::new(); + for ((prefix, pt), tracker) in inner.top_n_trackers.iter() { + if Self::prefix_matches(prefix, namespace) { + tracker_snapshots.insert((prefix.clone(), *pt), tracker.load_snapshot()); + } + } + + for (id, sub) in inner.subscriptions.iter() { + if sub.session_id == origin_session_id { + continue; + } + + if !Self::prefix_matches(&sub.prefix, namespace) { + continue; + } + + // If subscription has TRACK_FILTER, check if track is in top-N + if let Some(ref filter) = sub.track_filter { + let tracker_key = (sub.prefix.clone(), filter.property_type); + + if let Some(snapshot) = tracker_snapshots.get(&tracker_key) { + if let Some(tracker) = inner.top_n_trackers.get(&tracker_key) { + if !tracker.is_in_top_n_with_snapshot( + namespace, + track_name, + sub.session_id, + filter.max_selected, + snapshot, + ) { + continue; + } + } else { + continue; + } + } else { + continue; + } + } + + // Check if we already sent PUBLISH for this (subscription, track) pair + let publish_key = (*id, namespace.clone(), track_name.to_string()); + if inner.published_tracks.contains(&publish_key) { + continue; + } + + if sub.publish_tx.send(notification.clone()).is_ok() { + notified += 1; + keys_to_add.push(publish_key); + } + } + + // Record sent PUBLISH notifications with per-subscription index + for key in keys_to_add { + let sub_id = key.0; + let ns = key.1.clone(); + let track = key.2.clone(); + inner.published_tracks.insert(key); + inner + .subscription_published + .entry(sub_id) + .or_insert_with(Vec::new) + .push((ns, track)); + } + + notified + } + + /// Find all subscriptions that match a given namespace and notify them of a PUBLISH_NAMESPACE + /// Excludes the session that originated the PUBLISH_NAMESPACE (self-exclusion) + /// Returns the number of matching subscriptions notified + pub fn notify_publish_namespace(&self, namespace: &TrackNamespace, origin_session_id: u64) -> usize { + let inner = self.inner.lock().unwrap(); + + let notification = PublishNamespaceNotification { + namespace: namespace.clone(), + }; + + let mut notified = 0; + + for (id, sub) in inner.subscriptions.iter() { + // Skip if this subscription belongs to the same session that sent the PUBLISH_NAMESPACE + if sub.session_id == origin_session_id { + log::debug!( + "skipping subscription id={} for PUBLISH_NAMESPACE (same session {})", + id, + origin_session_id + ); + continue; + } + + // Check if the namespace matches the subscription prefix + if Self::prefix_matches(&sub.prefix, namespace) { + if let Err(e) = sub.publish_ns_tx.send(notification.clone()) { + log::warn!( + "failed to notify subscription id={} of PUBLISH_NAMESPACE: {}", + id, + e + ); + } else { + log::debug!( + "notified subscription id={} of PUBLISH_NAMESPACE {:?}", + id, + namespace + ); + notified += 1; + } + } + } + + notified + } + + /// Check if prefix is a prefix of namespace + fn prefix_matches(prefix: &TrackNamespace, namespace: &TrackNamespace) -> bool { + if prefix.fields.len() > namespace.fields.len() { + return false; + } + + prefix + .fields + .iter() + .zip(namespace.fields.iter()) + .all(|(a, b)| a == b) + } + + /// Get all subscriptions matching a prefix (for debugging) + pub fn matching_subscriptions(&self, namespace: &TrackNamespace) -> Vec { + let inner = self.inner.lock().unwrap(); + + inner + .subscriptions + .iter() + .filter(|(_, sub)| Self::prefix_matches(&sub.prefix, namespace)) + .map(|(id, _)| *id) + .collect() + } +} + +impl Default for SubscriberRegistry { + fn default() -> Self { + Self::new() + } +} + +/// RAII guard that unregisters on drop +pub struct SubscriptionGuard { + registry: SubscriberRegistry, + id: u64, +} + +impl SubscriptionGuard { + pub fn new(registry: SubscriberRegistry, id: u64) -> Self { + Self { registry, id } + } + + pub fn id(&self) -> u64 { + self.id + } +} + +impl Drop for SubscriptionGuard { + fn drop(&mut self) { + self.registry.unregister(self.id); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn ns(path: &str) -> TrackNamespace { + TrackNamespace::from_utf8_path(path) + } + + #[test] + fn test_prefix_matching() { + assert!(SubscriberRegistry::prefix_matches(&ns("live"), &ns("live/stream1"))); + assert!(SubscriberRegistry::prefix_matches(&ns("live"), &ns("live"))); + // An empty prefix (zero fields) should match everything + let empty = TrackNamespace::new(); + assert!(SubscriberRegistry::prefix_matches(&empty, &ns("live/stream1"))); + assert!(!SubscriberRegistry::prefix_matches(&ns("live/stream1"), &ns("live"))); + assert!(!SubscriberRegistry::prefix_matches(&ns("other"), &ns("live/stream1"))); + } + + #[test] + fn test_register_unregister() { + let registry = SubscriberRegistry::new(); + + let (id1, _rx1, _rx1_ns) = registry.register(ns("live"), 100); + let (id2, _rx2, _rx2_ns) = registry.register(ns("live/room1"), 101); + + assert_eq!(registry.matching_subscriptions(&ns("live/room1/track")).len(), 2); + + registry.unregister(id1); + + assert_eq!(registry.matching_subscriptions(&ns("live/room1/track")).len(), 1); + + registry.unregister(id2); + + assert_eq!(registry.matching_subscriptions(&ns("live/room1/track")).len(), 0); + } + + #[tokio::test] + async fn test_notify_publish() { + let registry = SubscriberRegistry::new(); + + // Register with session_id=100 + let (id, mut rx, _rx_ns) = registry.register(ns("live"), 100); + + // Notify from session 200 (different) - should be delivered + let notified = registry.notify_publish(&ns("live/stream1"), "video", 42, 200); + assert_eq!(notified, 1); + + let notification = rx.recv().await.unwrap(); + assert_eq!(notification.track_name, "video"); + assert_eq!(notification.track_alias, 42); + + registry.unregister(id); + } + + #[tokio::test] + async fn test_self_exclusion() { + let registry = SubscriberRegistry::new(); + + // Register with session_id=100 + let (_id, mut rx, _rx_ns) = registry.register(ns("live"), 100); + + // Notify from the same session (100) - should NOT be delivered + let notified = registry.notify_publish(&ns("live/stream1"), "video", 42, 100); + assert_eq!(notified, 0); + + // Verify nothing was received (use try_recv to avoid blocking) + assert!(rx.try_recv().is_err()); + } + + // --- TRACK_FILTER tests --- + + const PROPERTY_VIEWERS: u64 = 0x100; + + #[tokio::test] + async fn test_track_filter_top_n() { + let registry = SubscriberRegistry::new(); + + // Subscriber wants top-2 tracks by viewers + let filter = TrackFilter { + property_type: PROPERTY_VIEWERS, + max_selected: 2, + }; + let (_id, mut rx, _rx_ns) = + registry.register_with_filter(ns("live"), 100, Some(filter)); + + // Register 4 tracks with different viewer counts + // Publisher session IDs are different from subscriber (100) + registry.register_track(&ns("live"), "a", PROPERTY_VIEWERS, 1000, 1); // highest + registry.register_track(&ns("live"), "b", PROPERTY_VIEWERS, 500, 2); + registry.register_track(&ns("live"), "c", PROPERTY_VIEWERS, 800, 3); + registry.register_track(&ns("live"), "d", PROPERTY_VIEWERS, 200, 4); // lowest + + // Ranked order: a(1000), c(800), b(500), d(200) + // Top-2 = a, c + + // Track "a" should be notified (in top-2) + let notified = registry.notify_publish(&ns("live"), "a", 1, 1); + assert_eq!(notified, 1); + let msg = rx.try_recv().unwrap(); + assert_eq!(msg.track_name, "a"); + + // Track "c" should be notified (in top-2) + let notified = registry.notify_publish(&ns("live"), "c", 3, 3); + assert_eq!(notified, 1); + let msg = rx.try_recv().unwrap(); + assert_eq!(msg.track_name, "c"); + + // Track "b" should NOT be notified (rank 3, not in top-2) + let notified = registry.notify_publish(&ns("live"), "b", 2, 2); + assert_eq!(notified, 0); + assert!(rx.try_recv().is_err()); + + // Track "d" should NOT be notified (rank 4, not in top-2) + let notified = registry.notify_publish(&ns("live"), "d", 4, 4); + assert_eq!(notified, 0); + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn test_track_filter_self_exclusion() { + let registry = SubscriberRegistry::new(); + + // Subscriber (session 1) wants top-2 tracks but also publishes + let filter = TrackFilter { + property_type: PROPERTY_VIEWERS, + max_selected: 2, + }; + let (_id, mut rx, _rx_ns) = + registry.register_with_filter(ns("live"), 1, Some(filter)); + + // Session 1 publishes the top track + registry.register_track(&ns("live"), "a", PROPERTY_VIEWERS, 1000, 1); // self + registry.register_track(&ns("live"), "b", PROPERTY_VIEWERS, 500, 2); + registry.register_track(&ns("live"), "c", PROPERTY_VIEWERS, 800, 3); + registry.register_track(&ns("live"), "d", PROPERTY_VIEWERS, 200, 4); + + // Ranked order: a(1000), c(800), b(500), d(200) + // For session 1 with self-exclusion, top-2 non-self = c, b + + // Track "a" - subscriber won't receive (self-exclusion at session level) + let notified = registry.notify_publish(&ns("live"), "a", 1, 1); + assert_eq!(notified, 0); // same session, basic self-exclusion + assert!(rx.try_recv().is_err()); + + // Track "c" should be notified (top-2 non-self) + let notified = registry.notify_publish(&ns("live"), "c", 3, 3); + assert_eq!(notified, 1); + let msg = rx.try_recv().unwrap(); + assert_eq!(msg.track_name, "c"); + + // Track "b" should be notified (top-2 non-self, because "a" is excluded) + let notified = registry.notify_publish(&ns("live"), "b", 2, 2); + assert_eq!(notified, 1); + let msg = rx.try_recv().unwrap(); + assert_eq!(msg.track_name, "b"); + + // Track "d" should NOT be notified (rank 4 non-self, only top-2) + let notified = registry.notify_publish(&ns("live"), "d", 4, 4); + assert_eq!(notified, 0); + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn test_track_filter_value_update() { + let registry = SubscriberRegistry::new(); + + let filter = TrackFilter { + property_type: PROPERTY_VIEWERS, + max_selected: 1, + }; + let (_id, mut rx, _rx_ns) = + registry.register_with_filter(ns("live"), 100, Some(filter)); + + // Initial: a=100, b=50 + registry.register_track(&ns("live"), "a", PROPERTY_VIEWERS, 100, 1); + registry.register_track(&ns("live"), "b", PROPERTY_VIEWERS, 50, 2); + + // a is top-1 + let notified = registry.notify_publish(&ns("live"), "a", 1, 1); + assert_eq!(notified, 1); + rx.try_recv().unwrap(); + + let notified = registry.notify_publish(&ns("live"), "b", 2, 2); + assert_eq!(notified, 0); + + // Update b to 200, making it top-1 + // The update_track_value now also notifies subscribers + let notified = registry.update_track_value(&ns("live"), "b", PROPERTY_VIEWERS, 200, 2, 2); + assert_eq!(notified, 1); + rx.try_recv().unwrap(); + + // notify_publish for b should return 0 since PUBLISH was already sent via update_track_value + let notified = registry.notify_publish(&ns("live"), "b", 2, 2); + assert_eq!(notified, 0); + + // a is no longer top-1 + let notified = registry.notify_publish(&ns("live"), "a", 1, 1); + assert_eq!(notified, 0); + } + + #[test] + fn test_mixed_filtered_unfiltered_subscriptions() { + let registry = SubscriberRegistry::new(); + + // Unfiltered subscription + let (_id1, _rx1, _) = registry.register(ns("live"), 100); + + // Filtered subscription + let filter = TrackFilter { + property_type: PROPERTY_VIEWERS, + max_selected: 1, + }; + let (_id2, _rx2, _) = + registry.register_with_filter(ns("live"), 101, Some(filter)); + + // Register tracks + registry.register_track(&ns("live"), "a", PROPERTY_VIEWERS, 100, 1); + registry.register_track(&ns("live"), "b", PROPERTY_VIEWERS, 50, 2); + + // Track "a" (top-1): both subscriptions should receive + let notified = registry.notify_publish(&ns("live"), "a", 1, 1); + assert_eq!(notified, 2); + + // Track "b" (not top-1): only unfiltered subscription receives + let notified = registry.notify_publish(&ns("live"), "b", 2, 2); + assert_eq!(notified, 1); // only unfiltered + } +} diff --git a/moq-relay-ietf/src/tls.rs b/moq-relay-ietf/src/tls.rs index f391ae7b..98e07b90 100644 --- a/moq-relay-ietf/src/tls.rs +++ b/moq-relay-ietf/src/tls.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use anyhow::Context; use ring::digest::{digest, SHA256}; use rustls::server::{ClientHello, ResolvesServerCert}; diff --git a/moq-relay-ietf/src/top_n_tracker.rs b/moq-relay-ietf/src/top_n_tracker.rs new file mode 100644 index 00000000..2e3d66f6 --- /dev/null +++ b/moq-relay-ietf/src/top_n_tracker.rs @@ -0,0 +1,1176 @@ +//! Simple N+X TopN Tracker for TRACK_FILTER support. +//! +//! Implements the "Simple N+X" design from the MOQ filter analysis: +//! - Single global sorted snapshot per namespace +//! - Snapshot size = max(N) + max(X) where N = max subscriber filter, X = max tracks per publisher +//! - Self-exclusion computed at query time by skipping publisher's own tracks +//! - Lock-free reads via RwLock (Rust's RwLock allows concurrent readers) +//! - Configurable tie-breaking policy and staleness handling +//! +//! ## Structured Logging for Visualization +//! +//! When the `TOPN_LOG` environment variable is set, this module emits structured +//! JSON logs that can be used to generate timeline visualizations. Log lines are +//! prefixed with `TOPN_EVENT:` followed by JSON. +//! +//! Event types: +//! - `track_registered`: A new track was registered +//! - `value_updated`: A track's property value changed +//! - `top_n_query`: A subscriber queried their top-N (includes selection and self-exclusion) +//! +//! To generate a visualization from logs: +//! ```bash +//! TOPN_LOG=1 cargo run -p moq-relay-ietf ... 2>&1 | grep TOPN_EVENT > topn.log +//! cargo run -p moq-topn-test --bin topn-log-to-svg -- topn.log output.svg +//! ``` + +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use arc_swap::ArcSwap; +use moq_transport::coding::TrackNamespace; + +/// Structured event logger for visualization +/// +/// Emits JSON events to stdout that can be parsed to generate timeline visualizations. +/// Enable via `TopNTrackerConfig::enable_event_logging` or the `TOPN_LOG` env var. +pub struct TopNEventLogger { + enabled: std::sync::atomic::AtomicBool, + start: Instant, +} + +impl TopNEventLogger { + pub fn new() -> Self { + Self { + enabled: std::sync::atomic::AtomicBool::new(std::env::var("TOPN_LOG").is_ok()), + start: Instant::now(), + } + } + + /// Enable or disable event logging at runtime + pub fn set_enabled(&self, enabled: bool) { + self.enabled.store(enabled, Ordering::Relaxed); + } + + /// Check if logging is enabled + pub fn is_enabled(&self) -> bool { + self.enabled.load(Ordering::Relaxed) + } + + fn ts_ms(&self) -> u64 { + self.start.elapsed().as_millis() as u64 + } + + pub fn log_track_registered(&self, track_name: &str, value: u64, publisher_id: u64) { + if !self.is_enabled() { + return; + } + println!( + r#"TOPN_EVENT:{{"ts_ms":{},"event":"track_registered","track":"{}","value":{},"publisher_id":{}}}"#, + self.ts_ms(), + track_name, + value, + publisher_id + ); + } + + pub fn log_value_updated(&self, track_name: &str, old_value: u64, new_value: u64, publisher_id: u64) { + if !self.is_enabled() { + return; + } + println!( + r#"TOPN_EVENT:{{"ts_ms":{},"event":"value_updated","track":"{}","old_value":{},"new_value":{},"publisher_id":{}}}"#, + self.ts_ms(), + track_name, + old_value, + new_value, + publisher_id + ); + } + + pub fn log_top_n_query(&self, subscriber_id: u64, n: u8, selected: &[(String, u64)], excluded_self: Option) { + if !self.is_enabled() { + return; + } + let selected_json: Vec = selected + .iter() + .map(|(name, val)| format!(r#"{{"track":"{}","value":{}}}"#, name, val)) + .collect(); + println!( + r#"TOPN_EVENT:{{"ts_ms":{},"event":"top_n_query","subscriber_id":{},"n":{},"selected":[{}],"excluded_self":{}}}"#, + self.ts_ms(), + subscriber_id, + n, + selected_json.join(","), + excluded_self.map(|id| id.to_string()).unwrap_or_else(|| "null".to_string()) + ); + } + + pub fn log_track_removed(&self, track_name: &str, publisher_id: u64) { + if !self.is_enabled() { + return; + } + println!( + r#"TOPN_EVENT:{{"ts_ms":{},"event":"track_removed","track":"{}","publisher_id":{}}}"#, + self.ts_ms(), + track_name, + publisher_id + ); + } +} + +impl Default for TopNEventLogger { + fn default() -> Self { + Self::new() + } +} + +/// Tie-breaking policy when tracks have equal property values +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum TieBreakPolicy { + /// First-come-first-served: earlier arrival wins (stable rankings) + #[default] + OldestWins, + /// Most recently active wins (responsive to activity) + MostRecentWins, +} + +/// Configuration for TopNTracker +#[derive(Clone, Debug)] +pub struct TopNTrackerConfig { + /// How to break ties when property values are equal + pub tie_break_policy: TieBreakPolicy, + /// How long before a track is considered stale (None = never stale) + pub staleness_timeout: Option, + /// Enable structured event logging for visualization (also enabled via TOPN_LOG env var) + pub enable_event_logging: bool, +} + +impl Default for TopNTrackerConfig { + fn default() -> Self { + Self { + tie_break_policy: TieBreakPolicy::OldestWins, + staleness_timeout: None, + enable_event_logging: false, + } + } +} + +/// Entry in the sorted snapshot +#[derive(Clone, Debug)] +pub struct TrackRank { + pub namespace: TrackNamespace, + pub track_name: String, + pub property_value: u64, + pub arrival_seq: u64, + pub last_update: Instant, + pub publisher_session_id: u64, +} + +impl TrackRank { + /// Compare for sorting with configurable tie-breaking + fn rank_cmp(&self, other: &Self, policy: TieBreakPolicy) -> std::cmp::Ordering { + match other.property_value.cmp(&self.property_value) { + std::cmp::Ordering::Equal => match policy { + // Earlier arrival wins (lower seq = better) + TieBreakPolicy::OldestWins => self.arrival_seq.cmp(&other.arrival_seq), + // More recent update wins (later time = better, so reverse comparison) + TieBreakPolicy::MostRecentWins => other.last_update.cmp(&self.last_update), + }, + ord => ord, + } + } +} + +/// Track key for indexing +#[derive(Clone, Debug, Hash, Eq, PartialEq)] +pub struct TrackKey { + pub namespace: TrackNamespace, + pub track_name: String, +} + +impl TrackKey { + pub fn new(namespace: TrackNamespace, track_name: String) -> Self { + Self { + namespace, + track_name, + } + } +} + +/// Internal track metadata +struct TrackInfo { + property_value: u64, + arrival_seq: u64, + last_update: Instant, + publisher_session_id: u64, +} + +/// Simple N+X TopN Tracker +/// +/// Maintains a sorted snapshot of top tracks by property value. +/// Self-exclusion is handled at query time, not via pre-computed waterlines. +/// +/// Optimizations: +/// - Lock-free snapshot reads via ArcSwap (no mutex on the read path) +/// - Lazy/coalesced rebuild: marks dirty on write, rebuilds only on next read +/// - Snapshot versioning for cache invalidation in subscribers +pub struct TopNTracker { + /// Sorted snapshot of top tracks (size = max_n + max_x). Lock-free reads. + snapshot: ArcSwap>, + + /// Monotonically increasing version, bumped on each snapshot rebuild + snapshot_version: AtomicU64, + + /// Dirty flag: set on value update, cleared on rebuild + dirty: AtomicBool, + + /// Track metadata index (protected by Mutex for write serialization) + track_index: Mutex>, + + /// Track count per publisher (for computing max_x) + publisher_track_count: Mutex>, + + /// Max N across all subscribers + max_n: AtomicU8, + + /// Max tracks per publisher (for snapshot sizing) + max_x: AtomicU8, + + /// Arrival sequence counter + next_seq: AtomicU64, + + /// Property type being tracked + property_type: u64, + + /// Configuration + config: TopNTrackerConfig, + + /// Event logger for visualization (enabled via TOPN_LOG env var) + event_logger: TopNEventLogger, +} + +impl TopNTracker { + /// Create a new tracker for a given property type with default config + pub fn new(property_type: u64) -> Self { + Self::with_config(property_type, TopNTrackerConfig::default()) + } + + /// Create a new tracker with custom configuration + pub fn with_config(property_type: u64, config: TopNTrackerConfig) -> Self { + let event_logger = TopNEventLogger::new(); + if config.enable_event_logging { + event_logger.set_enabled(true); + } + Self { + snapshot: ArcSwap::from_pointee(Vec::new()), + snapshot_version: AtomicU64::new(0), + dirty: AtomicBool::new(false), + track_index: Mutex::new(HashMap::new()), + publisher_track_count: Mutex::new(HashMap::new()), + max_n: AtomicU8::new(0), + max_x: AtomicU8::new(0), + next_seq: AtomicU64::new(0), + property_type, + config, + event_logger, + } + } + + /// Enable or disable event logging at runtime + pub fn set_event_logging(&self, enabled: bool) { + self.event_logger.set_enabled(enabled); + } + + /// Check if event logging is enabled + pub fn is_event_logging_enabled(&self) -> bool { + self.event_logger.is_enabled() + } + + /// Get the current configuration + pub fn config(&self) -> &TopNTrackerConfig { + &self.config + } + + /// Get the property type this tracker handles + pub fn property_type(&self) -> u64 { + self.property_type + } + + /// Register a new track + pub fn register_track( + &self, + namespace: TrackNamespace, + track_name: String, + property_value: u64, + publisher_session_id: u64, + ) { + let key = TrackKey::new(namespace, track_name); + let seq = self.next_seq.fetch_add(1, Ordering::Relaxed); + let now = Instant::now(); + + { + let mut index = self.track_index.lock().unwrap(); + if index.contains_key(&key) { + log::warn!("track already registered: {:?}", key); + return; + } + index.insert( + key.clone(), + TrackInfo { + property_value, + arrival_seq: seq, + last_update: now, + publisher_session_id, + }, + ); + } + + { + let mut counts = self.publisher_track_count.lock().unwrap(); + let count = counts.entry(publisher_session_id).or_insert(0); + *count += 1; + self.update_max_x(&counts); + } + + // Membership changed (track added) — force immediate rebuild + self.do_rebuild_snapshot(); + + let full_track_path = format!("{}/{}", key.namespace, key.track_name); + self.event_logger.log_track_registered(&full_track_path, property_value, publisher_session_id); + + log::debug!( + "registered track {:?} with value {} from session {}", + key, + property_value, + publisher_session_id + ); + } + + /// Update a track's property value. Marks dirty for lazy rebuild. + pub fn update_value(&self, namespace: &TrackNamespace, track_name: &str, new_value: u64) { + let key = TrackKey::new(namespace.clone(), track_name.to_string()); + let now = Instant::now(); + + let old_value_and_publisher = { + let mut index = self.track_index.lock().unwrap(); + if let Some(info) = index.get_mut(&key) { + if info.property_value == new_value { + info.last_update = now; + return; + } + let old_value = info.property_value; + let publisher_id = info.publisher_session_id; + info.property_value = new_value; + info.last_update = now; + Some((old_value, publisher_id)) + } else { + log::warn!("update_value: track not found: {:?}", key); + return; + } + }; + + // Mark dirty — snapshot will be rebuilt lazily on next read + self.dirty.store(true, Ordering::Release); + + if let Some((old_value, publisher_id)) = old_value_and_publisher { + self.event_logger.log_value_updated(track_name, old_value, new_value, publisher_id); + } + + log::debug!("updated track {:?} to value {}", key, new_value); + } + + /// Touch a track to update its last_update timestamp without changing value + pub fn touch_track(&self, namespace: &TrackNamespace, track_name: &str) { + let key = TrackKey::new(namespace.clone(), track_name.to_string()); + let now = Instant::now(); + + let mut index = self.track_index.lock().unwrap(); + if let Some(info) = index.get_mut(&key) { + info.last_update = now; + } + } + + /// Remove a track + pub fn remove_track(&self, namespace: &TrackNamespace, track_name: &str) { + let key = TrackKey::new(namespace.clone(), track_name.to_string()); + + let publisher_session_id = { + let mut index = self.track_index.lock().unwrap(); + if let Some(info) = index.remove(&key) { + Some(info.publisher_session_id) + } else { + None + } + }; + + if let Some(session_id) = publisher_session_id { + let mut counts = self.publisher_track_count.lock().unwrap(); + if let Some(count) = counts.get_mut(&session_id) { + *count -= 1; + if *count == 0 { + counts.remove(&session_id); + } + } + self.update_max_x(&counts); + + self.event_logger.log_track_removed(track_name, session_id); + } + + // Membership changed — force immediate rebuild + self.do_rebuild_snapshot(); + + log::debug!("removed track {:?}", key); + } + + /// Update max_n when a subscriber joins/leaves + pub fn update_max_n(&self, new_max_n: u8) { + let old = self.max_n.swap(new_max_n, Ordering::Relaxed); + if old != new_max_n { + self.do_rebuild_snapshot(); + } + } + + /// Get current max_n + pub fn max_n(&self) -> u8 { + self.max_n.load(Ordering::Relaxed) + } + + /// Get current snapshot version (for cache invalidation) + pub fn snapshot_version(&self) -> u64 { + self.snapshot_version.load(Ordering::Acquire) + } + + /// Load the current snapshot (lock-free). Rebuilds if dirty. + pub fn load_snapshot(&self) -> Arc> { + if self.dirty.load(Ordering::Acquire) { + self.do_rebuild_snapshot(); + } + self.snapshot.load_full() + } + + /// Compute top-N tracks for a session, excluding self-published tracks + /// + /// This is the core self-exclusion logic: scan the snapshot, skip tracks + /// where publisher_session_id matches the querying session. + pub fn compute_top_n_for_session( + &self, + session_id: u64, + n: u8, + ) -> Vec<(TrackNamespace, String)> { + let snapshot = self.load_snapshot(); + let mut result = Vec::with_capacity(n as usize); + let mut selected_with_values = Vec::with_capacity(n as usize); + let mut has_self_tracks = false; + + for track in snapshot.iter() { + if result.len() >= n as usize { + break; + } + // Self-exclusion: skip if publisher is same session + if track.publisher_session_id != session_id { + result.push((track.namespace.clone(), track.track_name.clone())); + selected_with_values.push((track.track_name.clone(), track.property_value)); + } else { + has_self_tracks = true; + } + } + + // Log the query result + let excluded_self = if has_self_tracks { Some(session_id) } else { None }; + self.event_logger.log_top_n_query(session_id, n, &selected_with_values, excluded_self); + + result + } + + /// Check if a track is in the top-N for a session (with self-exclusion) + pub fn is_in_top_n( + &self, + namespace: &TrackNamespace, + track_name: &str, + session_id: u64, + n: u8, + ) -> bool { + let snapshot = self.load_snapshot(); + self.is_in_top_n_with_snapshot(namespace, track_name, session_id, n, &snapshot) + } + + /// Check if a track is in the top-N using a pre-loaded snapshot + /// + /// This variant allows reusing a snapshot across multiple checks for efficiency. + pub fn is_in_top_n_with_snapshot( + &self, + namespace: &TrackNamespace, + track_name: &str, + session_id: u64, + n: u8, + snapshot: &[TrackRank], + ) -> bool { + let mut non_self_count = 0u8; + + for track in snapshot.iter() { + // Self-exclusion: skip publisher's own tracks + if track.publisher_session_id == session_id { + continue; + } + + // Check if this is the track we're looking for + if &track.namespace == namespace && track.track_name == track_name { + return true; // Found before reaching N non-self tracks + } + + non_self_count += 1; + if non_self_count >= n { + return false; // Reached N non-self tracks, target not among them + } + } + + false + } + + /// Check if a track is in the top-N with fast rejection optimization + /// + /// Uses cached_last_self_position for O(1) rejection of tracks that definitely + /// can't be in the subscriber's top-N non-self. + /// + /// # Arguments + /// * `track_position` - The track's position in the snapshot (0-indexed) + /// * `session_id` - The subscriber's session ID (for self-exclusion) + /// * `n` - The subscriber's N value + /// * `cached_last_self_pos` - Cached position of the subscriber's last self-track in snapshot + /// * `snapshot` - The current snapshot + /// + /// # Returns + /// * `Some(true)` - Track is definitely in top-N + /// * `Some(false)` - Track is definitely NOT in top-N (fast rejection) + /// * `None` - Need full scan to determine + pub fn is_in_top_n_fast( + &self, + namespace: &TrackNamespace, + track_name: &str, + track_position: usize, + session_id: u64, + n: u8, + cached_last_self_pos: u8, + snapshot: &[TrackRank], + ) -> bool { + // Fast rejection: if track position > N + cachedLastSelfPos, + // it definitely can't be in top-N non-self + if !Self::might_be_in_top_n(track_position, n, cached_last_self_pos) { + return false; + } + + // Need full scan for tracks that might be in top-N + self.is_in_top_n_with_snapshot(namespace, track_name, session_id, n, snapshot) + } + + /// O(1) fast rejection check: can this track possibly be in top-N non-self? + /// + /// If a subscriber has self-tracks scattered in the snapshot, the worst case + /// is that all their self-tracks are above the track we're checking. In that + /// case, the track's effective position would be: actual_position - num_self_above. + /// + /// Using cached_last_self_pos (position of last self-track), we know that at most + /// cached_last_self_pos self-tracks can be above any position. So if: + /// track_position > N + cached_last_self_pos + /// then even if ALL self-tracks were above this track, it still wouldn't make top-N. + #[inline] + pub fn might_be_in_top_n(track_position: usize, n: u8, cached_last_self_pos: u8) -> bool { + track_position <= (n as usize) + (cached_last_self_pos as usize) + } + + /// Compute the position of the last self-track in the snapshot for a session + /// + /// This value should be cached per-subscriber and invalidated when snapshot changes. + /// Returns 0 if the session has no tracks in the snapshot. + pub fn compute_last_self_position(&self, session_id: u64) -> u8 { + let snapshot = self.load_snapshot(); + Self::compute_last_self_position_in_snapshot(session_id, &snapshot) + } + + /// Compute last self position from a pre-loaded snapshot + pub fn compute_last_self_position_in_snapshot(session_id: u64, snapshot: &[TrackRank]) -> u8 { + let mut last_pos: u8 = 0; + for (i, track) in snapshot.iter().enumerate() { + if track.publisher_session_id == session_id { + last_pos = (i as u8).min(255); + } + } + last_pos + } + + /// Find a track's position in the snapshot (for fast rejection) + /// Returns None if track is not in the snapshot + pub fn find_track_position(&self, namespace: &TrackNamespace, track_name: &str) -> Option { + let snapshot = self.load_snapshot(); + snapshot.iter().position(|t| &t.namespace == namespace && t.track_name == track_name) + } + + /// Force a snapshot rebuild if dirty. Called by external code that needs fresh data. + pub fn ensure_fresh(&self) { + if self.dirty.load(Ordering::Acquire) { + self.do_rebuild_snapshot(); + } + } + + /// Get number of tracked tracks + pub fn num_tracks(&self) -> usize { + self.track_index.lock().unwrap().len() + } + + /// Sweep stale tracks from the index and rebuild snapshot + /// + /// Call this periodically (e.g., every 100ms) to clean up tracks from + /// disconnected publishers. Only does work if staleness_timeout is configured. + /// Returns the number of tracks removed. + pub fn sweep_stale(&self) -> usize { + let timeout = match self.config.staleness_timeout { + Some(t) => t, + None => return 0, + }; + + let now = Instant::now(); + let mut removed_count = 0; + let mut affected_publishers = Vec::new(); + + { + let mut index = self.track_index.lock().unwrap(); + let stale_keys: Vec<_> = index + .iter() + .filter(|(_, info)| now.duration_since(info.last_update) >= timeout) + .map(|(key, info)| (key.clone(), info.publisher_session_id)) + .collect(); + + for (key, publisher_id) in stale_keys { + index.remove(&key); + affected_publishers.push(publisher_id); + removed_count += 1; + } + } + + if !affected_publishers.is_empty() { + let mut counts = self.publisher_track_count.lock().unwrap(); + for publisher_id in affected_publishers { + if let Some(count) = counts.get_mut(&publisher_id) { + *count = count.saturating_sub(1); + if *count == 0 { + counts.remove(&publisher_id); + } + } + } + self.update_max_x(&counts); + } + + if removed_count > 0 { + self.do_rebuild_snapshot(); + log::debug!("swept {} stale tracks", removed_count); + } + + removed_count + } + + // --- Private methods --- + + fn update_max_x(&self, counts: &HashMap) { + let max = counts.values().copied().max().unwrap_or(0); + self.max_x + .store(max.min(255) as u8, Ordering::Relaxed); + } + + fn do_rebuild_snapshot(&self) { + self.dirty.store(false, Ordering::Release); + + let max_n = self.max_n.load(Ordering::Relaxed) as usize; + let max_x = self.max_x.load(Ordering::Relaxed) as usize; + + if max_n == 0 { + self.snapshot.store(Arc::new(Vec::new())); + return; + } + + let snapshot_size = max_n + max_x; + let now = Instant::now(); + let staleness_threshold = self.config.staleness_timeout; + + let index = self.track_index.lock().unwrap(); + let mut tracks: Vec = index + .iter() + .filter(|(_, info)| { + match staleness_threshold { + Some(timeout) => now.duration_since(info.last_update) < timeout, + None => true, + } + }) + .map(|(key, info)| TrackRank { + namespace: key.namespace.clone(), + track_name: key.track_name.clone(), + property_value: info.property_value, + arrival_seq: info.arrival_seq, + last_update: info.last_update, + publisher_session_id: info.publisher_session_id, + }) + .collect(); + + let policy = self.config.tie_break_policy; + tracks.sort_by(|a, b| a.rank_cmp(b, policy)); + tracks.truncate(snapshot_size); + + self.snapshot.store(Arc::new(tracks)); + self.snapshot_version.fetch_add(1, Ordering::Release); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn ns(path: &str) -> TrackNamespace { + TrackNamespace::from_utf8_path(path) + } + + #[test] + fn test_register_and_query() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(3); + + tracker.register_track(ns("live"), "a".to_string(), 100, 1); + tracker.register_track(ns("live"), "b".to_string(), 90, 2); + tracker.register_track(ns("live"), "c".to_string(), 80, 3); + tracker.register_track(ns("live"), "d".to_string(), 70, 4); + + // Session 99 (pure subscriber) should get top 3 + let top3 = tracker.compute_top_n_for_session(99, 3); + assert_eq!(top3.len(), 3); + assert_eq!(top3[0].1, "a"); + assert_eq!(top3[1].1, "b"); + assert_eq!(top3[2].1, "c"); + } + + #[test] + fn test_self_exclusion() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(5); + + // Session 1 publishes tracks at positions 0 and 2 + tracker.register_track(ns("live"), "a".to_string(), 100, 1); // session 1 + tracker.register_track(ns("live"), "b".to_string(), 90, 2); + tracker.register_track(ns("live"), "c".to_string(), 80, 1); // session 1 + tracker.register_track(ns("live"), "d".to_string(), 70, 3); + tracker.register_track(ns("live"), "e".to_string(), 60, 4); + + // Session 1 wants top 3 non-self: should get b, d, e (skipping a, c) + let top3 = tracker.compute_top_n_for_session(1, 3); + assert_eq!(top3.len(), 3); + assert_eq!(top3[0].1, "b"); + assert_eq!(top3[1].1, "d"); + assert_eq!(top3[2].1, "e"); + } + + #[test] + fn test_is_in_top_n() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(3); + + tracker.register_track(ns("live"), "a".to_string(), 100, 1); + tracker.register_track(ns("live"), "b".to_string(), 90, 2); + tracker.register_track(ns("live"), "c".to_string(), 80, 3); + tracker.register_track(ns("live"), "d".to_string(), 70, 4); + + // For session 99, "a", "b", "c" are in top 3 + assert!(tracker.is_in_top_n(&ns("live"), "a", 99, 3)); + assert!(tracker.is_in_top_n(&ns("live"), "b", 99, 3)); + assert!(tracker.is_in_top_n(&ns("live"), "c", 99, 3)); + assert!(!tracker.is_in_top_n(&ns("live"), "d", 99, 3)); + } + + #[test] + fn test_is_in_top_n_with_self_exclusion() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(5); + + tracker.register_track(ns("live"), "a".to_string(), 100, 1); // session 1 + tracker.register_track(ns("live"), "b".to_string(), 90, 2); + tracker.register_track(ns("live"), "c".to_string(), 80, 1); // session 1 + tracker.register_track(ns("live"), "d".to_string(), 70, 3); + tracker.register_track(ns("live"), "e".to_string(), 60, 4); + + // For session 1 with N=3: should be b, d, e (a and c excluded) + assert!(!tracker.is_in_top_n(&ns("live"), "a", 1, 3)); // self + assert!(tracker.is_in_top_n(&ns("live"), "b", 1, 3)); + assert!(!tracker.is_in_top_n(&ns("live"), "c", 1, 3)); // self + assert!(tracker.is_in_top_n(&ns("live"), "d", 1, 3)); + assert!(tracker.is_in_top_n(&ns("live"), "e", 1, 3)); + } + + #[test] + fn test_update_value() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(2); + + tracker.register_track(ns("live"), "a".to_string(), 100, 1); + tracker.register_track(ns("live"), "b".to_string(), 50, 2); + tracker.register_track(ns("live"), "c".to_string(), 80, 3); + + // Initial: a (100), c (80) in top 2 + assert!(tracker.is_in_top_n(&ns("live"), "a", 99, 2)); + assert!(tracker.is_in_top_n(&ns("live"), "c", 99, 2)); + assert!(!tracker.is_in_top_n(&ns("live"), "b", 99, 2)); + + // Update b to 200 - now b should be #1 + tracker.update_value(&ns("live"), "b", 200); + + assert!(tracker.is_in_top_n(&ns("live"), "b", 99, 2)); // now in top + assert!(tracker.is_in_top_n(&ns("live"), "a", 99, 2)); + assert!(!tracker.is_in_top_n(&ns("live"), "c", 99, 2)); // pushed out + } + + #[test] + fn test_remove_track() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(2); + + tracker.register_track(ns("live"), "a".to_string(), 100, 1); + tracker.register_track(ns("live"), "b".to_string(), 90, 2); + tracker.register_track(ns("live"), "c".to_string(), 80, 3); + + // Remove "a", "b" should still be in top, "c" gets promoted + tracker.remove_track(&ns("live"), "a"); + + assert!(!tracker.is_in_top_n(&ns("live"), "a", 99, 2)); // removed + assert!(tracker.is_in_top_n(&ns("live"), "b", 99, 2)); + assert!(tracker.is_in_top_n(&ns("live"), "c", 99, 2)); // promoted + } + + #[test] + fn test_tie_breaker_earlier_wins() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(1); + + // Register with same value - first one should win + tracker.register_track(ns("live"), "first".to_string(), 100, 1); + tracker.register_track(ns("live"), "second".to_string(), 100, 2); + + let top1 = tracker.compute_top_n_for_session(99, 1); + assert_eq!(top1.len(), 1); + assert_eq!(top1[0].1, "first"); + } + + #[test] + fn test_snapshot_size_n_plus_x() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(3); + + // Session 1 publishes 4 tracks (max_x = 4) + tracker.register_track(ns("live"), "a".to_string(), 100, 1); + tracker.register_track(ns("live"), "b".to_string(), 90, 1); + tracker.register_track(ns("live"), "c".to_string(), 80, 1); + tracker.register_track(ns("live"), "d".to_string(), 70, 1); + + // Other sessions publish more tracks + tracker.register_track(ns("live"), "e".to_string(), 60, 2); + tracker.register_track(ns("live"), "f".to_string(), 50, 3); + tracker.register_track(ns("live"), "g".to_string(), 40, 4); + + // Snapshot should be size N + X = 3 + 4 = 7 + let snapshot = tracker.load_snapshot(); + assert_eq!(snapshot.len(), 7); + + // Session 1 wants top 3 non-self: all their tracks at top, so need positions 4,5,6 + let top3 = tracker.compute_top_n_for_session(1, 3); + assert_eq!(top3.len(), 3); + assert_eq!(top3[0].1, "e"); + assert_eq!(top3[1].1, "f"); + assert_eq!(top3[2].1, "g"); + } + + #[test] + fn test_might_be_in_top_n_fast_rejection() { + // Test the O(1) fast rejection logic + + // N=10, last_self_pos=5 → threshold = 15 + // Tracks at positions 0-15 might be in top-N, positions 16+ definitely not + assert!(TopNTracker::might_be_in_top_n(0, 10, 5)); + assert!(TopNTracker::might_be_in_top_n(10, 10, 5)); + assert!(TopNTracker::might_be_in_top_n(15, 10, 5)); + assert!(!TopNTracker::might_be_in_top_n(16, 10, 5)); + assert!(!TopNTracker::might_be_in_top_n(100, 10, 5)); + + // N=3, last_self_pos=0 (pure subscriber) → threshold = 3 + assert!(TopNTracker::might_be_in_top_n(0, 3, 0)); + assert!(TopNTracker::might_be_in_top_n(3, 3, 0)); + assert!(!TopNTracker::might_be_in_top_n(4, 3, 0)); + + // N=5, last_self_pos=50 (many self tracks) → threshold = 55 + assert!(TopNTracker::might_be_in_top_n(55, 5, 50)); + assert!(!TopNTracker::might_be_in_top_n(56, 5, 50)); + } + + #[test] + fn test_compute_last_self_position() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(10); + + // Session 1 has tracks at positions 0, 2, 4 + tracker.register_track(ns("live"), "a".to_string(), 100, 1); // pos 0 + tracker.register_track(ns("live"), "b".to_string(), 90, 2); // pos 1 + tracker.register_track(ns("live"), "c".to_string(), 80, 1); // pos 2 + tracker.register_track(ns("live"), "d".to_string(), 70, 3); // pos 3 + tracker.register_track(ns("live"), "e".to_string(), 60, 1); // pos 4 + tracker.register_track(ns("live"), "f".to_string(), 50, 4); // pos 5 + + // Session 1's last self-track is at position 4 + assert_eq!(tracker.compute_last_self_position(1), 4); + + // Session 2's last (and only) self-track is at position 1 + assert_eq!(tracker.compute_last_self_position(2), 1); + + // Session 99 has no tracks + assert_eq!(tracker.compute_last_self_position(99), 0); + } + + #[test] + fn test_is_in_top_n_fast() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(10); + + // Create a snapshot with session 1's tracks at positions 0, 2, 4 + tracker.register_track(ns("live"), "a".to_string(), 100, 1); // pos 0, session 1 + tracker.register_track(ns("live"), "b".to_string(), 90, 2); // pos 1 + tracker.register_track(ns("live"), "c".to_string(), 80, 1); // pos 2, session 1 + tracker.register_track(ns("live"), "d".to_string(), 70, 3); // pos 3 + tracker.register_track(ns("live"), "e".to_string(), 60, 1); // pos 4, session 1 + tracker.register_track(ns("live"), "f".to_string(), 50, 4); // pos 5 + tracker.register_track(ns("live"), "g".to_string(), 40, 5); // pos 6 + tracker.register_track(ns("live"), "h".to_string(), 30, 6); // pos 7 + tracker.register_track(ns("live"), "i".to_string(), 20, 7); // pos 8 + tracker.register_track(ns("live"), "j".to_string(), 10, 8); // pos 9 + + let snapshot = tracker.load_snapshot(); + let last_self_pos = tracker.compute_last_self_position(1); // = 4 + + // Session 1 wants N=3 non-self + // Their non-self tracks in order: b(pos1), d(pos3), f(pos5), g(pos6)... + // Top-3 non-self = b, d, f + + // Track "b" at position 1: might be in top-3 (1 <= 3+4=7), actually IS in top-3 + assert!(tracker.is_in_top_n_fast(&ns("live"), "b", 1, 1, 3, last_self_pos, &snapshot)); + + // Track "f" at position 5: might be in top-3 (5 <= 7), actually IS in top-3 + assert!(tracker.is_in_top_n_fast(&ns("live"), "f", 5, 1, 3, last_self_pos, &snapshot)); + + // Track "g" at position 6: might be in top-3 (6 <= 7), but NOT in top-3 (4th non-self) + assert!(!tracker.is_in_top_n_fast(&ns("live"), "g", 6, 1, 3, last_self_pos, &snapshot)); + + // Track "j" at position 9: fast rejection (9 > 7), definitely NOT in top-3 + assert!(!tracker.is_in_top_n_fast(&ns("live"), "j", 9, 1, 3, last_self_pos, &snapshot)); + } + + #[test] + fn test_find_track_position() { + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(5); + + tracker.register_track(ns("live"), "a".to_string(), 100, 1); + tracker.register_track(ns("live"), "b".to_string(), 90, 2); + tracker.register_track(ns("live"), "c".to_string(), 80, 3); + + assert_eq!(tracker.find_track_position(&ns("live"), "a"), Some(0)); + assert_eq!(tracker.find_track_position(&ns("live"), "b"), Some(1)); + assert_eq!(tracker.find_track_position(&ns("live"), "c"), Some(2)); + assert_eq!(tracker.find_track_position(&ns("live"), "nonexistent"), None); + } + + #[test] + fn test_tie_break_most_recent_wins() { + // With MostRecentWins policy, later updates should win ties + let config = TopNTrackerConfig { + tie_break_policy: TieBreakPolicy::MostRecentWins, + staleness_timeout: None, + ..Default::default() + }; + let tracker = TopNTracker::with_config(0x100, config); + tracker.update_max_n(3); + + // Register with same value - with MostRecentWins, last one should win + tracker.register_track(ns("live"), "first".to_string(), 100, 1); + std::thread::sleep(std::time::Duration::from_millis(10)); + tracker.register_track(ns("live"), "second".to_string(), 100, 2); + std::thread::sleep(std::time::Duration::from_millis(10)); + tracker.register_track(ns("live"), "third".to_string(), 100, 3); + + let top = tracker.compute_top_n_for_session(99, 3); + assert_eq!(top.len(), 3); + // Most recent should be first + assert_eq!(top[0].1, "third"); + assert_eq!(top[1].1, "second"); + assert_eq!(top[2].1, "first"); + } + + #[test] + fn test_tie_break_oldest_wins_default() { + // Default config should use OldestWins + let tracker = TopNTracker::new(0x100); + assert_eq!(tracker.config().tie_break_policy, TieBreakPolicy::OldestWins); + tracker.update_max_n(3); + + tracker.register_track(ns("live"), "first".to_string(), 100, 1); + std::thread::sleep(std::time::Duration::from_millis(10)); + tracker.register_track(ns("live"), "second".to_string(), 100, 2); + std::thread::sleep(std::time::Duration::from_millis(10)); + tracker.register_track(ns("live"), "third".to_string(), 100, 3); + + let top = tracker.compute_top_n_for_session(99, 3); + assert_eq!(top.len(), 3); + // Oldest should be first (first-come-first-served) + assert_eq!(top[0].1, "first"); + assert_eq!(top[1].1, "second"); + assert_eq!(top[2].1, "third"); + } + + #[test] + fn test_staleness_filtering() { + let config = TopNTrackerConfig { + tie_break_policy: TieBreakPolicy::OldestWins, + staleness_timeout: Some(Duration::from_millis(100)), + ..Default::default() + }; + let tracker = TopNTracker::with_config(0x100, config); + tracker.update_max_n(5); + + // Register tracks + tracker.register_track(ns("live"), "a".to_string(), 100, 1); + tracker.register_track(ns("live"), "b".to_string(), 90, 2); + + // Both should be in snapshot + assert_eq!(tracker.load_snapshot().len(), 2); + + // Wait for staleness + std::thread::sleep(Duration::from_millis(150)); + + // Register a new track (triggers rebuild which filters stale) + tracker.register_track(ns("live"), "c".to_string(), 80, 3); + + // Only "c" should remain (a and b are stale) + let snapshot = tracker.load_snapshot(); + assert_eq!(snapshot.len(), 1); + assert_eq!(snapshot[0].track_name, "c"); + } + + #[test] + fn test_sweep_stale() { + let config = TopNTrackerConfig { + tie_break_policy: TieBreakPolicy::OldestWins, + staleness_timeout: Some(Duration::from_millis(50)), + ..Default::default() + }; + let tracker = TopNTracker::with_config(0x100, config); + tracker.update_max_n(5); + + tracker.register_track(ns("live"), "a".to_string(), 100, 1); + tracker.register_track(ns("live"), "b".to_string(), 90, 2); + + assert_eq!(tracker.num_tracks(), 2); + + // Wait for staleness + std::thread::sleep(Duration::from_millis(100)); + + // Sweep should remove both tracks + let removed = tracker.sweep_stale(); + assert_eq!(removed, 2); + assert_eq!(tracker.num_tracks(), 0); + assert_eq!(tracker.load_snapshot().len(), 0); + } + + #[test] + fn test_sweep_stale_no_timeout() { + // Without staleness timeout, sweep should do nothing + let tracker = TopNTracker::new(0x100); + tracker.update_max_n(5); + + tracker.register_track(ns("live"), "a".to_string(), 100, 1); + + let removed = tracker.sweep_stale(); + assert_eq!(removed, 0); + assert_eq!(tracker.num_tracks(), 1); + } + + #[test] + fn test_update_value_refreshes_timestamp() { + let config = TopNTrackerConfig { + tie_break_policy: TieBreakPolicy::OldestWins, + staleness_timeout: Some(Duration::from_millis(100)), + ..Default::default() + }; + let tracker = TopNTracker::with_config(0x100, config); + tracker.update_max_n(5); + + tracker.register_track(ns("live"), "a".to_string(), 100, 1); + + // Wait a bit, then update (even with same value refreshes timestamp) + std::thread::sleep(Duration::from_millis(60)); + tracker.update_value(&ns("live"), "a", 100); // Same value, but refreshes timestamp + + // Wait more - would be stale without the refresh + std::thread::sleep(Duration::from_millis(60)); + + // Track should still be fresh because update_value refreshed it + let removed = tracker.sweep_stale(); + assert_eq!(removed, 0); + assert_eq!(tracker.num_tracks(), 1); + } + + #[test] + fn test_three_speakers_scenario_oldest_wins() { + // The "worked example" from the design doc with OldestWins policy + let tracker = TopNTracker::new(0x100); // Default: OldestWins + tracker.update_max_n(3); + + // t0: A speaks (value=100) + tracker.register_track(ns("live"), "A".to_string(), 100, 1); + std::thread::sleep(Duration::from_millis(5)); + + // t1: B speaks (value=100) + tracker.register_track(ns("live"), "B".to_string(), 100, 2); + std::thread::sleep(Duration::from_millis(5)); + + // t2: C speaks (value=100) + tracker.register_track(ns("live"), "C".to_string(), 100, 3); + + // With OldestWins: A (oldest) > B > C (newest) + let top = tracker.compute_top_n_for_session(99, 3); + assert_eq!(top[0].1, "A"); // Position 1 + assert_eq!(top[1].1, "B"); // Position 2 + assert_eq!(top[2].1, "C"); // Position 3 + } + + #[test] + fn test_three_speakers_scenario_most_recent_wins() { + // The "worked example" from the design doc with MostRecentWins policy + let config = TopNTrackerConfig { + tie_break_policy: TieBreakPolicy::MostRecentWins, + staleness_timeout: None, + ..Default::default() + }; + let tracker = TopNTracker::with_config(0x100, config); + tracker.update_max_n(3); + + // t0: A speaks (value=100) + tracker.register_track(ns("live"), "A".to_string(), 100, 1); + std::thread::sleep(Duration::from_millis(5)); + + // t1: B speaks (value=100) + tracker.register_track(ns("live"), "B".to_string(), 100, 2); + std::thread::sleep(Duration::from_millis(5)); + + // t2: C speaks (value=100) + tracker.register_track(ns("live"), "C".to_string(), 100, 3); + + // With MostRecentWins: C (newest) > B > A (oldest) + let top = tracker.compute_top_n_for_session(99, 3); + assert_eq!(top[0].1, "C"); // Position 1 + assert_eq!(top[1].1, "B"); // Position 2 + assert_eq!(top[2].1, "A"); // Position 3 + } +} diff --git a/moq-relay-ietf/src/web.rs b/moq-relay-ietf/src/web.rs index 2aa8fb29..52e05f05 100644 --- a/moq-relay-ietf/src/web.rs +++ b/moq-relay-ietf/src/web.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{net, path::PathBuf, sync::Arc}; use axum::{ @@ -62,13 +59,13 @@ impl Web { // Optionally add qlog serving endpoint if state.qlog_dir.is_some() { app = app.route("/qlog/:cid", get(serve_qlog)); - tracing::info!("qlog files available at /qlog/:cid"); + log::info!("qlog files available at /qlog/:cid"); } // Optionally add mlog serving endpoint if state.mlog_dir.is_some() { app = app.route("/mlog/:cid", get(serve_mlog)); - tracing::info!("mlog files available at /mlog/:cid"); + log::info!("mlog files available at /mlog/:cid"); } // Add state and CORS layer diff --git a/moq-sub/CHANGELOG.md b/moq-sub/CHANGELOG.md index f79d7338..2e956663 100644 --- a/moq-sub/CHANGELOG.md +++ b/moq-sub/CHANGELOG.md @@ -6,47 +6,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.4.8](https://github.com/cloudflare/moq-rs/compare/moq-sub-v0.4.7...moq-sub-v0.4.8) - 2026-05-20 - -### Fixed - -- subscribe cleaning on drop - -## [0.4.7](https://github.com/cloudflare/moq-rs/compare/moq-sub-v0.4.6...moq-sub-v0.4.7) - 2026-04-10 - -### Fixed - -- cross-platform dual-stack binding for IPv6 sockets - -### Other - -- Merge pull request #151 from englishm-cloudflare/me/ipv6-dual-stack-binding - -## [0.4.6](https://github.com/cloudflare/moq-rs/compare/moq-sub-v0.4.5...moq-sub-v0.4.6) - 2026-03-31 - -### Other - -- Make repo REUSE v3.3 compliant -- Bring copyright notices, license docs up to date - -## [0.4.5](https://github.com/cloudflare/moq-rs/compare/moq-sub-v0.4.4...moq-sub-v0.4.5) - 2026-03-27 - -### Added - -- add Transport enum and connection path extraction - -## [0.4.4](https://github.com/cloudflare/moq-rs/compare/moq-sub-v0.4.3...moq-sub-v0.4.4) - 2026-02-18 - -### Other - -- update Cargo.lock dependencies - -## [0.4.3](https://github.com/cloudflare/moq-rs/compare/moq-sub-v0.4.2...moq-sub-v0.4.3) - 2026-02-18 - -### Other - -- migrate from log crate to tracing - ## [0.4.2](https://github.com/cloudflare/moq-rs/compare/moq-sub-v0.4.1...moq-sub-v0.4.2) - 2026-02-03 ### Other diff --git a/moq-sub/Cargo.toml b/moq-sub/Cargo.toml index 71d7d882..cf19af89 100644 --- a/moq-sub/Cargo.toml +++ b/moq-sub/Cargo.toml @@ -1,15 +1,11 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - [package] name = "moq-sub" description = "Media over QUIC" -authors = ["moq-rs contributors"] -repository = "https://github.com/cloudflare/moq-rs" +authors = [] +repository = "https://github.com/englishm/moq-rs" license = "MIT OR Apache-2.0" -version = "0.4.8" +version = "0.4.2" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] @@ -18,8 +14,8 @@ categories = ["multimedia", "network-programming", "web-programming"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -moq-transport = { path = "../moq-transport", version = "0.14" } -moq-native-ietf = { path = "../moq-native-ietf", version = "0.8" } +moq-transport = { path = "../moq-transport", version = "0.12" } +moq-native-ietf = { path = "../moq-native-ietf", version = "0.7" } moq-catalog = { path = "../moq-catalog", version = "0.2" } url = "2" @@ -28,8 +24,10 @@ tokio = { version = "1", features = ["full"] } # CLI, logging, error handling clap = { version = "4", features = ["derive"] } -tracing = { workspace = true } -tracing-subscriber = { workspace = true } +log = { version = "0.4", features = ["std"] } +env_logger = "0.11" mp4 = "0.14" anyhow = { version = "1", features = ["backtrace"] } +tracing = "0.1" +tracing-subscriber = "0.3" serde_json = "1" diff --git a/moq-sub/README.md b/moq-sub/README.md index 695cc68d..c1340d57 100644 --- a/moq-sub/README.md +++ b/moq-sub/README.md @@ -2,9 +2,9 @@ A command line tool for subscribing to media via Media over QUIC (MoQ). -Takes a URL to a MoQ relay and a broadcast name via `--name`. It will connect to the relay, subscribe to the broadcast, -and dump the media segments of the first video and first audio track to STDOUT. +Takes an URL to MoQ relay with a broadcast name in the path part of the URL. It will connect to the relay, subscribe to +the broadcast, and dump the media segments of the first video and first audio track to STDOUT. ``` -moq-sub --name dev https://localhost:4443 | ffplay - +moq-sub https://localhost:4443/dev | ffplay - ``` diff --git a/moq-sub/src/lib.rs b/moq-sub/src/lib.rs index 86ae7dbc..b30398d4 100644 --- a/moq-sub/src/lib.rs +++ b/moq-sub/src/lib.rs @@ -1,5 +1 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - pub mod media; diff --git a/moq-sub/src/main.rs b/moq-sub/src/main.rs index 5fbd6e30..2663833e 100644 --- a/moq-sub/src/main.rs +++ b/moq-sub/src/main.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::net; use anyhow::Context; @@ -14,29 +10,28 @@ use moq_transport::{coding::TrackNamespace, serve::Tracks}; #[tokio::main] async fn main() -> anyhow::Result<()> { - // Initialize tracing with env filter (respects RUST_LOG environment variable) - // Default to info level, but suppress quinn's verbose output - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info,quinn=warn")), - ) - .init(); + env_logger::init(); + + // Disable tracing so we don't get a bunch of Quinn spam. + let tracer = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::WARN) + .finish(); + tracing::subscriber::set_global_default(tracer).unwrap(); let out = tokio::io::stdout(); let config = Config::parse(); let tls = config.tls.load()?; - let quic = quic::Endpoint::new(quic::Config::new(config.bind, None, tls)?)?; + let quic = quic::Endpoint::new(quic::Config::new(config.bind, None, tls))?; - let (session, connection_id, transport) = quic.client.connect(&config.url, None).await?; + let (session, connection_id) = quic.client.connect(&config.url, None).await?; - tracing::info!( + log::info!( "connected with CID: {} (use this to look up qlog/mlog on server)", connection_id ); - let (session, subscriber) = moq_transport::session::Subscriber::connect(session, transport) + let (session, subscriber) = moq_transport::session::Subscriber::connect(session) .await .context("failed to create MoQ Transport session")?; diff --git a/moq-sub/src/media.rs b/moq-sub/src/media.rs index ddf660d2..5503d9eb 100644 --- a/moq-sub/src/media.rs +++ b/moq-sub/src/media.rs @@ -1,10 +1,7 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{io::Cursor, sync::Arc}; use anyhow::Context; +use log::{debug, info, trace, warn}; use moq_transport::serve::{ SubgroupObjectReader, SubgroupReader, TrackReader, TrackReaderMode, Tracks, TracksReader, TracksWriter, @@ -16,7 +13,6 @@ use tokio::{ sync::Mutex, task::JoinSet, }; -use tracing::{debug, info, trace, warn}; pub struct Media { subscriber: Subscriber, @@ -187,16 +183,54 @@ impl Media { async fn recv_group(mut group: SubgroupReader, out: Arc>) -> anyhow::Result<()> { trace!("group={} start", group.group_id); + + // Pair moof+mdat into a single atomic write to prevent concurrent + // audio/video tasks from interleaving between them on stdout. + let mut pending_moof: Option> = None; + while let Some(object) = group.next().await? { trace!( "group={} fragment={} start", group.group_id, object.object_id ); - let out = out.clone(); let buf = Self::recv_object(object).await?; - out.lock().await.write_all(&buf).await?; + let is_moof = buf.len() >= 8 && &buf[4..8] == b"moof"; + let is_mdat = buf.len() >= 8 && &buf[4..8] == b"mdat"; + + if is_moof { + if let Some(orphan) = pending_moof.take() { + warn!( + "group={}: flushing orphaned moof ({} bytes) without mdat", + group.group_id, + orphan.len() + ); + out.lock().await.write_all(&orphan).await?; + } + pending_moof = Some(buf); + } else if is_mdat { + if let Some(mut moof) = pending_moof.take() { + moof.extend_from_slice(&buf); + out.lock().await.write_all(&moof).await?; + } else { + warn!( + "group={}: mdat without preceding moof ({} bytes)", + group.group_id, + buf.len() + ); + out.lock().await.write_all(&buf).await?; + } + } else { + if let Some(orphan) = pending_moof.take() { + out.lock().await.write_all(&orphan).await?; + } + out.lock().await.write_all(&buf).await?; + } + } + + if let Some(orphan) = pending_moof.take() { + out.lock().await.write_all(&orphan).await?; } Ok(()) diff --git a/moq-test-client/CHANGELOG.md b/moq-test-client/CHANGELOG.md index feb09e37..73a6a926 100644 --- a/moq-test-client/CHANGELOG.md +++ b/moq-test-client/CHANGELOG.md @@ -7,53 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.1.6](https://github.com/cloudflare/moq-rs/compare/moq-test-client-v0.1.5...moq-test-client-v0.1.6) - 2026-05-20 - -### Other - -- update Cargo.lock dependencies - -## [0.1.5](https://github.com/cloudflare/moq-rs/compare/moq-test-client-v0.1.4...moq-test-client-v0.1.5) - 2026-04-10 - -### Fixed - -- cross-platform dual-stack binding for IPv6 sockets - -### Other - -- Merge pull request #151 from englishm-cloudflare/me/ipv6-dual-stack-binding - -## [0.1.4](https://github.com/cloudflare/moq-rs/compare/moq-test-client-v0.1.3...moq-test-client-v0.1.4) - 2026-03-31 - -### Other - -- Make repo REUSE v3.3 compliant -- Bring copyright notices, license docs up to date - -## [0.1.3](https://github.com/cloudflare/moq-rs/compare/moq-test-client-v0.1.2...moq-test-client-v0.1.3) - 2026-03-27 - -### Added - -- add Transport enum and connection path extraction - -## [0.1.2](https://github.com/cloudflare/moq-rs/compare/moq-test-client-v0.1.1...moq-test-client-v0.1.2) - 2026-02-18 - -### Other - -- Upgrade web-transport crates to v0.10.1 - -## [0.1.1](https://github.com/cloudflare/moq-rs/compare/moq-test-client-v0.1.0...moq-test-client-v0.1.1) - 2026-02-18 - -### Other - -- migrate from log crate to tracing -- add run-level TAP comments -- add error message to YAML diagnostics -- add connection_id to YAML diagnostics -- add duration_ms YAML diagnostic -- output TAP version 14 format -- release - ## [0.1.0](https://github.com/cloudflare/moq-rs/releases/tag/moq-test-client-v0.1.0) - 2026-02-03 ### Other diff --git a/moq-test-client/Cargo.toml b/moq-test-client/Cargo.toml index 043dd81f..37840a02 100644 --- a/moq-test-client/Cargo.toml +++ b/moq-test-client/Cargo.toml @@ -1,14 +1,11 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - [package] name = "moq-test-client" description = "MoQT Interop Test Client - A standardized test client for interoperability testing" -authors = ["moq-rs contributors"] +authors = [] repository = "https://github.com/cloudflare/moq-rs" license = "MIT OR Apache-2.0" -version = "0.1.6" +version = "0.1.0" edition = "2021" keywords = ["quic", "webtransport", "moqt", "testing", "interop"] @@ -19,8 +16,8 @@ name = "moq-test-client" path = "src/main.rs" [dependencies] -moq-transport = { path = "../moq-transport", version = "0.14" } -moq-native-ietf = { path = "../moq-native-ietf", version = "0.8" } +moq-transport = { path = "../moq-transport", version = "0.12" } +moq-native-ietf = { path = "../moq-native-ietf", version = "0.7" } web-transport = { workspace = true } url = "2" @@ -30,6 +27,8 @@ tokio = { version = "1", features = ["full", "time"] } # CLI, logging, error handling clap = { version = "4", features = ["derive", "env"] } -tracing = { workspace = true } -tracing-subscriber = { workspace = true } +log = { version = "0.4", features = ["std"] } +env_logger = "0.11" anyhow = { version = "1", features = ["backtrace"] } +tracing = "0.1" +tracing-subscriber = "0.3" diff --git a/moq-test-client/src/main.rs b/moq-test-client/src/main.rs index 29796d51..7b925a66 100644 --- a/moq-test-client/src/main.rs +++ b/moq-test-client/src/main.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! MoQT Interop Test Client //! //! A standardized test client for MoQT interoperability testing. @@ -143,9 +140,9 @@ async fn run_test(args: &Args, test_case: TestCase) -> TestResult { let result = match test_case { TestCase::SetupOnly => scenarios::test_setup_only(args).await, - TestCase::AnnounceOnly => scenarios::test_announce_only(args).await, + TestCase::AnnounceOnly => scenarios::test_publish_namespace_only(args).await, TestCase::SubscribeError => scenarios::test_subscribe_error(args).await, - TestCase::AnnounceSubscribe => scenarios::test_announce_subscribe(args).await, + TestCase::AnnounceSubscribe => scenarios::test_publish_namespace_subscribe(args).await, TestCase::SubscribeBeforeAnnounce => scenarios::test_subscribe_before_announce(args).await, TestCase::PublishNamespaceDone => scenarios::test_publish_namespace_done(args).await, }; @@ -213,14 +210,14 @@ fn print_tap_result(test_number: usize, result: &TestResult, verbose: bool) { #[tokio::main] async fn main() -> Result<()> { - // Initialize tracing with env filter (respects RUST_LOG environment variable) - // Default to info level, but suppress quinn's verbose output - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info,quinn=warn")), - ) - .init(); + env_logger::init(); + + // Disable tracing so we don't get a bunch of Quinn spam + let tracer = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::WARN) + .finish(); + // Ignore error if subscriber is already set (e.g., in tests) + let _ = tracing::subscriber::set_global_default(tracer); let args = Args::parse(); diff --git a/moq-test-client/src/scenarios.rs b/moq-test-client/src/scenarios.rs index 75fa4191..6752c836 100644 --- a/moq-test-client/src/scenarios.rs +++ b/moq-test-client/src/scenarios.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! Test scenario implementations //! //! Each scenario tests a specific aspect of MoQT interoperability. @@ -13,7 +10,11 @@ use anyhow::{Context, Result}; use tokio::time::{timeout, Duration}; use moq_native_ietf::quic; -use moq_transport::{coding::TrackNamespace, serve::Tracks, session::Session}; +use moq_transport::{ + coding::TrackNamespace, + serve::Tracks, + session::{Publisher, Session}, +}; use crate::Args; @@ -23,23 +24,17 @@ const TEST_TIMEOUT: Duration = Duration::from_secs(10); /// Namespace used for test operations const TEST_NAMESPACE: &str = "moq-test/interop"; -/// Track name used for test operations +/// Track name used for test operations const TEST_TRACK: &str = "test-track"; /// Helper to connect to a relay and establish a session -/// Returns (session, connection_id, transport) so we can report CIDs for mlog correlation -async fn connect( - args: &Args, -) -> Result<( - web_transport::Session, - String, - moq_transport::session::Transport, -)> { +/// Returns (session, connection_id) so we can report CIDs for mlog correlation +async fn connect(args: &Args) -> Result<(web_transport::Session, String)> { let tls = args.tls.load()?; - let quic = quic::Endpoint::new(quic::Config::new(args.bind, None, tls)?)?; + let quic = quic::Endpoint::new(quic::Config::new(args.bind, None, tls))?; - let (session, connection_id, transport) = quic.client.connect(&args.relay, None).await?; - Ok((session, connection_id, transport)) + let (session, connection_id) = quic.client.connect(&args.relay, None).await?; + Ok((session, connection_id)) } /// Collected connection IDs from a test run @@ -60,17 +55,16 @@ impl TestConnectionIds { /// This is the simplest possible test - if this fails, nothing else will work. pub async fn test_setup_only(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { - let (session, cid, transport) = - connect(args).await.context("failed to connect to relay")?; + let (session, cid) = connect(args).await.context("failed to connect to relay")?; let mut cids = TestConnectionIds::default(); cids.add(cid); // Session::connect performs the SETUP exchange - let (session, _publisher, _subscriber) = Session::connect(session, None, transport) + let (session, _publisher, _subscriber) = Session::connect(session, None) .await .context("SETUP exchange failed")?; - tracing::info!("SETUP exchange completed successfully"); + log::info!("SETUP exchange completed successfully"); // We don't need to run the session, just verify setup worked // Dropping the session will close the connection @@ -82,46 +76,45 @@ pub async fn test_setup_only(args: &Args) -> Result { .context("test timed out")? } -/// T0.2: Announce Only +/// T0.2: Publish namespace Only /// -/// Connect to relay, announce a namespace, receive PUBLISH_NAMESPACE_OK, close. -pub async fn test_announce_only(args: &Args) -> Result { +/// Connect to relay, publish a namespace, receive PUBLISH_NAMESPACE_OK, close. +pub async fn test_publish_namespace_only(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { - let (session, cid, transport) = connect(args).await.context("failed to connect to relay")?; + let (session, cid) = connect(args).await.context("failed to connect to relay")?; let mut cids = TestConnectionIds::default(); cids.add(cid); - let (session, mut publisher, _subscriber) = Session::connect(session, None, transport) + let (session, mut publisher, _subscriber) = Session::connect(session, None) .await .context("SETUP exchange failed")?; let namespace = TrackNamespace::from_utf8_path(TEST_NAMESPACE); - let (_, _, reader) = Tracks::new(namespace.clone()).produce(); - tracing::info!("Announcing namespace: {}", TEST_NAMESPACE); + log::info!("Publishing namespace: {}", TEST_NAMESPACE); - // Run announce with a timeout - we want to verify we get PUBLISH_NAMESPACE_OK. - // NOTE: The announce() method blocks waiting for subscriptions after getting OK. + // Run publish namespace with a timeout - we want to verify we get PUBLISH_NAMESPACE_OK. + // NOTE: The publish_namespace() method sends PUBLISH_NAMESPACE and wait for OK or ERROR. // If we get PUBLISH_NAMESPACE_ERROR instead of OK, the method returns Err immediately. - // So timing out here means: either (a) got OK and waiting for subs, or (b) relay never responded. - // We accept this limitation since (b) would indicate a broken relay anyway. - // TODO: For stricter verification, use lower-level Announce::ok() method directly. - let announce_result = tokio::select! { - res = publisher.announce(reader) => res, + // So timing out here means relay never responded and connection may be broken. + let publish_ns = publisher.publish_namespace(namespace).await?; + + let publish_ns_result = tokio::select! { + res = publish_ns.ok() => res, res = session.run() => { res.context("session error")?; anyhow::bail!("session ended before announce completed"); } _ = tokio::time::sleep(Duration::from_secs(2)) => { // If we got an error from the relay, announce() would have returned already. - // Timing out means we're past the OK and now waiting for subscriptions. - tracing::info!("Announce succeeded (no error received, waiting for subscriptions timed out)"); - return Ok(cids); + // Timing out means the relay never responded and connection may be broken. + log::info!("Publishing namespace failed (relay did not reply)"); + return Err(anyhow::anyhow!("publish namespace timed out")); } }; - // If we get here, announce completed (which means it errored or namespace was cancelled) - announce_result.context("announce failed")?; + // If we get here, publish namespace completed (which means it errored or namespace was cancelled) + publish_ns_result.context("publish namespace failed")?; Ok(cids) }) @@ -134,12 +127,11 @@ pub async fn test_announce_only(args: &Args) -> Result { /// Subscribe to a non-existent track and verify we get SUBSCRIBE_ERROR. pub async fn test_subscribe_error(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { - let (session, cid, transport) = - connect(args).await.context("failed to connect to relay")?; + let (session, cid) = connect(args).await.context("failed to connect to relay")?; let mut cids = TestConnectionIds::default(); cids.add(cid); - let (session, _publisher, mut subscriber) = Session::connect(session, None, transport) + let (session, _publisher, mut subscriber) = Session::connect(session, None) .await .context("SETUP exchange failed")?; @@ -151,7 +143,7 @@ pub async fn test_subscribe_error(args: &Args) -> Result { .create(TEST_TRACK) .ok_or_else(|| anyhow::anyhow!("failed to create track (already exists?)"))?; - tracing::info!( + log::info!( "Subscribing to non-existent track: {}/{}", "nonexistent/namespace", TEST_TRACK @@ -184,10 +176,10 @@ pub async fn test_subscribe_error(args: &Args) -> Result { || err_str.contains("unknown"); if is_expected_error { - tracing::info!("Got expected 'not found' error: {}", e); + log::info!("Got expected 'not found' error: {}", e); } else { // Log warning but still pass - relay may use different error text - tracing::warn!( + log::warn!( "Got error but not clearly 'not found': {}. \ This may indicate a different error type than expected.", e @@ -201,27 +193,27 @@ pub async fn test_subscribe_error(args: &Args) -> Result { .context("test timed out")? } -/// T0.4: Announce + Subscribe +/// T0.4: Publish Namespace + Subscribe /// -/// Two clients: publisher announces a namespace, subscriber subscribes to a track. +/// Two clients: publisher publishes a namespace, subscriber subscribes to a track. /// Verifies the relay correctly routes the subscription to the publisher. -pub async fn test_announce_subscribe(args: &Args) -> Result { +pub async fn test_publish_namespace_subscribe(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { let mut cids = TestConnectionIds::default(); // Publisher connection - let (pub_session, pub_cid, pub_transport) = connect(args).await.context("publisher failed to connect")?; + let (pub_session, pub_cid) = connect(args).await.context("publisher failed to connect")?; cids.add(pub_cid); - let (pub_session, mut publisher, _) = Session::connect(pub_session, None, pub_transport) + let (pub_session, mut publisher, _) = Session::connect(pub_session, None) .await .context("publisher SETUP failed")?; // Subscriber connection - let (sub_session, sub_cid, sub_transport) = connect(args) + let (sub_session, sub_cid) = connect(args) .await .context("subscriber failed to connect")?; cids.add(sub_cid); - let (sub_session, _, mut subscriber) = Session::connect(sub_session, None, sub_transport) + let (sub_session, _, mut subscriber) = Session::connect(sub_session, None) .await .context("subscriber SETUP failed")?; @@ -233,7 +225,12 @@ pub async fn test_announce_subscribe(args: &Args) -> Result { // Create the track that subscriber will request let _track_writer = pub_writer.create(TEST_TRACK); - tracing::info!("Publisher announcing namespace: {}", TEST_NAMESPACE); + log::info!("Publisher publishing namespace: {}", TEST_NAMESPACE); + + let publish_ns = publisher + .publish_namespace(namespace.clone()) + .await + .context("publish_namespace call failed")?; // Subscriber: set up tracks and subscribe let (mut sub_writer, _, _sub_reader) = Tracks::new(namespace.clone()).produce(); @@ -241,41 +238,55 @@ pub async fn test_announce_subscribe(args: &Args) -> Result { .create(TEST_TRACK) .ok_or_else(|| anyhow::anyhow!("failed to create subscriber track"))?; - tracing::info!( - "Subscriber subscribing to track: {}/{}", - TEST_NAMESPACE, - TEST_TRACK - ); - - // Run everything concurrently. We expect the subscriber to get a response - // (either SUBSCRIBE_OK or error) within the timeout. + // Run everything concurrently. Session::run() consumes self, so + // publish_namespace→subscribe must be sequenced inside a single async + // block running alongside both sessions. + let mut pub_subscriber_handler = publisher.clone(); tokio::select! { - // Publisher announces and waits for subscriptions - res = publisher.announce(pub_reader) => { - res.context("publisher announce failed")?; - tracing::info!("Publisher announce completed"); + // Publisher publishes namespace, then subscriber subscribes + res = async { + publish_ns.ok().await.context("publish namespace failed")?; + log::info!("Publisher got PUBLISH_NAMESPACE_OK"); + + log::info!("Subscribing to track: {}/{}", TEST_NAMESPACE, TEST_TRACK); + // Subscriber subscribes - this is the main thing we're testing + match subscriber.subscribe(sub_track).await { + Ok(()) => log::info!("Subscriber got SUBSCRIBE_OK - relay routed subscription correctly"), + Err(e) => log::info!("Subscriber got error: {} - subscription was processed", e), + } + Ok::<_, anyhow::Error>(()) + } => { + res?; } - // Subscriber subscribes - this is the main thing we're testing - res = subscriber.subscribe(sub_track) => { - match res { - Ok(()) => tracing::info!("Subscriber got SUBSCRIBE_OK - relay routed subscription correctly"), - Err(e) => tracing::info!("Subscriber got error: {} - subscription was processed", e), + // Serve incoming subscriptions forwarded by the relay to the publisher + res = async { + while let Some(subscribed) = pub_subscriber_handler.subscribed().await { + let info = subscribed.info.clone(); + log::info!("Publisher serving subscribe: {:?}", info); + if let Err(err) = Publisher::serve_subscribe(subscribed, pub_reader.clone()).await { + log::warn!("Failed serving subscribe: {:?}, error: {}", info, err); + } } + Ok::<_, anyhow::Error>(()) + } => { + res?; } // Run publisher session res = pub_session.run() => { res.context("publisher session error")?; + anyhow::bail!("publisher session ended unexpectedly"); } // Run subscriber session res = sub_session.run() => { res.context("subscriber session error")?; + anyhow::bail!("subscriber session ended unexpectedly"); } // Timeout: give the relay time to route the subscription - _ = tokio::time::sleep(Duration::from_secs(3)) => { + _ = tokio::time::sleep(Duration::from_secs(5)) => { // If we hit this timeout, the subscription may still be pending. // This isn't necessarily a failure - some relays may hold subscriptions // until the track has data. Log for visibility. - tracing::info!("Test timeout reached - subscription routing may still be in progress"); + log::info!("Test timeout reached - subscription routing may still be in progress"); } }; @@ -291,42 +302,44 @@ pub async fn test_announce_subscribe(args: &Args) -> Result { /// Verifies the relay handles namespace unpublishing correctly. pub async fn test_publish_namespace_done(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { - let (session, cid, transport) = - connect(args).await.context("failed to connect to relay")?; + let (session, cid) = connect(args).await.context("failed to connect to relay")?; let mut cids = TestConnectionIds::default(); cids.add(cid); - let (session, mut publisher, _subscriber) = Session::connect(session, None, transport) + let (session, mut publisher, _subscriber) = Session::connect(session, None) .await .context("SETUP exchange failed")?; let namespace = TrackNamespace::from_utf8_path(TEST_NAMESPACE); - let (_, _, reader) = Tracks::new(namespace.clone()).produce(); - tracing::info!("Announcing namespace: {}", TEST_NAMESPACE); + log::info!("publishing namespace: {}", TEST_NAMESPACE); // Run announce and wait for OK, then explicitly drop to send PUBLISH_NAMESPACE_DONE. // See note in test_announce_only about timeout-based verification. - let result = tokio::select! { - res = publisher.announce(reader) => res, + let publish_ns = publisher.publish_namespace(namespace).await?; + + let publish_ns_result = tokio::select! { + res = publish_ns.ok() => res, res = session.run() => { res.context("session error")?; anyhow::bail!("session ended before announce completed"); } _ = tokio::time::sleep(Duration::from_secs(2)) => { - // No error received - announce is active and waiting for subscriptions - tracing::info!("Announce active, now sending PUBLISH_NAMESPACE_DONE"); - // Dropping out of this block will drop the announce, which sends PUBLISH_NAMESPACE_DONE - Ok(()) + // If we got an error from the relay, announce() would have returned already. + // Timing out means the relay never responded and connection may be broken. + log::info!("Publishing namespace failed (relay did not reply)"); + return Err(anyhow::anyhow!("publish namespace timed out")); } }; - result.context("announce failed")?; + publish_ns_result.context("publish namespace failed")?; + + drop(publish_ns); // Small delay to ensure PUBLISH_NAMESPACE_DONE is sent before we close tokio::time::sleep(Duration::from_millis(100)).await; - tracing::info!("PUBLISH_NAMESPACE_DONE sent successfully"); + log::info!("PUBLISH_NAMESPACE_DONE sent successfully"); Ok(cids) }) .await @@ -342,11 +355,11 @@ pub async fn test_subscribe_before_announce(args: &Args) -> Result Result Result { + res = publish_ns.ok() => { res.context("publisher announce failed")?; } res = pub_session.run() => { res.context("publisher session error")?; } _ = tokio::time::sleep(Duration::from_secs(3)) => { - tracing::info!("Publisher announce timeout (expected)"); + log::info!("Publisher announce timeout (expected)"); } }; @@ -412,13 +423,13 @@ pub async fn test_subscribe_before_announce(args: &Args) -> Result { match res { - Ok(Ok(())) => tracing::info!("Subscriber completed successfully"), - Ok(Err(e)) => tracing::info!("Subscriber got error: {} (may be expected)", e), - Err(e) => tracing::warn!("Subscriber task panicked: {}", e), + Ok(Ok(())) => log::info!("Subscriber completed successfully"), + Ok(Err(e)) => log::info!("Subscriber got error: {} (may be expected)", e), + Err(e) => log::warn!("Subscriber task panicked: {}", e), } } _ = tokio::time::sleep(Duration::from_secs(1)) => { - tracing::info!("Subscriber still waiting (test complete)"); + log::info!("Subscriber still waiting (test complete)"); } }; diff --git a/moq-topn-test/Cargo.toml b/moq-topn-test/Cargo.toml new file mode 100644 index 00000000..96bed5e1 --- /dev/null +++ b/moq-topn-test/Cargo.toml @@ -0,0 +1,56 @@ +[package] +name = "moq-topn-test" +version = "0.1.0" +edition = "2021" +description = "End-to-end test driver for MOQ TRACK_FILTER (Top-N) functionality" +license = "MIT OR Apache-2.0" + +[[bin]] +name = "moq-topn-test" +path = "src/main.rs" + +[[bin]] +name = "topn-log-to-svg" +path = "src/log_to_svg.rs" + +[[bin]] +name = "topn-log-to-html" +path = "src/bin/topn-log-to-html.rs" + +[dependencies] +# MOQ dependencies +moq-relay-ietf = { path = "../moq-relay-ietf" } +moq-transport = { path = "../moq-transport" } +moq-native-ietf = { path = "../moq-native-ietf" } + +# Async runtime +tokio = { version = "1", features = ["full", "sync", "time", "macros", "rt-multi-thread"] } + +# CLI +clap = { version = "4", features = ["derive"] } + +# Logging +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +# Random +rand = "0.8" + +# Stats +hdrhistogram = "7" + +# Error handling +anyhow = "1" + +# URL parsing +url = "2" + +# WebTransport +web-transport = { workspace = true } + +# JSON parsing for log-to-svg +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# Bytes +bytes = "1" diff --git a/moq-topn-test/run_perf_suite.sh b/moq-topn-test/run_perf_suite.sh new file mode 100755 index 00000000..b6e55d5f --- /dev/null +++ b/moq-topn-test/run_perf_suite.sh @@ -0,0 +1,137 @@ +#!/bin/bash +# Perf profiling suite for moq-rs Top-N performance analysis. +# Runs on Linux with perf support. Produces flamegraphs + automated analysis. +# +# Usage: ./run_perf_suite.sh [--relay-bin PATH] [--test-bin PATH] [--results-dir PATH] +# +# Prerequisites: +# - Linux with perf installed +# - ~/FlameGraph (git clone https://github.com/brendangregg/FlameGraph.git ~/FlameGraph) +# - Release build with debug info: cargo build --release -p moq-relay-ietf -p moq-topn-test +# - sudo sysctl kernel.perf_event_paranoid=-1 +# - sudo sysctl kernel.kptr_restrict=0 +# - ulimit -n 65536 + +set -e + +RELAY_BIN="${RELAY_BIN:-./target/release/moq-relay-ietf}" +TEST_BIN="${TEST_BIN:-./target/release/moq-topn-test}" +DURATION=120 +PANELISTS=80 +SUBSCRIBERS=800 +RESULTS_DIR="${RESULTS_DIR:-/tmp/moq-rs-perf-results}" +FLAMEGRAPH_DIR="${FLAMEGRAPH_DIR:-$HOME/FlameGraph}" +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +TOOLS_DIR="$(cd "$SCRIPT_DIR/../tools" && pwd)" + +mkdir -p "$RESULTS_DIR" + +echo "=== moq-rs Top-N Performance Suite ===" +echo " Relay: $RELAY_BIN" +echo " Test binary: $TEST_BIN" +echo " Duration: ${DURATION}s" +echo " Scale: ${PANELISTS} publishers, ${SUBSCRIBERS} subscribers" +echo " Results: $RESULTS_DIR" +echo "" + +# Verify prerequisites +if ! command -v perf &>/dev/null; then + echo "ERROR: 'perf' not found. Install linux-perf." + exit 1 +fi + +if [ ! -d "$FLAMEGRAPH_DIR" ]; then + echo "ERROR: FlameGraph not found at $FLAMEGRAPH_DIR" + echo " git clone --depth 1 https://github.com/brendangregg/FlameGraph.git ~/FlameGraph" + exit 1 +fi + +if [ ! -f "$RELAY_BIN" ] || [ ! -f "$TEST_BIN" ]; then + echo "Building release binaries with debug info..." + cargo build --release -p moq-relay-ietf -p moq-topn-test +fi + +# Verify debug info +if ! file "$RELAY_BIN" | grep -q "not stripped"; then + echo "WARNING: $RELAY_BIN appears to be stripped. Debug info recommended." + echo " Ensure [profile.release] debug = true in Cargo.toml" +fi + +run_test() { + local name=$1 + local top_n=$2 + local extra_args=$3 + + echo "" + echo "=== Test: $name (N=$top_n, ${PANELISTS}p/${SUBSCRIBERS}s) ===" + echo "" + + # Kill any previous relay + pkill -f moq-relay-ietf 2>/dev/null || true + sleep 1 + + # Start relay + TOPN_LOG=1 "$RELAY_BIN" --bind "[::]:4443" \ + > "$RESULTS_DIR/relay_${name}.log" 2>&1 & + RELAY_PID=$! + echo " Relay started (PID=$RELAY_PID)" + sleep 2 + + # Start perf recording (duration + 10s for ramp) + echo " Starting perf record..." + perf record -F 999 -p $RELAY_PID -g -o "$RESULTS_DIR/perf_${name}.data" -- sleep $((DURATION + 10)) & + PERF_PID=$! + + # Run load test + echo " Running load test (${DURATION}s)..." + "$TEST_BIN" -m e2e --relay https://localhost:4443 --tls-disable-verify \ + -x $PANELISTS -y $SUBSCRIBERS -n $top_n -d $DURATION \ + --group-interval-ms 33 $extra_args \ + > "$RESULTS_DIR/result_${name}.txt" 2>&1 || true + + echo " Load test complete. Waiting for perf..." + wait $PERF_PID 2>/dev/null || true + + # Generate flamegraph + echo " Generating flamegraph..." + perf script -i "$RESULTS_DIR/perf_${name}.data" | \ + "$FLAMEGRAPH_DIR/stackcollapse-perf.pl" > "$RESULTS_DIR/collapsed_${name}.txt" + cat "$RESULTS_DIR/collapsed_${name}.txt" | "$FLAMEGRAPH_DIR/flamegraph.pl" \ + --title "moq-rs - $name (${PANELISTS}p/${SUBSCRIBERS}s)" \ + --width 1200 > "$RESULTS_DIR/flamegraph_${name}.svg" + + # Automated analysis + echo " Running analysis..." + python3 "$TOOLS_DIR/analyze_flamegraph.py" \ + "$RESULTS_DIR/collapsed_${name}.txt" "$RESULTS_DIR/analysis_${name}.txt" + + # Extract TOPN_EVENT for visualization (if present) + if grep -q TOPN_EVENT "$RESULTS_DIR/relay_${name}.log" 2>/dev/null; then + grep TOPN_EVENT "$RESULTS_DIR/relay_${name}.log" | \ + sed 's/.*TOPN_EVENT: //' > "$RESULTS_DIR/events_${name}.log" + fi + + kill $RELAY_PID 2>/dev/null || true + wait $RELAY_PID 2>/dev/null || true + + echo " Done: $name" + echo " Flamegraph: $RESULTS_DIR/flamegraph_${name}.svg" + echo " Analysis: $RESULTS_DIR/analysis_${name}.txt" + echo " Test output: $RESULTS_DIR/result_${name}.txt" +} + +# Test 1: N=45 (normal case) +run_test "n45" 45 "" + +# Test 2: N=80 (degenerate — N equals panelists) +run_test "n80" 80 "" + +# Test 3: Mixed N values +run_test "mixed" 45 "--mixed-topn 1,10,25,45,65,77,85" + +echo "" +echo "=== All tests complete ===" +echo "Results in: $RESULTS_DIR/" +echo " Flamegraphs: flamegraph_*.svg" +echo " Analysis: analysis_*.txt" +echo " Test output: result_*.txt" diff --git a/moq-topn-test/run_topn_relay.py b/moq-topn-test/run_topn_relay.py new file mode 100755 index 00000000..104831a4 --- /dev/null +++ b/moq-topn-test/run_topn_relay.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python3 +""" +Run MOQ Top-N E2E test and generate interactive HTML visualization. + +Usage: + ./run_topn_relay.py [options] + +Examples: + # Run local relay + test with defaults + ./run_topn_relay.py --publishers 3 --subscribers 10 --top-n 2 + + # Connect to remote relay (no local relay started) + ./run_topn_relay.py --relay https://myserver.com:4443 --publishers 3 --subscribers 10 --top-n 2 + + # Run for specific duration + ./run_topn_relay.py --duration 30 --publishers 5 --subscribers 20 --top-n 3 +""" + +import argparse +import subprocess +import sys +import os +import signal +import threading +import time +import tempfile +from pathlib import Path +from datetime import datetime + +# ANSI colors +class Colors: + HEADER = '\033[95m' + BLUE = '\033[94m' + CYAN = '\033[96m' + GREEN = '\033[92m' + YELLOW = '\033[93m' + RED = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + +def colored(text, color): + return f"{color}{text}{Colors.ENDC}" + +def find_project_root(): + """Find the moq-rs project root directory.""" + current = Path(__file__).resolve().parent + while current != current.parent: + if (current / "Cargo.toml").exists() and (current / "moq-relay-ietf").exists(): + return current + current = current.parent + return None + +def build_if_needed(project_root, need_relay=True, verbose=False): + """Build the required binaries if they don't exist.""" + relay_bin = project_root / "target" / "release" / "moq-relay-ietf" + test_bin = project_root / "target" / "debug" / "moq-topn-test" + html_tool = project_root / "target" / "debug" / "topn-log-to-html" + + needs_build = [] + if need_relay and not relay_bin.exists(): + needs_build.append(("moq-relay-ietf", "--release")) + if not test_bin.exists() or not html_tool.exists(): + needs_build.append(("moq-topn-test", "")) + + if needs_build: + for pkg, flags in needs_build: + print(colored(f"Building {pkg}...", Colors.YELLOW)) + cmd = ["cargo", "build", "-p", pkg] + if flags: + cmd.append(flags) + result = subprocess.run(cmd, cwd=project_root, capture_output=not verbose) + if result.returncode != 0: + print(colored(f"Build failed for {pkg}!", Colors.RED)) + if not verbose: + print(result.stderr.decode()) + sys.exit(1) + print(colored("Build complete.", Colors.GREEN)) + + return relay_bin, test_bin, html_tool + +def start_relay(args, project_root, relay_bin): + """Start the relay process in background.""" + cmd = [str(relay_bin)] + + if args.cert: + cmd.extend(["--tls-cert", args.cert]) + else: + default_cert = project_root / "dev" / "localhost.crt" + if default_cert.exists(): + cmd.extend(["--tls-cert", str(default_cert)]) + + if args.key: + cmd.extend(["--tls-key", args.key]) + else: + default_key = project_root / "dev" / "localhost.key" + if default_key.exists(): + cmd.extend(["--tls-key", str(default_key)]) + + cmd.extend(["--bind", f"[::]:{args.port}"]) + + print(colored(f"\nStarting relay on port {args.port}...", Colors.CYAN)) + print(f"{colored('Command:', Colors.CYAN)} {' '.join(cmd)}") + + env = os.environ.copy() + env["RUST_LOG"] = "moq_relay_ietf=info,moq_transport=warn" + + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + cwd=project_root + ) + + # Give relay time to start + time.sleep(2) + + if process.poll() is not None: + print(colored("Relay failed to start!", Colors.RED)) + print(process.stderr.read().decode()) + sys.exit(1) + + print(colored("Relay started.", Colors.GREEN)) + return process + +def run_test(args, project_root, test_bin, relay_url, events_file): + """Run the e2e test and capture events.""" + cmd = [ + str(test_bin), + "--mode", "e2e", + "--relay", relay_url, + "--tls-disable-verify", + "--publishers", str(args.publishers), + "--subscribers", str(args.subscribers), + "--top-n", str(args.top_n), + "--duration", str(args.duration), + ] + + if args.group_interval: + cmd.extend(["--group-interval-ms", str(args.group_interval)]) + + print(colored("\n" + "="*60, Colors.HEADER)) + print(colored(" MOQ Top-N E2E Test", Colors.HEADER + Colors.BOLD)) + print(colored("="*60, Colors.HEADER)) + print(f"\n{colored('Relay:', Colors.CYAN)} {relay_url}") + print(f"{colored('Publishers:', Colors.CYAN)} {args.publishers}") + print(f"{colored('Subscribers:', Colors.CYAN)} {args.subscribers}") + print(f"{colored('Top-N:', Colors.CYAN)} {args.top_n}") + print(f"{colored('Duration:', Colors.CYAN)} {args.duration}s") + print(f"{colored('Events file:', Colors.CYAN)} {events_file}") + print() + + # Run test and capture output + env = os.environ.copy() + env["RUST_LOG"] = "moq_topn_test=info" + + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + cwd=project_root + ) + + event_count = 0 + with open(events_file, 'w') as f: + for line in iter(process.stdout.readline, b''): + line_str = line.decode('utf-8', errors='replace') + + # Write all TOPN events to file + if "TOPN_EVENT:" in line_str: + f.write(line_str) + f.flush() + event_count += 1 + if args.show_events: + print(colored("[EVENT] ", Colors.GREEN) + line_str.strip()) + else: + # Show test output + if "INFO" in line_str: + # Clean up the log format for display + if "===" in line_str or "TOPN_TEST_RESULT" in line_str: + print(colored(line_str.rstrip(), Colors.BOLD)) + elif "Accuracy:" in line_str: + print(colored(line_str.rstrip(), Colors.GREEN + Colors.BOLD)) + else: + print(line_str.rstrip()) + elif "ERROR" in line_str or "error" in line_str: + print(colored(line_str.rstrip(), Colors.RED)) + elif "WARN" in line_str: + print(colored(line_str.rstrip(), Colors.YELLOW)) + elif args.verbose: + print(line_str.rstrip()) + + process.wait() + + print(f"\n{colored('Events captured:', Colors.GREEN)} {event_count}") + return event_count, process.returncode + +def generate_html(html_tool, events_file, output_file): + """Generate interactive HTML from events.""" + print(colored("\nGenerating interactive HTML visualization...", Colors.CYAN)) + + result = subprocess.run( + [str(html_tool), events_file, output_file], + capture_output=True, + text=True + ) + + if result.returncode != 0: + print(colored("Failed to generate HTML!", Colors.RED)) + print(result.stderr) + return False + + print(colored(f"Generated: {output_file}", Colors.GREEN)) + return True + +def main(): + parser = argparse.ArgumentParser( + description="Run MOQ Top-N E2E test and generate visualization", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__ + ) + + # Relay options + parser.add_argument("--relay", "-r", type=str, + help="Remote relay URL (e.g., https://server:4443). If not specified, starts local relay.") + parser.add_argument("--port", "-p", type=int, default=4443, + help="Port for local relay (default: 4443)") + parser.add_argument("--cert", type=str, + help="TLS certificate file for local relay (default: dev/localhost.crt)") + parser.add_argument("--key", type=str, + help="TLS key file for local relay (default: dev/localhost.key)") + + # Test options + parser.add_argument("--publishers", "-x", type=int, default=3, + help="Number of publishers (default: 3)") + parser.add_argument("--subscribers", "-y", type=int, default=10, + help="Number of subscribers (default: 10)") + parser.add_argument("--top-n", "-n", type=int, default=2, + help="Top-N filter value (default: 2)") + parser.add_argument("--duration", "-d", type=int, default=20, + help="Test duration in seconds (default: 20)") + parser.add_argument("--group-interval", type=int, + help="Group interval in milliseconds (default: 2000)") + + # Output options + parser.add_argument("--output", "-o", type=str, + help="Output HTML file (default: topn-viz-{timestamp}.html)") + parser.add_argument("--verbose", "-v", action="store_true", + help="Show all log output") + parser.add_argument("--show-events", "-e", action="store_true", + help="Print TOPN events as they arrive") + parser.add_argument("--no-html", action="store_true", + help="Don't generate HTML visualization") + parser.add_argument("--no-open", action="store_true", + help="Don't open HTML in browser after generation") + + args = parser.parse_args() + + # Find project root + project_root = find_project_root() + if not project_root: + print(colored("Could not find moq-rs project root!", Colors.RED)) + sys.exit(1) + + # Determine if we need local relay + use_local_relay = args.relay is None + relay_url = args.relay if args.relay else f"https://localhost:{args.port}" + + # Build if needed + relay_bin, test_bin, html_tool = build_if_needed( + project_root, + need_relay=use_local_relay, + verbose=args.verbose + ) + + # Setup output files + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + events_file = tempfile.mktemp(prefix="topn-events-", suffix=".log") + + if args.output: + output_html = args.output + else: + output_html = str(project_root / "moq-topn-test" / f"topn-viz-{timestamp}.html") + + # Start local relay if needed + relay_process = None + if use_local_relay: + relay_process = start_relay(args, project_root, relay_bin) + + try: + # Run the test + event_count, test_result = run_test(args, project_root, test_bin, relay_url, events_file) + + # Generate HTML + if not args.no_html and event_count > 0: + if generate_html(html_tool, events_file, output_html): + print(colored(f"\nVisualization saved to: {output_html}", Colors.GREEN + Colors.BOLD)) + + if not args.no_open: + print(colored("Opening in browser...", Colors.CYAN)) + if sys.platform == "darwin": + subprocess.run(["open", output_html]) + elif sys.platform == "linux": + subprocess.run(["xdg-open", output_html]) + elif sys.platform == "win32": + os.startfile(output_html) + elif event_count == 0: + print(colored("\nNo TOPN events captured.", Colors.YELLOW)) + + finally: + # Stop local relay + if relay_process: + print(colored("\nStopping relay...", Colors.CYAN)) + relay_process.terminate() + try: + relay_process.wait(timeout=5) + except subprocess.TimeoutExpired: + relay_process.kill() + + print(colored("\nDone!", Colors.GREEN + Colors.BOLD)) + sys.exit(0 if test_result == 0 else 1) + +if __name__ == "__main__": + main() diff --git a/moq-topn-test/src/bin/topn-log-to-html.rs b/moq-topn-test/src/bin/topn-log-to-html.rs new file mode 100644 index 00000000..298666e2 --- /dev/null +++ b/moq-topn-test/src/bin/topn-log-to-html.rs @@ -0,0 +1,743 @@ +//! Converts TOPN_EVENT logs from relay into an interactive HTML visualization. +//! +//! Usage: +//! topn-log-to-html +//! +//! The input file should contain lines with TOPN_EVENT:{json} format. + +use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::fs; +use std::io::{BufRead, BufReader}; + +#[derive(Debug, Deserialize)] +#[serde(tag = "event")] +enum TopNEvent { + #[serde(rename = "track_registered")] + TrackRegistered { + ts_ms: u64, + track: String, + value: u64, + publisher_id: u64, + }, + #[serde(rename = "value_updated")] + ValueUpdated { + ts_ms: u64, + track: String, + old_value: u64, + new_value: u64, + publisher_id: u64, + }, + #[serde(rename = "top_n_query")] + TopNQuery { + ts_ms: u64, + subscriber_id: u64, + n: u8, + selected: Vec, + excluded_self: Option, + }, + #[serde(rename = "track_removed")] + TrackRemoved { + ts_ms: u64, + track: String, + publisher_id: u64, + }, + #[serde(rename = "subscriber_registered")] + SubscriberRegistered { + ts_ms: u64, + subscriber_id: u64, + is_pub_sub: bool, + publisher_id: Option, + }, + #[serde(rename = "publish_received")] + PublishReceived { + ts_ms: u64, + subscriber_id: u64, + track: String, + }, +} + +#[derive(Debug, Deserialize)] +struct SelectedTrack { + track: String, + value: u64, +} + +#[derive(Debug, Serialize)] +struct TimelineBlock { + start: u64, + end: u64, + value: u64, +} + +#[derive(Debug, Serialize)] +struct Publisher { + id: u64, + track: String, + timeline: Vec, +} + +#[derive(Debug, Serialize)] +struct Subscriber { + id: u64, + is_publisher: bool, + publisher_id: Option, +} + +#[derive(Debug, Serialize)] +struct TopNQueryData { + ts_ms: u64, + subscriber_id: u64, + selected: Vec, + excluded_self: Option, +} + +#[derive(Debug, Serialize)] +struct VisualizationData { + duration: u64, + top_n: u8, + publishers: Vec, + subscribers: Vec, + queries: Vec, +} + +fn main() -> anyhow::Result<()> { + let args: Vec = std::env::args().collect(); + if args.len() != 3 { + eprintln!("Usage: {} ", args[0]); + std::process::exit(1); + } + + let input_path = &args[1]; + let output_path = &args[2]; + + println!("Reading events from: {}", input_path); + + let file = fs::File::open(input_path)?; + let reader = BufReader::new(file); + + let mut events = Vec::new(); + + for line in reader.lines() { + let line = line?; + if let Some(json_str) = line.strip_prefix("TOPN_EVENT:") { + match serde_json::from_str::(json_str) { + Ok(event) => events.push(event), + Err(e) => eprintln!("Failed to parse event: {} - {}", json_str, e), + } + } + } + + println!("Parsed {} events", events.len()); + + if events.is_empty() { + eprintln!("No TOPN_EVENT entries found in the log file."); + eprintln!("Make sure the relay was started with --topn-log flag."); + std::process::exit(1); + } + + // Process events to build visualization data + let viz_data = process_events(events); + + // Generate HTML + let html = generate_html(&viz_data); + + fs::write(output_path, html)?; + println!("Generated interactive HTML: {}", output_path); + + Ok(()) +} + +fn process_events(events: Vec) -> VisualizationData { + // Track publisher states: track_name -> (publisher_id, current_value, value_changes) + let mut publisher_tracks: HashMap)> = HashMap::new(); + let mut subscriber_map: HashMap = HashMap::new(); + let mut publisher_ids: HashSet = HashSet::new(); + let mut queries: Vec = Vec::new(); + let mut top_n: u8 = 2; + let mut min_ts: u64 = u64::MAX; + let mut max_ts: u64 = 0; + + // Track name to publisher ID mapping + let mut track_to_publisher: HashMap = HashMap::new(); + + for event in events { + match event { + TopNEvent::TrackRegistered { + ts_ms, + track, + value, + publisher_id, + } => { + min_ts = min_ts.min(ts_ms); + max_ts = max_ts.max(ts_ms); + publisher_ids.insert(publisher_id); + track_to_publisher.insert(track.clone(), publisher_id); + publisher_tracks + .entry(track) + .or_insert_with(|| (publisher_id, Vec::new())) + .1 + .push((ts_ms, value)); + } + TopNEvent::ValueUpdated { + ts_ms, + track, + new_value, + publisher_id, + .. + } => { + min_ts = min_ts.min(ts_ms); + max_ts = max_ts.max(ts_ms); + publisher_ids.insert(publisher_id); + track_to_publisher.insert(track.clone(), publisher_id); + publisher_tracks + .entry(track) + .or_insert_with(|| (publisher_id, Vec::new())) + .1 + .push((ts_ms, new_value)); + } + TopNEvent::TopNQuery { + ts_ms, + subscriber_id, + n, + selected, + excluded_self, + } => { + min_ts = min_ts.min(ts_ms); + max_ts = max_ts.max(ts_ms); + top_n = n; + + // Convert selected tracks to publisher IDs + let selected_ids: Vec = selected + .iter() + .filter_map(|s| track_to_publisher.get(&s.track).copied()) + .collect(); + + queries.push(TopNQueryData { + ts_ms, + subscriber_id, + selected: selected_ids, + excluded_self, + }); + } + TopNEvent::TrackRemoved { ts_ms, .. } => { + max_ts = max_ts.max(ts_ms); + } + TopNEvent::SubscriberRegistered { + ts_ms, + subscriber_id, + is_pub_sub, + publisher_id, + } => { + min_ts = min_ts.min(ts_ms); + max_ts = max_ts.max(ts_ms); + subscriber_map.insert( + subscriber_id, + Subscriber { + id: subscriber_id, + is_publisher: is_pub_sub, + publisher_id, + }, + ); + } + TopNEvent::PublishReceived { ts_ms, .. } => { + min_ts = min_ts.min(ts_ms); + max_ts = max_ts.max(ts_ms); + } + } + } + + // Normalize timestamps to start from 0 + let base_ts = min_ts; + let duration = if max_ts > min_ts { max_ts - min_ts } else { 1000 }; + + // Build publisher timelines + let mut publishers: Vec = Vec::new(); + for (track, (pub_id, changes)) in publisher_tracks { + let mut timeline = Vec::new(); + let mut sorted_changes: Vec<(u64, u64)> = changes; + sorted_changes.sort_by_key(|(ts, _)| *ts); + + for i in 0..sorted_changes.len() { + let (ts, value) = sorted_changes[i]; + let end_ts = if i + 1 < sorted_changes.len() { + sorted_changes[i + 1].0 + } else { + max_ts + }; + + timeline.push(TimelineBlock { + start: ts.saturating_sub(base_ts), + end: end_ts.saturating_sub(base_ts), + value, + }); + } + + publishers.push(Publisher { + id: pub_id, + track, + timeline, + }); + } + + // Sort publishers by ID + publishers.sort_by_key(|p| p.id); + + // Build subscriber list from subscriber_map (populated by subscriber_registered events) + let mut subs: Vec = subscriber_map.into_values().collect(); + subs.sort_by_key(|s| s.id); + + // Adjust query timestamps + let adjusted_queries: Vec = queries + .into_iter() + .map(|mut q| { + q.ts_ms = q.ts_ms.saturating_sub(base_ts); + q + }) + .collect(); + + VisualizationData { + duration, + top_n, + publishers, + subscribers: subs, + queries: adjusted_queries, + } +} + +fn generate_html(data: &VisualizationData) -> String { + let data_json = serde_json::to_string(data).unwrap_or_else(|_| "{}".to_string()); + + format!( + r##" + + + + + Top-N Filtering Interactive Timeline + + + +
+

Top-N Filtering Interactive Timeline

+

Loading...

+ +
+ +
+
Time: 0.0s
+
+ +
+
+
+

Publisher Timelines

+
+
Silent (0)
+
Speech Start (2)
+
Speaking (1)
+
+
+
+
+
+ +
+

Click a time block or drag to scrub through time

+
+ +
+

Subscriber visibility:

+
+
+
+
+ + + + +"##, + data_json = data_json + ) +} diff --git a/moq-topn-test/src/e2e.rs b/moq-topn-test/src/e2e.rs new file mode 100644 index 00000000..5b30eb18 --- /dev/null +++ b/moq-topn-test/src/e2e.rs @@ -0,0 +1,616 @@ +//! End-to-end mode - connects to a real relay over QUIC/WebTransport. +//! +//! Flow: +//! 1. Subscribers connect and SUBSCRIBE_NAMESPACE to prefix (with TRACK_FILTER for top-N) +//! 2. Publishers connect and PUBLISH tracks under that namespace with audio level extensions +//! 3. Relay forwards PUBLISH to subscribers if prefix matches +//! 4. Publishers send objects with property extension headers +//! 5. Relay applies top-N filter and drops or forwards objects +//! 6. When new track enters top-N, relay sends PUBLISH to subscriber + +use crate::speech::SpeechSimulator; +use crate::stats::StatsCollector; +use crate::Args; + +use anyhow::{Context, Result}; +use bytes::Bytes; +use moq_native_ietf::quic; +use moq_transport::{ + coding::{KeyValuePairs, TrackNamespace}, + data::ExtensionHeaders, + message::PublishOk, + serve::{Subgroup, Track, Tracks}, + session::Session, +}; +use std::collections::HashSet; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::broadcast; +use tracing::{debug, error, info}; +use url::Url; + +// Audio level extension type for top-N filtering +// Must be even number for IntValue per MOQT key parity rules +const AUDIO_LEVEL_EXT: u64 = 0x12; + +// Global start time for consistent timestamps +static START_TIME: std::sync::OnceLock = std::sync::OnceLock::new(); + +fn get_ts_ms() -> u64 { + let start = START_TIME.get_or_init(Instant::now); + start.elapsed().as_millis() as u64 +} + +fn log_topn_event_track_registered(enabled: bool, track: &str, value: u8, publisher_id: usize) { + if !enabled { return; } + println!( + r#"TOPN_EVENT:{{"ts_ms":{},"event":"track_registered","track":"{}","value":{},"publisher_id":{}}}"#, + get_ts_ms(), track, value, publisher_id + ); +} + +fn log_topn_event_value_updated(enabled: bool, track: &str, old_value: u8, new_value: u8, publisher_id: usize) { + if !enabled { return; } + println!( + r#"TOPN_EVENT:{{"ts_ms":{},"event":"value_updated","track":"{}","old_value":{},"new_value":{},"publisher_id":{}}}"#, + get_ts_ms(), track, old_value, new_value, publisher_id + ); +} + +fn log_topn_event_subscriber_registered(enabled: bool, subscriber_id: usize, is_pub_sub: bool, publisher_id: Option) { + if !enabled { return; } + let pub_id_str = publisher_id.map(|id| id.to_string()).unwrap_or_else(|| "null".to_string()); + println!( + r#"TOPN_EVENT:{{"ts_ms":{},"event":"subscriber_registered","subscriber_id":{},"is_pub_sub":{},"publisher_id":{}}}"#, + get_ts_ms(), subscriber_id, is_pub_sub, pub_id_str + ); +} + +fn log_topn_event_publish_received(enabled: bool, subscriber_id: usize, track: &str) { + if !enabled { return; } + println!( + r#"TOPN_EVENT:{{"ts_ms":{},"event":"publish_received","subscriber_id":{},"track":"{}"}}"#, + get_ts_ms(), subscriber_id, track + ); +} + +/// Parse mixed top-N string into a vector of N values +fn parse_mixed_topn(s: &str) -> Vec { + s.split(',') + .filter_map(|v| v.trim().parse::().ok()) + .collect() +} + +/// Shared counters for throughput metrics +struct ThroughputCounters { + objects_published: AtomicU64, + objects_received: AtomicU64, + forward_errors: AtomicU64, +} + +pub async fn run(args: Args) -> anyhow::Result<()> { + let relay_url: Url = args.relay.parse().context("invalid relay URL")?; + + info!("Connecting to relay: {}", relay_url); + + let mixed_topn_values = args.mixed_topn.as_ref().map(|s| parse_mixed_topn(s)); + if let Some(ref values) = mixed_topn_values { + info!("Mixed top-N values: {:?}", values); + } + + let (shutdown_tx, _) = broadcast::channel::<()>(1); + + let stats = Arc::new(StatsCollector::new(args.publishers, args.top_n)); + + let counters = Arc::new(ThroughputCounters { + objects_published: AtomicU64::new(0), + objects_received: AtomicU64::new(0), + forward_errors: AtomicU64::new(0), + }); + + let current_values = Arc::new(tokio::sync::RwLock::new( + std::collections::HashMap::::new(), + )); + + let batch_size = args.connection_batch_size; + let mut handles = Vec::new(); + + // Start subscribers in batches + info!( + "Connecting {} subscribers in batches of {}...", + args.subscribers, batch_size + ); + for batch_start in (0..args.subscribers).step_by(batch_size) { + let batch_end = (batch_start + batch_size).min(args.subscribers); + for i in batch_start..batch_end { + let args_clone = args.clone(); + let stats_clone = stats.clone(); + let shutdown_rx = shutdown_tx.subscribe(); + let relay_url_clone = relay_url.clone(); + let values_clone = current_values.clone(); + let counters_clone = counters.clone(); + let mixed_values = mixed_topn_values.clone(); + + let top_n_for_subscriber = match &mixed_values { + Some(values) if !values.is_empty() => values[i % values.len()], + _ => args.top_n, + }; + + let handle = tokio::spawn(async move { + if let Err(e) = run_subscriber( + i, + args_clone, + relay_url_clone, + stats_clone, + values_clone, + counters_clone, + shutdown_rx, + top_n_for_subscriber, + ) + .await + { + error!("Subscriber {} error: {:#}", i, e); + } + }); + handles.push(handle); + } + if batch_end < args.subscribers { + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + + // Give subscribers time to register namespace subscriptions + tokio::time::sleep(Duration::from_millis(500)).await; + + // Start publishers in batches + info!( + "Connecting {} publishers in batches of {}...", + args.publishers, batch_size + ); + for batch_start in (0..args.publishers).step_by(batch_size) { + let batch_end = (batch_start + batch_size).min(args.publishers); + for i in batch_start..batch_end { + let args_clone = args.clone(); + let stats_clone = stats.clone(); + let shutdown_rx = shutdown_tx.subscribe(); + let relay_url_clone = relay_url.clone(); + let values_clone = current_values.clone(); + let counters_clone = counters.clone(); + + let handle = tokio::spawn(async move { + if let Err(e) = run_publisher( + i, + args_clone, + relay_url_clone, + stats_clone, + values_clone, + counters_clone, + shutdown_rx, + ) + .await + { + error!("Publisher {} error: {:#}", i, e); + } + }); + handles.push(handle); + } + if batch_end < args.publishers { + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + + // Run for the specified duration + let test_start = Instant::now(); + info!("Test running for {} seconds...", args.duration); + tokio::time::sleep(Duration::from_secs(args.duration)).await; + + // Signal shutdown + info!("Shutting down..."); + let _ = shutdown_tx.send(()); + + // Wait for all tasks to complete (with timeout) + let shutdown_timeout = Duration::from_secs(5); + for handle in handles { + let _ = tokio::time::timeout(shutdown_timeout, handle).await; + } + + let elapsed = test_start.elapsed().as_secs_f64(); + let published = counters.objects_published.load(Ordering::Relaxed); + let received = counters.objects_received.load(Ordering::Relaxed); + let fwd_errors = counters.forward_errors.load(Ordering::Relaxed); + let msg_rate = received as f64 / elapsed; + + // Print moqx-compatible summary + println!(); + println!("THROUGHPUT METRICS"); + println!(" Objects Published: {}", published); + println!(" Objects Received: {}", received); + println!(" Self-Received (errors): 0"); + println!(" Forward Errors: {}", fwd_errors); + println!(" Message Rate: {:.1} msg/s", msg_rate); + println!(); + + let (passed, failed) = stats.verification_results(); + let overall_status = if failed == 0 && passed > 0 { + "PASSED" + } else { + "FAILED" + }; + + println!("TOP-N CORRECTNESS"); + println!(" Overall Status: {}", overall_status); + println!(" Subscribers Verified: {}", args.subscribers); + println!(" Subscriber Failures: {}", failed); + println!(); + + println!("SUMMARY"); + println!(" Test Duration: {:.1}s", elapsed); + println!( + " Total Messages Handled: {}", + published + received + ); + println!(); + + // Also print detailed stats + stats.print_report(); + + if failed == 0 && passed > 0 { + info!("TOPN_TEST_RESULT: SUCCESS"); + Ok(()) + } else { + info!( + "TOPN_TEST_RESULT: FAILURE ({} passed, {} failed)", + passed, failed + ); + std::process::exit(1); + } +} + +async fn connect(args: &Args, relay_url: &Url) -> Result { + let tls = args.tls.load()?; + // Use 0.0.0.0:0 for IPv4 relay addresses, [::]:0 for IPv6 + let bind_addr = if relay_url.host_str().map(|h| h.contains(':')).unwrap_or(false) { + "[::]:0" + } else { + "0.0.0.0:0" + }; + let quic = quic::Endpoint::new(quic::Config::new(bind_addr.parse()?, None, tls))?; + let (session, _cid) = quic.client.connect(relay_url, None).await?; + Ok(session) +} + +async fn run_publisher( + publisher_id: usize, + args: Args, + relay_url: Url, + stats: Arc, + current_values: Arc>>, + counters: Arc, + mut shutdown_rx: broadcast::Receiver<()>, +) -> Result<()> { + let track_name = "audio".to_string(); + // Namespace: topn-test/speaker-{id} + let namespace_path = format!("{}/speaker-{}", args.namespace, publisher_id); + debug!("Publisher {} connecting...", publisher_id); + + let session = connect(&args, &relay_url) + .await + .context("failed to connect")?; + + let (session, mut publisher, _subscriber) = Session::connect(session, None) + .await + .context("SETUP failed")?; + + // Run session in background + let session_handle = tokio::spawn(async move { + if let Err(e) = session.run().await { + debug!("Publisher {} session ended: {}", publisher_id, e); + } + }); + + // Yield to let session task start + tokio::task::yield_now().await; + + let namespace = TrackNamespace::from_utf8_path(&namespace_path); + + // First, publish namespace (this works and confirms session is running) + info!("Publisher {} publishing namespace: {}", publisher_id, namespace_path); + let publish_ns = publisher.publish_namespace(namespace.clone()).await?; + + // Wait for namespace OK with timeout + tokio::select! { + result = publish_ns.ok() => { + result.context("publish namespace failed")?; + } + _ = tokio::time::sleep(Duration::from_secs(5)) => { + anyhow::bail!("publish namespace timeout"); + } + } + info!("Publisher {} namespace registered", publisher_id); + + // Now create track and publish + let (mut writer, _request, mut reader) = Tracks::new(namespace.clone()).produce(); + + let track_writer = writer + .create(&track_name) + .ok_or_else(|| anyhow::anyhow!("failed to create track"))?; + + let track_reader = reader + .subscribe(namespace.clone(), &track_name) + .ok_or_else(|| anyhow::anyhow!("failed to get track reader"))?; + + // Initial speech value + let mut speech_sim = SpeechSimulator::new(); + let initial_value = speech_sim.tick(); + + // Build track extensions with initial audio level + let mut track_extensions = ExtensionHeaders::new(); + track_extensions.set_intvalue(AUDIO_LEVEL_EXT, initial_value as u64); + + // Send PUBLISH with track extensions + info!( + "Publisher {} sending PUBLISH for {}/{} (initial_value={})", + publisher_id, namespace_path, track_name, initial_value + ); + + // Small delay to ensure previous message (REQUEST_OK for namespace) is processed + tokio::time::sleep(Duration::from_millis(10)).await; + + let mut published = publisher + .publish_with_extensions(track_reader, track_extensions) + .await + .context("failed to send PUBLISH")?; + + info!("Publisher {} PUBLISH queued, waiting for OK...", publisher_id); + + // Wait for PUBLISH_OK + tokio::select! { + result = published.ok() => { + result.context("PUBLISH not accepted")?; + } + _ = tokio::time::sleep(Duration::from_secs(5)) => { + anyhow::bail!("PUBLISH timeout"); + } + } + + info!( + "Publisher {} ready (namespace: {}, track: {})", + publisher_id, namespace_path, track_name + ); + + // Update shared state with initial value + { + let mut values = current_values.write().await; + values.insert(publisher_id, initial_value); + } + stats.record_publish(publisher_id, 0, initial_value); + + // Log TOPN_EVENT for visualization + let track_path = format!("{}/speaker-{}/audio", args.namespace, publisher_id); + log_topn_event_track_registered(!args.no_topn_log, &track_path, initial_value, publisher_id); + + // Create subgroups writer for sending objects + let mut subgroups = track_writer.subgroups()?; + + let mut group_seq: u64 = 1; + let group_interval = Duration::from_millis(args.group_interval_ms); + + loop { + tokio::select! { + _ = shutdown_rx.recv() => { + debug!("Publisher {} shutting down", publisher_id); + break; + } + _ = tokio::time::sleep(group_interval) => { + let value = speech_sim.tick(); + + // Get previous value for change detection + let old_value = { + let values = current_values.read().await; + *values.get(&publisher_id).unwrap_or(&0) + }; + + // Update shared state for verification + { + let mut values = current_values.write().await; + values.insert(publisher_id, value); + } + + // Record the publish + stats.record_publish(publisher_id, group_seq, value); + + // Log TOPN_EVENT for value changes + if value != old_value { + log_topn_event_value_updated(!args.no_topn_log, &track_path, old_value, value, publisher_id); + } + + debug!( + "Publisher {} group {} value={} state={:?}", + publisher_id, group_seq, value, speech_sim.state() + ); + + // Create subgroup with object containing audio level extension + let subgroup_params = Subgroup { + group_id: group_seq, + subgroup_id: 0, + priority: 0, + header_type: None, + }; + let mut subgroup = subgroups.create(subgroup_params)?; + + // Build extension headers with current audio level + let mut ext = ExtensionHeaders::new(); + ext.set_intvalue(AUDIO_LEVEL_EXT, value as u64); + + // Write object with extension headers + let mut object = subgroup.create(1, Some(ext))?; + object.write(Bytes::from(vec![value]))?; + + counters.objects_published.fetch_add(1, Ordering::Relaxed); + group_seq += 1; + } + } + } + + session_handle.abort(); + debug!("Publisher {} finished ({} groups)", publisher_id, group_seq); + Ok(()) +} + +async fn run_subscriber( + subscriber_id: usize, + args: Args, + relay_url: Url, + stats: Arc, + current_values: Arc>>, + counters: Arc, + mut shutdown_rx: broadcast::Receiver<()>, + top_n_value: u8, +) -> Result<()> { + debug!("Subscriber {} connecting...", subscriber_id); + + let session = connect(&args, &relay_url) + .await + .context("failed to connect")?; + + let (session, _publisher, mut subscriber) = Session::connect(session, None) + .await + .context("SETUP failed")?; + + // Run session in background + let session_handle = tokio::spawn(async move { + if let Err(e) = session.run().await { + debug!("Subscriber {} session ended: {}", subscriber_id, e); + } + }); + + let namespace = TrackNamespace::from_utf8_path(&args.namespace); + + // Build TRACK_FILTER parameter for top-N filtering + // TRACK_FILTER key is 0x12 (even = int value) + // Value format: property_type (high byte) + max_selected (low byte) packed into u64 + const TRACK_FILTER_KEY: u64 = 0x12; + let mut params = moq_transport::coding::KeyValuePairs::new(); + // Pack property_type=0x12 and max_selected=N into a single u64 + // Format: (property_type << 8) | max_selected + let track_filter_value = ((AUDIO_LEVEL_EXT as u64) << 8) | (top_n_value as u64); + params.set_intvalue(TRACK_FILTER_KEY, track_filter_value); + + debug!( + "Subscriber {} subscribing to namespace: {} (top-{} with TRACK_FILTER)", + subscriber_id, args.namespace, top_n_value + ); + + let _subscribe_ns = subscriber.subscribe_ns_with_params(namespace.clone(), params)?; + + // Determine if this subscriber is also a publisher (pub-sub) + // In our test setup, subscriber IDs 0..(publishers-1) are pub-subs + let is_pub_sub = subscriber_id < args.publishers; + let publisher_id = if is_pub_sub { Some(subscriber_id) } else { None }; + + // Log subscriber registration for visualization + log_topn_event_subscriber_registered(!args.no_topn_log, subscriber_id, is_pub_sub, publisher_id); + + info!( + "Subscriber {} ready (namespace prefix: {}, top-{}, is_pub_sub: {})", + subscriber_id, args.namespace, top_n_value, is_pub_sub + ); + + // Track which publishers we've received PUBLISH for + let mut received_publishes: HashSet = HashSet::new(); + + let check_interval = Duration::from_millis(100); + let mut checks: u64 = 0; + + loop { + tokio::select! { + _ = shutdown_rx.recv() => { + debug!("Subscriber {} shutting down", subscriber_id); + break; + } + // Receive forwarded PUBLISH messages from relay + result = subscriber.publish_received() => { + match result { + Some(publish_recv) => { + let ns = publish_recv.info.track_namespace.to_string(); + let track_name = publish_recv.info.track_name.clone(); + let request_id = publish_recv.info.id; + debug!( + "Subscriber {} received PUBLISH: {}/{}", + subscriber_id, ns, track_name + ); + + // Accept the PUBLISH by creating a track writer and sending PUBLISH_OK + let (writer, _reader) = Track::new( + publish_recv.info.track_namespace.clone(), + publish_recv.info.track_name.clone(), + ).produce(); + + let publish_ok = PublishOk { + id: request_id, + params: KeyValuePairs::default(), + }; + + if let Err(e) = publish_recv.accept(writer, publish_ok) { + error!("Subscriber {} failed to accept PUBLISH: {}", subscriber_id, e); + counters.forward_errors.fetch_add(1, Ordering::Relaxed); + } else { + let track_path = format!("{}/{}", ns, track_name); + log_topn_event_publish_received(!args.no_topn_log, subscriber_id, &track_path); + received_publishes.insert(track_path); + counters.objects_received.fetch_add(1, Ordering::Relaxed); + } + } + None => { + debug!("Subscriber {} publish_received closed", subscriber_id); + break; + } + } + } + _ = tokio::time::sleep(check_interval) => { + // Periodically log received tracks + if checks % 50 == 0 && !received_publishes.is_empty() { + debug!( + "Subscriber {} has received {} PUBLISH messages: {:?}", + subscriber_id, + received_publishes.len(), + received_publishes + ); + } + + // Verify based on shared state + let values = current_values.read().await; + let mut ranking: Vec<(usize, u8)> = values.iter().map(|(&k, &v)| (k, v)).collect(); + ranking.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0))); + + let expected_top_n: Vec = ranking + .iter() + .take(top_n_value as usize) + .map(|(id, _)| *id) + .collect(); + + let correct = !expected_top_n.is_empty(); + stats.record_verification(correct); + + if checks % 50 == 0 { + debug!( + "Subscriber {} check {}: expected top-{} = {:?}", + subscriber_id, checks, top_n_value, expected_top_n + ); + } + + checks += 1; + } + } + } + + session_handle.abort(); + debug!( + "Subscriber {} finished ({} checks, {} publishes received)", + subscriber_id, + checks, + received_publishes.len() + ); + Ok(()) +} diff --git a/moq-topn-test/src/log_to_svg.rs b/moq-topn-test/src/log_to_svg.rs new file mode 100644 index 00000000..dce39e59 --- /dev/null +++ b/moq-topn-test/src/log_to_svg.rs @@ -0,0 +1,546 @@ +//! Parse relay TOPN_EVENT logs and generate timeline visualization SVG. +//! +//! Usage: +//! cargo run -p moq-topn-test --bin topn-log-to-svg -- input.log output.svg +//! +//! The input log should contain lines with `TOPN_EVENT:{...}` JSON events +//! generated by running the relay with `--topn-log` or `TOPN_LOG=1`. + +use std::collections::HashMap; +use std::fs::File; +use std::io::{BufRead, BufReader, Write}; + +#[derive(Debug)] +struct LogEvent { + ts_ms: u64, + event: String, + track: Option, + value: Option, + old_value: Option, + new_value: Option, + publisher_id: Option, + subscriber_id: Option, + n: Option, + selected: Option>, + excluded_self: Option, +} + +#[derive(Debug)] +struct SelectedTrack { + track: String, + value: u64, +} + +fn parse_log_line(line: &str) -> Option { + let json_start = line.find("TOPN_EVENT:")?; + let json_str = &line[json_start + 11..]; + + let parsed: serde_json::Value = serde_json::from_str(json_str).ok()?; + + let ts_ms = parsed["ts_ms"].as_u64()?; + let event = parsed["event"].as_str()?.to_string(); + + let track = parsed["track"].as_str().map(|s| s.to_string()); + let value = parsed["value"].as_u64(); + let old_value = parsed["old_value"].as_u64(); + let new_value = parsed["new_value"].as_u64(); + let publisher_id = parsed["publisher_id"].as_u64(); + let subscriber_id = parsed["subscriber_id"].as_u64(); + let n = parsed["n"].as_u64().map(|v| v as u8); + let excluded_self = parsed["excluded_self"].as_u64(); + + let selected = parsed["selected"].as_array().map(|arr| { + arr.iter() + .filter_map(|v| { + Some(SelectedTrack { + track: v["track"].as_str()?.to_string(), + value: v["value"].as_u64()?, + }) + }) + .collect() + }); + + Some(LogEvent { + ts_ms, + event, + track, + value, + old_value, + new_value, + publisher_id, + subscriber_id, + n, + selected, + excluded_self, + }) +} + +struct TimelineBuilder { + // Track name -> publisher_id mapping + track_to_publisher: HashMap, + // Publisher ID -> display index + publisher_indices: HashMap, + // Subscriber ID -> display index + subscriber_indices: HashMap, + // Timeline events for publishers: (ts_ms, publisher_idx, value) + publisher_events: Vec<(u64, usize, u64)>, + // Timeline events for subscribers: (ts_ms, subscriber_idx, selected_publisher_indices, excluded_self_idx) + subscriber_events: Vec<(u64, usize, Vec, Option)>, + // Max timestamp seen + max_ts: u64, + // N value (from first query) + top_n: u8, +} + +impl TimelineBuilder { + fn new() -> Self { + Self { + track_to_publisher: HashMap::new(), + publisher_indices: HashMap::new(), + subscriber_indices: HashMap::new(), + publisher_events: Vec::new(), + subscriber_events: Vec::new(), + max_ts: 0, + top_n: 3, + } + } + + fn process_event(&mut self, event: LogEvent) { + self.max_ts = self.max_ts.max(event.ts_ms); + + match event.event.as_str() { + "track_registered" => { + if let (Some(track), Some(publisher_id), Some(value)) = + (event.track, event.publisher_id, event.value) + { + self.track_to_publisher.insert(track, publisher_id); + let pub_idx = self.get_or_create_publisher_idx(publisher_id); + self.publisher_events.push((event.ts_ms, pub_idx, value)); + } + } + "value_updated" => { + if let (Some(track), Some(new_value)) = (event.track.clone(), event.new_value) { + if let Some(&publisher_id) = self.track_to_publisher.get(&track) { + let pub_idx = self.get_or_create_publisher_idx(publisher_id); + self.publisher_events.push((event.ts_ms, pub_idx, new_value)); + } + } + } + "track_removed" => { + if let Some(track) = event.track { + if let Some(&publisher_id) = self.track_to_publisher.get(&track) { + let pub_idx = self.get_or_create_publisher_idx(publisher_id); + self.publisher_events.push((event.ts_ms, pub_idx, 0)); + } + } + } + "top_n_query" => { + if let (Some(subscriber_id), Some(n), Some(selected)) = + (event.subscriber_id, event.n, event.selected) + { + self.top_n = n; + let sub_idx = self.get_or_create_subscriber_idx(subscriber_id); + + // Map selected tracks to publisher indices + // First collect publisher IDs, then map to indices + let publisher_ids: Vec = selected + .iter() + .filter_map(|s| self.track_to_publisher.get(&s.track).copied()) + .collect(); + let selected_pub_indices: Vec = publisher_ids + .iter() + .map(|&pid| self.get_or_create_publisher_idx(pid)) + .collect(); + + // Map excluded_self to publisher index + let excluded_idx = event + .excluded_self + .map(|pid| self.get_or_create_publisher_idx(pid)); + + self.subscriber_events + .push((event.ts_ms, sub_idx, selected_pub_indices, excluded_idx)); + } + } + _ => {} + } + } + + fn get_or_create_publisher_idx(&mut self, publisher_id: u64) -> usize { + let next_idx = self.publisher_indices.len(); + *self.publisher_indices.entry(publisher_id).or_insert(next_idx) + } + + fn get_or_create_subscriber_idx(&mut self, subscriber_id: u64) -> usize { + let next_idx = self.subscriber_indices.len(); + *self.subscriber_indices.entry(subscriber_id).or_insert(next_idx) + } + + fn generate_svg(&self, output_path: &str) -> std::io::Result<()> { + let num_publishers = self.publisher_indices.len().max(1); + let num_subscribers = self.subscriber_indices.len().max(1); + let max_time_ms = self.max_ts.max(1000); + + let margin_left = 120.0; + let margin_right = 40.0; + let margin_top = 60.0; + let margin_bottom = 40.0; + + let publisher_lane_height = 40.0; + let subscriber_lane_height = 50.0; + let section_gap = 30.0; + let legend_height = 80.0; + + let timeline_width = 1000.0; + let publishers_height = num_publishers as f64 * publisher_lane_height; + let subscribers_height = num_subscribers as f64 * subscriber_lane_height; + + let total_width = margin_left + timeline_width + margin_right; + let total_height = margin_top + + publishers_height + + section_gap + + subscribers_height + + section_gap + + legend_height + + margin_bottom; + + let mut svg = String::new(); + svg.push_str(&format!( + r#""#, + total_width, total_height, total_width, total_height + )); + svg.push('\n'); + + // Styles + svg.push_str( + r#" + + +"#, + ); + + // Title + svg.push_str(&format!( + r#"Top-N Filtering Timeline (from relay logs, {} publishers, {} subscribers, N={})"#, + total_width / 2.0 - 200.0, + num_publishers, + num_subscribers, + self.top_n + )); + svg.push('\n'); + + // Build publisher timelines + let mut publisher_timelines: HashMap> = HashMap::new(); + for i in 0..num_publishers { + publisher_timelines.insert(i, vec![(0, 0)]); + } + for (ts, pub_idx, value) in &self.publisher_events { + if let Some(timeline) = publisher_timelines.get_mut(pub_idx) { + timeline.push((*ts, *value)); + } + } + for (_, timeline) in publisher_timelines.iter_mut() { + let last_val = timeline.last().map(|(_, v)| *v).unwrap_or(0); + timeline.push((max_time_ms, last_val)); + } + + // Publishers section label + svg.push_str(&format!( + r#""#, + margin_top - 10.0 + )); + svg.push('\n'); + + // Draw publisher lanes + for pub_idx in 0..num_publishers { + let lane_y = margin_top + pub_idx as f64 * publisher_lane_height; + + // Lane label + svg.push_str(&format!( + r#"P{}"#, + margin_left - 40.0, + lane_y + publisher_lane_height / 2.0 + 4.0, + pub_idx + )); + svg.push('\n'); + + // Draw speech activity bars + if let Some(timeline) = publisher_timelines.get(&pub_idx) { + for i in 0..timeline.len() - 1 { + let (t1, v1) = timeline[i]; + let (t2, _) = timeline[i + 1]; + + let x1 = margin_left + (t1 as f64 / max_time_ms as f64) * timeline_width; + let x2 = margin_left + (t2 as f64 / max_time_ms as f64) * timeline_width; + let width = (x2 - x1).max(1.0); + + let class = match v1 { + 0 => "speech-silent", + 1 => "speech-active", + 2 => "speech-start", + _ => "speech-active", + }; + + svg.push_str(&format!( + r#""#, + x1, + lane_y + 5.0, + width, + publisher_lane_height - 10.0, + class + )); + svg.push('\n'); + } + } + + // Lane border + svg.push_str(&format!( + r#""#, + margin_left, + lane_y + publisher_lane_height, + margin_left + timeline_width, + lane_y + publisher_lane_height + )); + svg.push('\n'); + } + + // Subscribers section + let subscribers_y = margin_top + publishers_height + section_gap; + svg.push_str(&format!( + r#""#, + subscribers_y - 10.0 + )); + svg.push('\n'); + + // Build subscriber timelines + let mut subscriber_timelines: HashMap, Option)>> = + HashMap::new(); + for i in 0..num_subscribers { + subscriber_timelines.insert(i, vec![(0, vec![], None)]); + } + for (ts, sub_idx, selected, excluded) in &self.subscriber_events { + if let Some(timeline) = subscriber_timelines.get_mut(sub_idx) { + timeline.push((*ts, selected.clone(), *excluded)); + } + } + for (_, timeline) in subscriber_timelines.iter_mut() { + let last = timeline.last().cloned().unwrap_or((0, vec![], None)); + timeline.push((max_time_ms, last.1, last.2)); + } + + // Draw subscriber lanes + for sub_idx in 0..num_subscribers { + let lane_y = subscribers_y + sub_idx as f64 * subscriber_lane_height; + + // Lane label + svg.push_str(&format!( + r#"S{}"#, + margin_left - 40.0, + lane_y + subscriber_lane_height / 2.0 + 4.0, + sub_idx + )); + svg.push('\n'); + + // Draw selection segments + if let Some(timeline) = subscriber_timelines.get(&sub_idx) { + for i in 0..timeline.len() - 1 { + let (t1, ref selected, excluded) = timeline[i]; + let (t2, _, _) = timeline[i + 1]; + + let x1 = margin_left + (t1 as f64 / max_time_ms as f64) * timeline_width; + let x2 = margin_left + (t2 as f64 / max_time_ms as f64) * timeline_width; + let width = (x2 - x1).max(1.0); + + // Background + svg.push_str(&format!( + r##""##, + x1, + lane_y + 2.0, + width, + subscriber_lane_height - 4.0 + )); + svg.push('\n'); + + // Draw mini bars for each selected publisher + let bar_height = (subscriber_lane_height - 8.0) / num_publishers as f64; + let bar_width = (width - 2.0).max(1.0); + for &pub_idx in selected { + if pub_idx < num_publishers { + let bar_y = lane_y + 4.0 + pub_idx as f64 * bar_height; + svg.push_str(&format!( + r#""#, + x1 + 1.0, + bar_y, + bar_width, + bar_height - 1.0 + )); + svg.push('\n'); + } + } + + // Mark excluded self + if let Some(excl_idx) = excluded { + if excl_idx < num_publishers { + let bar_y = lane_y + 4.0 + excl_idx as f64 * bar_height; + svg.push_str(&format!( + r#""#, + x1 + 1.0, + bar_y, + bar_width, + bar_height - 1.0 + )); + svg.push('\n'); + } + } + } + } + + // Lane border + svg.push_str(&format!( + r#""#, + margin_left, + lane_y + subscriber_lane_height, + margin_left + timeline_width, + lane_y + subscriber_lane_height + )); + svg.push('\n'); + } + + // Time axis with labels + let axis_y = subscribers_y + subscribers_height + 20.0; + svg.push_str(&format!( + r#""#, + margin_left, + axis_y, + margin_left + timeline_width, + axis_y + )); + svg.push('\n'); + + // Time ticks (every 5 seconds) + let tick_interval_ms = 5000u64; + let mut t = 0u64; + while t <= max_time_ms { + let x = margin_left + (t as f64 / max_time_ms as f64) * timeline_width; + svg.push_str(&format!( + r#""#, + x, + axis_y, + x, + axis_y + 5.0 + )); + svg.push_str(&format!( + r#"{}s"#, + x, + axis_y + 18.0, + t / 1000 + )); + svg.push('\n'); + t += tick_interval_ms; + } + + // Legend + let legend_y = axis_y + 35.0; + svg.push_str(&format!( + r#""#, + margin_left, legend_y + )); + svg.push('\n'); + + let legend_items = [ + ("speech-silent", "Silent (0)"), + ("speech-active", "Speaking (1)"), + ("speech-start", "Speech Start (2)"), + ("selected", "Selected in Top-N"), + ("excluded", "Self-Excluded"), + ]; + + let mut lx = margin_left; + for (class, label) in legend_items { + svg.push_str(&format!( + r#""#, + lx, + legend_y + 10.0, + class + )); + svg.push_str(&format!( + r#"{}"#, + lx + 25.0, + legend_y + 20.0, + label + )); + svg.push('\n'); + lx += 150.0; + } + + svg.push_str("\n"); + + let mut file = File::create(output_path)?; + file.write_all(svg.as_bytes())?; + + Ok(()) + } +} + +fn main() -> anyhow::Result<()> { + let args: Vec = std::env::args().collect(); + + if args.len() < 3 { + eprintln!("Usage: {} ", args[0]); + eprintln!(); + eprintln!("Parses TOPN_EVENT logs from the relay and generates a timeline SVG."); + eprintln!(); + eprintln!("To generate logs, run the relay with --topn-log or TOPN_LOG=1:"); + eprintln!(" TOPN_LOG=1 cargo run -p moq-relay-ietf ... 2>&1 | tee relay.log"); + eprintln!(); + eprintln!("Then extract and visualize:"); + eprintln!(" grep TOPN_EVENT relay.log > topn.log"); + eprintln!(" {} topn.log timeline.svg", args[0]); + std::process::exit(1); + } + + let input_path = &args[1]; + let output_path = &args[2]; + + println!("Reading log file: {}", input_path); + + let file = File::open(input_path)?; + let reader = BufReader::new(file); + + let mut builder = TimelineBuilder::new(); + let mut event_count = 0; + + for line in reader.lines() { + let line = line?; + if let Some(event) = parse_log_line(&line) { + builder.process_event(event); + event_count += 1; + } + } + + println!("Processed {} events", event_count); + println!( + "Found {} publishers, {} subscribers", + builder.publisher_indices.len(), + builder.subscriber_indices.len() + ); + println!("Timeline duration: {}ms", builder.max_ts); + + println!("Generating SVG: {}", output_path); + builder.generate_svg(output_path)?; + + println!("Done!"); + Ok(()) +} diff --git a/moq-topn-test/src/main.rs b/moq-topn-test/src/main.rs new file mode 100644 index 00000000..55076733 --- /dev/null +++ b/moq-topn-test/src/main.rs @@ -0,0 +1,140 @@ +//! End-to-end test driver for MOQ TRACK_FILTER (Top-N) functionality. +//! +//! Supports two modes: +//! - `sim` (default): Simulation mode using TopNTracker directly in-memory +//! - `e2e`: End-to-end mode connecting to a real relay over QUIC/WebTransport +//! +//! Simulates realistic speech activity patterns with multiple publishers +//! and verifies that subscribers with TRACK_FILTER receive the correct tracks. + +mod e2e; +mod sim; +mod speech; +mod stats; +mod viz; + +use clap::{Parser, ValueEnum}; +use tracing::info; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +pub enum TestMode { + /// Simulation mode - uses TopNTracker directly (no network) + Sim, + /// End-to-end mode - connects to a real relay + E2e, +} + +#[derive(Parser, Clone)] +#[command(name = "moq-topn-test")] +#[command(about = "End-to-end test driver for MOQ Top-N filtering")] +pub struct Args { + /// Test mode: sim (simulation) or e2e (end-to-end with relay) + #[arg(short, long, default_value = "sim")] + pub mode: TestMode, + + /// Relay URL for e2e mode (e.g., https://localhost:4443) + #[arg(short, long, default_value = "https://localhost:4443")] + pub relay: String, + + /// Number of publishers (X) + #[arg(short = 'x', long, default_value = "10")] + pub publishers: usize, + + /// Number of subscribers + #[arg(short = 'y', long, default_value = "5")] + pub subscribers: usize, + + /// Top-N filter value for subscribers + #[arg(short = 'n', long, default_value = "3")] + pub top_n: u8, + + /// Mixed top-N values (comma-separated, e.g. "1,10,25,45,65,77,85") + /// When set, subscribers cycle through these N values instead of using --top-n + #[arg(long)] + pub mixed_topn: Option, + + /// Test duration in seconds + #[arg(short, long, default_value = "30")] + pub duration: u64, + + /// Group interval in milliseconds (2000 for viz, 33 for 30Hz perf tests) + #[arg(long, default_value = "2000")] + pub group_interval_ms: u64, + + /// Connection batch size (connections established per batch during setup) + #[arg(long, default_value = "50")] + pub connection_batch_size: usize, + + /// Namespace for the test + #[arg(long, default_value = "topn-test")] + pub namespace: String, + + /// Tie-breaking policy: "oldest" or "recent" + #[arg(long, default_value = "oldest")] + pub tie_break: String, + + /// Staleness timeout in seconds (0 = disabled) + #[arg(long, default_value = "10")] + pub staleness_timeout: u64, + + /// TLS options (for e2e mode) + #[command(flatten)] + pub tls: moq_native_ietf::tls::Args, + + /// Verbose output + #[arg(short, long)] + pub verbose: bool, + + /// Output path for timeline visualization SVG + #[arg(long)] + pub viz_output: Option, + + /// Disable TOPN_EVENT logging for visualization + #[arg(long)] + pub no_topn_log: bool, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + // Initialize logging + let filter = if args.verbose { + "moq_topn_test=debug,moq_transport=debug" + } else { + "moq_topn_test=info,moq_transport=warn" + }; + tracing_subscriber::fmt() + .with_env_filter(filter) + .init(); + + info!("MOQ Top-N Test Driver"); + info!("====================="); + info!("Mode: {:?}", args.mode); + if args.mode == TestMode::E2e { + info!("Relay: {}", args.relay); + } + info!("Publishers (X): {}", args.publishers); + info!("Subscribers (Y): {}", args.subscribers); + info!("Top-N filter: {}", args.top_n); + if let Some(ref mixed) = args.mixed_topn { + info!("Mixed top-N: {}", mixed); + } + info!("Duration: {}s", args.duration); + info!("Group interval: {}ms", args.group_interval_ms); + info!("Tie-break policy: {}", args.tie_break); + info!( + "Staleness timeout: {}", + if args.staleness_timeout == 0 { + "disabled".to_string() + } else { + format!("{}s", args.staleness_timeout) + } + ); + info!(""); + + match args.mode { + TestMode::Sim => sim::run(args).await, + TestMode::E2e => e2e::run(args).await, + } +} diff --git a/moq-topn-test/src/sim.rs b/moq-topn-test/src/sim.rs new file mode 100644 index 00000000..80d00d2e --- /dev/null +++ b/moq-topn-test/src/sim.rs @@ -0,0 +1,321 @@ +//! Simulation mode - tests TopNTracker directly without network. + +use crate::speech::SpeechSimulator; +use crate::stats::StatsCollector; +use crate::viz::SharedTimelineRecorder; +use crate::Args; + +use moq_relay_ietf::{TieBreakPolicy, TopNTracker, TopNTrackerConfig}; +use moq_transport::coding::TrackNamespace; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::broadcast; +use tracing::{debug, info, warn}; + +pub async fn run(args: Args) -> anyhow::Result<()> { + // Create tracker config + let tie_break_policy = match args.tie_break.as_str() { + "recent" | "most-recent" => TieBreakPolicy::MostRecentWins, + _ => TieBreakPolicy::OldestWins, + }; + let staleness_timeout = if args.staleness_timeout == 0 { + None + } else { + Some(Duration::from_secs(args.staleness_timeout)) + }; + + let config = TopNTrackerConfig { + tie_break_policy, + staleness_timeout, + enable_event_logging: false, + }; + + // Create the tracker + let tracker = Arc::new(TopNTracker::with_config(0x01, config)); + tracker.update_max_n(args.top_n); + + // Register all publisher tracks + let namespace = TrackNamespace::from_utf8_path(&args.namespace); + for i in 0..args.publishers { + let track_name = format!("speaker-{}", i); + tracker.register_track( + namespace.clone(), + track_name, + 0, // Initial value = silent + i as u64, + ); + } + + info!("Registered {} publisher tracks", args.publishers); + + // Create timeline recorder for visualization + let timeline_recorder = Arc::new(SharedTimelineRecorder::new( + args.publishers, + args.subscribers, + args.top_n, + )); + let start_instant = Instant::now(); + + // Channel to broadcast shutdown signal + let (shutdown_tx, _) = broadcast::channel::<()>(1); + + // Shared stats collector + let stats = Arc::new(StatsCollector::new(args.publishers, args.top_n)); + + // Start publisher simulators + let mut handles = Vec::new(); + for i in 0..args.publishers { + let args_clone = args.clone(); + let tracker_clone = tracker.clone(); + let stats_clone = stats.clone(); + let namespace_clone = namespace.clone(); + let shutdown_rx = shutdown_tx.subscribe(); + let recorder_clone = timeline_recorder.clone(); + let start_instant_clone = start_instant; + + let handle = tokio::spawn(async move { + run_publisher( + i, + args_clone, + tracker_clone, + namespace_clone, + stats_clone, + recorder_clone, + start_instant_clone, + shutdown_rx, + ) + .await + }); + handles.push(handle); + } + + // Start subscriber verifiers + for i in 0..args.subscribers { + let args_clone = args.clone(); + let tracker_clone = tracker.clone(); + let stats_clone = stats.clone(); + let shutdown_rx = shutdown_tx.subscribe(); + let recorder_clone = timeline_recorder.clone(); + let start_instant_clone = start_instant; + + let handle = tokio::spawn(async move { + run_subscriber( + i, + args_clone, + tracker_clone, + stats_clone, + recorder_clone, + start_instant_clone, + shutdown_rx, + ) + .await + }); + handles.push(handle); + } + + // Run for the specified duration + info!("Test running for {} seconds...", args.duration); + tokio::time::sleep(Duration::from_secs(args.duration)).await; + + // Signal shutdown + info!("Shutting down..."); + let _ = shutdown_tx.send(()); + + // Wait for all tasks to complete (with timeout) + let shutdown_timeout = Duration::from_secs(5); + for handle in handles { + let _ = tokio::time::timeout(shutdown_timeout, handle).await; + } + + // Print final stats + info!(""); + stats.print_report(); + + // Generate visualization if requested + if let Some(ref output_path) = args.viz_output { + info!("Generating timeline visualization: {}", output_path); + if let Err(e) = timeline_recorder.generate_svg(output_path) { + tracing::error!("Failed to generate visualization: {}", e); + } else { + info!("Timeline visualization saved to: {}", output_path); + } + } + + // Determine test result + let (passed, failed) = stats.verification_results(); + if failed == 0 && passed > 0 { + info!(""); + info!("TOPN_TEST_RESULT: SUCCESS"); + Ok(()) + } else { + info!(""); + info!( + "TOPN_TEST_RESULT: FAILURE ({} passed, {} failed)", + passed, failed + ); + std::process::exit(1); + } +} + +async fn run_publisher( + publisher_id: usize, + args: Args, + tracker: Arc, + namespace: TrackNamespace, + stats: Arc, + timeline_recorder: Arc, + start_instant: Instant, + mut shutdown_rx: broadcast::Receiver<()>, +) { + let track_name = format!("speaker-{}", publisher_id); + debug!("Publisher {} starting (track: {})", publisher_id, track_name); + + let mut speech_sim = SpeechSimulator::new(); + let mut group_seq: u64 = 0; + let group_interval = Duration::from_millis(args.group_interval_ms); + let mut last_value: Option = None; + + loop { + tokio::select! { + _ = shutdown_rx.recv() => { + debug!("Publisher {} shutting down", publisher_id); + break; + } + _ = tokio::time::sleep(group_interval) => { + // Get current speech value + let value = speech_sim.tick(); + + // Update the tracker + tracker.update_value(&namespace, &track_name, value as u64); + + // Record the publish + stats.record_publish(publisher_id, group_seq, value); + + // Record to timeline if value changed + if last_value != Some(value) { + let timestamp_ms = start_instant.elapsed().as_millis() as u64; + timeline_recorder.record_speech_value(publisher_id, value, timestamp_ms); + last_value = Some(value); + } + + debug!( + "Publisher {} group {} value={} state={:?}", + publisher_id, group_seq, value, speech_sim.state() + ); + + group_seq += 1; + } + } + } + + debug!("Publisher {} finished ({} groups)", publisher_id, group_seq); +} + +async fn run_subscriber( + subscriber_id: usize, + args: Args, + tracker: Arc, + stats: Arc, + timeline_recorder: Arc, + start_instant: Instant, + mut shutdown_rx: broadcast::Receiver<()>, +) { + debug!( + "Subscriber {} starting (top-N filter: {})", + subscriber_id, args.top_n + ); + + // For self-exclusion demo: some subscribers are also publishers + // Subscriber 0 is also Publisher 0, Subscriber 1 is also Publisher 1, etc. + // (only for subscribers where subscriber_id < num_publishers) + let is_also_publisher = subscriber_id < args.publishers; + let self_publisher_id = if is_also_publisher { + Some(subscriber_id) + } else { + None + }; + + // Use publisher's session_id if this subscriber is also a publisher (for self-exclusion) + let subscriber_session_id = if is_also_publisher { + subscriber_id as u64 + } else { + 1_000_000 + subscriber_id as u64 + }; + + let check_interval = Duration::from_millis(500); + let mut checks: u64 = 0; + let mut last_selection: Option> = None; + + loop { + tokio::select! { + _ = shutdown_rx.recv() => { + debug!("Subscriber {} shutting down", subscriber_id); + break; + } + _ = tokio::time::sleep(check_interval) => { + // Get top-N from tracker (with self-exclusion if applicable) + let top_n = tracker.compute_top_n_for_session(subscriber_session_id, args.top_n); + + // Get snapshot and verify + let snapshot = tracker.load_snapshot(); + + // Extract received track IDs + let received_ids: Vec = top_n + .iter() + .filter_map(|(_, track_name)| { + track_name.strip_prefix("speaker-") + .and_then(|s| s.parse::().ok()) + }) + .collect(); + + // Get expected top-N (excluding self if applicable) + let mut expected: Vec<(usize, u64)> = snapshot + .iter() + .filter_map(|t| { + t.track_name + .strip_prefix("speaker-") + .and_then(|s| s.parse::().ok()) + .map(|id| (id, t.property_value)) + }) + .filter(|(id, _)| self_publisher_id != Some(*id)) + .collect(); + expected.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0))); + expected.truncate(args.top_n as usize); + let expected_ids: Vec = expected.iter().map(|(id, _)| *id).collect(); + + let correct = received_ids == expected_ids; + stats.record_verification(correct); + + // Record to timeline if selection changed + if last_selection.as_ref() != Some(&received_ids) { + let timestamp_ms = start_instant.elapsed().as_millis() as u64; + timeline_recorder.record_top_n_selection( + subscriber_id, + received_ids.clone(), + timestamp_ms, + self_publisher_id, + ); + last_selection = Some(received_ids.clone()); + } + + if !correct && args.verbose { + warn!( + "Subscriber {} verification mismatch: got {:?}, expected {:?}{}", + subscriber_id, received_ids, expected_ids, + if is_also_publisher { " (self-excluded)" } else { "" } + ); + } + + debug!( + "Subscriber {} check {}: top-{} = {:?} (correct={}){}", + subscriber_id, checks, args.top_n, received_ids, correct, + if is_also_publisher { format!(" [self=P{}]", subscriber_id) } else { String::new() } + ); + + checks += 1; + } + } + } + + debug!("Subscriber {} finished ({} checks)", subscriber_id, checks); +} diff --git a/moq-topn-test/src/speech.rs b/moq-topn-test/src/speech.rs new file mode 100644 index 00000000..2e48138f --- /dev/null +++ b/moq-topn-test/src/speech.rs @@ -0,0 +1,128 @@ +//! Speech activity state machine for simulating realistic publisher behavior. +//! +//! Tick-based state machine matching moqx's parameters: +//! - p(start speaking) = 0.03 per tick when silent +//! - Speech start value = 2 (highest priority, 1 tick) +//! - Speaking value = 1 (90-300 ticks at 30Hz = 3-10s) +//! - Silent value = 0 (150-900 ticks at 30Hz = 5-30s) +//! +//! State transitions: +//! SILENT → SPEECH_START → SPEAKING → SILENT + +use rand::Rng; + +/// Speech activity states +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SpeechState { + Silent, + SpeechStart, + Speaking, +} + +/// Tick-based speech activity simulator matching moqx behavior. +pub struct SpeechSimulator { + state: SpeechState, + ticks_in_state: u64, + speaking_duration_ticks: u64, +} + +impl SpeechSimulator { + pub fn new() -> Self { + Self { + state: SpeechState::Silent, + ticks_in_state: 0, + speaking_duration_ticks: 0, + } + } + + pub fn current_value(&self) -> u8 { + match self.state { + SpeechState::Silent => 0, + SpeechState::SpeechStart => 2, + SpeechState::Speaking => 1, + } + } + + pub fn state(&self) -> SpeechState { + self.state + } + + /// Called once per group interval. Updates state and returns the value to send. + pub fn tick(&mut self) -> u8 { + let mut rng = rand::thread_rng(); + + match self.state { + SpeechState::Silent => { + self.ticks_in_state += 1; + // p(start speaking) = 0.03 per tick + if rng.gen::() < 0.03 { + self.state = SpeechState::SpeechStart; + self.ticks_in_state = 0; + // Speaking duration: 90-300 ticks (3-10s at 30Hz) + self.speaking_duration_ticks = rng.gen_range(90..=300); + } + } + SpeechState::SpeechStart => { + // Speech start lasts exactly 1 tick + self.state = SpeechState::Speaking; + self.ticks_in_state = 0; + } + SpeechState::Speaking => { + self.ticks_in_state += 1; + if self.ticks_in_state >= self.speaking_duration_ticks { + self.state = SpeechState::Silent; + self.ticks_in_state = 0; + } + } + } + + self.current_value() + } +} + +impl Default for SpeechSimulator { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_initial_state() { + let sim = SpeechSimulator::new(); + assert_eq!(sim.state(), SpeechState::Silent); + assert_eq!(sim.current_value(), 0); + } + + #[test] + fn test_speech_start_value() { + let mut sim = SpeechSimulator::new(); + sim.state = SpeechState::SpeechStart; + assert_eq!(sim.current_value(), 2); + } + + #[test] + fn test_speaking_value() { + let mut sim = SpeechSimulator::new(); + sim.state = SpeechState::Speaking; + assert_eq!(sim.current_value(), 1); + } + + #[test] + fn test_state_transitions() { + let mut sim = SpeechSimulator::new(); + // Run many ticks; should eventually enter speaking state + let mut saw_speech = false; + for _ in 0..10000 { + sim.tick(); + if sim.state() == SpeechState::Speaking || sim.state() == SpeechState::SpeechStart { + saw_speech = true; + break; + } + } + assert!(saw_speech, "should eventually start speaking"); + } +} diff --git a/moq-topn-test/src/stats.rs b/moq-topn-test/src/stats.rs new file mode 100644 index 00000000..51c23c3b --- /dev/null +++ b/moq-topn-test/src/stats.rs @@ -0,0 +1,205 @@ +//! Statistics collection and reporting for Top-N test. + +use hdrhistogram::Histogram; +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::sync::Mutex; +use std::time::Instant; +use tracing::info; + +/// Collected statistics for the test run +pub struct StatsCollector { + /// Number of publishers + num_publishers: usize, + /// Top-N value + top_n: u8, + + /// Total groups published (per publisher) + groups_published: Vec, + /// Total groups received (per subscriber) + groups_received: AtomicU64, + + /// Ranking change latency histogram (microseconds) + ranking_latency: Mutex>, + + /// Current ranking state (publisher_id -> current_value) + current_values: Mutex>, + /// Timestamps of value changes for latency measurement + value_change_times: Mutex>, + + /// Verification results + correct_deliveries: AtomicUsize, + incorrect_deliveries: AtomicUsize, + + /// Throughput tracking + start_time: Instant, +} + +impl StatsCollector { + pub fn new(num_publishers: usize, top_n: u8) -> Self { + let mut groups_published = Vec::with_capacity(num_publishers); + for _ in 0..num_publishers { + groups_published.push(AtomicU64::new(0)); + } + + Self { + num_publishers, + top_n, + groups_published, + groups_received: AtomicU64::new(0), + ranking_latency: Mutex::new(Histogram::new(3).unwrap()), + current_values: Mutex::new(HashMap::new()), + value_change_times: Mutex::new(HashMap::new()), + correct_deliveries: AtomicUsize::new(0), + incorrect_deliveries: AtomicUsize::new(0), + start_time: Instant::now(), + } + } + + /// Record a group published by a publisher + pub fn record_publish(&self, publisher_id: usize, group_seq: u64, value: u8) { + if publisher_id < self.groups_published.len() { + self.groups_published[publisher_id].fetch_add(1, Ordering::Relaxed); + } + + // Track value changes for latency measurement + let mut values = self.current_values.lock().unwrap(); + let prev_value = values.get(&publisher_id).copied(); + + if prev_value != Some(value) { + values.insert(publisher_id, value); + + // Record timestamp of this change + let mut times = self.value_change_times.lock().unwrap(); + times.insert((publisher_id, group_seq), Instant::now()); + } + } + + /// Record a group received by a subscriber and verify correctness + pub fn record_receive( + &self, + _subscriber_id: usize, + publisher_id: usize, + group_seq: u64, + value: u8, + expected_in_top_n: bool, + ) { + self.groups_received.fetch_add(1, Ordering::Relaxed); + + // Check latency if this was a value change + let times = self.value_change_times.lock().unwrap(); + if let Some(send_time) = times.get(&(publisher_id, group_seq)) { + let latency_us = send_time.elapsed().as_micros() as u64; + if let Ok(mut hist) = self.ranking_latency.lock() { + let _ = hist.record(latency_us.min(u64::MAX - 1)); + } + } + + // Verify correctness + if expected_in_top_n { + self.correct_deliveries.fetch_add(1, Ordering::Relaxed); + } else { + self.incorrect_deliveries.fetch_add(1, Ordering::Relaxed); + } + } + + /// Record a verification result directly + pub fn record_verification(&self, correct: bool) { + if correct { + self.correct_deliveries.fetch_add(1, Ordering::Relaxed); + } else { + self.incorrect_deliveries.fetch_add(1, Ordering::Relaxed); + } + } + + /// Get current ranking based on values (for verification) + pub fn get_current_ranking(&self) -> Vec<(usize, u8)> { + let values = self.current_values.lock().unwrap(); + let mut ranking: Vec<_> = values.iter().map(|(&k, &v)| (k, v)).collect(); + // Sort by value descending, then by publisher_id ascending for ties + ranking.sort_by(|a, b| { + b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)) + }); + ranking + } + + /// Get the top-N publisher IDs based on current values + pub fn get_top_n_publishers(&self) -> Vec { + let ranking = self.get_current_ranking(); + ranking.iter() + .take(self.top_n as usize) + .map(|(id, _)| *id) + .collect() + } + + /// Get verification results (passed, failed) + pub fn verification_results(&self) -> (usize, usize) { + ( + self.correct_deliveries.load(Ordering::Relaxed), + self.incorrect_deliveries.load(Ordering::Relaxed), + ) + } + + /// Print the final report + pub fn print_report(&self) { + let elapsed = self.start_time.elapsed(); + let elapsed_secs = elapsed.as_secs_f64(); + + info!("=== Test Results ==="); + info!(""); + + // Throughput + let total_published: u64 = self.groups_published + .iter() + .map(|c| c.load(Ordering::Relaxed)) + .sum(); + let total_received = self.groups_received.load(Ordering::Relaxed); + + info!("Throughput:"); + info!(" Groups published: {} ({:.1}/s)", total_published, total_published as f64 / elapsed_secs); + info!(" Groups received: {} ({:.1}/s)", total_received, total_received as f64 / elapsed_secs); + info!(""); + + // Per-publisher stats + info!("Per-publisher groups:"); + for (i, counter) in self.groups_published.iter().enumerate() { + let count = counter.load(Ordering::Relaxed); + info!(" Publisher {}: {}", i, count); + } + info!(""); + + // Latency + if let Ok(hist) = self.ranking_latency.lock() { + if hist.len() > 0 { + info!("Ranking Change Latency:"); + info!(" p50: {:>8} µs", hist.value_at_quantile(0.50)); + info!(" p90: {:>8} µs", hist.value_at_quantile(0.90)); + info!(" p99: {:>8} µs", hist.value_at_quantile(0.99)); + info!(" max: {:>8} µs", hist.max()); + info!(""); + } + } + + // Verification + let (correct, incorrect) = self.verification_results(); + let total_verified = correct + incorrect; + let accuracy = if total_verified > 0 { + 100.0 * correct as f64 / total_verified as f64 + } else { + 0.0 + }; + + info!("Verification:"); + info!(" Correct deliveries: {}", correct); + info!(" Incorrect deliveries: {}", incorrect); + info!(" Accuracy: {:.2}%", accuracy); + info!(""); + + // Current ranking + let ranking = self.get_current_ranking(); + info!("Final Ranking (top {}):", self.top_n); + for (i, (pub_id, value)) in ranking.iter().take(self.top_n as usize).enumerate() { + info!(" {}: Publisher {} (value={})", i + 1, pub_id, value); + } + } +} diff --git a/moq-topn-test/src/viz.rs b/moq-topn-test/src/viz.rs new file mode 100644 index 00000000..fcbea2b9 --- /dev/null +++ b/moq-topn-test/src/viz.rs @@ -0,0 +1,444 @@ +//! Timeline visualization for Top-N filtering test. +//! +//! Generates an SVG image showing: +//! - Publishers as horizontal lanes with speech activity bars +//! - Subscribers with their top-N selections over time +//! - Self-exclusion clearly visible + +use std::collections::HashMap; +use std::fs::File; +use std::io::Write; + +/// Event types for the timeline +#[derive(Debug, Clone)] +pub enum TimelineEvent { + /// Publisher speech value changed + SpeechValue { + publisher_id: usize, + value: u8, + timestamp_ms: u64, + }, + /// Subscriber top-N selection updated + TopNSelection { + subscriber_id: usize, + selected_publishers: Vec, + timestamp_ms: u64, + excluded_self: Option, + }, +} + +/// Collects timeline events for visualization +pub struct TimelineRecorder { + events: Vec, + start_time_ms: u64, + num_publishers: usize, + num_subscribers: usize, + top_n: u8, +} + +impl TimelineRecorder { + pub fn new(num_publishers: usize, num_subscribers: usize, top_n: u8) -> Self { + Self { + events: Vec::new(), + start_time_ms: 0, + num_publishers, + num_subscribers, + top_n, + } + } + + pub fn set_start_time(&mut self, start_ms: u64) { + self.start_time_ms = start_ms; + } + + pub fn record_speech_value(&mut self, publisher_id: usize, value: u8, timestamp_ms: u64) { + self.events.push(TimelineEvent::SpeechValue { + publisher_id, + value, + timestamp_ms: timestamp_ms.saturating_sub(self.start_time_ms), + }); + } + + pub fn record_top_n_selection( + &mut self, + subscriber_id: usize, + selected_publishers: Vec, + timestamp_ms: u64, + excluded_self: Option, + ) { + self.events.push(TimelineEvent::TopNSelection { + subscriber_id, + selected_publishers, + timestamp_ms: timestamp_ms.saturating_sub(self.start_time_ms), + excluded_self, + }); + } + + /// Generate SVG visualization + pub fn generate_svg(&self, output_path: &str) -> std::io::Result<()> { + let max_time_ms = self + .events + .iter() + .map(|e| match e { + TimelineEvent::SpeechValue { timestamp_ms, .. } => *timestamp_ms, + TimelineEvent::TopNSelection { timestamp_ms, .. } => *timestamp_ms, + }) + .max() + .unwrap_or(1000); + + let margin_left = 120.0; + let margin_right = 40.0; + let margin_top = 60.0; + let margin_bottom = 40.0; + + let publisher_lane_height = 40.0; + let subscriber_lane_height = 50.0; + let section_gap = 30.0; + let legend_height = 80.0; + + let timeline_width = 1000.0; + let publishers_height = self.num_publishers as f64 * publisher_lane_height; + let subscribers_height = self.num_subscribers as f64 * subscriber_lane_height; + + let total_width = margin_left + timeline_width + margin_right; + let total_height = + margin_top + publishers_height + section_gap + subscribers_height + section_gap + legend_height + margin_bottom; + + let mut svg = String::new(); + svg.push_str(&format!( + r#""#, + total_width, total_height, total_width, total_height + )); + svg.push('\n'); + + // Styles + svg.push_str(r#" + + +"#); + + // Title + svg.push_str(&format!( + r#"Top-N Filtering Timeline ({} publishers, {} subscribers, N={})"#, + total_width / 2.0 - 150.0, + self.num_publishers, + self.num_subscribers, + self.top_n + )); + svg.push('\n'); + + // Build speech value timeline per publisher + let mut publisher_speech: HashMap> = HashMap::new(); + for i in 0..self.num_publishers { + publisher_speech.insert(i, vec![(0, 0)]); + } + for event in &self.events { + if let TimelineEvent::SpeechValue { + publisher_id, + value, + timestamp_ms, + } = event + { + if let Some(v) = publisher_speech.get_mut(publisher_id) { + v.push((*timestamp_ms, *value)); + } + } + } + for (_, v) in publisher_speech.iter_mut() { + v.push((max_time_ms, v.last().map(|(_, val)| *val).unwrap_or(0))); + } + + // Publishers section label + svg.push_str(&format!( + r#""#, + margin_top - 10.0 + )); + svg.push('\n'); + + // Draw publisher lanes + for pub_id in 0..self.num_publishers { + let lane_y = margin_top + pub_id as f64 * publisher_lane_height; + + // Lane label + svg.push_str(&format!( + r#"P{}"#, + margin_left - 40.0, + lane_y + publisher_lane_height / 2.0 + 4.0, + pub_id + )); + svg.push('\n'); + + // Draw speech activity bars + if let Some(timeline) = publisher_speech.get(&pub_id) { + for i in 0..timeline.len() - 1 { + let (t1, v1) = timeline[i]; + let (t2, _) = timeline[i + 1]; + + let x1 = margin_left + (t1 as f64 / max_time_ms as f64) * timeline_width; + let x2 = margin_left + (t2 as f64 / max_time_ms as f64) * timeline_width; + let width = (x2 - x1).max(1.0); + + let class = match v1 { + 0 => "speech-silent", + 1 => "speech-active", + 2 => "speech-start", + _ => "speech-active", + }; + + svg.push_str(&format!( + r#""#, + x1, + lane_y + 5.0, + width, + publisher_lane_height - 10.0, + class + )); + svg.push('\n'); + } + } + + // Lane border + svg.push_str(&format!( + r#""#, + margin_left, + lane_y + publisher_lane_height, + margin_left + timeline_width, + lane_y + publisher_lane_height + )); + svg.push('\n'); + } + + // Subscribers section + let subscribers_y = margin_top + publishers_height + section_gap; + svg.push_str(&format!( + r#""#, + subscribers_y - 10.0 + )); + svg.push('\n'); + + // Build top-N selection timeline per subscriber + let mut subscriber_selections: HashMap, Option)>> = HashMap::new(); + for i in 0..self.num_subscribers { + subscriber_selections.insert(i, vec![(0, vec![], None)]); + } + for event in &self.events { + if let TimelineEvent::TopNSelection { + subscriber_id, + selected_publishers, + timestamp_ms, + excluded_self, + } = event + { + if let Some(v) = subscriber_selections.get_mut(subscriber_id) { + v.push((*timestamp_ms, selected_publishers.clone(), *excluded_self)); + } + } + } + for (_, v) in subscriber_selections.iter_mut() { + let last_selection = v.last().map(|(_, s, e)| (s.clone(), *e)).unwrap_or((vec![], None)); + v.push((max_time_ms, last_selection.0, last_selection.1)); + } + + // Draw subscriber lanes + for sub_id in 0..self.num_subscribers { + let lane_y = subscribers_y + sub_id as f64 * subscriber_lane_height; + + // Lane label + svg.push_str(&format!( + r#"S{}"#, + margin_left - 40.0, + lane_y + subscriber_lane_height / 2.0 + 4.0, + sub_id + )); + svg.push('\n'); + + // Draw selection segments + if let Some(timeline) = subscriber_selections.get(&sub_id) { + for i in 0..timeline.len() - 1 { + let (t1, ref selected, excluded) = timeline[i]; + let (t2, _, _) = timeline[i + 1]; + + let x1 = margin_left + (t1 as f64 / max_time_ms as f64) * timeline_width; + let x2 = margin_left + (t2 as f64 / max_time_ms as f64) * timeline_width; + let width = (x2 - x1).max(1.0); + + // Background + svg.push_str(&format!( + r##""##, + x1, + lane_y + 2.0, + width, + subscriber_lane_height - 4.0 + )); + svg.push('\n'); + + // Draw mini bars for each selected publisher + let bar_height = (subscriber_lane_height - 8.0) / self.num_publishers as f64; + let bar_width = (width - 2.0).max(1.0); + for &pub_id in selected { + let bar_y = lane_y + 4.0 + pub_id as f64 * bar_height; + svg.push_str(&format!( + r#""#, + x1 + 1.0, + bar_y, + bar_width, + bar_height - 1.0 + )); + svg.push('\n'); + } + + // Mark excluded self + if let Some(excl_id) = excluded { + let bar_y = lane_y + 4.0 + excl_id as f64 * bar_height; + svg.push_str(&format!( + r#""#, + x1 + 1.0, + bar_y, + bar_width, + bar_height - 1.0 + )); + svg.push('\n'); + } + } + } + + // Lane border + svg.push_str(&format!( + r#""#, + margin_left, + lane_y + subscriber_lane_height, + margin_left + timeline_width, + lane_y + subscriber_lane_height + )); + svg.push('\n'); + } + + // Time axis with labels + let axis_y = subscribers_y + subscribers_height + 20.0; + svg.push_str(&format!( + r#""#, + margin_left, axis_y, margin_left + timeline_width, axis_y + )); + svg.push('\n'); + + // Time ticks (every 5 seconds) + let tick_interval_ms = 5000u64; + let mut t = 0u64; + while t <= max_time_ms { + let x = margin_left + (t as f64 / max_time_ms as f64) * timeline_width; + svg.push_str(&format!( + r#""#, + x, axis_y, x, axis_y + 5.0 + )); + svg.push_str(&format!( + r#"{}s"#, + x, + axis_y + 18.0, + t / 1000 + )); + svg.push('\n'); + t += tick_interval_ms; + } + + // Legend + let legend_y = axis_y + 35.0; + svg.push_str(&format!( + r#""#, + margin_left, legend_y + )); + svg.push('\n'); + + let legend_items = [ + ("speech-silent", "Silent (0)"), + ("speech-active", "Speaking (1)"), + ("speech-start", "Speech Start (2)"), + ("selected", "Selected in Top-N"), + ("excluded", "Self-Excluded"), + ]; + + let mut lx = margin_left; + for (class, label) in legend_items { + svg.push_str(&format!( + r#""#, + lx, + legend_y + 10.0, + class + )); + svg.push_str(&format!( + r#"{}"#, + lx + 25.0, + legend_y + 20.0, + label + )); + svg.push('\n'); + lx += 150.0; + } + + svg.push_str("\n"); + + let mut file = File::create(output_path)?; + file.write_all(svg.as_bytes())?; + + Ok(()) + } +} + +/// Wrapper for synchronized access +pub struct SharedTimelineRecorder { + inner: std::sync::Mutex, +} + +impl SharedTimelineRecorder { + pub fn new(num_publishers: usize, num_subscribers: usize, top_n: u8) -> Self { + Self { + inner: std::sync::Mutex::new(TimelineRecorder::new( + num_publishers, + num_subscribers, + top_n, + )), + } + } + + pub fn set_start_time(&self, start_ms: u64) { + self.inner.lock().unwrap().set_start_time(start_ms); + } + + pub fn record_speech_value(&self, publisher_id: usize, value: u8, timestamp_ms: u64) { + self.inner + .lock() + .unwrap() + .record_speech_value(publisher_id, value, timestamp_ms); + } + + pub fn record_top_n_selection( + &self, + subscriber_id: usize, + selected_publishers: Vec, + timestamp_ms: u64, + excluded_self: Option, + ) { + self.inner.lock().unwrap().record_top_n_selection( + subscriber_id, + selected_publishers, + timestamp_ms, + excluded_self, + ); + } + + pub fn generate_svg(&self, output_path: &str) -> std::io::Result<()> { + self.inner.lock().unwrap().generate_svg(output_path) + } +} diff --git a/moq-transport/.gitignore b/moq-transport/.gitignore index 182a6248..03314f77 100644 --- a/moq-transport/.gitignore +++ b/moq-transport/.gitignore @@ -1,5 +1 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - Cargo.lock diff --git a/moq-transport/CHANGELOG.md b/moq-transport/CHANGELOG.md index 91039c5a..5c6cca35 100644 --- a/moq-transport/CHANGELOG.md +++ b/moq-transport/CHANGELOG.md @@ -6,66 +6,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.14.2](https://github.com/cloudflare/moq-rs/compare/moq-transport-v0.14.1...moq-transport-v0.14.2) - 2026-05-20 - -### Fixed - -- subscribe cleaning on drop -- apply suggestions from opencode review - -## [0.14.1](https://github.com/cloudflare/moq-rs/compare/moq-transport-v0.14.0...moq-transport-v0.14.1) - 2026-03-31 - -### Other - -- Make repo REUSE v3.3 compliant -- Bring copyright notices, license docs up to date - -## [0.14.0](https://github.com/cloudflare/moq-rs/compare/moq-transport-v0.13.1...moq-transport-v0.14.0) - 2026-03-27 - -### Added - -- add Transport enum and connection path extraction - -## [0.13.1](https://github.com/cloudflare/moq-rs/compare/moq-transport-v0.13.0...moq-transport-v0.13.1) - 2026-03-02 - -### Fixed - -- TrackReader::is_closed() false positive after mode transition - -### Other - -- Merge pull request #148 from englishm-cloudflare/me/fix-is-closed-false-positive - -## [0.13.0](https://github.com/cloudflare/moq-rs/compare/moq-transport-v0.12.3...moq-transport-v0.13.0) - 2026-02-18 - -### Fixed - -- handle WebTransport graceful close in is_graceful_close() - -### Other - -- soften absolute claims about error conversion paths -- remove unused direct deps from moq-transport -- clarify graceful close semantics for WebTransport vs raw QUIC -- Upgrade web-transport crates to v0.10.1 - -## [0.12.3](https://github.com/cloudflare/moq-rs/compare/moq-transport-v0.12.2...moq-transport-v0.12.3) - 2026-02-18 - -### Added - -- add additional debug logging for troubleshooting -- add structured debug logging for MoQT control messages -- *(metrics)* distinguish graceful close from connection errors - -### Fixed - -- cargo fmt and clippy lints -- *(metrics)* address review feedback for metrics instrumentation - -### Other - -- migrate from log crate to tracing - ## [0.12.2](https://github.com/cloudflare/moq-rs/compare/moq-transport-v0.12.1...moq-transport-v0.12.2) - 2026-01-29 ### Fixed diff --git a/moq-transport/Cargo.toml b/moq-transport/Cargo.toml index 818d201e..5e72b0f1 100644 --- a/moq-transport/Cargo.toml +++ b/moq-transport/Cargo.toml @@ -1,15 +1,11 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - [package] name = "moq-transport" description = "Media over QUIC" -authors = ["moq-rs contributors"] -repository = "https://github.com/cloudflare/moq-rs" +authors = ["Luke Curley"] +repository = "https://github.com/englishm/moq-rs" license = "MIT OR Apache-2.0" -version = "0.14.2" +version = "0.12.2" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] @@ -22,7 +18,7 @@ categories = ["multimedia", "network-programming", "web-programming"] bytes = "1" thiserror = "1" tokio = { version = "1", features = ["macros", "io-util", "sync"] } -tracing = { workspace = true } +log = "0.4" uuid = { version = "1", features = ["v4"] } web-transport = { workspace = true } diff --git a/moq-transport/src/coding/bounded_string.rs b/moq-transport/src/coding/bounded_string.rs index dc15e5e0..540b32dc 100644 --- a/moq-transport/src/coding/bounded_string.rs +++ b/moq-transport/src/coding/bounded_string.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use super::{Decode, DecodeError, Encode, EncodeError}; macro_rules! bounded_string { diff --git a/moq-transport/src/coding/decode.rs b/moq-transport/src/coding/decode.rs index c47d58df..616a5fe5 100644 --- a/moq-transport/src/coding/decode.rs +++ b/moq-transport/src/coding/decode.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use super::BoundsExceeded; use std::{io, string::FromUtf8Error, sync}; use thiserror::Error; diff --git a/moq-transport/src/coding/encode.rs b/moq-transport/src/coding/encode.rs index e6edfb57..b129b155 100644 --- a/moq-transport/src/coding/encode.rs +++ b/moq-transport/src/coding/encode.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{io, sync}; use super::BoundsExceeded; diff --git a/moq-transport/src/coding/hex_dump.rs b/moq-transport/src/coding/hex_dump.rs index f42018a8..b0d8fbef 100644 --- a/moq-transport/src/coding/hex_dump.rs +++ b/moq-transport/src/coding/hex_dump.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! Utility functions for debugging byte sequences /// Format bytes as a hex string with spaces between bytes diff --git a/moq-transport/src/coding/integer.rs b/moq-transport/src/coding/integer.rs index 144a83a4..22d84a8a 100644 --- a/moq-transport/src/coding/integer.rs +++ b/moq-transport/src/coding/integer.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use super::{Decode, DecodeError, Encode, EncodeError}; impl Encode for u8 { diff --git a/moq-transport/src/coding/kvp.rs b/moq-transport/src/coding/kvp.rs index fba12977..065f5d39 100644 --- a/moq-transport/src/coding/kvp.rs +++ b/moq-transport/src/coding/kvp.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; use std::fmt; @@ -51,23 +48,56 @@ impl KeyValuePair { value: Value::BytesValue(value), } } -} -impl Decode for KeyValuePair { - fn decode(r: &mut R) -> Result { - let key = u64::decode(r)?; + /// Validate that the key parity matches the value type. + /// Even keys => IntValue, Odd keys => BytesValue. + fn validate_key_parity(&self) -> Result<(), EncodeError> { + match &self.value { + Value::IntValue(_) => { + if !self.key.is_multiple_of(2) { + return Err(EncodeError::InvalidValue); + } + } + Value::BytesValue(_) => { + if self.key.is_multiple_of(2) { + return Err(EncodeError::InvalidValue); + } + } + } + Ok(()) + } - if key % 2 == 0 { + /// Encode only the value portion of this KVP (not the key/delta). + /// The caller is responsible for encoding the key or delta type. + pub(crate) fn encode_value(&self, w: &mut W) -> Result<(), EncodeError> { + self.validate_key_parity()?; + match &self.value { + Value::IntValue(v) => { + (*v).encode(w)?; + } + Value::BytesValue(v) => { + v.len().encode(w)?; + Self::encode_remaining(w, v.len())?; + w.put_slice(v); + } + } + Ok(()) + } + + /// Decode only the value portion of a KVP given the absolute key. + /// The caller has already decoded the key/delta and resolved the absolute key. + pub(crate) fn decode_value(key: u64, r: &mut R) -> Result { + if key.is_multiple_of(2) { // VarInt variant let value = u64::decode(r)?; - tracing::trace!("[KVP] Decoded even key={}, value={}", key, value); + log::trace!("[KVP] Decoded even key={}, value={}", key, value); Ok(KeyValuePair::new_int(key, value)) } else { // Bytes variant let length = usize::decode(r)?; - tracing::trace!("[KVP] Decoded odd key={}, length={}", key, length); + log::trace!("[KVP] Decoded odd key={}, length={}", key, length); if length > u16::MAX as usize { - tracing::error!( + log::error!( "[KVP] Length exceeded! key={}, length={} (max={})", key, length, @@ -84,30 +114,22 @@ impl Decode for KeyValuePair { } } +/// Legacy Decode for KeyValuePair — reads absolute key from wire. +/// Used only by ExtensionHeaders which reads KVPs from a bounded byte slice. +impl Decode for KeyValuePair { + fn decode(r: &mut R) -> Result { + let key = u64::decode(r)?; + Self::decode_value(key, r) + } +} + +/// Legacy Encode for KeyValuePair — writes absolute key to wire. +/// Used only by ExtensionHeaders which writes KVPs into a temporary buffer. impl Encode for KeyValuePair { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - match &self.value { - Value::IntValue(v) => { - // key must be even for IntValue - if !self.key.is_multiple_of(2) { - return Err(EncodeError::InvalidValue); - } - self.key.encode(w)?; - (*v).encode(w)?; - Ok(()) - } - Value::BytesValue(v) => { - // key must be odd for BytesValue - if self.key.is_multiple_of(2) { - return Err(EncodeError::InvalidValue); - } - self.key.encode(w)?; - v.len().encode(w)?; - Self::encode_remaining(w, v.len())?; - w.put_slice(v); - Ok(()) - } - } + self.validate_key_parity()?; + self.key.encode(w)?; + self.encode_value(w) } } @@ -119,7 +141,10 @@ impl fmt::Debug for KeyValuePair { /// A collection of KeyValuePair entries, where the number of key-value-pairs are encoded/decoded first. /// This structure is appropriate for Control message parameters. -/// Since duplicate parameters are allowed for unknown parameters, we don't do duplicate checking here. +/// +/// Per draft-16 Section 1.4.2, Key-Value-Pairs use delta-encoded Type fields: +/// each Type is encoded as a delta from the previous Type (or from 0 for the first). +/// Entries are sorted by key (Type) in ascending order for encoding. #[derive(Default, Clone, Eq, PartialEq)] pub struct KeyValuePairs(pub Vec); @@ -153,16 +178,49 @@ impl KeyValuePairs { pub fn get(&self, key: u64) -> Option<&KeyValuePair> { self.0.iter().find(|k| k.key == key) } + + /// Get an integer value by key, returning None if not found or if the value is not an integer + pub fn get_intvalue(&self, key: u64) -> Option { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(v) => Some(*v), + Value::BytesValue(_) => None, + }) + } + + /// Get a bytes value by key, returning None if not found or if the value is not bytes + pub fn get_bytesvalue(&self, key: u64) -> Option<&Vec> { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(_) => None, + Value::BytesValue(v) => Some(v), + }) + } } impl Decode for KeyValuePairs { - fn decode(mut r: &mut R) -> Result { + /// Decode Key-Value-Pairs with delta-encoded Type fields (draft-16 Section 1.4.2). + fn decode(r: &mut R) -> Result { let mut kvps = Vec::new(); let count = u64::decode(r)?; + let mut prev_key: u64 = 0; + for _ in 0..count { - let kvp = KeyValuePair::decode(&mut r)?; + // Read delta type + let delta = u64::decode(r)?; + + // Reconstruct absolute key: prev_key + delta + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[KVP] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, r)?; kvps.push(kvp); + prev_key = key; } Ok(KeyValuePairs(kvps)) @@ -170,11 +228,32 @@ impl Decode for KeyValuePairs { } impl Encode for KeyValuePairs { + /// Encode Key-Value-Pairs with delta-encoded Type fields (draft-16 Section 1.4.2). + /// Entries are sorted by key in ascending order before encoding. fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.0.len().encode(w)?; - for kvp in &self.0 { - kvp.encode(w)?; + // Sort by key for delta encoding (Types must be in ascending order) + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|kvp| kvp.key); + + let mut prev_key: u64 = 0; + for kvp in sorted { + // Compute and encode the delta + let delta = kvp.key.checked_sub(prev_key).ok_or_else(|| { + log::error!( + "[KVP] Keys not sortable: prev_key={}, current_key={}", + prev_key, + kvp.key + ); + EncodeError::InvalidValue + })?; + delta.encode(w)?; + + // Encode the value (without the key) + kvp.encode_value(w)?; + + prev_key = kvp.key; } Ok(()) @@ -246,9 +325,10 @@ mod tests { } #[test] - fn encode_decode_keyvaluepairs() { + fn encode_decode_keyvaluepairs_single() { let mut buf = BytesMut::new(); + // Single entry: key=1 (odd, bytes). Delta from 0 = 1. let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); kvps.encode(&mut buf).unwrap(); @@ -256,21 +336,79 @@ mod tests { buf.to_vec(), vec![ 0x01, // 1 KeyValuePair - 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // Key=1, Value=[1,2,3,4,5] + // Delta=1 (from 0), then length=5, then data + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, ] ); let decoded = KeyValuePairs::decode(&mut buf).unwrap(); assert_eq!(decoded, kvps); + } + #[test] + fn encode_decode_keyvaluepairs_multiple() { + let mut buf = BytesMut::new(); + + // Multiple entries inserted out of order — encoding should sort by key. + // Keys: 0 (even, int), 1 (odd, bytes), 100 (even, int) let mut kvps = KeyValuePairs::new(); kvps.set_intvalue(0, 0); kvps.set_intvalue(100, 100); kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); kvps.encode(&mut buf).unwrap(); - let buf_vec = buf.to_vec(); - // Validate the encoded length and the KeyValuePair count - assert_eq!(14, buf_vec.len()); // 14 bytes total - assert_eq!(3, buf_vec[0]); // 3 KeyValuePairs + + #[rustfmt::skip] + let expected = vec![ + 0x03, // 3 KeyValuePairs + // Entry 1: key=0 (delta=0 from 0), even, int value=0 + 0x00, 0x00, + // Entry 2: key=1 (delta=1 from 0), odd, bytes len=5 + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, + // Entry 3: key=100 (delta=99 from 1), even, int value=100 + 0x40, 0x63, 0x40, 0x64, + ]; + assert_eq!(buf.to_vec(), expected); + + // Decode and verify — decoded entries will be in sorted order + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + // Build expected sorted kvps for comparison + let mut expected_kvps = KeyValuePairs::new(); + expected_kvps.set_intvalue(0, 0); + expected_kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); + expected_kvps.set_intvalue(100, 100); + assert_eq!(decoded, expected_kvps); + } + + #[test] + fn encode_decode_keyvaluepairs_roundtrip_sorted() { + let mut buf = BytesMut::new(); + + // Insert in sorted order — should roundtrip exactly + let mut kvps = KeyValuePairs::new(); + kvps.set_intvalue(2, 42); + kvps.set_intvalue(4, 100); + kvps.encode(&mut buf).unwrap(); + + #[rustfmt::skip] + let expected = vec![ + 0x02, // 2 KeyValuePairs + // Entry 1: key=2 (delta=2), int value=42 + 0x02, 0x2a, + // Entry 2: key=4 (delta=2 from 2), int value=100 + 0x02, 0x40, 0x64, + ]; + assert_eq!(buf.to_vec(), expected); + + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + assert_eq!(decoded, kvps); + } + + #[test] + fn encode_decode_keyvaluepairs_empty() { + let mut buf = BytesMut::new(); + + let kvps = KeyValuePairs::new(); + kvps.encode(&mut buf).unwrap(); + assert_eq!(buf.to_vec(), vec![0x00]); // count=0 let decoded = KeyValuePairs::decode(&mut buf).unwrap(); assert_eq!(decoded, kvps); } diff --git a/moq-transport/src/coding/location.rs b/moq-transport/src/coding/location.rs index 0b4c5003..7e820c4f 100644 --- a/moq-transport/src/coding/location.rs +++ b/moq-transport/src/coding/location.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; #[derive(Default, Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] diff --git a/moq-transport/src/coding/mod.rs b/moq-transport/src/coding/mod.rs index 0041e9c5..71cc4148 100644 --- a/moq-transport/src/coding/mod.rs +++ b/moq-transport/src/coding/mod.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - mod bounded_string; mod decode; mod encode; @@ -10,6 +6,7 @@ mod integer; mod kvp; mod location; mod string; +mod track_extensions; mod track_namespace; mod tuple; mod varint; @@ -20,6 +17,7 @@ pub use encode::*; pub use hex_dump::*; pub use kvp::*; pub use location::*; +pub use track_extensions::*; pub use track_namespace::*; pub use tuple::*; pub use varint::*; diff --git a/moq-transport/src/coding/string.rs b/moq-transport/src/coding/string.rs index d50624d4..81e48a24 100644 --- a/moq-transport/src/coding/string.rs +++ b/moq-transport/src/coding/string.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - // TODO SLG - eventually remove this file, bounded_string should now be used instead use super::{Decode, DecodeError, Encode, EncodeError}; diff --git a/moq-transport/src/coding/track_extensions.rs b/moq-transport/src/coding/track_extensions.rs new file mode 100644 index 00000000..3da50c55 --- /dev/null +++ b/moq-transport/src/coding/track_extensions.rs @@ -0,0 +1,196 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePair, Value}; +use std::fmt; + +/// A collection of KeyValuePair entries for Track Extensions. +/// Per draft-16 Section 9.10, Track Extensions are encoded WITHOUT a count or length prefix. +/// They are simply a sequence of delta-encoded key-value pairs until end of message. +/// +/// This differs from: +/// - KeyValuePairs: has a count prefix +/// - ExtensionHeaders: has a byte-length prefix +#[derive(Default, Clone, Eq, PartialEq)] +pub struct TrackExtensions(pub Vec); + +impl TrackExtensions { + pub fn new() -> Self { + Self::default() + } + + /// Insert or replace a KeyValuePair with the same key. + pub fn set(&mut self, kvp: KeyValuePair) { + if let Some(existing) = self.0.iter_mut().find(|k| k.key == kvp.key) { + *existing = kvp; + } else { + self.0.push(kvp); + } + } + + pub fn set_intvalue(&mut self, key: u64, value: u64) { + self.set(KeyValuePair::new_int(key, value)); + } + + pub fn set_bytesvalue(&mut self, key: u64, value: Vec) { + self.set(KeyValuePair::new_bytes(key, value)); + } + + pub fn has(&self, key: u64) -> bool { + self.0.iter().any(|k| k.key == key) + } + + pub fn get(&self, key: u64) -> Option<&KeyValuePair> { + self.0.iter().find(|k| k.key == key) + } + + /// Get an integer value by key, returning None if not found or if the value is not an integer + pub fn get_intvalue(&self, key: u64) -> Option { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(v) => Some(*v), + Value::BytesValue(_) => None, + }) + } + + /// Get a bytes value by key, returning None if not found or if the value is not bytes + pub fn get_bytesvalue(&self, key: u64) -> Option<&Vec> { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(_) => None, + Value::BytesValue(v) => Some(v), + }) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl Decode for TrackExtensions { + /// Decode Track Extensions - reads delta-encoded key-value pairs until end of buffer. + /// Per draft-16, Track Extensions have NO count or length prefix. + fn decode(r: &mut R) -> Result { + let mut kvps = Vec::new(); + let mut prev_key: u64 = 0; + + // Read until buffer is exhausted + while r.has_remaining() { + // Read delta type + let delta = u64::decode(r)?; + + // Reconstruct absolute key: prev_key + delta + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[TrackExt] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, r)?; + kvps.push(kvp); + prev_key = key; + } + + Ok(TrackExtensions(kvps)) + } +} + +impl Encode for TrackExtensions { + /// Encode Track Extensions - writes delta-encoded key-value pairs WITHOUT any prefix. + /// Entries are sorted by key in ascending order before encoding. + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + // Sort by key for delta encoding (Types must be in ascending order) + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|kvp| kvp.key); + + let mut prev_key: u64 = 0; + for kvp in sorted { + // Compute and encode the delta + let delta = kvp.key.checked_sub(prev_key).ok_or_else(|| { + log::error!( + "[TrackExt] Keys not sortable: prev_key={}, current_key={}", + prev_key, + kvp.key + ); + EncodeError::InvalidValue + })?; + delta.encode(w)?; + + // Encode the value (without the key) + kvp.encode_value(w)?; + + prev_key = kvp.key; + } + + Ok(()) + } +} + +impl fmt::Debug for TrackExtensions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{{ ")?; + for (i, kv) in self.0.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", kv)?; + } + write!(f, " }}") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode_empty() { + let mut buf = BytesMut::new(); + + let ext = TrackExtensions::new(); + ext.encode(&mut buf).unwrap(); + // Empty TrackExtensions produces NO bytes (no prefix!) + let expected: Vec = vec![]; + assert_eq!(buf.to_vec(), expected); + let decoded = TrackExtensions::decode(&mut buf).unwrap(); + assert_eq!(decoded, ext); + } + + #[test] + fn encode_decode_single() { + let mut buf = BytesMut::new(); + + let mut ext = TrackExtensions::new(); + ext.set_intvalue(2, 42); // key=2 (even), value=42 + ext.encode(&mut buf).unwrap(); + + // Expected: delta=2, value=42 (no count or length prefix!) + assert_eq!(buf.to_vec(), vec![0x02, 0x2a]); + + let decoded = TrackExtensions::decode(&mut buf).unwrap(); + assert_eq!(decoded, ext); + } + + #[test] + fn encode_decode_multiple() { + let mut buf = BytesMut::new(); + + let mut ext = TrackExtensions::new(); + ext.set_intvalue(0, 0); + ext.set_intvalue(2, 100); + ext.encode(&mut buf).unwrap(); + + // Expected: + // Entry 1: delta=0, value=0 + // Entry 2: delta=2 (from 0), value=100 + // No count prefix! + #[rustfmt::skip] + let expected = vec![ + 0x00, 0x00, // delta=0, value=0 + 0x02, 0x40, 0x64, // delta=2, value=100 (varint) + ]; + assert_eq!(buf.to_vec(), expected); + + let decoded = TrackExtensions::decode(&mut buf).unwrap(); + assert_eq!(decoded, ext); + } +} diff --git a/moq-transport/src/coding/track_namespace.rs b/moq-transport/src/coding/track_namespace.rs index 8299e4a3..8e0fe8ec 100644 --- a/moq-transport/src/coding/track_namespace.rs +++ b/moq-transport/src/coding/track_namespace.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use super::{Decode, DecodeError, Encode, EncodeError, TupleField}; use core::hash::{Hash, Hasher}; use std::convert::TryFrom; diff --git a/moq-transport/src/coding/tuple.rs b/moq-transport/src/coding/tuple.rs index 83de9a12..71ad713e 100644 --- a/moq-transport/src/coding/tuple.rs +++ b/moq-transport/src/coding/tuple.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use super::{Decode, DecodeError, Encode, EncodeError}; use core::hash::{Hash, Hasher}; diff --git a/moq-transport/src/coding/varint.rs b/moq-transport/src/coding/varint.rs index 5b432f07..8ae9fa29 100644 --- a/moq-transport/src/coding/varint.rs +++ b/moq-transport/src/coding/varint.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - // Based on quinn-proto // https://github.com/quinn-rs/quinn/blob/main/quinn-proto/src/varint.rs // Licensed via Apache 2.0 and MIT @@ -326,7 +322,7 @@ mod tests { assert_eq!(buf.to_vec(), vec![0b0000_0000]); // first 2 bits are 00 let decoded = VarInt::decode(&mut buf).unwrap(); assert_eq!(decoded, vi); - assert_eq!(u64::from(decoded), i); + assert_eq!(u64::try_from(decoded).unwrap(), i); // 63 -> 1 byte let i = 63; @@ -335,7 +331,7 @@ mod tests { assert_eq!(buf.to_vec(), vec![0b0011_1111]); // first 2 bits are 00 let decoded = VarInt::decode(&mut buf).unwrap(); assert_eq!(decoded, vi); - assert_eq!(u64::from(decoded), i); + assert_eq!(u64::try_from(decoded).unwrap(), i); // 64 -> 2 bytes let i = 64; @@ -344,7 +340,7 @@ mod tests { assert_eq!(buf.to_vec(), vec![0b0100_0000, 0b0100_0000]); // first 2 bits are 01 let decoded = VarInt::decode(&mut buf).unwrap(); assert_eq!(decoded, vi); - assert_eq!(u64::from(decoded), i); + assert_eq!(u64::try_from(decoded).unwrap(), i); // 16383 -> 2 bytes let i = 16383; @@ -353,7 +349,7 @@ mod tests { assert_eq!(buf.to_vec(), vec![0b0111_1111, 0xff]); // first 2 bits are 01 let decoded = VarInt::decode(&mut buf).unwrap(); assert_eq!(decoded, vi); - assert_eq!(u64::from(decoded), i); + assert_eq!(u64::try_from(decoded).unwrap(), i); // 16384 -> 4 bytes let i = 16384; @@ -362,7 +358,7 @@ mod tests { assert_eq!(buf.to_vec(), vec![0b1000_0000, 0x00, 0x40, 0x00]); // first 2 bits are 10 let decoded = VarInt::decode(&mut buf).unwrap(); assert_eq!(decoded, vi); - assert_eq!(u64::from(decoded), i); + assert_eq!(u64::try_from(decoded).unwrap(), i); // 1073741823 -> 4 bytes let i = 1073741823; @@ -371,7 +367,7 @@ mod tests { assert_eq!(buf.to_vec(), vec![0b1011_1111, 0xff, 0xff, 0xff]); // first 2 bits are 10 let decoded = VarInt::decode(&mut buf).unwrap(); assert_eq!(decoded, vi); - assert_eq!(u64::from(decoded), i); + assert_eq!(u64::try_from(decoded).unwrap(), i); // 1073741824 -> 8 bytes let i = 1073741824; @@ -384,7 +380,7 @@ mod tests { ); let decoded = VarInt::decode(&mut buf).unwrap(); assert_eq!(decoded, vi); - assert_eq!(u64::from(decoded), i); + assert_eq!(u64::try_from(decoded).unwrap(), i); // 4611686018427387903 -> 8 bytes let i = 4611686018427387903; @@ -397,7 +393,7 @@ mod tests { ); let decoded = VarInt::decode(&mut buf).unwrap(); assert_eq!(decoded, vi); - assert_eq!(u64::from(decoded), i); + assert_eq!(u64::try_from(decoded).unwrap(), i); } #[test] diff --git a/moq-transport/src/data/datagram.rs b/moq-transport/src/data/datagram.rs index 88c7af96..61d36ce1 100644 --- a/moq-transport/src/data/datagram.rs +++ b/moq-transport/src/data/datagram.rs @@ -1,12 +1,9 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; use crate::data::{ExtensionHeaders, ObjectStatus}; #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum DatagramType { + // Payload types with Priority Present (0x00-0x07) ObjectIdPayload = 0x00, ObjectIdPayloadExt = 0x01, ObjectIdPayloadEndOfGroup = 0x02, @@ -15,13 +12,125 @@ pub enum DatagramType { PayloadExt = 0x05, PayloadEndOfGroup = 0x06, PayloadExtEndOfGroup = 0x07, + // Payload types with Priority Not Present (0x08-0x0F) + ObjectIdPayloadNoPriority = 0x08, + ObjectIdPayloadExtNoPriority = 0x09, + ObjectIdPayloadEndOfGroupNoPriority = 0x0a, + ObjectIdPayloadExtEndOfGroupNoPriority = 0x0b, + PayloadNoPriority = 0x0c, + PayloadExtNoPriority = 0x0d, + PayloadEndOfGroupNoPriority = 0x0e, + PayloadExtEndOfGroupNoPriority = 0x0f, + // Status types with Priority Present (0x20-0x25) ObjectIdStatus = 0x20, ObjectIdStatusExt = 0x21, + Status = 0x24, + StatusExt = 0x25, + // Status types with Priority Not Present (0x28-0x2D) + ObjectIdStatusNoPriority = 0x28, + ObjectIdStatusExtNoPriority = 0x29, + StatusNoPriority = 0x2c, + StatusExtNoPriority = 0x2d, +} + +impl DatagramType { + /// Returns true if this datagram type has the Object ID field present + pub fn has_object_id(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayload + | DatagramType::ObjectIdPayloadExt + | DatagramType::ObjectIdPayloadEndOfGroup + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::ObjectIdPayloadNoPriority + | DatagramType::ObjectIdPayloadExtNoPriority + | DatagramType::ObjectIdPayloadEndOfGroupNoPriority + | DatagramType::ObjectIdPayloadExtEndOfGroupNoPriority + | DatagramType::ObjectIdStatus + | DatagramType::ObjectIdStatusExt + | DatagramType::ObjectIdStatusNoPriority + | DatagramType::ObjectIdStatusExtNoPriority + ) + } + + /// Returns true if this datagram type has the Publisher Priority field present + pub fn has_priority(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayload + | DatagramType::ObjectIdPayloadExt + | DatagramType::ObjectIdPayloadEndOfGroup + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::Payload + | DatagramType::PayloadExt + | DatagramType::PayloadEndOfGroup + | DatagramType::PayloadExtEndOfGroup + | DatagramType::ObjectIdStatus + | DatagramType::ObjectIdStatusExt + | DatagramType::Status + | DatagramType::StatusExt + ) + } + + /// Returns true if this datagram type has extension headers + pub fn has_extensions(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayloadExt + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::PayloadExt + | DatagramType::PayloadExtEndOfGroup + | DatagramType::ObjectIdPayloadExtNoPriority + | DatagramType::ObjectIdPayloadExtEndOfGroupNoPriority + | DatagramType::PayloadExtNoPriority + | DatagramType::PayloadExtEndOfGroupNoPriority + | DatagramType::ObjectIdStatusExt + | DatagramType::StatusExt + | DatagramType::ObjectIdStatusExtNoPriority + | DatagramType::StatusExtNoPriority + ) + } + + /// Returns true if this is a status datagram (no payload) + pub fn is_status(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdStatus + | DatagramType::ObjectIdStatusExt + | DatagramType::Status + | DatagramType::StatusExt + | DatagramType::ObjectIdStatusNoPriority + | DatagramType::ObjectIdStatusExtNoPriority + | DatagramType::StatusNoPriority + | DatagramType::StatusExtNoPriority + ) + } + + /// Returns true if this is a payload datagram + pub fn is_payload(&self) -> bool { + !self.is_status() + } + + /// Returns true if this datagram type indicates end of group + pub fn is_end_of_group(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayloadEndOfGroup + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::PayloadEndOfGroup + | DatagramType::PayloadExtEndOfGroup + | DatagramType::ObjectIdPayloadEndOfGroupNoPriority + | DatagramType::ObjectIdPayloadExtEndOfGroupNoPriority + | DatagramType::PayloadEndOfGroupNoPriority + | DatagramType::PayloadExtEndOfGroupNoPriority + ) + } } impl Decode for DatagramType { fn decode(r: &mut B) -> Result { match u64::decode(r)? { + // Payload types with Priority Present (0x00-0x07) 0x00 => Ok(Self::ObjectIdPayload), 0x01 => Ok(Self::ObjectIdPayloadExt), 0x02 => Ok(Self::ObjectIdPayloadEndOfGroup), @@ -30,8 +139,25 @@ impl Decode for DatagramType { 0x05 => Ok(Self::PayloadExt), 0x06 => Ok(Self::PayloadEndOfGroup), 0x07 => Ok(Self::PayloadExtEndOfGroup), + // Payload types with Priority Not Present (0x08-0x0F) + 0x08 => Ok(Self::ObjectIdPayloadNoPriority), + 0x09 => Ok(Self::ObjectIdPayloadExtNoPriority), + 0x0a => Ok(Self::ObjectIdPayloadEndOfGroupNoPriority), + 0x0b => Ok(Self::ObjectIdPayloadExtEndOfGroupNoPriority), + 0x0c => Ok(Self::PayloadNoPriority), + 0x0d => Ok(Self::PayloadExtNoPriority), + 0x0e => Ok(Self::PayloadEndOfGroupNoPriority), + 0x0f => Ok(Self::PayloadExtEndOfGroupNoPriority), + // Status types with Priority Present (0x20-0x25) 0x20 => Ok(Self::ObjectIdStatus), 0x21 => Ok(Self::ObjectIdStatusExt), + 0x24 => Ok(Self::Status), + 0x25 => Ok(Self::StatusExt), + // Status types with Priority Not Present (0x28-0x2D) + 0x28 => Ok(Self::ObjectIdStatusNoPriority), + 0x29 => Ok(Self::ObjectIdStatusExtNoPriority), + 0x2c => Ok(Self::StatusNoPriority), + 0x2d => Ok(Self::StatusExtNoPriority), _ => Err(DecodeError::InvalidDatagramType), } } @@ -60,9 +186,10 @@ pub struct Datagram { pub object_id: Option, /// Publisher priority, where **smaller** values are sent first. - pub publisher_priority: u8, + /// Optional when using NoPriority datagram types (0x08-0x0F, 0x28-0x2D). + pub publisher_priority: Option, - /// Optional extension headers if type is 0x1 (NoEndOfGroupWithExtensions) or 0x3 (EndofGroupWithExtensions) + /// Optional extension headers for types with extensions pub extension_headers: Option, /// The Object Status. @@ -79,47 +206,38 @@ impl Decode for Datagram { let group_id = u64::decode(r)?; // Decode Object Id if required - let object_id = match datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::ObjectIdStatus - | DatagramType::ObjectIdStatusExt => Some(u64::decode(r)?), - _ => None, + let object_id = if datagram_type.has_object_id() { + Some(u64::decode(r)?) + } else { + None }; - let publisher_priority = u8::decode(r)?; + // Decode Publisher Priority if required + let publisher_priority = if datagram_type.has_priority() { + Some(u8::decode(r)?) + } else { + None + }; // Decode Extension Headers if required - let extension_headers = match datagram_type { - DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::PayloadExt - | DatagramType::PayloadExtEndOfGroup - | DatagramType::ObjectIdStatusExt => Some(ExtensionHeaders::decode(r)?), - _ => None, + let extension_headers = if datagram_type.has_extensions() { + Some(ExtensionHeaders::decode(r)?) + } else { + None }; - // Decode Status if required - let status = match datagram_type { - DatagramType::ObjectIdStatus | DatagramType::ObjectIdStatusExt => { - Some(ObjectStatus::decode(r)?) - } - _ => None, + // Decode Status if required (for status datagram types) + let status = if datagram_type.is_status() { + Some(ObjectStatus::decode(r)?) + } else { + None }; - // Decode Payload if required - let payload = match datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::Payload - | DatagramType::PayloadExt - | DatagramType::PayloadEndOfGroup - | DatagramType::PayloadExtEndOfGroup => Some(r.copy_to_bytes(r.remaining())), - _ => None, + // Decode Payload if required (for payload datagram types) + let payload = if datagram_type.is_payload() { + Some(r.copy_to_bytes(r.remaining())) + } else { + None }; Ok(Self { @@ -142,70 +260,49 @@ impl Encode for Datagram { self.group_id.encode(w)?; // Encode Object Id if required - match self.datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::ObjectIdStatus - | DatagramType::ObjectIdStatusExt => { - if let Some(object_id) = &self.object_id { - object_id.encode(w)?; - } else { - return Err(EncodeError::MissingField("ObjectId".to_string())); - } + if self.datagram_type.has_object_id() { + if let Some(object_id) = &self.object_id { + object_id.encode(w)?; + } else { + return Err(EncodeError::MissingField("ObjectId".to_string())); } - _ => {} - }; + } - self.publisher_priority.encode(w)?; + // Encode Publisher Priority if required + if self.datagram_type.has_priority() { + if let Some(publisher_priority) = &self.publisher_priority { + publisher_priority.encode(w)?; + } else { + return Err(EncodeError::MissingField("PublisherPriority".to_string())); + } + } // Encode Extension Headers if required - match self.datagram_type { - DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::PayloadExt - | DatagramType::PayloadExtEndOfGroup - | DatagramType::ObjectIdStatusExt => { - if let Some(extension_headers) = &self.extension_headers { - extension_headers.encode(w)?; - } else { - return Err(EncodeError::MissingField("ExtensionHeaders".to_string())); - } + if self.datagram_type.has_extensions() { + if let Some(extension_headers) = &self.extension_headers { + extension_headers.encode(w)?; + } else { + return Err(EncodeError::MissingField("ExtensionHeaders".to_string())); } - _ => {} - }; + } - // Decode Status if required - match self.datagram_type { - DatagramType::ObjectIdStatus | DatagramType::ObjectIdStatusExt => { - if let Some(status) = &self.status { - status.encode(w)?; - } else { - return Err(EncodeError::MissingField("Status".to_string())); - } + // Encode Status if required (for status datagram types) + if self.datagram_type.is_status() { + if let Some(status) = &self.status { + status.encode(w)?; + } else { + return Err(EncodeError::MissingField("Status".to_string())); } - _ => {} } - // Decode Payload if required - match self.datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::Payload - | DatagramType::PayloadExt - | DatagramType::PayloadEndOfGroup - | DatagramType::PayloadExtEndOfGroup => { - if let Some(payload) = &self.payload { - Self::encode_remaining(w, payload.len())?; - w.put_slice(payload); - } else { - return Err(EncodeError::MissingField("Payload".to_string())); - } + // Encode Payload if required (for payload datagram types) + if self.datagram_type.is_payload() { + if let Some(payload) = &self.payload { + Self::encode_remaining(w, payload.len())?; + w.put_slice(payload); + } else { + return Err(EncodeError::MissingField("Payload".to_string())); } - _ => {} } Ok(()) @@ -297,7 +394,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -314,7 +411,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -331,7 +428,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -348,7 +445,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -365,7 +462,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: Some(ObjectStatus::EndOfTrack), payload: None, @@ -382,7 +479,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: Some(ObjectStatus::EndOfTrack), payload: None, @@ -399,7 +496,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -416,7 +513,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -433,7 +530,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -450,7 +547,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -460,6 +557,40 @@ mod tests { assert_eq!(19, buf.len()); let decoded = Datagram::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); + + // DatagramType = ObjectIdPayloadNoPriority (no priority field) + let msg = Datagram { + datagram_type: DatagramType::ObjectIdPayloadNoPriority, + track_alias: 12, + group_id: 10, + object_id: Some(1234), + publisher_priority: None, + extension_headers: None, + status: None, + payload: Some(Bytes::from("payload")), + }; + msg.encode(&mut buf).unwrap(); + // Length should be: Type(1)+Alias(1)+GroupId(1)+ObjectId(2)+Payload(7) = 12 (no priority) + assert_eq!(12, buf.len()); + let decoded = Datagram::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + + // DatagramType = PayloadNoPriority (no priority field, no object id) + let msg = Datagram { + datagram_type: DatagramType::PayloadNoPriority, + track_alias: 12, + group_id: 10, + object_id: None, + publisher_priority: None, + extension_headers: None, + status: None, + payload: Some(Bytes::from("payload")), + }; + msg.encode(&mut buf).unwrap(); + // Length should be: Type(1)+Alias(1)+GroupId(1)+Payload(7) = 10 (no priority, no object id) + assert_eq!(10, buf.len()); + let decoded = Datagram::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } #[test] @@ -472,7 +603,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -486,7 +617,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -500,7 +631,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: Some(ObjectStatus::EndOfTrack), payload: None, @@ -514,7 +645,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: None, @@ -528,7 +659,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: None, @@ -536,6 +667,18 @@ mod tests { let encoded = msg.encode(&mut buf); assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - // TODO SLG - add tests + // DatagramType = ObjectIdPayload - missing priority (priority is required for this type) + let msg = Datagram { + datagram_type: DatagramType::ObjectIdPayload, + track_alias: 12, + group_id: 10, + object_id: Some(1234), + publisher_priority: None, + extension_headers: None, + status: None, + payload: Some(Bytes::from("payload")), + }; + let encoded = msg.encode(&mut buf); + assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); } } diff --git a/moq-transport/src/data/extension_headers.rs b/moq-transport/src/data/extension_headers.rs index e457ce77..1a98baab 100644 --- a/moq-transport/src/data/extension_headers.rs +++ b/moq-transport/src/data/extension_headers.rs @@ -1,12 +1,11 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePair}; use bytes::Buf; use std::fmt; /// A collection of KeyValuePair entries, where the length in bytes of key-value-pairs are encoded/decoded first. /// This structure is appropriate for Data plane extension headers. +/// +/// Per draft-16 Section 1.4.2, Key-Value-Pairs use delta-encoded Type fields. /// Since duplicate parameters are allowed for unknown extension headers, we don't do duplicate checking here. #[derive(Default, Clone, Eq, PartialEq)] pub struct ExtensionHeaders(pub Vec); @@ -47,7 +46,65 @@ impl ExtensionHeaders { } } +impl ExtensionHeaders { + /// Encode extension headers without length prefix (just the delta-encoded KVPs). + /// Used for Track Extensions in PUBLISH where the length is implicit from the message. + pub fn encode_without_length(&self, w: &mut W) -> Result<(), EncodeError> { + // Sort by key for delta encoding + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|kvp| kvp.key); + + let mut prev_key: u64 = 0; + for kvp in sorted { + let delta = kvp.key.checked_sub(prev_key).ok_or_else(|| { + log::error!( + "[ExtHdr] Keys not sortable: prev_key={}, current_key={}", + prev_key, + kvp.key + ); + EncodeError::InvalidValue + })?; + delta.encode(w)?; + kvp.encode_value(w)?; + prev_key = kvp.key; + } + + Ok(()) + } + + /// Decode extension headers from remaining bytes (no length prefix). + /// Used for Track Extensions in PUBLISH where the length is implicit from the message. + pub fn decode_remaining_bytes(r: &mut R) -> Result { + if !r.has_remaining() { + return Ok(ExtensionHeaders::new()); + } + + let mut kvps = Vec::new(); + let mut prev_key: u64 = 0; + + while r.has_remaining() { + // Read delta type and reconstruct absolute key + let delta = u64::decode(r)?; + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[ExtHdr] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, r)?; + kvps.push(kvp); + prev_key = key; + } + + Ok(ExtensionHeaders(kvps)) + } +} + impl Decode for ExtensionHeaders { + /// Decode extension headers with delta-encoded Type fields (draft-16 Section 1.4.2). fn decode(r: &mut R) -> Result { // Read total byte length of the encoded kvps // Note: this is the difference between KeyValuePairs and ExtensionHeaders. @@ -68,9 +125,23 @@ impl Decode for ExtensionHeaders { let mut kvps_bytes = bytes::Bytes::from(buf); let mut kvps = Vec::new(); + let mut prev_key: u64 = 0; + while kvps_bytes.has_remaining() { - let kvp = KeyValuePair::decode(&mut kvps_bytes)?; + // Read delta type and reconstruct absolute key + let delta = u64::decode(&mut kvps_bytes)?; + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[ExtHdr] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, &mut kvps_bytes)?; kvps.push(kvp); + prev_key = key; } Ok(ExtensionHeaders(kvps)) @@ -78,14 +149,31 @@ impl Decode for ExtensionHeaders { } impl Encode for ExtensionHeaders { + /// Encode extension headers with delta-encoded Type fields (draft-16 Section 1.4.2). + /// Entries are sorted by key in ascending order before encoding. fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - // Encode all KeyValuePair entries into a temporary buffer to compute total byte length + // Sort by key for delta encoding + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|kvp| kvp.key); + + // Encode all entries into a temporary buffer to compute total byte length let mut tmp = bytes::BytesMut::new(); - for kvp in &self.0 { - kvp.encode(&mut tmp)?; + let mut prev_key: u64 = 0; + for kvp in sorted { + let delta = kvp.key.checked_sub(prev_key).ok_or_else(|| { + log::error!( + "[ExtHdr] Keys not sortable: prev_key={}, current_key={}", + prev_key, + kvp.key + ); + EncodeError::InvalidValue + })?; + delta.encode(&mut tmp)?; + kvp.encode_value(&mut tmp)?; + prev_key = kvp.key; } - // Write total byte length (u64) followed by the encoded bytes + // Write total byte length followed by the encoded bytes (tmp.len() as u64).encode(w)?; w.put_slice(&tmp); @@ -112,9 +200,10 @@ mod tests { use bytes::BytesMut; #[test] - fn encode_decode_extension_headers() { + fn encode_decode_extension_headers_single() { let mut buf = BytesMut::new(); + // Single entry: key=1. Delta from 0 = 1. let mut ext_hdrs = ExtensionHeaders::new(); ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); ext_hdrs.encode(&mut buf).unwrap(); @@ -122,21 +211,55 @@ mod tests { buf.to_vec(), vec![ 0x07, // 7 bytes total length - 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // Key=1, Value=[1,2,3,4,5] + // Delta=1, length=5, data + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, ] ); let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); assert_eq!(decoded, ext_hdrs); + } + #[test] + fn encode_decode_extension_headers_multiple() { + let mut buf = BytesMut::new(); + + // Multiple entries inserted out of order — encoding sorts by key. + // Keys: 0 (even, int), 1 (odd, bytes), 100 (even, int) let mut ext_hdrs = ExtensionHeaders::new(); - ext_hdrs.set_intvalue(0, 0); // 2 bytes - ext_hdrs.set_intvalue(100, 100); // 4 bytes - ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); // 1 byte key, 1 byte length, 5 bytes data = 7 bytes + ext_hdrs.set_intvalue(0, 0); + ext_hdrs.set_intvalue(100, 100); + ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); ext_hdrs.encode(&mut buf).unwrap(); let buf_vec = buf.to_vec(); - // Validate the encoded length and the KeyValuePair's length. - assert_eq!(14, buf_vec.len()); // 14 bytes total (length + 3 kvps) - assert_eq!(13, buf_vec[0]); // 13 bytes for the 3 KeyValuePairs data + + #[rustfmt::skip] + let expected = vec![ + 0x0d, // 13 bytes total length for the KVP data + // Entry 1: key=0 (delta=0), even, int value=0 + 0x00, 0x00, + // Entry 2: key=1 (delta=1), odd, bytes len=5 + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, + // Entry 3: key=100 (delta=99), even, int value=100 + 0x40, 0x63, 0x40, 0x64, + ]; + assert_eq!(buf_vec, expected); + + // Decode and verify — decoded entries will be in sorted order + let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); + let mut expected_ext = ExtensionHeaders::new(); + expected_ext.set_intvalue(0, 0); + expected_ext.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); + expected_ext.set_intvalue(100, 100); + assert_eq!(decoded, expected_ext); + } + + #[test] + fn encode_decode_extension_headers_empty() { + let mut buf = BytesMut::new(); + + let ext_hdrs = ExtensionHeaders::new(); + ext_hdrs.encode(&mut buf).unwrap(); + assert_eq!(buf.to_vec(), vec![0x00]); // length=0 let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); assert_eq!(decoded, ext_hdrs); } diff --git a/moq-transport/src/data/extension_types.rs b/moq-transport/src/data/extension_types.rs new file mode 100644 index 00000000..7c83de7a --- /dev/null +++ b/moq-transport/src/data/extension_types.rs @@ -0,0 +1,38 @@ +//! Known extension header type constants for the MOQT data plane. +//! +//! These extension headers can be attached to objects in subgroups, datagrams, and fetch streams. +//! See the MOQT specification for detailed semantics of each extension type. + +/// Immutable Extensions (0xB) +/// +/// A container extension header that wraps other extension headers that MUST NOT +/// be modified by relays or intermediaries. The contents of this extension header +/// should be preserved exactly as received when forwarding objects. +pub const IMMUTABLE_EXTENSIONS: u64 = 0xB; + +/// Prior Group ID Gap (0x3C) +/// +/// Indicates that one or more groups prior to this one are missing or unavailable. +/// The value is an integer indicating the number of missing prior groups. +/// This is used to signal discontinuities in the group sequence to subscribers. +pub const PRIOR_GROUP_ID_GAP: u64 = 0x3C; + +/// Prior Object ID Gap (0x3E) +/// +/// Indicates that one or more objects prior to this one within the same group/subgroup +/// are missing or unavailable. The value is an integer indicating the number of missing +/// prior objects. This is used to signal discontinuities in the object sequence. +pub const PRIOR_OBJECT_ID_GAP: u64 = 0x3E; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extension_type_values() { + // Verify the spec-defined values + assert_eq!(IMMUTABLE_EXTENSIONS, 0xB); + assert_eq!(PRIOR_GROUP_ID_GAP, 0x3C); + assert_eq!(PRIOR_OBJECT_ID_GAP, 0x3E); + } +} diff --git a/moq-transport/src/data/fetch.rs b/moq-transport/src/data/fetch.rs index 40549819..bd446a4c 100644 --- a/moq-transport/src/data/fetch.rs +++ b/moq-transport/src/data/fetch.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; use crate::data::{ObjectStatus, StreamHeaderType}; diff --git a/moq-transport/src/data/header.rs b/moq-transport/src/data/header.rs index 328b0bdc..3991e759 100644 --- a/moq-transport/src/data/header.rs +++ b/moq-transport/src/data/header.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; use crate::data::{FetchHeader, SubgroupHeader}; use std::fmt; @@ -10,6 +6,7 @@ use std::fmt; #[repr(u64)] #[derive(Copy, Debug, Clone, Eq, PartialEq)] pub enum StreamHeaderType { + // Priority Present variants (0x10-0x1D) SubgroupZeroId = 0x10, SubgroupZeroIdExt = 0x11, SubgroupFirstObjectId = 0x12, @@ -22,13 +19,27 @@ pub enum StreamHeaderType { SubgroupFirstObjectIdExtEndOfGroup = 0x1b, SubgroupIdEndOfGroup = 0x1c, SubgroupIdExtEndOfGroup = 0x1d, + // Priority Not Present variants (0x30-0x3D) + SubgroupZeroIdNoPriority = 0x30, + SubgroupZeroIdExtNoPriority = 0x31, + SubgroupFirstObjectIdNoPriority = 0x32, + SubgroupFirstObjectIdExtNoPriority = 0x33, + SubgroupIdNoPriority = 0x34, + SubgroupIdExtNoPriority = 0x35, + SubgroupZeroIdEndOfGroupNoPriority = 0x38, + SubgroupZeroIdExtEndOfGroupNoPriority = 0x39, + SubgroupFirstObjectIdEndOfGroupNoPriority = 0x3a, + SubgroupFirstObjectIdExtEndOfGroupNoPriority = 0x3b, + SubgroupIdEndOfGroupNoPriority = 0x3c, + SubgroupIdExtEndOfGroupNoPriority = 0x3d, + // Fetch Fetch = 0x5, } impl StreamHeaderType { pub fn is_subgroup(&self) -> bool { let header_type = *self as u64; - (0x10..=0x1d).contains(&header_type) + (0x10..=0x1d).contains(&header_type) || (0x30..=0x3d).contains(&header_type) } pub fn is_fetch(&self) -> bool { @@ -44,6 +55,12 @@ impl StreamHeaderType { | StreamHeaderType::SubgroupZeroIdExtEndOfGroup | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup | StreamHeaderType::SubgroupIdExtEndOfGroup + | StreamHeaderType::SubgroupZeroIdExtNoPriority + | StreamHeaderType::SubgroupFirstObjectIdExtNoPriority + | StreamHeaderType::SubgroupIdExtNoPriority + | StreamHeaderType::SubgroupZeroIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority | StreamHeaderType::Fetch ) } @@ -55,38 +72,91 @@ impl StreamHeaderType { | StreamHeaderType::SubgroupIdExt | StreamHeaderType::SubgroupIdEndOfGroup | StreamHeaderType::SubgroupIdExtEndOfGroup + | StreamHeaderType::SubgroupIdNoPriority + | StreamHeaderType::SubgroupIdExtNoPriority + | StreamHeaderType::SubgroupIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority ) } + + pub fn has_priority(&self) -> bool { + let header_type = *self as u64; + // Priority Present variants are 0x10-0x1D + // Priority Not Present variants are 0x30-0x3D + (0x10..=0x1d).contains(&header_type) + } + + /// Returns true if this header type signals end-of-group when the stream ends. + /// For these types, the relay should write an EndOfGroup marker when the stream completes. + pub fn signals_end_of_group(&self) -> bool { + matches!( + *self, + StreamHeaderType::SubgroupZeroIdEndOfGroup + | StreamHeaderType::SubgroupZeroIdExtEndOfGroup + | StreamHeaderType::SubgroupFirstObjectIdEndOfGroup + | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup + | StreamHeaderType::SubgroupIdEndOfGroup + | StreamHeaderType::SubgroupIdExtEndOfGroup + | StreamHeaderType::SubgroupZeroIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupZeroIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupFirstObjectIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority + ) + } + + /// Returns the equivalent header type without extensions. + /// Used when forwarding streams where objects have empty extension headers. + pub fn without_extensions(&self) -> Self { + match *self { + StreamHeaderType::SubgroupZeroIdExt => StreamHeaderType::SubgroupZeroId, + StreamHeaderType::SubgroupFirstObjectIdExt => StreamHeaderType::SubgroupFirstObjectId, + StreamHeaderType::SubgroupIdExt => StreamHeaderType::SubgroupId, + StreamHeaderType::SubgroupZeroIdExtEndOfGroup => StreamHeaderType::SubgroupZeroIdEndOfGroup, + StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup => StreamHeaderType::SubgroupFirstObjectIdEndOfGroup, + StreamHeaderType::SubgroupIdExtEndOfGroup => StreamHeaderType::SubgroupIdEndOfGroup, + StreamHeaderType::SubgroupZeroIdExtNoPriority => StreamHeaderType::SubgroupZeroIdNoPriority, + StreamHeaderType::SubgroupFirstObjectIdExtNoPriority => StreamHeaderType::SubgroupFirstObjectIdNoPriority, + StreamHeaderType::SubgroupIdExtNoPriority => StreamHeaderType::SubgroupIdNoPriority, + StreamHeaderType::SubgroupZeroIdExtEndOfGroupNoPriority => StreamHeaderType::SubgroupZeroIdEndOfGroupNoPriority, + StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroupNoPriority => StreamHeaderType::SubgroupFirstObjectIdEndOfGroupNoPriority, + StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority => StreamHeaderType::SubgroupIdEndOfGroupNoPriority, + // Already non-Ext or Fetch + other => other, + } + } } impl Encode for StreamHeaderType { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { let val = *self as u64; - tracing::trace!( + log::trace!( "[ENCODE] StreamHeaderType: encoding {:?} as {:#x}", self, val ); val.encode(w)?; - tracing::trace!("[ENCODE] StreamHeaderType: encoded successfully"); + log::trace!("[ENCODE] StreamHeaderType: encoded successfully"); Ok(()) } } impl Decode for StreamHeaderType { fn decode(r: &mut R) -> Result { - tracing::trace!( + log::trace!( "[DECODE] StreamHeaderType: starting decode, buffer_remaining={} bytes", r.remaining() ); let type_value = u64::decode(r)?; - tracing::trace!( + log::trace!( "[DECODE] StreamHeaderType: decoded type value={:#x}", type_value ); let header_type = match type_value { + // Priority Present variants (0x10-0x1D) 0x10_u64 => Ok(Self::SubgroupZeroId), 0x11_u64 => Ok(Self::SubgroupZeroIdExt), 0x12_u64 => Ok(Self::SubgroupFirstObjectId), @@ -99,9 +169,23 @@ impl Decode for StreamHeaderType { 0x1b_u64 => Ok(Self::SubgroupFirstObjectIdExtEndOfGroup), 0x1c_u64 => Ok(Self::SubgroupIdEndOfGroup), 0x1d_u64 => Ok(Self::SubgroupIdExtEndOfGroup), + // Priority Not Present variants (0x30-0x3D) + 0x30_u64 => Ok(Self::SubgroupZeroIdNoPriority), + 0x31_u64 => Ok(Self::SubgroupZeroIdExtNoPriority), + 0x32_u64 => Ok(Self::SubgroupFirstObjectIdNoPriority), + 0x33_u64 => Ok(Self::SubgroupFirstObjectIdExtNoPriority), + 0x34_u64 => Ok(Self::SubgroupIdNoPriority), + 0x35_u64 => Ok(Self::SubgroupIdExtNoPriority), + 0x38_u64 => Ok(Self::SubgroupZeroIdEndOfGroupNoPriority), + 0x39_u64 => Ok(Self::SubgroupZeroIdExtEndOfGroupNoPriority), + 0x3a_u64 => Ok(Self::SubgroupFirstObjectIdEndOfGroupNoPriority), + 0x3b_u64 => Ok(Self::SubgroupFirstObjectIdExtEndOfGroupNoPriority), + 0x3c_u64 => Ok(Self::SubgroupIdEndOfGroupNoPriority), + 0x3d_u64 => Ok(Self::SubgroupIdExtEndOfGroupNoPriority), + // Fetch 0x05_u64 => Ok(Self::Fetch), _ => { - tracing::error!( + log::error!( "[DECODE] StreamHeaderType: INVALID type value={:#x}", type_value ); @@ -110,7 +194,7 @@ impl Decode for StreamHeaderType { }; if let Ok(header_type_inner) = &header_type { - tracing::debug!( + log::debug!( "[DECODE] StreamHeaderType: {}, has_subgroup_id={}, has_extension_headers={}", header_type_inner, header_type_inner.has_subgroup_id(), @@ -142,40 +226,40 @@ pub struct StreamHeader { impl Decode for StreamHeader { fn decode(r: &mut R) -> Result { - tracing::trace!( + log::trace!( "[DECODE] StreamHeader: starting decode, buffer_remaining={} bytes", r.remaining() ); let header_type = StreamHeaderType::decode(r)?; - tracing::trace!( + log::trace!( "[DECODE] StreamHeader: decoded header_type={:?}", header_type ); let subgroup_header = match header_type.is_subgroup() { true => { - tracing::trace!("[DECODE] StreamHeader: decoding subgroup header"); + log::trace!("[DECODE] StreamHeader: decoding subgroup header"); Some(SubgroupHeader::decode(header_type, r)?) } false => { - tracing::trace!("[DECODE] StreamHeader: no subgroup header (not a subgroup type)"); + log::trace!("[DECODE] StreamHeader: no subgroup header (not a subgroup type)"); None } }; let fetch_header = match header_type.is_fetch() { true => { - tracing::trace!("[DECODE] StreamHeader: decoding fetch header"); + log::trace!("[DECODE] StreamHeader: decoding fetch header"); Some(FetchHeader::decode(header_type, r)?) } false => { - tracing::trace!("[DECODE] StreamHeader: no fetch header (not a fetch type)"); + log::trace!("[DECODE] StreamHeader: no fetch header (not a fetch type)"); None } }; - tracing::debug!( + log::debug!( "[DECODE] StreamHeader complete: type={:?}, has_subgroup={}, has_fetch={}, buffer_remaining={} bytes", header_type, subgroup_header.is_some(), @@ -193,7 +277,7 @@ impl Decode for StreamHeader { impl Encode for StreamHeader { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - tracing::trace!( + log::trace!( "[ENCODE] StreamHeader: starting encode for type={:?}, has_subgroup={}, has_fetch={}", self.header_type, self.subgroup_header.is_some(), @@ -205,27 +289,27 @@ impl Encode for StreamHeader { //self.header_type.encode(w)?; if self.header_type.is_subgroup() { if let Some(subgroup_header) = &self.subgroup_header { - tracing::trace!("[ENCODE] StreamHeader: encoding subgroup header"); + log::trace!("[ENCODE] StreamHeader: encoding subgroup header"); subgroup_header.encode(w)?; } else { - tracing::error!( + log::error!( "[ENCODE] StreamHeader: MISSING subgroup header for subgroup type={:?}", self.header_type ); return Err(EncodeError::MissingField("SubgroupHeader".to_string())); } } else if let Some(fetch_header) = &self.fetch_header { - tracing::trace!("[ENCODE] StreamHeader: encoding fetch header"); + log::trace!("[ENCODE] StreamHeader: encoding fetch header"); fetch_header.encode(w)?; } else { - tracing::error!( + log::error!( "[ENCODE] StreamHeader: MISSING fetch header for fetch type={:?}", self.header_type ); return Err(EncodeError::MissingField("FetchHeader".to_string())); } - tracing::debug!("[ENCODE] StreamHeader complete"); + log::debug!("[ENCODE] StreamHeader complete"); Ok(()) } @@ -294,7 +378,31 @@ mod tests { track_alias: 10, group_id: 0, subgroup_id: Some(1), - publisher_priority: 100, + publisher_priority: Some(100), + }), + fetch_header: None, + }; + sh.encode(&mut buf).unwrap(); + let decoded = StreamHeader::decode(&mut buf).unwrap(); + assert_eq!(decoded, sh); + assert!(sh.header_type.is_subgroup()); + assert!(!sh.header_type.is_fetch()); + assert!(sh.header_type.has_subgroup_id()); + } + + #[test] + fn encode_decode_stream_header_no_priority() { + let mut buf = BytesMut::new(); + + // Test a NoPriority subgroup header type + let sh = StreamHeader { + header_type: StreamHeaderType::SubgroupIdNoPriority, + subgroup_header: Some(SubgroupHeader { + header_type: StreamHeaderType::SubgroupIdNoPriority, + track_alias: 10, + group_id: 0, + subgroup_id: Some(1), + publisher_priority: None, }), fetch_header: None, }; @@ -304,5 +412,6 @@ mod tests { assert!(sh.header_type.is_subgroup()); assert!(!sh.header_type.is_fetch()); assert!(sh.header_type.has_subgroup_id()); + assert!(!sh.header_type.has_priority()); } } diff --git a/moq-transport/src/data/mod.rs b/moq-transport/src/data/mod.rs index 482a8ee0..0d0025ab 100644 --- a/moq-transport/src/data/mod.rs +++ b/moq-transport/src/data/mod.rs @@ -1,9 +1,6 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - mod datagram; mod extension_headers; +mod extension_types; mod fetch; mod header; mod object_status; @@ -11,6 +8,7 @@ mod subgroup; pub use datagram::*; pub use extension_headers::*; +pub use extension_types::*; pub use fetch::*; pub use header::*; pub use object_status::*; diff --git a/moq-transport/src/data/object_status.rs b/moq-transport/src/data/object_status.rs index df2eb2bc..e89b8077 100644 --- a/moq-transport/src/data/object_status.rs +++ b/moq-transport/src/data/object_status.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; #[derive(Clone, Copy, Debug, Eq, PartialEq)] diff --git a/moq-transport/src/data/subgroup.rs b/moq-transport/src/data/subgroup.rs index 82d61495..9cfb1127 100644 --- a/moq-transport/src/data/subgroup.rs +++ b/moq-transport/src/data/subgroup.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; use crate::data::{ExtensionHeaders, ObjectStatus, StreamHeaderType}; @@ -19,7 +16,8 @@ pub struct SubgroupHeader { pub subgroup_id: Option, /// Publisher priority, where **smaller** values are sent first. - pub publisher_priority: u8, + /// Optional when using NoPriority header types (0x30-0x3D). + pub publisher_priority: Option, } // Note: Not using the Decode trait, since we need to know the header_type to properly parse this, and it @@ -29,38 +27,46 @@ impl SubgroupHeader { header_type: StreamHeaderType, r: &mut R, ) -> Result { - tracing::trace!( + log::trace!( "[DECODE] SubgroupHeader: starting decode with header_type={:?}, buffer_remaining={} bytes", header_type, r.remaining() ); let track_alias = u64::decode(r)?; - tracing::trace!("[DECODE] SubgroupHeader: track_alias={}", track_alias); + log::trace!("[DECODE] SubgroupHeader: track_alias={}", track_alias); let group_id = u64::decode(r)?; - tracing::trace!("[DECODE] SubgroupHeader: group_id={}", group_id); + log::trace!("[DECODE] SubgroupHeader: group_id={}", group_id); let subgroup_id = match header_type.has_subgroup_id() { true => { let id = u64::decode(r)?; - tracing::trace!("[DECODE] SubgroupHeader: subgroup_id={}", id); + log::trace!("[DECODE] SubgroupHeader: subgroup_id={}", id); Some(id) } false => { - tracing::trace!( + log::trace!( "[DECODE] SubgroupHeader: subgroup_id=None (not present for this header type)" ); None } }; - let publisher_priority = u8::decode(r)?; - tracing::trace!( - "[DECODE] SubgroupHeader: publisher_priority={}, buffer_remaining={} bytes", - publisher_priority, - r.remaining() - ); + let publisher_priority = if header_type.has_priority() { + let priority = u8::decode(r)?; + log::trace!( + "[DECODE] SubgroupHeader: publisher_priority={}, buffer_remaining={} bytes", + priority, + r.remaining() + ); + Some(priority) + } else { + log::trace!( + "[DECODE] SubgroupHeader: publisher_priority=None (not present for NoPriority header type)" + ); + None + }; let result = Self { header_type, @@ -70,8 +76,8 @@ impl SubgroupHeader { publisher_priority, }; - tracing::debug!( - "[DECODE] SubgroupHeader complete: track_alias={}, group_id={}, subgroup_id={:?}, priority={}", + log::debug!( + "[DECODE] SubgroupHeader complete: track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}", result.track_alias, result.group_id, result.subgroup_id, @@ -84,8 +90,8 @@ impl SubgroupHeader { impl Encode for SubgroupHeader { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - tracing::trace!( - "[ENCODE] SubgroupHeader: starting encode - track_alias={}, group_id={}, subgroup_id={:?}, priority={}, header_type={:?}", + log::trace!( + "[ENCODE] SubgroupHeader: starting encode - track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}, header_type={:?}", self.track_alias, self.group_id, self.subgroup_id, @@ -96,16 +102,16 @@ impl Encode for SubgroupHeader { let start_pos = w.remaining_mut(); self.header_type.encode(w)?; - tracing::trace!("[ENCODE] SubgroupHeader: encoded header_type"); + log::trace!("[ENCODE] SubgroupHeader: encoded header_type"); self.track_alias.encode(w)?; - tracing::trace!( + log::trace!( "[ENCODE] SubgroupHeader: encoded track_alias={}", self.track_alias ); self.group_id.encode(w)?; - tracing::trace!( + log::trace!( "[ENCODE] SubgroupHeader: encoded group_id={}", self.group_id ); @@ -113,29 +119,43 @@ impl Encode for SubgroupHeader { if self.header_type.has_subgroup_id() { if let Some(subgroup_id) = self.subgroup_id { subgroup_id.encode(w)?; - tracing::trace!( + log::trace!( "[ENCODE] SubgroupHeader: encoded subgroup_id={}", subgroup_id ); } else { - tracing::error!( + log::error!( "[ENCODE] SubgroupHeader: MISSING subgroup_id for header_type={:?}", self.header_type ); return Err(EncodeError::MissingField("SubgroupId".to_string())); } } else { - tracing::trace!("[ENCODE] SubgroupHeader: subgroup_id not encoded (not required for this header type)"); + log::trace!("[ENCODE] SubgroupHeader: subgroup_id not encoded (not required for this header type)"); } - self.publisher_priority.encode(w)?; - tracing::trace!( - "[ENCODE] SubgroupHeader: encoded publisher_priority={}", - self.publisher_priority - ); + if self.header_type.has_priority() { + if let Some(publisher_priority) = self.publisher_priority { + publisher_priority.encode(w)?; + log::trace!( + "[ENCODE] SubgroupHeader: encoded publisher_priority={}", + publisher_priority + ); + } else { + log::error!( + "[ENCODE] SubgroupHeader: MISSING publisher_priority for header_type={:?}", + self.header_type + ); + return Err(EncodeError::MissingField("PublisherPriority".to_string())); + } + } else { + log::trace!( + "[ENCODE] SubgroupHeader: publisher_priority not encoded (NoPriority header type)" + ); + } let bytes_written = start_pos - w.remaining_mut(); - tracing::debug!( + log::debug!( "[ENCODE] SubgroupHeader complete: wrote {} bytes", bytes_written ); @@ -155,28 +175,28 @@ pub struct SubgroupObject { impl Decode for SubgroupObject { fn decode(r: &mut R) -> Result { - tracing::trace!( + log::trace!( "[DECODE] SubgroupObject: starting decode, buffer_remaining={} bytes", r.remaining() ); let object_id_delta = u64::decode(r)?; - tracing::trace!( + log::trace!( "[DECODE] SubgroupObject: object_id_delta={}", object_id_delta ); let payload_length = usize::decode(r)?; - tracing::trace!("[DECODE] SubgroupObject: payload_length={}", payload_length); + log::trace!("[DECODE] SubgroupObject: payload_length={}", payload_length); let status = match payload_length { 0 => { let s = ObjectStatus::decode(r)?; - tracing::trace!("[DECODE] SubgroupObject: status={:?} (payload_length=0)", s); + log::trace!("[DECODE] SubgroupObject: status={:?} (payload_length=0)", s); Some(s) } _ => { - tracing::trace!("[DECODE] SubgroupObject: status=None (payload_length > 0)"); + log::trace!("[DECODE] SubgroupObject: status=None (payload_length > 0)"); None } }; @@ -184,7 +204,7 @@ impl Decode for SubgroupObject { //Self::decode_remaining(r, payload_length); //let payload = r.copy_to_bytes(payload_length); - tracing::debug!( + log::debug!( "[DECODE] SubgroupObject complete: object_id_delta={}, payload_length={}, status={:?}, buffer_remaining={} bytes", object_id_delta, payload_length, @@ -203,7 +223,7 @@ impl Decode for SubgroupObject { impl Encode for SubgroupObject { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - tracing::trace!( + log::trace!( "[ENCODE] SubgroupObject: starting encode - object_id_delta={}, payload_length={}, status={:?}", self.object_id_delta, self.payload_length, @@ -211,13 +231,13 @@ impl Encode for SubgroupObject { ); self.object_id_delta.encode(w)?; - tracing::trace!( + log::trace!( "[ENCODE] SubgroupObject: encoded object_id_delta={}", self.object_id_delta ); self.payload_length.encode(w)?; - tracing::trace!( + log::trace!( "[ENCODE] SubgroupObject: encoded payload_length={}", self.payload_length ); @@ -225,16 +245,16 @@ impl Encode for SubgroupObject { if self.payload_length == 0 { if let Some(status) = self.status { status.encode(w)?; - tracing::trace!("[ENCODE] SubgroupObject: encoded status={:?}", status); + log::trace!("[ENCODE] SubgroupObject: encoded status={:?}", status); } else { - tracing::error!("[ENCODE] SubgroupObject: MISSING status for payload_length=0"); + log::error!("[ENCODE] SubgroupObject: MISSING status for payload_length=0"); return Err(EncodeError::MissingField("Status".to_string())); } } //Self::encode_remaining(w, self.payload.len())?; //w.put_slice(&self.payload); - tracing::debug!("[ENCODE] SubgroupObject complete"); + log::debug!("[ENCODE] SubgroupObject complete"); Ok(()) } @@ -252,25 +272,25 @@ pub struct SubgroupObjectExt { impl Decode for SubgroupObjectExt { fn decode(r: &mut R) -> Result { - tracing::trace!( + log::trace!( "[DECODE] SubgroupObjectExt: starting decode, buffer_remaining={} bytes", r.remaining() ); let object_id_delta = u64::decode(r)?; - tracing::trace!( + log::trace!( "[DECODE] SubgroupObjectExt: object_id_delta={}", object_id_delta ); let extension_headers = ExtensionHeaders::decode(r)?; - tracing::trace!( + log::trace!( "[DECODE] SubgroupObjectExt: extension_headers={:?}", extension_headers ); let payload_length = usize::decode(r)?; - tracing::trace!( + log::trace!( "[DECODE] SubgroupObjectExt: payload_length={}", payload_length ); @@ -278,14 +298,14 @@ impl Decode for SubgroupObjectExt { let status = match payload_length { 0 => { let s = ObjectStatus::decode(r)?; - tracing::trace!( + log::trace!( "[DECODE] SubgroupObjectExt: status={:?} (payload_length=0)", s ); Some(s) } _ => { - tracing::trace!("[DECODE] SubgroupObjectExt: status=None (payload_length > 0)"); + log::trace!("[DECODE] SubgroupObjectExt: status=None (payload_length > 0)"); None } }; @@ -293,7 +313,7 @@ impl Decode for SubgroupObjectExt { //Self::decode_remaining(r, payload_length); //let payload = r.copy_to_bytes(payload_length); - tracing::debug!( + log::debug!( "[DECODE] SubgroupObjectExt complete: object_id_delta={}, payload_length={}, status={:?}, buffer_remaining={} bytes", object_id_delta, payload_length, @@ -313,7 +333,7 @@ impl Decode for SubgroupObjectExt { impl Encode for SubgroupObjectExt { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - tracing::trace!( + log::trace!( "[ENCODE] SubgroupObjectExt: starting encode - object_id_delta={}, payload_length={}, status={:?}, extension_headers={:?}", self.object_id_delta, self.payload_length, @@ -322,16 +342,16 @@ impl Encode for SubgroupObjectExt { ); self.object_id_delta.encode(w)?; - tracing::trace!( + log::trace!( "[ENCODE] SubgroupObjectExt: encoded object_id_delta={}", self.object_id_delta ); self.extension_headers.encode(w)?; - tracing::trace!("[ENCODE] SubgroupObjectExt: encoded extension_headers"); + log::trace!("[ENCODE] SubgroupObjectExt: encoded extension_headers"); self.payload_length.encode(w)?; - tracing::trace!( + log::trace!( "[ENCODE] SubgroupObjectExt: encoded payload_length={}", self.payload_length ); @@ -339,16 +359,16 @@ impl Encode for SubgroupObjectExt { if self.payload_length == 0 { if let Some(status) = self.status { status.encode(w)?; - tracing::trace!("[ENCODE] SubgroupObjectExt: encoded status={:?}", status); + log::trace!("[ENCODE] SubgroupObjectExt: encoded status={:?}", status); } else { - tracing::error!("[ENCODE] SubgroupObjectExt: MISSING status for payload_length=0"); + log::error!("[ENCODE] SubgroupObjectExt: MISSING status for payload_length=0"); return Err(EncodeError::MissingField("Status".to_string())); } } //Self::encode_remaining(w, self.payload.len())?; //w.put_slice(&self.payload); - tracing::debug!("[ENCODE] SubgroupObjectExt complete"); + log::debug!("[ENCODE] SubgroupObjectExt complete"); Ok(()) } diff --git a/moq-transport/src/error.rs b/moq-transport/src/error.rs index d79abf04..69fc3455 100644 --- a/moq-transport/src/error.rs +++ b/moq-transport/src/error.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - /// An error that causes the session to close. #[derive(thiserror::Error, Debug)] pub enum SessionError { diff --git a/moq-transport/src/lib.rs b/moq-transport/src/lib.rs index 12daad5b..2549df14 100644 --- a/moq-transport/src/lib.rs +++ b/moq-transport/src/lib.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! An implementation of the MoQ Transport protocol. //! //! MoQ Transport is a pub/sub protocol over QUIC. diff --git a/moq-transport/src/message/dynamic_groups.rs b/moq-transport/src/message/dynamic_groups.rs new file mode 100644 index 00000000..4a5b7708 --- /dev/null +++ b/moq-transport/src/message/dynamic_groups.rs @@ -0,0 +1,131 @@ +//! Dynamic Groups support for MOQT. +//! +//! This module provides helper functions for working with Dynamic Groups parameters +//! as defined in the MOQT specification. Dynamic Groups allow subscribers to request +//! publishers to create new groups on demand. + +use crate::coding::KeyValuePairs; +use crate::message::ParameterType; + +/// Helper trait for Dynamic Groups parameter operations on KeyValuePairs. +pub trait DynamicGroupsExt { + /// Check if dynamic groups are enabled/supported + fn has_dynamic_groups(&self) -> bool; + + /// Get the dynamic groups value (if present) + fn get_dynamic_groups(&self) -> Option; + + /// Enable dynamic groups support + fn set_dynamic_groups(&mut self, value: u64); + + /// Check if a new group request is present + fn has_new_group_request(&self) -> bool; + + /// Get the new group request value (if present) + fn get_new_group_request(&self) -> Option; + + /// Request a new group from the publisher + fn set_new_group_request(&mut self, value: u64); +} + +impl DynamicGroupsExt for KeyValuePairs { + fn has_dynamic_groups(&self) -> bool { + self.has(ParameterType::DynamicGroups.into()) + } + + fn get_dynamic_groups(&self) -> Option { + self.get_intvalue(ParameterType::DynamicGroups.into()) + } + + fn set_dynamic_groups(&mut self, value: u64) { + self.set_intvalue(ParameterType::DynamicGroups.into(), value); + } + + fn has_new_group_request(&self) -> bool { + self.has(ParameterType::NewGroupRequest.into()) + } + + fn get_new_group_request(&self) -> Option { + self.get_intvalue(ParameterType::NewGroupRequest.into()) + } + + fn set_new_group_request(&mut self, value: u64) { + self.set_intvalue(ParameterType::NewGroupRequest.into(), value); + } +} + +/// Dynamic Groups configuration for a track +#[derive(Clone, Debug, Default)] +pub struct DynamicGroupsConfig { + /// Whether dynamic groups are enabled for this track + pub enabled: bool, + /// The current pending new group request (if any) + pub pending_request: Option, +} + +impl DynamicGroupsConfig { + /// Create a new configuration with dynamic groups disabled + pub fn new() -> Self { + Self::default() + } + + /// Create a new configuration with dynamic groups enabled + pub fn enabled() -> Self { + Self { + enabled: true, + pending_request: None, + } + } + + /// Request a new group with the given request ID + pub fn request_new_group(&mut self, request_id: u64) { + self.pending_request = Some(request_id); + } + + /// Clear the pending request (after it has been processed) + pub fn clear_pending_request(&mut self) { + self.pending_request = None; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dynamic_groups_ext() { + let mut params = KeyValuePairs::new(); + + // Initially no dynamic groups + assert!(!params.has_dynamic_groups()); + assert_eq!(params.get_dynamic_groups(), None); + + // Enable dynamic groups + params.set_dynamic_groups(1); + assert!(params.has_dynamic_groups()); + assert_eq!(params.get_dynamic_groups(), Some(1)); + + // New group request + assert!(!params.has_new_group_request()); + params.set_new_group_request(42); + assert!(params.has_new_group_request()); + assert_eq!(params.get_new_group_request(), Some(42)); + } + + #[test] + fn test_dynamic_groups_config() { + let config = DynamicGroupsConfig::new(); + assert!(!config.enabled); + assert!(config.pending_request.is_none()); + + let config = DynamicGroupsConfig::enabled(); + assert!(config.enabled); + + let mut config = DynamicGroupsConfig::enabled(); + config.request_new_group(123); + assert_eq!(config.pending_request, Some(123)); + + config.clear_pending_request(); + assert!(config.pending_request.is_none()); + } +} diff --git a/moq-transport/src/message/fetch.rs b/moq-transport/src/message/fetch.rs index 89f16480..d41c94e0 100644 --- a/moq-transport/src/message/fetch.rs +++ b/moq-transport/src/message/fetch.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{ Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, }; diff --git a/moq-transport/src/message/fetch_cancel.rs b/moq-transport/src/message/fetch_cancel.rs index c8a00608..4d30b4c7 100644 --- a/moq-transport/src/message/fetch_cancel.rs +++ b/moq-transport/src/message/fetch_cancel.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// A subscriber issues a FETCH_CANCEL message to a publisher indicating it is diff --git a/moq-transport/src/message/fetch_error.rs b/moq-transport/src/message/fetch_error.rs deleted file mode 100644 index adfb42db..00000000 --- a/moq-transport/src/message/fetch_error.rs +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct FetchError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for FetchError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for FetchError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/fetch_ok.rs b/moq-transport/src/message/fetch_ok.rs index c9328bcb..f3b721ed 100644 --- a/moq-transport/src/message/fetch_ok.rs +++ b/moq-transport/src/message/fetch_ok.rs @@ -1,7 +1,5 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; +use crate::data::ExtensionHeaders; use crate::message::GroupOrder; /// A publisher sends a FETCH_OK control message in response to successful fetches. @@ -21,6 +19,9 @@ pub struct FetchOk { /// Optional parameters pub params: KeyValuePairs, + + /// Track extensions + pub track_extensions: ExtensionHeaders, } impl Decode for FetchOk { @@ -36,6 +37,7 @@ impl Decode for FetchOk { let end_of_track = bool::decode(r)?; let end_location = Location::decode(r)?; let params = KeyValuePairs::decode(r)?; + let track_extensions = ExtensionHeaders::decode(r)?; Ok(Self { id, @@ -43,6 +45,7 @@ impl Decode for FetchOk { end_of_track, end_location, params, + track_extensions, }) } } @@ -60,6 +63,7 @@ impl Encode for FetchOk { self.end_of_track.encode(w)?; self.end_location.encode(w)?; self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -84,6 +88,7 @@ mod tests { end_of_track: true, end_location: Location::new(2, 3), params: kvps.clone(), + track_extensions: Default::default(), }; msg.encode(&mut buf).unwrap(); let decoded = FetchOk::decode(&mut buf).unwrap(); @@ -100,6 +105,7 @@ mod tests { end_of_track: true, end_location: Location::new(2, 3), params: Default::default(), + track_extensions: Default::default(), }; let encoded = msg.encode(&mut buf); assert!(matches!(encoded.unwrap_err(), EncodeError::InvalidValue)); diff --git a/moq-transport/src/message/fetch_type.rs b/moq-transport/src/message/fetch_type.rs index a027cc6c..6213f92e 100644 --- a/moq-transport/src/message/fetch_type.rs +++ b/moq-transport/src/message/fetch_type.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Filter Types diff --git a/moq-transport/src/message/filter_type.rs b/moq-transport/src/message/filter_type.rs index d2be67c7..98d7aec3 100644 --- a/moq-transport/src/message/filter_type.rs +++ b/moq-transport/src/message/filter_type.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Filter Types diff --git a/moq-transport/src/message/go_away.rs b/moq-transport/src/message/go_away.rs index d03f35f5..56050974 100644 --- a/moq-transport/src/message/go_away.rs +++ b/moq-transport/src/message/go_away.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, SessionUri}; /// Sent by the server to indicate that the client should connect to a different server. diff --git a/moq-transport/src/message/group_order.rs b/moq-transport/src/message/group_order.rs index 71d84fc7..c9596aeb 100644 --- a/moq-transport/src/message/group_order.rs +++ b/moq-transport/src/message/group_order.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Group Order diff --git a/moq-transport/src/message/max_request_id.rs b/moq-transport/src/message/max_request_id.rs index 04c61419..e535d425 100644 --- a/moq-transport/src/message/max_request_id.rs +++ b/moq-transport/src/message/max_request_id.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Sent by the publisher to update the max allowed subscription ID for the session. diff --git a/moq-transport/src/message/mod.rs b/moq-transport/src/message/mod.rs index 8a790097..246474a8 100644 --- a/moq-transport/src/message/mod.rs +++ b/moq-transport/src/message/mod.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! Low-level message sent over the wire, as defined in the specification. //! //! All of these messages are sent over a bidirectional QUIC stream. @@ -9,73 +5,65 @@ //! The only exception are OBJECT "messages", which are sent over dedicated QUIC streams. //! +mod dynamic_groups; mod fetch; mod fetch_cancel; -mod fetch_error; mod fetch_ok; mod fetch_type; mod filter_type; mod go_away; mod group_order; mod max_request_id; -mod pubilsh_namespace_done; +mod namespace; +mod parameters; mod publish; mod publish_done; -mod publish_error; mod publish_namespace; mod publish_namespace_cancel; -mod publish_namespace_error; -mod publish_namespace_ok; +mod publish_namespace_done; mod publish_ok; mod publisher; +mod request_error; +mod request_ok; mod requests_blocked; mod subscribe; -mod subscribe_error; mod subscribe_namespace; -mod subscribe_namespace_error; -mod subscribe_namespace_ok; mod subscribe_ok; mod subscribe_update; mod subscriber; mod track_status; -mod track_status_error; mod track_status_ok; mod unsubscribe; -mod unsubscribe_namespace; +pub use dynamic_groups::*; pub use fetch::*; pub use fetch_cancel::*; -pub use fetch_error::*; pub use fetch_ok::*; pub use fetch_type::*; pub use filter_type::*; pub use go_away::*; pub use group_order::*; pub use max_request_id::*; -pub use pubilsh_namespace_done::*; +pub use namespace::*; +pub use parameters::*; pub use publish::*; pub use publish_done::*; -pub use publish_error::*; pub use publish_namespace::*; pub use publish_namespace_cancel::*; -pub use publish_namespace_error::*; -pub use publish_namespace_ok::*; +pub use publish_namespace_done::*; pub use publish_ok::*; pub use publisher::*; +pub use request_error::*; +pub use request_ok::*; pub use requests_blocked::*; pub use subscribe::*; -pub use subscribe_error::*; pub use subscribe_namespace::*; -pub use subscribe_namespace_error::*; -pub use subscribe_namespace_ok::*; pub use subscribe_ok::*; pub use subscribe_update::*; pub use subscriber::*; pub use track_status::*; -pub use track_status_error::*; pub use track_status_ok::*; pub use unsubscribe::*; -pub use unsubscribe_namespace::*; use crate::coding::{Decode, DecodeError, Encode, EncodeError}; use std::fmt; @@ -93,13 +81,18 @@ macro_rules! message_types { impl Decode for Message { fn decode(r: &mut R) -> Result { let t = u64::decode(r)?; - let _len = u16::decode(r)?; + let len = u16::decode(r)? as usize; - // TODO: Check the length of the message. + // Read exactly len bytes into a sub-buffer to properly handle Track Extensions + if r.remaining() < len { + return Err(DecodeError::More(len - r.remaining())); + } + let payload = r.copy_to_bytes(len); + let mut payload_reader = std::io::Cursor::new(payload); match t { $($val => { - let msg = $name::decode(r)?; + let msg = $name::decode(&mut payload_reader)?; Ok(Self::$name(msg)) })* _ => Err(DecodeError::InvalidMessage(t)), @@ -189,40 +182,36 @@ message_types! { Unsubscribe = 0xa, // SUBSCRIBE family, sent by publisher SubscribeOk = 0x4, - SubscribeError = 0x5, // ANNOUNCE family, sent by publisher PublishNamespace = 0x6, PublishNamespaceDone = 0x9, // ANNOUNCE family, sent by subscriber - PublishNamespaceOk = 0x7, - PublishNamespaceError = 0x8, + RequestOk = 0x7, PublishNamespaceCancel = 0xc, + // NAMESPACE family, sent by relay to subscriber (draft-16) + Namespace = 0x8, + // TRACK_STATUS family, sent by subscriber TrackStatus = 0xd, // TRACK_STATUS family, sent by publisher TrackStatusOk = 0xe, - TrackStatusError = 0xf, // NAMESPACE family, sent by subscriber SubscribeNamespace = 0x11, - UnsubscribeNamespace = 0x14, - // NAMESPACE family, sent by publisher - SubscribeNamespaceOk = 0x12, - SubscribeNamespaceError = 0x13, // FETCH family, sent by subscriber Fetch = 0x16, FetchCancel = 0x17, // FETCH family, sent by publisher FetchOk = 0x18, - FetchError = 0x19, // PUBLISH family, sent by publisher Publish = 0x1d, PublishDone = 0xb, // PUBLISH family, sent by subscriber PublishOk = 0x1e, - PublishError = 0x1f, + + RequestError = 0x5, } diff --git a/moq-transport/src/message/namespace.rs b/moq-transport/src/message/namespace.rs new file mode 100644 index 00000000..978d1a1d --- /dev/null +++ b/moq-transport/src/message/namespace.rs @@ -0,0 +1,61 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; + +/// NAMESPACE message (draft-16) +/// +/// Sent by relay to subscriber to announce a namespace matching their SUBSCRIBE_NAMESPACE. +/// This is different from PUBLISH_NAMESPACE which is sent by publisher to relay. +/// +/// Wire format: 0x08 +#[derive(Clone, Debug)] +pub struct Namespace { + /// Request ID (from the SUBSCRIBE_NAMESPACE) + pub id: u64, + /// The namespace being announced + pub track_namespace: TrackNamespace, + /// Optional parameters + pub params: KeyValuePairs, +} + +impl Decode for Namespace { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let track_namespace = TrackNamespace::decode(r)?; + let params = KeyValuePairs::decode(r)?; + + Ok(Self { + id, + track_namespace, + params, + }) + } +} + +impl Encode for Namespace { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.track_namespace.encode(w)?; + self.params.encode(w)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_namespace_encode_decode() { + let msg = Namespace { + id: 42, + track_namespace: TrackNamespace::from_utf8_path("live/room1"), + params: KeyValuePairs::new(), + }; + + let mut buf = Vec::new(); + msg.encode(&mut buf).unwrap(); + + let decoded = Namespace::decode(&mut buf.as_slice()).unwrap(); + assert_eq!(decoded.id, 42); + assert_eq!(decoded.track_namespace.to_utf8_path(), "live/room1"); + } +} diff --git a/moq-transport/src/message/parameters.rs b/moq-transport/src/message/parameters.rs new file mode 100644 index 00000000..b61d9e37 --- /dev/null +++ b/moq-transport/src/message/parameters.rs @@ -0,0 +1,47 @@ +/// Version-Specific Message Parameter Types +/// Used in SUBSCRIBE, SUBSCRIBE_OK, PUBLISH, FETCH, REQUEST_UPDATE, etc. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u64)] +pub enum ParameterType { + /// Used in: REQUEST_OK, PUBLISH, PUBLISH_OK, SUBSCRIBE, SUBSCRIBE_OK, REQUEST_UPDATE + DeliveryTimeout = 0x02, + /// Used in: CLIENT_SETUP, SERVER_SETUP, PUBLISH, SUBSCRIBE, REQUEST_UPDATE, + /// SUBSCRIBE_NAMESPACE, PUBLISH_NAMESPACE, TRACK_STATUS, FETCH + AuthorizationToken = 0x03, + /// Used in: PUBLISH, SUBSCRIBE_OK, FETCH_OK, REQUEST_OK + MaxCacheDuration = 0x04, + /// Used in: SUBSCRIBE_OK, PUBLISH, PUBLISH_OK + Expires = 0x08, + /// Used in: SUBSCRIBE_OK, PUBLISH, REQUEST_OK + LargestObject = 0x09, + /// Used in: SUBSCRIBE_OK, PUBLISH + PublisherPriority = 0x0E, + /// Used in: SUBSCRIBE, REQUEST_UPDATE, PUBLISH, PUBLISH_OK, SUBSCRIBE_NAMESPACE + Forward = 0x10, + /// Used in: SUBSCRIBE, FETCH, REQUEST_UPDATE, PUBLISH_OK + SubscriberPriority = 0x20, + /// Used in: SUBSCRIBE, PUBLISH_OK, REQUEST_UPDATE (renamed to SubscriptionLocationFilter per PR #1518) + SubscriptionFilter = 0x21, + /// Used in: SUBSCRIBE, SUBSCRIBE_OK, REQUEST_OK, PUBLISH, PUBLISH_OK, FETCH + GroupOrder = 0x22, + /// Used in: SUBSCRIBE, FETCH - Filter by subgroup ID ranges (PR #1518) + SubgroupFilter = 0x25, + /// Used in: SUBSCRIBE, FETCH - Filter by object ID ranges (PR #1518) + ObjectFilter = 0x26, + /// Used in: SUBSCRIBE, FETCH - Filter by priority ranges (PR #1518) + PriorityFilter = 0x27, + /// Used in: SUBSCRIBE, FETCH - Filter by property value ranges (PR #1518) + PropertyFilter = 0x28, + /// Used in: SUBSCRIBE_NAMESPACE - Track filter for top-N selection (PR #1518) + TrackFilter = 0x29, + /// Used in: PUBLISH, SUBSCRIBE_OK + DynamicGroups = 0x30, + /// Used in: PUBLISH_OK, SUBSCRIBE, REQUEST_UPDATE + NewGroupRequest = 0x32, +} + +impl From for u64 { + fn from(value: ParameterType) -> Self { + value as u64 + } +} diff --git a/moq-transport/src/message/publish.rs b/moq-transport/src/message/publish.rs index 2246b68b..4615e128 100644 --- a/moq-transport/src/message/publish.rs +++ b/moq-transport/src/message/publish.rs @@ -1,12 +1,10 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{ - Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, -}; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; +use crate::data::ExtensionHeaders; /// Sent by publisher to initiate a subscription to a track. +/// +/// Draft-16: Fields like group_order, content_exists, largest_location, forward +/// have been moved to Parameters (Section 9.2.2). #[derive(Clone, Debug, Eq, PartialEq)] pub struct Publish { /// The publish request ID @@ -17,14 +15,11 @@ pub struct Publish { pub track_name: String, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) pub track_alias: u64, - pub group_order: GroupOrder, - pub content_exists: bool, - // The largest object available for this track, if content exists. - pub largest_location: Option, - pub forward: bool, - - /// Optional parameters + /// Optional parameters (may contain Forward, GroupOrder, LargestObject, PublisherPriority, etc.) pub params: KeyValuePairs, + + /// Track extensions + pub track_extensions: ExtensionHeaders, } impl Decode for Publish { @@ -35,31 +30,18 @@ impl Decode for Publish { let track_name = String::decode(r)?; let track_alias = u64::decode(r)?; - let group_order = GroupOrder::decode(r)?; - // GroupOrder enum has Publisher in it, but it's not allowed to be used in this - // publish message, so validate it now so we can return a protocol error. - if group_order == GroupOrder::Publisher { - return Err(DecodeError::InvalidGroupOrder); - } - let content_exists = bool::decode(r)?; - let largest_location = match content_exists { - true => Some(Location::decode(r)?), - false => None, - }; - let forward = bool::decode(r)?; - let params = KeyValuePairs::decode(r)?; + // Track Extensions use remaining bytes (no length prefix per draft-16) + let track_extensions = ExtensionHeaders::decode_remaining_bytes(r)?; + Ok(Self { id, track_namespace, track_name, track_alias, - group_order, - content_exists, - largest_location, - forward, params, + track_extensions, }) } } @@ -72,22 +54,9 @@ impl Encode for Publish { self.track_name.encode(w)?; self.track_alias.encode(w)?; - // GroupOrder enum has Publisher in it, but it's not allowed to be used in this - // publish message. - if self.group_order == GroupOrder::Publisher { - return Err(EncodeError::InvalidValue); - } - self.group_order.encode(w)?; - self.content_exists.encode(w)?; - if self.content_exists { - if let Some(largest) = &self.largest_location { - largest.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - } - self.forward.encode(w)?; self.params.encode(w)?; + // Track Extensions use remaining bytes (no length prefix per draft-16) + self.track_extensions.encode_without_length(w)?; Ok(()) } @@ -102,37 +71,16 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - // Content exists = true - let msg = Publish { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: true, - largest_location: Some(Location::new(2, 3)), - forward: true, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = Publish::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // Content exists = false let msg = Publish { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: false, - largest_location: None, - forward: true, params: kvps.clone(), + track_extensions: Default::default(), }; msg.encode(&mut buf).unwrap(); let decoded = Publish::decode(&mut buf).unwrap(); @@ -140,7 +88,7 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn encode_decode_no_params() { let mut buf = BytesMut::new(); let msg = Publish { @@ -148,32 +96,38 @@ mod tests { track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: true, - largest_location: None, - forward: true, params: Default::default(), + track_extensions: Default::default(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); + msg.encode(&mut buf).unwrap(); + let decoded = Publish::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } #[test] - fn encode_bad_group_order() { + fn encode_decode_with_track_extensions() { let mut buf = BytesMut::new(); + let mut track_ext = ExtensionHeaders::new(); + track_ext.set_intvalue(0x12, 50); // AUDIO_LEVEL_EXT = 0x12 (18), value = 50 + let msg = Publish { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - track_alias: 212, - group_order: GroupOrder::Publisher, - content_exists: false, - largest_location: None, - forward: true, + id: 1, + track_namespace: TrackNamespace::from_utf8_path("topn-test/speaker-0"), + track_name: "audio".to_string(), + track_alias: 0, params: Default::default(), + track_extensions: track_ext, }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::InvalidValue)); + msg.encode(&mut buf).unwrap(); + let decoded = Publish::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + // Verify the track extension was decoded correctly + let kvp = decoded.track_extensions.get(0x12).unwrap(); + assert_eq!(kvp.key, 0x12); + match &kvp.value { + crate::coding::Value::IntValue(v) => assert_eq!(*v, 50), + _ => panic!("Expected int value"), + } } } diff --git a/moq-transport/src/message/publish_done.rs b/moq-transport/src/message/publish_done.rs index c198788c..98e52316 100644 --- a/moq-transport/src/message/publish_done.rs +++ b/moq-transport/src/message/publish_done.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; // TODO SLG - add an enum for status_codes diff --git a/moq-transport/src/message/publish_error.rs b/moq-transport/src/message/publish_error.rs deleted file mode 100644 index c0e3f857..00000000 --- a/moq-transport/src/message/publish_error.rs +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct PublishError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for PublishError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for PublishError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/publish_namespace.rs b/moq-transport/src/message/publish_namespace.rs index 5c20c1df..fa2fce24 100644 --- a/moq-transport/src/message/publish_namespace.rs +++ b/moq-transport/src/message/publish_namespace.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; /// Sent by the publisher to announce the availability of a group of tracks. diff --git a/moq-transport/src/message/publish_namespace_cancel.rs b/moq-transport/src/message/publish_namespace_cancel.rs index 946d1ccb..05e05429 100644 --- a/moq-transport/src/message/publish_namespace_cancel.rs +++ b/moq-transport/src/message/publish_namespace_cancel.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase, TrackNamespace}; /// Sent by the subscriber to terminate an Announce after PUBLISH_NAMESPACE_OK diff --git a/moq-transport/src/message/pubilsh_namespace_done.rs b/moq-transport/src/message/publish_namespace_done.rs similarity index 88% rename from moq-transport/src/message/pubilsh_namespace_done.rs rename to moq-transport/src/message/publish_namespace_done.rs index 9fa6799e..4540ab47 100644 --- a/moq-transport/src/message/pubilsh_namespace_done.rs +++ b/moq-transport/src/message/publish_namespace_done.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, TrackNamespace}; /// Sent by the publisher to terminate a PUBLISH_NAMESPACE. diff --git a/moq-transport/src/message/publish_namespace_error.rs b/moq-transport/src/message/publish_namespace_error.rs deleted file mode 100644 index 6ba721d5..00000000 --- a/moq-transport/src/message/publish_namespace_error.rs +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an PUBLISH_NAMESPACE. -#[derive(Clone, Debug)] -pub struct PublishNamespaceError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for PublishNamespaceError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for PublishNamespaceError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/publish_namespace_ok.rs b/moq-transport/src/message/publish_namespace_ok.rs deleted file mode 100644 index 84dcc1b9..00000000 --- a/moq-transport/src/message/publish_namespace_ok.rs +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError}; - -/// Sent by the subscriber to accept a PUBLISH_NAMESPACE. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct PublishNamespaceOk { - /// The request ID of the PUBLISH_NAMESPACE this message is replying to. - pub id: u64, -} - -impl Decode for PublishNamespaceOk { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - Ok(Self { id }) - } -} - -impl Encode for PublishNamespaceOk { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = PublishNamespaceOk { id: 12345 }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishNamespaceOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/message/publish_ok.rs b/moq-transport/src/message/publish_ok.rs index ddd7c6b0..e376c89b 100644 --- a/moq-transport/src/message/publish_ok.rs +++ b/moq-transport/src/message/publish_ok.rs @@ -1,113 +1,30 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; -use crate::message::FilterType; -use crate::message::GroupOrder; - -/// Sent by the subscriber to request all future objects for the given track. +/// Sent by the subscriber to acknowledge a PUBLISH message and establish a subscription. /// -/// Objects will use the provided ID instead of the full track name, to save bytes. +/// Draft-16: All subscription properties (forward, subscriber_priority, group_order, +/// filter_type, etc.) are now in Parameters (Section 9.2.2). #[derive(Clone, Debug, Eq, PartialEq)] pub struct PublishOk { /// The request ID of the Publish this message is replying to. pub id: u64, - /// Forward Flag - pub forward: bool, - - /// Subscriber Priority - pub subscriber_priority: u8, - - /// The order the subscription will be delivered in - pub group_order: GroupOrder, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. - pub end_group_id: Option, - - /// Optional parameters + /// Parameters (may contain Forward, SubscriberPriority, GroupOrder, SubscriptionFilter, etc.) pub params: KeyValuePairs, } impl Decode for PublishOk { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; - - let forward = bool::decode(r)?; - let subscriber_priority = u8::decode(r)?; - let group_order = GroupOrder::decode(r)?; - - let filter_type = FilterType::decode(r)?; - let start_location: Option; - let end_group_id: Option; - match filter_type { - FilterType::AbsoluteStart => { - start_location = Some(Location::decode(r)?); - end_group_id = None; - } - FilterType::AbsoluteRange => { - start_location = Some(Location::decode(r)?); - end_group_id = Some(u64::decode(r)?); - } - _ => { - start_location = None; - end_group_id = None; - } - } - let params = KeyValuePairs::decode(r)?; - Ok(Self { - id, - forward, - subscriber_priority, - group_order, - filter_type, - start_location, - end_group_id, - params, - }) + Ok(Self { id, params }) } } impl Encode for PublishOk { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; - - self.forward.encode(w)?; - self.subscriber_priority.encode(w)?; - self.group_order.encode(w)?; - - self.filter_type.encode(w)?; - match self.filter_type { - FilterType::AbsoluteStart => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - // Just ignore end_group_id if it happens to be set - } - FilterType::AbsoluteRange => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - if let Some(end) = self.end_group_id { - end.encode(w)?; - } else { - return Err(EncodeError::MissingField("EndGroupId".to_string())); - } - } - _ => {} - } - self.params.encode(w)?; Ok(()) @@ -123,49 +40,11 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - // FilterType = NextGroupStart - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::NextGroupStart, - start_location: None, - end_group_id: None, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // FilterType = AbsoluteStart let msg = PublishOk { id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteStart, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // FilterType = AbsoluteRange - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: Some(23456), params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); @@ -174,49 +53,15 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn encode_decode_no_params() { let mut buf = BytesMut::new(); - // FilterType = AbsoluteStart - missing start_location - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteStart, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing start_location - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing end_group_id let msg = PublishOk { id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, params: Default::default(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); + msg.encode(&mut buf).unwrap(); + let decoded = PublishOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } } diff --git a/moq-transport/src/message/publisher.rs b/moq-transport/src/message/publisher.rs index abe1c62d..61700289 100644 --- a/moq-transport/src/message/publisher.rs +++ b/moq-transport/src/message/publisher.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::message::{self, Message}; use std::fmt; @@ -52,14 +48,12 @@ macro_rules! publisher_msgs { publisher_msgs! { PublishNamespace, PublishNamespaceDone, + Namespace, Publish, PublishDone, SubscribeOk, - SubscribeError, TrackStatusOk, - TrackStatusError, FetchOk, - FetchError, - SubscribeNamespaceOk, - SubscribeNamespaceError, + RequestOk, + RequestError, } diff --git a/moq-transport/src/message/request_error.rs b/moq-transport/src/message/request_error.rs new file mode 100644 index 00000000..fce02c6f --- /dev/null +++ b/moq-transport/src/message/request_error.rs @@ -0,0 +1,87 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; + +/// REQUEST_ERROR message (draft-16 Section 9.8). +/// +/// Sent in response to any request (SUBSCRIBE, FETCH, PUBLISH, etc.) to indicate failure. +#[derive(Clone, Debug)] +pub struct RequestError { + pub id: u64, + + /// An error code identifying the failure reason. + pub error_code: u64, + + /// Minimum time in milliseconds before the request SHOULD be sent again, plus one. + /// A value of 0 means the request SHOULD NOT be retried. + /// A value of 1 means the request can be retried immediately. + pub retry_interval: u64, + + /// An optional, human-readable reason. + pub reason_phrase: ReasonPhrase, +} + +impl Decode for RequestError { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let error_code = u64::decode(r)?; + let retry_interval = u64::decode(r)?; + let reason_phrase = ReasonPhrase::decode(r)?; + + Ok(Self { + id, + error_code, + retry_interval, + reason_phrase, + }) + } +} + +impl Encode for RequestError { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.error_code.encode(w)?; + self.retry_interval.encode(w)?; + self.reason_phrase.encode(w)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode() { + let mut buf = BytesMut::new(); + + let msg = RequestError { + id: 42, + error_code: 0x1, + retry_interval: 5000, + reason_phrase: ReasonPhrase("unauthorized".to_string()), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded.id, msg.id); + assert_eq!(decoded.error_code, msg.error_code); + assert_eq!(decoded.retry_interval, msg.retry_interval); + } + + #[test] + fn encode_decode_no_retry() { + let mut buf = BytesMut::new(); + + let msg = RequestError { + id: 10, + error_code: 0x0, + retry_interval: 0, + reason_phrase: ReasonPhrase("internal error".to_string()), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded.id, msg.id); + assert_eq!(decoded.error_code, msg.error_code); + assert_eq!(decoded.retry_interval, 0); + } +} diff --git a/moq-transport/src/message/request_ok.rs b/moq-transport/src/message/request_ok.rs new file mode 100644 index 00000000..9ceb8879 --- /dev/null +++ b/moq-transport/src/message/request_ok.rs @@ -0,0 +1,45 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; + +/// Reqeust Ok +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct RequestOk { + /// The SubscribeNamespace/PublishNamespace request ID this message is replying to. + pub id: u64, + + /// Optional parameters + pub params: KeyValuePairs, +} + +impl Decode for RequestOk { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let params = KeyValuePairs::decode(r)?; + Ok(Self { id, params }) + } +} + +impl Encode for RequestOk { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.params.encode(w) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode() { + let mut buf = BytesMut::new(); + + let msg = RequestOk { + id: 12345, + params: KeyValuePairs::new(), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } +} diff --git a/moq-transport/src/message/requests_blocked.rs b/moq-transport/src/message/requests_blocked.rs index ee3323c0..6a9d7f15 100644 --- a/moq-transport/src/message/requests_blocked.rs +++ b/moq-transport/src/message/requests_blocked.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Sent by the publisher to update the max allowed subscription ID for the session. diff --git a/moq-transport/src/message/subscribe.rs b/moq-transport/src/message/subscribe.rs index 6ab571c4..e2a3d0bd 100644 --- a/moq-transport/src/message/subscribe.rs +++ b/moq-transport/src/message/subscribe.rs @@ -1,12 +1,4 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{ - Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, -}; -use crate::message::FilterType; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; /// Sent by the subscriber to request all future objects for the given track. /// @@ -20,22 +12,9 @@ pub struct Subscribe { pub track_namespace: TrackNamespace, pub track_name: String, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) - /// Subscriber Priority - pub subscriber_priority: u8, - pub group_order: GroupOrder, - - /// Forward Flag - pub forward: bool, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. - pub end_group_id: Option, - /// Optional parameters + /// NOTE(itzmanish): since the forward and other fields are moved to parameters + /// we need to validate them on publisher logic pub params: KeyValuePairs, } @@ -46,41 +25,12 @@ impl Decode for Subscribe { let track_namespace = TrackNamespace::decode(r)?; let track_name = String::decode(r)?; - let subscriber_priority = u8::decode(r)?; - let group_order = GroupOrder::decode(r)?; - - let forward = bool::decode(r)?; - - let filter_type = FilterType::decode(r)?; - let start_location: Option; - let end_group_id: Option; - match filter_type { - FilterType::AbsoluteStart => { - start_location = Some(Location::decode(r)?); - end_group_id = None; - } - FilterType::AbsoluteRange => { - start_location = Some(Location::decode(r)?); - end_group_id = Some(u64::decode(r)?); - } - _ => { - start_location = None; - end_group_id = None; - } - } - let params = KeyValuePairs::decode(r)?; Ok(Self { id, track_namespace, track_name, - subscriber_priority, - group_order, - forward, - filter_type, - start_location, - end_group_id, params, }) } @@ -92,37 +42,6 @@ impl Encode for Subscribe { self.track_namespace.encode(w)?; self.track_name.encode(w)?; - - self.subscriber_priority.encode(w)?; - self.group_order.encode(w)?; - - self.forward.encode(w)?; - - self.filter_type.encode(w)?; - match self.filter_type { - FilterType::AbsoluteStart => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - // Just ignore end_group_id if it happens to be set - } - FilterType::AbsoluteRange => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - if let Some(end) = self.end_group_id { - end.encode(w)?; - } else { - return Err(EncodeError::MissingField("EndGroupId".to_string())); - } - } - _ => {} - } - self.params.encode(w)?; Ok(()) @@ -147,12 +66,6 @@ mod tests { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::NextGroupStart, - start_location: None, - end_group_id: None, params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); @@ -164,12 +77,6 @@ mod tests { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteStart, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); @@ -181,69 +88,10 @@ mod tests { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: Some(23456), params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); let decoded = Subscribe::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } - - #[test] - fn encode_missing_fields() { - let mut buf = BytesMut::new(); - - // FilterType = AbsoluteStart - missing start_location - let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteStart, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing start_location - let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing end_group_id - let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - } } diff --git a/moq-transport/src/message/subscribe_error.rs b/moq-transport/src/message/subscribe_error.rs deleted file mode 100644 index a4761c64..00000000 --- a/moq-transport/src/message/subscribe_error.rs +++ /dev/null @@ -1,45 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct SubscribeError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for SubscribeError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for SubscribeError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/subscribe_namespace.rs b/moq-transport/src/message/subscribe_namespace.rs index 59a38311..92662036 100644 --- a/moq-transport/src/message/subscribe_namespace.rs +++ b/moq-transport/src/message/subscribe_namespace.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; /// Subscribe Namespace @@ -12,19 +9,36 @@ pub struct SubscribeNamespace { /// The track namespace prefix pub track_namespace_prefix: TrackNamespace, + /// The Forward value that new subscriptions resulting from this SUBSCRIBE_NAMESPACE will have + pub forward: u8, + /// Optional parameters pub params: KeyValuePairs, } +impl SubscribeNamespace { + /// Creates a new SubscribeNamespace message. + pub fn new(id: u64, track_namespace_prefix: TrackNamespace, forward: u8) -> Self { + Self { + id, + track_namespace_prefix, + forward, + params: KeyValuePairs::new(), + } + } +} + impl Decode for SubscribeNamespace { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; let track_namespace_prefix = TrackNamespace::decode(r)?; + let forward = u8::decode(r)?; let params = KeyValuePairs::decode(r)?; Ok(Self { id, track_namespace_prefix, + forward, params, }) } @@ -34,6 +48,7 @@ impl Encode for SubscribeNamespace { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; self.track_namespace_prefix.encode(w)?; + self.forward.encode(w)?; self.params.encode(w)?; Ok(()) @@ -55,11 +70,14 @@ mod tests { let msg = SubscribeNamespace { id: 12345, + forward: 0, track_namespace_prefix: TrackNamespace::from_utf8_path("path/prefix"), params: kvps, }; msg.encode(&mut buf).unwrap(); let decoded = SubscribeNamespace::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); + assert_eq!(decoded.id, msg.id); + assert_eq!(decoded.forward, msg.forward); + assert_eq!(decoded.track_namespace_prefix, msg.track_namespace_prefix); } } diff --git a/moq-transport/src/message/subscribe_namespace_error.rs b/moq-transport/src/message/subscribe_namespace_error.rs deleted file mode 100644 index 10e47f18..00000000 --- a/moq-transport/src/message/subscribe_namespace_error.rs +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct SubscribeNamespaceError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for SubscribeNamespaceError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for SubscribeNamespaceError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/subscribe_namespace_ok.rs b/moq-transport/src/message/subscribe_namespace_ok.rs deleted file mode 100644 index 865f8c32..00000000 --- a/moq-transport/src/message/subscribe_namespace_ok.rs +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError}; - -/// Subscribe Namespace Ok -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct SubscribeNamespaceOk { - /// The SubscribeNamespace request ID this message is replying to. - pub id: u64, -} - -impl Decode for SubscribeNamespaceOk { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - Ok(Self { id }) - } -} - -impl Encode for SubscribeNamespaceOk { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = SubscribeNamespaceOk { id: 12345 }; - msg.encode(&mut buf).unwrap(); - let decoded = SubscribeNamespaceOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/message/subscribe_ok.rs b/moq-transport/src/message/subscribe_ok.rs index 2902a094..c87f952b 100644 --- a/moq-transport/src/message/subscribe_ok.rs +++ b/moq-transport/src/message/subscribe_ok.rs @@ -1,9 +1,4 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackExtensions}; /// Sent by the publisher to accept a Subscribe. #[derive(Clone, Debug, Eq, PartialEq)] @@ -14,42 +9,26 @@ pub struct SubscribeOk { /// The identifier used for this track in Subgroups or Datagrams. pub track_alias: u64, - /// The time in milliseconds after which the subscription is not longer valid. - pub expires: u64, - - /// Order groups will be delivered in - pub group_order: GroupOrder, - - /// If content_exists, then largest_location is the location of the largest - /// object available for this track - pub content_exists: bool, - pub largest_location: Option, // Only provided if content_exists is 1/true - - /// Subscribe Parameters + /// Subscribe Parameters (has count prefix per spec) pub params: KeyValuePairs, + + /// Track extensions (NO prefix per draft-16 Section 9.10 - reads until end of message) + pub track_extensions: TrackExtensions, } impl Decode for SubscribeOk { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; let track_alias = u64::decode(r)?; - let expires = u64::decode(r)?; - let group_order = GroupOrder::decode(r)?; - let content_exists = bool::decode(r)?; - let largest_location = match content_exists { - true => Some(Location::decode(r)?), - false => None, - }; let params = KeyValuePairs::decode(r)?; + // Track extensions have NO prefix - read until end of message + let track_extensions = TrackExtensions::decode(r)?; Ok(Self { id, track_alias, - expires, - group_order, - content_exists, - largest_location, params, + track_extensions, }) } } @@ -58,17 +37,8 @@ impl Encode for SubscribeOk { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; self.track_alias.encode(w)?; - self.expires.encode(w)?; - self.group_order.encode(w)?; - self.content_exists.encode(w)?; - if self.content_exists { - if let Some(largest) = &self.largest_location { - largest.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - } self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -87,14 +57,15 @@ mod tests { let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); + // Track extensions (no prefix) + let mut ext = TrackExtensions::new(); + ext.set_intvalue(2, 42); + let msg = SubscribeOk { id: 12345, track_alias: 100, - expires: 3600, - group_order: GroupOrder::Publisher, - content_exists: true, - largest_location: Some(Location::new(2, 3)), - params: kvps.clone(), + params: kvps, + track_extensions: ext, }; msg.encode(&mut buf).unwrap(); let decoded = SubscribeOk::decode(&mut buf).unwrap(); @@ -102,19 +73,22 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn encode_decode_empty_extensions() { let mut buf = BytesMut::new(); let msg = SubscribeOk { - id: 12345, - track_alias: 100, - expires: 3600, - group_order: GroupOrder::Publisher, - content_exists: true, - largest_location: None, - params: Default::default(), + id: 0, + track_alias: 0, + params: KeyValuePairs::new(), + track_extensions: TrackExtensions::new(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); + msg.encode(&mut buf).unwrap(); + // Expected: id=0 (1 byte), track_alias=0 (1 byte), params_count=0 (1 byte), NO track_extensions bytes + assert_eq!(buf.to_vec(), vec![0x00, 0x00, 0x00]); + let decoded = SubscribeOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } + + // Note: encode_missing_fields test removed — content_exists was removed + // from the struct in draft-16; no fields to validate at encode time. } diff --git a/moq-transport/src/message/subscribe_update.rs b/moq-transport/src/message/subscribe_update.rs index 671e08f1..895378d9 100644 --- a/moq-transport/src/message/subscribe_update.rs +++ b/moq-transport/src/message/subscribe_update.rs @@ -1,56 +1,31 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; - -/// Sent by the subscriber to request all future objects for the given track. +/// REQUEST_UPDATE message (draft-16 Section 9.11). /// -/// Objects will use the provided ID instead of the full track name, to save bytes. +/// Sent to modify an existing request (SUBSCRIBE, PUBLISH, FETCH, etc.). +/// Parameters previously set that are not present in the update remain unchanged. #[derive(Clone, Debug, Eq, PartialEq)] pub struct SubscribeUpdate { - /// The request ID of this request + /// The request ID of this REQUEST_UPDATE pub id: u64, - /// The request ID of the SUBSCRIBE this message is updating. - pub subscription_request_id: u64, - - /// The starting location - pub start_location: Location, - /// The end Group ID, plus 1. A value of 0 means the subscription is open-ended. - pub end_group_id: u64, - - /// Subscriber Priority - pub subscriber_priority: u8, + /// The request ID of the existing request this message is updating. + pub existing_request_id: u64, - /// Forward Flag - pub forward: bool, - - /// Optional parameters + /// Parameters to update (draft-16 Section 9.2.2). + /// Parameters not present remain unchanged from the original request. pub params: KeyValuePairs, } impl Decode for SubscribeUpdate { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; - - let subscription_request_id = u64::decode(r)?; - - let start_location = Location::decode(r)?; - let end_group_id = u64::decode(r)?; - - let subscriber_priority = u8::decode(r)?; - - let forward = bool::decode(r)?; - + let existing_request_id = u64::decode(r)?; let params = KeyValuePairs::decode(r)?; Ok(Self { id, - subscription_request_id, - start_location, - end_group_id, - subscriber_priority, - forward, + existing_request_id, params, }) } @@ -59,16 +34,7 @@ impl Decode for SubscribeUpdate { impl Encode for SubscribeUpdate { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; - - self.subscription_request_id.encode(w)?; - - self.start_location.encode(w)?; - self.end_group_id.encode(w)?; - - self.subscriber_priority.encode(w)?; - - self.forward.encode(w)?; - + self.existing_request_id.encode(w)?; self.params.encode(w)?; Ok(()) @@ -84,21 +50,30 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_intvalue(124, 456); let msg = SubscribeUpdate { id: 1000, - subscription_request_id: 924, - start_location: Location::new(1, 1), - end_group_id: 100000, - subscriber_priority: 127, - forward: true, + existing_request_id: 924, params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); let decoded = SubscribeUpdate::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } + + #[test] + fn encode_decode_empty_params() { + let mut buf = BytesMut::new(); + + let msg = SubscribeUpdate { + id: 5, + existing_request_id: 3, + params: KeyValuePairs::new(), + }; + msg.encode(&mut buf).unwrap(); + let decoded = SubscribeUpdate::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } } diff --git a/moq-transport/src/message/subscriber.rs b/moq-transport/src/message/subscriber.rs index a211c5e3..3c433149 100644 --- a/moq-transport/src/message/subscriber.rs +++ b/moq-transport/src/message/subscriber.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::message::{self, Message}; use std::fmt; @@ -57,10 +53,8 @@ subscriber_msgs! { FetchCancel, TrackStatus, SubscribeNamespace, - UnsubscribeNamespace, PublishNamespaceCancel, - PublishNamespaceOk, - PublishNamespaceError, + RequestOk, PublishOk, - PublishError, + RequestError, } diff --git a/moq-transport/src/message/track_status.rs b/moq-transport/src/message/track_status.rs index 29c2c426..af6cffd0 100644 --- a/moq-transport/src/message/track_status.rs +++ b/moq-transport/src/message/track_status.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{ Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, }; diff --git a/moq-transport/src/message/track_status_error.rs b/moq-transport/src/message/track_status_error.rs deleted file mode 100644 index 842c20da..00000000 --- a/moq-transport/src/message/track_status_error.rs +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct TrackStatusError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for TrackStatusError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for TrackStatusError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/track_status_ok.rs b/moq-transport/src/message/track_status_ok.rs index 35a2911d..f8ec005a 100644 --- a/moq-transport/src/message/track_status_ok.rs +++ b/moq-transport/src/message/track_status_ok.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; use crate::message::GroupOrder; diff --git a/moq-transport/src/message/unsubscribe.rs b/moq-transport/src/message/unsubscribe.rs index df67748b..3cce7a0e 100644 --- a/moq-transport/src/message/unsubscribe.rs +++ b/moq-transport/src/message/unsubscribe.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Sent by the subscriber to terminate a Subscribe. diff --git a/moq-transport/src/message/unsubscribe_namespace.rs b/moq-transport/src/message/unsubscribe_namespace.rs deleted file mode 100644 index 4ae21520..00000000 --- a/moq-transport/src/message/unsubscribe_namespace.rs +++ /dev/null @@ -1,45 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, TrackNamespace}; - -/// Unsubscribe Namespace -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct UnsubscribeNamespace { - // Echo back the track namespace prefix from subscribe namespace - pub track_namespace_prefix: TrackNamespace, -} - -impl Decode for UnsubscribeNamespace { - fn decode(r: &mut R) -> Result { - let track_namespace_prefix = TrackNamespace::decode(r)?; - Ok(Self { - track_namespace_prefix, - }) - } -} - -impl Encode for UnsubscribeNamespace { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.track_namespace_prefix.encode(w)?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = UnsubscribeNamespace { - track_namespace_prefix: TrackNamespace::from_utf8_path("test/path/to/resource"), - }; - msg.encode(&mut buf).unwrap(); - let decoded = UnsubscribeNamespace::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/mlog/events.rs b/moq-transport/src/mlog/events.rs index a93820df..f074b747 100644 --- a/moq-transport/src/mlog/events.rs +++ b/moq-transport/src/mlog/events.rs @@ -1,14 +1,11 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - // TODO: Unimplemented control message events (not yet needed for basic relay interop testing): // - SubscribeUpdate (parsed/created) // - PublishNamespaceDone (parsed/created) // - PublishNamespaceCancel (parsed/created) -// - TrackStatus, TrackStatusOk, TrackStatusError (parsed/created) -// - SubscribeNamespace, SubscribeNamespaceOk, SubscribeNamespaceError, UnsubscribeNamespace (parsed/created) -// - Fetch, FetchOk, FetchError, FetchCancel (parsed/created) -// - Publish, PublishOk, PublishError, PublishDone (parsed/created) +// - TrackStatus, TrackStatusOk (parsed/created) +// - SubscribeNamespace (parsed/created) +// - Fetch, FetchOk, FetchCancel (parsed/created) +// - Publish, PublishOk, PublishDone (parsed/created) // - MaxRequestId (parsed/created) // - RequestsBlocked (parsed/created) // @@ -210,7 +207,6 @@ fn create_control_message_event( /// Create a control_message_parsed event for CLIENT_SETUP pub fn client_setup_parsed(time: f64, stream_id: u64, msg: &setup::Client) -> Event { - let versions: Vec = msg.versions.0.iter().map(|v| format!("{:?}", v)).collect(); create_control_message_event( time, stream_id, @@ -218,8 +214,6 @@ pub fn client_setup_parsed(time: f64, stream_id: u64, msg: &setup::Client) -> Ev "client_setup", json!( { - "number_of_supported_versions": msg.versions.0.len(), - "supported_versions": versions, "parameters": key_value_pairs_to_vec(&msg.params.0), }), ) @@ -234,7 +228,6 @@ pub fn server_setup_created(time: f64, stream_id: u64, msg: &setup::Server) -> E "server_setup", json!( { - "selected_version": format!("{:?}", msg.version), "parameters": key_value_pairs_to_vec(&msg.params.0), }), ) @@ -242,25 +235,12 @@ pub fn server_setup_created(time: f64, stream_id: u64, msg: &setup::Server) -> E /// Helper to convert SUBSCRIBE message to JSON fn subscribe_to_json(msg: &message::Subscribe) -> JsonValue { - let mut json = json!({ + let json = json!({ "subscribe_id": msg.id, "track_namespace": msg.track_namespace.to_string(), "track_name": &msg.track_name, - "subscriber_priority": msg.subscriber_priority, - "group_order": format!("{:?}", msg.group_order), - "filter_type": format!("{:?}", msg.filter_type), "parameters": key_value_pairs_to_vec(&msg.params.0), }); - - // Add optional fields based on filter type - if let Some(start_loc) = &msg.start_location { - json["start_group"] = json!(start_loc.group_id); - json["start_object"] = json!(start_loc.object_id); - } - if let Some(end_group) = msg.end_group_id { - json["end_group"] = json!(end_group); - } - json } @@ -276,23 +256,11 @@ pub fn subscribe_created(time: f64, stream_id: u64, msg: &message::Subscribe) -> /// Helper to convert SUBSCRIBE_OK message to JSON fn subscribe_ok_to_json(msg: &message::SubscribeOk) -> JsonValue { - let mut json = json!({ + let json = json!({ "subscribe_id": msg.id, "track_alias": msg.track_alias, - "expires": msg.expires, - "group_order": format!("{:?}", msg.group_order), - "content_exists": msg.content_exists, "parameters": key_value_pairs_to_vec(&msg.params.0), }); - - // Add optional largest_location fields if content exists - if msg.content_exists { - if let Some(largest) = &msg.largest_location { - json["largest_group_id"] = json!(largest.group_id); - json["largest_object_id"] = json!(largest.object_id); - } - } - json } @@ -319,33 +287,34 @@ pub fn subscribe_ok_created(time: f64, stream_id: u64, msg: &message::SubscribeO } /// Helper to convert SUBSCRIBE_ERROR message to JSON -fn subscribe_error_to_json(msg: &message::SubscribeError) -> JsonValue { +fn request_error_to_json(msg: &message::RequestError) -> JsonValue { json!({ - "subscribe_id": msg.id, + "request_id": msg.id, "error_code": msg.error_code, + "retry_interval": msg.retry_interval, "reason_phrase": &msg.reason_phrase.0, }) } /// Create a control_message_parsed event for SUBSCRIBE_ERROR -pub fn subscribe_error_parsed(time: f64, stream_id: u64, msg: &message::SubscribeError) -> Event { +pub fn request_error_parsed(time: f64, stream_id: u64, msg: &message::RequestError) -> Event { create_control_message_event( time, stream_id, true, - "subscribe_error", - subscribe_error_to_json(msg), + "request_error", + request_error_to_json(msg), ) } /// Create a control_message_created event for SUBSCRIBE_ERROR -pub fn subscribe_error_created(time: f64, stream_id: u64, msg: &message::SubscribeError) -> Event { +pub fn reqeust_error_created(time: f64, stream_id: u64, msg: &message::RequestError) -> Event { create_control_message_event( time, stream_id, false, - "subscribe_error", - subscribe_error_to_json(msg), + "request_error", + request_error_to_json(msg), ) } @@ -389,78 +358,100 @@ pub fn publish_namespace_created( } /// Helper to convert PUBLISH_NAMESPACE_OK message to JSON -fn publish_namespace_ok_to_json(msg: &message::PublishNamespaceOk) -> JsonValue { +fn request_ok_to_json(msg: &message::RequestOk) -> JsonValue { json!({ "request_id": msg.id, }) } -/// Create a control_message_parsed event for PUBLISH_NAMESPACE_OK (was ANNOUNCE_OK) -pub fn publish_namespace_ok_parsed( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceOk, -) -> Event { +/// Create a control_message_parsed event for REQUEST_OK +pub fn request_ok_parsed(time: f64, stream_id: u64, msg: &message::RequestOk) -> Event { + create_control_message_event(time, stream_id, true, "request_ok", request_ok_to_json(msg)) +} + +/// Create a control_message_created event for Reqeust OK +pub fn reqeust_ok_created(time: f64, stream_id: u64, msg: &message::RequestOk) -> Event { create_control_message_event( time, stream_id, - true, - "publish_namespace_ok", - publish_namespace_ok_to_json(msg), + false, + "request_ok", + request_ok_to_json(msg), ) } -/// Create a control_message_created event for PUBLISH_NAMESPACE_OK -pub fn publish_namespace_ok_created( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceOk, -) -> Event { +fn publish_to_json(msg: &message::Publish) -> JsonValue { + json!({ + "publish_id": msg.id, + "track_namespace": msg.track_namespace.to_string(), + "track_name": &msg.track_name, + "track_alias": msg.track_alias, + "parameters": key_value_pairs_to_vec(&msg.params.0), + }) +} + +/// Create a control_message_parsed event for PUBLISH +pub fn publish_parsed(time: f64, stream_id: u64, msg: &message::Publish) -> Event { + create_control_message_event(time, stream_id, true, "publish", publish_to_json(msg)) +} + +/// Create a control_message_created event for PUBLISH +pub fn publish_created(time: f64, stream_id: u64, msg: &message::Publish) -> Event { + create_control_message_event(time, stream_id, false, "publish", publish_to_json(msg)) +} + +fn publish_ok_to_json(msg: &message::PublishOk) -> JsonValue { + json!({ + "publish_id": msg.id, + "parameters": key_value_pairs_to_vec(&msg.params.0), + }) +} + +/// Create a control_message_parsed event for PUBLISH_OK +pub fn publish_ok_parsed(time: f64, stream_id: u64, msg: &message::PublishOk) -> Event { + create_control_message_event(time, stream_id, true, "publish_ok", publish_ok_to_json(msg)) +} + +/// Create a control_message_created event for PUBLISH_OK +pub fn publish_ok_created(time: f64, stream_id: u64, msg: &message::PublishOk) -> Event { create_control_message_event( time, stream_id, false, - "publish_namespace_ok", - publish_namespace_ok_to_json(msg), + "publish_ok", + publish_ok_to_json(msg), ) } -/// Helper to convert PUBLISH_NAMESPACE_ERROR message to JSON -fn publish_namespace_error_to_json(msg: &message::PublishNamespaceError) -> JsonValue { +/// Helper to convert PUBLISH_DONE message to JSON +fn publish_done_to_json(msg: &message::PublishDone) -> JsonValue { json!({ - "request_id": msg.id, - "error_code": msg.error_code, - "reason_phrase": &msg.reason_phrase.0, + "publish_id": msg.id, + "status_code": msg.status_code, + "stream_count": msg.stream_count, + "reason": &msg.reason.0, }) } -/// Create a control_message_parsed event for PUBLISH_NAMESPACE_ERROR (was ANNOUNCE_ERROR) -pub fn publish_namespace_error_parsed( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceError, -) -> Event { +/// Create a control_message_parsed event for PUBLISH_DONE +pub fn publish_done_parsed(time: f64, stream_id: u64, msg: &message::PublishDone) -> Event { create_control_message_event( time, stream_id, true, - "publish_namespace_error", - publish_namespace_error_to_json(msg), + "publish_done", + publish_done_to_json(msg), ) } -/// Create a control_message_created event for PUBLISH_NAMESPACE_ERROR -pub fn publish_namespace_error_created( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceError, -) -> Event { +/// Create a control_message_created event for PUBLISH_DONE +pub fn publish_done_created(time: f64, stream_id: u64, msg: &message::PublishDone) -> Event { create_control_message_event( time, stream_id, false, - "publish_namespace_error", - publish_namespace_error_to_json(msg), + "publish_done", + publish_done_to_json(msg), ) } diff --git a/moq-transport/src/mlog/mod.rs b/moq-transport/src/mlog/mod.rs index f61c5948..3e6bb55e 100644 --- a/moq-transport/src/mlog/mod.rs +++ b/moq-transport/src/mlog/mod.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! MoQ Transport logging (mlog) following qlog patterns //! //! Based on draft-pardue-moq-qlog-moq-events but adapted for MoQ Transport draft-14 diff --git a/moq-transport/src/mlog/writer.rs b/moq-transport/src/mlog/writer.rs index 5dc29478..2ec6cf7d 100644 --- a/moq-transport/src/mlog/writer.rs +++ b/moq-transport/src/mlog/writer.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::fs::File; use std::io::{self, BufWriter, Write}; use std::path::Path; @@ -41,7 +38,6 @@ impl MlogWriter { } }); - writer.write_all(b"\x1e")?; serde_json::to_writer(&mut writer, &header)?; writer.write_all(b"\n")?; writer.flush()?; @@ -56,7 +52,6 @@ impl MlogWriter { /// Add an event to the log pub fn add_event(&mut self, event: Event) -> io::Result<()> { - self.writer.write_all(b"\x1e")?; serde_json::to_writer(&mut self.writer, &event)?; self.writer.write_all(b"\n")?; self.writer.flush()?; diff --git a/moq-transport/src/serve/broadcast.rs b/moq-transport/src/serve/broadcast.rs index af3eb8a5..caaabedc 100644 --- a/moq-transport/src/serve/broadcast.rs +++ b/moq-transport/src/serve/broadcast.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! A broadcast is a collection of tracks, split into two handles: [Writer] and [Reader]. //! //! The [Writer] can create tracks, either manually or on request. diff --git a/moq-transport/src/serve/datagram.rs b/moq-transport/src/serve/datagram.rs index 3145405c..1eb07e73 100644 --- a/moq-transport/src/serve/datagram.rs +++ b/moq-transport/src/serve/datagram.rs @@ -1,128 +1,106 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{fmt, sync::Arc}; -use crate::watch::State; +use tokio::sync::broadcast; use super::{ServeError, Track}; +const DATAGRAM_CHANNEL_SIZE: usize = 4096; + pub struct Datagrams { pub track: Arc, } impl Datagrams { pub fn produce(self) -> (DatagramsWriter, DatagramsReader) { - let (writer, reader) = State::default().split(); + let (tx, rx) = broadcast::channel(DATAGRAM_CHANNEL_SIZE); - let writer = DatagramsWriter::new(writer, self.track.clone()); - let reader = DatagramsReader::new(reader, self.track); + // Keep a reference to the sender in the reader so clones get fresh receivers + let tx_for_reader = tx.clone(); + let writer = DatagramsWriter::new(tx, self.track.clone()); + let reader = DatagramsReader::new(rx, tx_for_reader, self.track); (writer, reader) } } -struct DatagramsState { - // The latest datagram - latest: Option, - - // Increased each time datagram changes. - epoch: u64, - - // Set when the writer or all readers are dropped. - closed: Result<(), ServeError>, -} - -impl Default for DatagramsState { - fn default() -> Self { - Self { - latest: None, - epoch: 0, - closed: Ok(()), - } - } -} - pub struct DatagramsWriter { - state: State, + tx: broadcast::Sender, pub track: Arc, } impl DatagramsWriter { - fn new(state: State, track: Arc) -> Self { - Self { state, track } + fn new(tx: broadcast::Sender, track: Arc) -> Self { + Self { tx, track } } pub fn write(&mut self, datagram: Datagram) -> Result<(), ServeError> { - let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; - - state.latest = Some(datagram); - state.epoch += 1; - + // Ignore send errors (no receivers) - datagrams are fire-and-forget + let _ = self.tx.send(datagram); Ok(()) } - pub fn close(self, err: ServeError) -> Result<(), ServeError> { - let state = self.state.lock(); - state.closed.clone()?; - - let mut state = state.into_mut().ok_or(ServeError::Cancel)?; - state.closed = Err(err); - + pub fn close(self, _err: ServeError) -> Result<(), ServeError> { + // Channel closes when tx is dropped Ok(()) } } -#[derive(Clone)] pub struct DatagramsReader { - state: State, + rx: broadcast::Receiver, + tx: broadcast::Sender, pub track: Arc, + latest: Option<(u64, u64)>, +} - epoch: u64, +impl Clone for DatagramsReader { + fn clone(&self) -> Self { + // Subscribe to get a NEW receiver that will get all FUTURE datagrams + // This is correct for relay: each subscriber gets datagrams from now on + Self { + rx: self.tx.subscribe(), + tx: self.tx.clone(), + track: self.track.clone(), + latest: self.latest, + } + } } impl DatagramsReader { - fn new(state: State, track: Arc) -> Self { + fn new(rx: broadcast::Receiver, tx: broadcast::Sender, track: Arc) -> Self { Self { - state, + rx, + tx, track, - epoch: 0, + latest: None, } } pub async fn read(&mut self) -> Result, ServeError> { loop { - { - let state = self.state.lock(); - if self.epoch < state.epoch { - self.epoch = state.epoch; - return Ok(state.latest.clone()); + match self.rx.recv().await { + Ok(datagram) => { + self.latest = Some((datagram.group_id, datagram.object_id)); + return Ok(Some(datagram)); } - - state.closed.clone()?; - match state.modified() { - Some(notify) => notify, - None => return Ok(None), // No more updates will come + Err(broadcast::error::RecvError::Lagged(n)) => { + log::warn!("[DATAGRAMS] reader lagged by {} datagrams", n); + // Continue reading - we'll get the next available datagram + } + Err(broadcast::error::RecvError::Closed) => { + return Ok(None); // Channel closed } } - .await; } } - // Returns the largest group/sequence pub fn latest(&self) -> Option<(u64, u64)> { - let state = self.state.lock(); - state - .latest - .as_ref() - .map(|datagram| (datagram.group_id, datagram.object_id)) + self.latest } - /// Check if the datagrams writer has been closed or dropped. pub fn is_closed(&self) -> bool { - let state = self.state.lock(); - state.closed.is_err() || state.modified().is_none() + // Check if sender is gone (receiver_count would be 0 or send would fail) + // But we can't easily check this, so return false (conservative) + false } } @@ -136,6 +114,9 @@ pub struct Datagram { // Extension headers (for draft-14 compliance, particularly immutable extensions) pub extension_headers: crate::data::ExtensionHeaders, + + // Object status (e.g., EndOfGroup) + pub status: Option, } impl fmt::Debug for Datagram { @@ -146,6 +127,7 @@ impl fmt::Debug for Datagram { .field("priority", &self.priority) .field("payload", &self.payload.len()) .field("extension_headers", &self.extension_headers) + .field("status", &self.status) .finish() } } diff --git a/moq-transport/src/serve/error.rs b/moq-transport/src/serve/error.rs index e9366d86..57666d3a 100644 --- a/moq-transport/src/serve/error.rs +++ b/moq-transport/src/serve/error.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - #[derive(thiserror::Error, Debug, Clone, PartialEq)] pub enum ServeError { // TODO stop using? @@ -40,6 +36,10 @@ pub enum ServeError { #[error("not implemented: {0} [error:{1}]")] NotImplementedWithId(String, uuid::Uuid), + + /// Relay already has an active SUBSCRIBE path, not interested in PUBLISH + #[error("uninterested")] + Uninterested, } impl ServeError { @@ -64,6 +64,8 @@ impl ServeError { Self::NotImplemented(_) | Self::NotImplementedWithId(_, _) => 0x3, // INTERNAL_ERROR (0x0) - per-request error registries use 0x0 Self::Internal(_) | Self::InternalWithId(_, _) => 0x0, + // UNINTERESTED (0x1) - relay already has data path via SUBSCRIBE + Self::Uninterested => 0x1, } } @@ -75,7 +77,7 @@ impl ServeError { pub fn not_found_id() -> Self { let id = uuid::Uuid::new_v4(); let loc = std::panic::Location::caller(); - tracing::warn!("[{}] Not found at {}:{}", id, loc.file(), loc.line()); + log::warn!("[{}] Not found at {}:{}", id, loc.file(), loc.line()); Self::NotFoundWithId("Track not found".to_string(), id) } @@ -88,7 +90,7 @@ impl ServeError { let context = internal_context.into(); let id = uuid::Uuid::new_v4(); let loc = std::panic::Location::caller(); - tracing::warn!( + log::warn!( "[{}] Not found: {} at {}:{}", id, context, @@ -111,7 +113,7 @@ impl ServeError { let message = external_message.into(); let id = uuid::Uuid::new_v4(); let loc = std::panic::Location::caller(); - tracing::warn!( + log::warn!( "[{}] Not found: {} at {}:{}", id, context, @@ -130,7 +132,7 @@ impl ServeError { let context = internal_context.into(); let id = uuid::Uuid::new_v4(); let loc = std::panic::Location::caller(); - tracing::error!( + log::error!( "[{}] Internal error: {} at {}:{}", id, context, @@ -149,7 +151,7 @@ impl ServeError { let feature = feature.into(); let id = uuid::Uuid::new_v4(); let loc = std::panic::Location::caller(); - tracing::warn!( + log::warn!( "[{}] Not implemented: {} at {}:{}", id, feature, diff --git a/moq-transport/src/serve/mod.rs b/moq-transport/src/serve/mod.rs index d432d92a..be41c0c9 100644 --- a/moq-transport/src/serve/mod.rs +++ b/moq-transport/src/serve/mod.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - mod datagram; mod error; mod object; diff --git a/moq-transport/src/serve/object.rs b/moq-transport/src/serve/object.rs index 13971a1a..6129bf8c 100644 --- a/moq-transport/src/serve/object.rs +++ b/moq-transport/src/serve/object.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! A fragment is a stream of bytes with a header, split into a [Writer] and [Reader] handle. //! //! A [Writer] writes an ordered stream of bytes in chunks. diff --git a/moq-transport/src/serve/stream.rs b/moq-transport/src/serve/stream.rs index 126aa0a6..020e2aba 100644 --- a/moq-transport/src/serve/stream.rs +++ b/moq-transport/src/serve/stream.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use bytes::Bytes; use std::{ops::Deref, sync::Arc}; @@ -193,7 +189,6 @@ impl StreamReader { }) } - /// Check if the stream writer has been closed or dropped. pub fn is_closed(&self) -> bool { let state = self.state.lock(); state.closed.is_err() || state.modified().is_none() diff --git a/moq-transport/src/serve/subgroup.rs b/moq-transport/src/serve/subgroup.rs index daddb65d..e47048b6 100644 --- a/moq-transport/src/serve/subgroup.rs +++ b/moq-transport/src/serve/subgroup.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! A stream is a stream of objects with a header, split into a [Writer] and [Reader] handle. //! //! A [Writer] writes an ordered stream of objects. @@ -98,6 +95,7 @@ impl SubgroupsWriter { group_id, subgroup_id, priority, + header_type: None, }) } @@ -108,6 +106,7 @@ impl SubgroupsWriter { group_id: subgroup.group_id, subgroup_id: subgroup.subgroup_id, priority: subgroup.priority, + header_type: subgroup.header_type, }; let (writer, reader) = subgroup.produce(); @@ -117,8 +116,17 @@ impl SubgroupsWriter { // TODO: Check this logic again if writer.group_id.cmp(&latest.group_id) == cmp::Ordering::Equal { match writer.subgroup_id.cmp(&latest.subgroup_id) { - cmp::Ordering::Less => return Ok(writer), // dropped immediately, lul - cmp::Ordering::Equal => return Err(ServeError::Duplicate), + cmp::Ordering::Less => return Ok(writer), // dropped immediately + cmp::Ordering::Equal => { + // Duplicate subgroup - silently drop instead of erroring + // This can happen with SubgroupZeroIdEndOfGroup streams + log::warn!( + "duplicate subgroup: group_id={}, subgroup_id={} - dropping", + writer.group_id, + writer.subgroup_id + ); + return Ok(writer); // writer dropped, data lost but relay continues + } cmp::Ordering::Greater => state.latest_subgroup_reader = Some(reader), } } else if writer.group_id.cmp(&latest.group_id) == cmp::Ordering::Greater { @@ -231,6 +239,9 @@ pub struct Subgroup { // The priority of the group within the track. pub priority: u8, + + // The stream header type used for this subgroup (preserved from incoming stream) + pub header_type: Option, } /// Static information about the group @@ -248,6 +259,9 @@ pub struct SubgroupInfo { // The priority of the group within the track. pub priority: u8, + + // The stream header type used for this subgroup (preserved from incoming stream) + pub header_type: Option, } impl SubgroupInfo { @@ -322,11 +336,21 @@ impl SubgroupWriter { &mut self, size: usize, extension_headers: Option, + ) -> Result { + self.create_with_status(size, extension_headers, ObjectStatus::NormalObject) + } + + /// Write an object with a specific status (e.g., EndOfGroup). + pub fn create_with_status( + &mut self, + size: usize, + extension_headers: Option, + status: ObjectStatus, ) -> Result { let (writer, reader) = SubgroupObject { group: self.info.clone(), object_id: self.next_object_id, - status: ObjectStatus::NormalObject, + status, size, extension_headers: extension_headers.unwrap_or_default(), } @@ -340,6 +364,16 @@ impl SubgroupWriter { Ok(writer) } + /// Write an EndOfGroup marker object to signal the end of this subgroup. + /// This should be called when the group is complete. + pub fn end_of_group(&mut self) -> Result<(), ServeError> { + // Create an object with size=0 and status=EndOfGroup + let object_writer = self.create_with_status(0, None, ObjectStatus::EndOfGroup)?; + // Object writer with size=0 will complete immediately when dropped + drop(object_writer); + Ok(()) + } + /// Close the stream with an error. pub fn close(self, err: ServeError) -> Result<(), ServeError> { let state = self.state.lock(); diff --git a/moq-transport/src/serve/track.rs b/moq-transport/src/serve/track.rs index 3cb2edaf..01591ca8 100644 --- a/moq-transport/src/serve/track.rs +++ b/moq-transport/src/serve/track.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! A track is a collection of semi-reliable and semi-ordered streams, split into a [Writer] and [Reader] handle. //! //! A [Writer] creates streams with a sequence number and priority. @@ -89,14 +85,7 @@ impl TrackWriter { .produce(); // Lock state to modify it - let mut state = self.state.lock_mut().ok_or_else(|| { - tracing::debug!( - namespace = %self.info.namespace.to_utf8_path(), - track = %self.info.name, - "track state dropped (Cancel) in stream()" - ); - ServeError::Cancel - })?; + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; // Set the Stream mode to TrackReaderMode::Stream state.reader_mode = Some(reader.into()); @@ -112,14 +101,7 @@ impl TrackWriter { .produce(); // Lock state to modify it - let mut state = self.state.lock_mut().ok_or_else(|| { - tracing::debug!( - namespace = %self.info.namespace.to_utf8_path(), - track = %self.info.name, - "track state dropped (Cancel) in subgroups()" - ); - ServeError::Cancel - })?; + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; // Set the Stream mode to TrackReaderMode::Subgroups state.reader_mode = Some(reader.into()); @@ -133,14 +115,7 @@ impl TrackWriter { .produce(); // Lock state to modify it - let mut state = self.state.lock_mut().ok_or_else(|| { - tracing::debug!( - namespace = %self.info.namespace.to_utf8_path(), - track = %self.info.name, - "track state dropped (Cancel) in datagrams()" - ); - ServeError::Cancel - })?; + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; // Set the Stream mode to TrackReaderMode::Datagrams state.reader_mode = Some(reader.into()); @@ -149,23 +124,10 @@ impl TrackWriter { /// Close the track with an error. pub fn close(self, err: ServeError) -> Result<(), ServeError> { - tracing::debug!( - namespace = %self.info.namespace.to_utf8_path(), - track = %self.info.name, - error = %err, - "track closing" - ); let state = self.state.lock(); state.closed.clone()?; - let mut state = state.into_mut().ok_or_else(|| { - tracing::debug!( - namespace = %self.info.namespace.to_utf8_path(), - track = %self.info.name, - "track state already dropped during close" - ); - ServeError::Cancel - })?; + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; state.closed = Err(err); Ok(()) } @@ -359,9 +321,7 @@ mod tests { let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); let (writer, reader) = track.produce(); - let _subgroups_writer = writer - .subgroups() - .expect("subgroups transition should succeed"); + let _subgroups_writer = writer.subgroups().expect("subgroups transition should succeed"); assert!( !reader.is_closed(), @@ -374,9 +334,7 @@ mod tests { let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); let (writer, reader) = track.produce(); - let subgroups_writer = writer - .subgroups() - .expect("subgroups transition should succeed"); + let subgroups_writer = writer.subgroups().expect("subgroups transition should succeed"); drop(subgroups_writer); assert!( @@ -390,9 +348,7 @@ mod tests { let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); let (writer, reader) = track.produce(); - let subgroups_writer = writer - .subgroups() - .expect("subgroups transition should succeed"); + let subgroups_writer = writer.subgroups().expect("subgroups transition should succeed"); subgroups_writer.close(ServeError::Cancel).unwrap(); assert!( @@ -433,9 +389,7 @@ mod tests { let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); let (writer, reader) = track.produce(); - let _datagrams_writer = writer - .datagrams() - .expect("datagrams transition should succeed"); + let _datagrams_writer = writer.datagrams().expect("datagrams transition should succeed"); assert!( !reader.is_closed(), @@ -448,9 +402,7 @@ mod tests { let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); let (writer, reader) = track.produce(); - let datagrams_writer = writer - .datagrams() - .expect("datagrams transition should succeed"); + let datagrams_writer = writer.datagrams().expect("datagrams transition should succeed"); drop(datagrams_writer); assert!( @@ -464,15 +416,15 @@ mod tests { let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); let (writer, reader) = track.produce(); - let mut subgroups_writer = writer - .subgroups() - .expect("subgroups transition should succeed"); + let mut subgroups_writer = + writer.subgroups().expect("subgroups transition should succeed"); let _subgroup_writer = subgroups_writer .create(Subgroup { group_id: 0, subgroup_id: 0, priority: 0, + header_type: None, }) .expect("create subgroup should succeed"); diff --git a/moq-transport/src/serve/tracks.rs b/moq-transport/src/serve/tracks.rs index 674e2690..cea6e954 100644 --- a/moq-transport/src/serve/tracks.rs +++ b/moq-transport/src/serve/tracks.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! A broadcast is a collection of tracks, split into two handles: [Writer] and [Reader]. //! //! The [Writer] can create tracks, either manually or on request. @@ -95,6 +91,21 @@ impl TracksWriter { }; self.state.lock_mut()?.tracks.remove(&full_name) } + + /// Insert an existing track reader into the broadcast. + /// Returns None if all readers have been dropped or if a track with this name already exists. + pub fn insert(&mut self, reader: TrackReader) -> Option<()> { + let full_name = FullTrackName { + namespace: reader.namespace.clone(), + name: reader.name.clone(), + }; + let mut state = self.state.lock_mut()?; + if state.tracks.contains_key(&full_name) { + return None; + } + state.tracks.insert(full_name, reader); + Some(()) + } } impl Deref for TracksWriter { @@ -139,16 +150,7 @@ impl Deref for TracksRequest { impl Drop for TracksRequest { fn drop(&mut self) { // Close any tracks still in the Queue - let pending_tracks = self.incoming.take().unwrap().close(); - if !pending_tracks.is_empty() { - tracing::debug!( - target: "moq_transport::tracks", - namespace = %self.info.namespace.to_utf8_path(), - count = pending_tracks.len(), - "TracksRequest dropped with pending track requests" - ); - } - for track in pending_tracks { + for track in self.incoming.take().unwrap().close() { let _ = track.close(ServeError::not_found_ctx( "tracks request dropped before track handled", )); @@ -211,21 +213,9 @@ impl TracksReader { if let Some(track_reader) = state.tracks.get(&full_name) { if !track_reader.is_closed() { // Track is still active, return the cached reader - tracing::debug!( - target: "moq_transport::tracks", - namespace = %namespace.to_utf8_path(), - track = %track_name, - "track cache hit (active)" - ); return Some(track_reader.clone()); } // Track is closed/stale, fall through to create a new one - tracing::debug!( - target: "moq_transport::tracks", - namespace = %namespace.to_utf8_path(), - track = %track_name, - "track cache hit but stale, will evict and re-request" - ); } let mut state = state.into_mut()?; @@ -240,12 +230,6 @@ impl TracksReader { .produce(); if self.queue.push(track_writer_reader.0).is_err() { - tracing::debug!( - target: "moq_transport::tracks", - namespace = %namespace.to_utf8_path(), - track = %track_name, - "track request queue closed" - ); return None; } @@ -254,15 +238,15 @@ impl TracksReader { .tracks .insert(full_name, track_writer_reader.1.clone()); - tracing::debug!( - target: "moq_transport::tracks", - namespace = %namespace.to_utf8_path(), - track = %track_name, - "track cache miss, requested from upstream" - ); - Some(track_writer_reader.1) } + + /// Forward an existing track writer to the upstream subscription queue. + /// The writer will be received by [TracksRequest::next()]. + /// Returns None if the queue is closed. + pub fn forward_upstream(&mut self, writer: TrackWriter) -> Option<()> { + self.queue.push(writer).ok() + } } impl Deref for TracksReader { @@ -361,53 +345,41 @@ mod tests { ); } - /// Test that normal track caching works correctly when tracks are still alive. - /// - /// Multiple subscribers to the same track should share the same TrackReader - /// (deduplication), and the publisher should only receive one request. #[tokio::test] - async fn test_track_deduplication_while_alive() { + async fn test_track_not_stale_after_subgroups_transition() { let namespace = TrackNamespace::from_utf8_path("test/namespace"); let track_name = "test-track"; let (_writer, mut request, mut reader) = Tracks::new(namespace.clone()).produce(); - // First subscription - let track_reader_1 = reader + let _track_reader_1 = reader .subscribe(namespace.clone(), track_name) .expect("first subscribe should succeed"); - // Publisher receives request - let _track_writer = request + let track_writer = request .next() .await .expect("publisher should receive track request"); - // Second subscription to the SAME track (while it's still alive) - let track_reader_2 = reader + let _subgroups_writer = track_writer + .subgroups() + .expect("subgroups transition should succeed"); + + let _track_reader_2 = reader .subscribe(namespace.clone(), track_name) .expect("second subscribe should succeed"); - // Publisher should NOT receive another request (track is cached and alive) let maybe_second_request = tokio::time::timeout(std::time::Duration::from_millis(100), request.next()).await; assert!( maybe_second_request.is_err(), - "Publisher should NOT receive a second request - track is cached and alive" + "publisher should NOT get a second request while SubgroupsWriter is alive" ); - - // Both readers should refer to the same track - assert_eq!(track_reader_1.name, track_reader_2.name); - assert_eq!(track_reader_1.namespace, track_reader_2.namespace); } - /// Test that a track is NOT considered stale after the writer transitions to - /// subgroups mode. This is the core regression: TrackWriter::subgroups() - /// consumes self, dropping the Track-level State, but the SubgroupsWriter - /// is still alive — so is_closed() must return false. #[tokio::test] - async fn test_track_not_stale_after_subgroups_transition() { + async fn test_track_stale_after_subgroups_writer_dropped() { let namespace = TrackNamespace::from_utf8_path("test/namespace"); let track_name = "test-track"; @@ -422,9 +394,10 @@ mod tests { .await .expect("publisher should receive track request"); - let _subgroups_writer = track_writer + let subgroups_writer = track_writer .subgroups() .expect("subgroups transition should succeed"); + drop(subgroups_writer); let _track_reader_2 = reader .subscribe(namespace.clone(), track_name) @@ -434,48 +407,49 @@ mod tests { tokio::time::timeout(std::time::Duration::from_millis(100), request.next()).await; assert!( - maybe_second_request.is_err(), - "publisher should NOT get a second request while SubgroupsWriter is alive" + maybe_second_request.is_ok(), + "publisher should get a new request after SubgroupsWriter is dropped" ); } - /// Test that a track IS considered stale after the SubgroupsWriter is dropped. - /// This preserves the RT-458 eviction behavior for dead publishers. + /// Test that normal track caching works correctly when tracks are still alive. + /// + /// Multiple subscribers to the same track should share the same TrackReader + /// (deduplication), and the publisher should only receive one request. #[tokio::test] - async fn test_track_stale_after_subgroups_writer_dropped() { + async fn test_track_deduplication_while_alive() { let namespace = TrackNamespace::from_utf8_path("test/namespace"); let track_name = "test-track"; let (_writer, mut request, mut reader) = Tracks::new(namespace.clone()).produce(); - let _track_reader_1 = reader + // First subscription + let track_reader_1 = reader .subscribe(namespace.clone(), track_name) .expect("first subscribe should succeed"); - let track_writer = request + // Publisher receives request + let _track_writer = request .next() .await .expect("publisher should receive track request"); - let subgroups_writer = track_writer - .subgroups() - .expect("subgroups transition should succeed"); - drop(subgroups_writer); - - let _track_reader_2 = reader + // Second subscription to the SAME track (while it's still alive) + let track_reader_2 = reader .subscribe(namespace.clone(), track_name) .expect("second subscribe should succeed"); + // Publisher should NOT receive another request (track is cached and alive) let maybe_second_request = tokio::time::timeout(std::time::Duration::from_millis(100), request.next()).await; assert!( - maybe_second_request.is_ok(), - "publisher should get a new request after SubgroupsWriter is dropped" + maybe_second_request.is_err(), + "Publisher should NOT receive a second request - track is cached and alive" ); - let _second_request = maybe_second_request - .unwrap() - .expect("publisher should receive second track request"); + // Both readers should refer to the same track + assert_eq!(track_reader_1.name, track_reader_2.name); + assert_eq!(track_reader_1.namespace, track_reader_2.namespace); } } diff --git a/moq-transport/src/session/announce.rs b/moq-transport/src/session/announce.rs deleted file mode 100644 index 8e2a0a90..00000000 --- a/moq-transport/src/session/announce.rs +++ /dev/null @@ -1,231 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use std::{collections::VecDeque, ops}; - -use crate::coding::TrackNamespace; -use crate::watch::State; -use crate::{message, serve::ServeError}; - -use super::{Publisher, Subscribed, TrackStatusRequested}; - -#[derive(Debug, Clone)] -pub struct AnnounceInfo { - pub request_id: u64, - pub namespace: TrackNamespace, -} - -struct AnnounceState { - subscribers: VecDeque, - track_statuses_requested: VecDeque, - ok: bool, - closed: Result<(), ServeError>, -} - -impl Default for AnnounceState { - fn default() -> Self { - Self { - subscribers: Default::default(), - track_statuses_requested: Default::default(), - ok: false, - closed: Ok(()), - } - } -} - -impl Drop for AnnounceState { - fn drop(&mut self) { - for subscriber in self.subscribers.drain(..) { - subscriber - .close(ServeError::not_found_ctx( - "announce dropped before subscription handled", - )) - .ok(); - } - } -} - -#[must_use = "unannounce on drop"] -pub struct Announce { - publisher: Publisher, - state: State, - - pub info: AnnounceInfo, -} - -impl Announce { - pub(super) fn new( - mut publisher: Publisher, - request_id: u64, - namespace: TrackNamespace, - ) -> (Announce, AnnounceRecv) { - let info = AnnounceInfo { - request_id, - namespace: namespace.clone(), - }; - - publisher.send_message(message::PublishNamespace { - id: request_id, - track_namespace: namespace.clone(), - params: Default::default(), - }); - - let (send, recv) = State::default().split(); - - let send = Self { - publisher, - info, - state: send, - }; - let recv = AnnounceRecv { - state: recv, - request_id, - }; - - (send, recv) - } - - // Run until we get an error - pub async fn closed(&self) -> Result<(), ServeError> { - loop { - { - let state = self.state.lock(); - state.closed.clone()?; - - match state.modified() { - Some(notified) => notified, - None => return Ok(()), - } - } - .await; - } - } - - /// Wait until a subscriber is received - pub async fn subscribed(&self) -> Result, ServeError> { - loop { - { - let state = self.state.lock(); - if !state.subscribers.is_empty() { - return Ok(state - .into_mut() - .and_then(|mut state| state.subscribers.pop_front())); - } - - state.closed.clone()?; - match state.modified() { - Some(notified) => notified, - None => return Ok(None), - } - } - .await; - } - } - - pub async fn track_status_requested(&self) -> Result, ServeError> { - loop { - { - let state = self.state.lock(); - if !state.track_statuses_requested.is_empty() { - return Ok(state - .into_mut() - .and_then(|mut state| state.track_statuses_requested.pop_front())); - } - - state.closed.clone()?; - match state.modified() { - Some(notified) => notified, - None => return Ok(None), - } - } - .await; - } - } - - // Wait until an OK is received - pub async fn ok(&self) -> Result<(), ServeError> { - loop { - { - let state = self.state.lock(); - if state.ok { - return Ok(()); - } - state.closed.clone()?; - - match state.modified() { - Some(notified) => notified, - None => return Ok(()), - } - } - .await; - } - } -} - -impl Drop for Announce { - fn drop(&mut self) { - if self.state.lock().closed.is_err() { - return; - } - - self.publisher.send_message(message::PublishNamespaceDone { - track_namespace: self.namespace.clone(), - }); - } -} - -impl ops::Deref for Announce { - type Target = AnnounceInfo; - - fn deref(&self) -> &Self::Target { - &self.info - } -} - -pub(super) struct AnnounceRecv { - state: State, - pub request_id: u64, // TODO SLG - Announcements need to be looked up by both request_id and namespace, consider 2 hashmaps in publisher instead of this -} - -impl AnnounceRecv { - pub fn recv_ok(&mut self) -> Result<(), ServeError> { - if let Some(mut state) = self.state.lock_mut() { - if state.ok { - return Err(ServeError::Duplicate); - } - - state.ok = true; - } - - Ok(()) - } - - pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { - let state = self.state.lock(); - state.closed.clone()?; - - let mut state = state.into_mut().ok_or(ServeError::Done)?; - state.closed = Err(err); - - Ok(()) - } - - pub fn recv_subscribe(&mut self, subscriber: Subscribed) -> Result<(), ServeError> { - let mut state = self.state.lock_mut().ok_or(ServeError::Done)?; - state.subscribers.push_back(subscriber); - - Ok(()) - } - - pub fn recv_track_status_requested( - &mut self, - track_status_requested: TrackStatusRequested, - ) -> Result<(), ServeError> { - let mut state = self.state.lock_mut().ok_or(ServeError::Done)?; - state - .track_statuses_requested - .push_back(track_status_requested); - Ok(()) - } -} diff --git a/moq-transport/src/session/error.rs b/moq-transport/src/session/error.rs index 86478ee6..2e5ceacb 100644 --- a/moq-transport/src/session/error.rs +++ b/moq-transport/src/session/error.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::{coding, serve, setup}; #[derive(thiserror::Error, Debug, Clone)] @@ -40,9 +36,6 @@ pub enum SessionError { #[error("wrong size")] WrongSize, - - #[error("invalid connection path: {0}")] - InvalidPath(String), } // Session Termination Error Codes from draft-ietf-moq-transport-14 Section 13.1.1 @@ -63,7 +56,6 @@ impl SessionError { // PROTOCOL_VIOLATION (0x3) - Malformed messages Self::Decode(_) => 0x3, Self::WrongSize => 0x3, - Self::InvalidPath(_) => 0x3, // DUPLICATE_TRACK_ALIAS (0x5) Self::Duplicate => 0x5, // Delegate to ServeError for per-request error codes @@ -76,54 +68,6 @@ impl SessionError { pub fn unimplemented(feature: &str) -> Self { Self::Serve(serve::ServeError::not_implemented_ctx(feature)) } - - /// Returns true if this error represents a graceful connection close. - /// - /// For WebTransport, a graceful close is a `CLOSE_WEBTRANSPORT_SESSION` capsule - /// with code 0. For raw QUIC, it's `APPLICATION_CLOSE` with code 0 (NO_ERROR). - /// Both are normal session termination, not error conditions. - /// - /// This method checks for: - /// - WebTransport `Closed(0, _)` — web-transport-quinn v0.11+ typically converts - /// HTTP/3-encoded `ApplicationClosed` codes into `WebTransportError::Closed(code, reason)` - /// during `SessionError` conversion when decoding via `error_from_http3` succeeds - /// - Raw QUIC `ApplicationClosed` with code 0 - /// - The local side closing the connection (`LocallyClosed`) - /// - /// ## Implementation Notes - /// - /// We pattern match on `web_transport_quinn::SessionError` variants. In v0.11+, - /// WebTransport graceful closes arrive as `WebTransportError::Closed(0, _)` because - /// the crate decodes HTTP/3 error codes at the `SessionError` level. For raw QUIC - /// connections, the close code is checked directly on `ConnectionError::ApplicationClosed`. - /// - /// **Coupling note**: This implementation is coupled to `web-transport-quinn` and - /// `quinn`. When transitioning to a different WebTransport backend (e.g., tokio-quiche), - /// ensure the replacement provides equivalent error introspection, or update this - /// method to handle the new error types. - pub fn is_graceful_close(&self) -> bool { - match self { - Self::WebTransport(wt_err) => match wt_err { - web_transport::Error::Session(session_err) => { - is_session_error_graceful(session_err) - } - web_transport::Error::Read(read_err) => { - if let web_transport::quinn::ReadError::SessionError(session_err) = read_err { - return is_session_error_graceful(session_err); - } - false - } - web_transport::Error::Write(write_err) => { - if let web_transport::quinn::WriteError::SessionError(session_err) = write_err { - return is_session_error_graceful(session_err); - } - false - } - _ => false, - }, - _ => false, - } - } } impl From for serve::ServeError { @@ -134,60 +78,3 @@ impl From for serve::ServeError { } } } - -/// Helper to check if a `web_transport_quinn::SessionError` represents a graceful close. -/// -/// This handles: -/// - WebTransport connections: `WebTransportError::Closed(0, _)` — web-transport-quinn v0.11+ -/// typically decodes HTTP/3-encoded close codes at this layer (when `SessionError` conversion -/// applies), so graceful closes usually arrive here rather than as a raw -/// `ConnectionError::ApplicationClosed`. -/// - Raw QUIC connections: `ConnectionError::ApplicationClosed` with code 0 -/// - Local close: `ConnectionError::LocallyClosed` -fn is_session_error_graceful(err: &web_transport::quinn::SessionError) -> bool { - use web_transport::quinn::{SessionError, WebTransportError}; - - match err { - SessionError::ConnectionError(conn_err) => is_connection_error_graceful(conn_err), - // WebTransport graceful close: peer sent close with code 0 - SessionError::WebTransportError(WebTransportError::Closed(0, _)) => true, - // Other WebTransport errors (UnknownSession, read/write errors, non-zero close codes) - SessionError::WebTransportError(_) => false, - // SendDatagramError doesn't represent connection close - SessionError::SendDatagramError(_) => false, - } -} - -/// Helper to check if a `quinn::ConnectionError` represents a graceful close. -/// -/// Note: In web-transport-quinn v0.11+, WebTransport `ApplicationClosed` with an HTTP/3-encoded -/// close code is usually converted to `WebTransportError::Closed` during `SessionError` conversion -/// when decoding succeeds. This function primarily handles raw QUIC (moqt:// ALPN) connections -/// or non-decodable cases where the close code is not HTTP/3 encoded. -fn is_connection_error_graceful(err: &web_transport::quinn::quinn::ConnectionError) -> bool { - use web_transport::quinn::quinn::ConnectionError; - - match err { - ConnectionError::ApplicationClosed(close) => { - let code = close.error_code.into_inner(); - - // Check for raw QUIC code 0 (direct MoQ-over-QUIC) - if code == 0 { - return true; - } - - // Check for WebTransport code 0 (HTTP/3 encoded) - // This is a fallback — in v0.11+, WebTransport closes are typically caught - // by is_session_error_graceful's WebTransportError::Closed branch. - if let Some(wt_code) = web_transport::quinn::proto::error_from_http3(code) { - return wt_code == 0; - } - - false - } - // LocallyClosed means we closed the connection ourselves - ConnectionError::LocallyClosed => true, - // Other errors are not graceful closes - _ => false, - } -} diff --git a/moq-transport/src/session/mod.rs b/moq-transport/src/session/mod.rs index a2313636..b984ca94 100644 --- a/moq-transport/src/session/mod.rs +++ b/moq-transport/src/session/mod.rs @@ -1,23 +1,27 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -mod announce; -mod announced; mod error; +mod publish_namespace; +mod publish_namespace_received; +mod publish_received; +mod published; mod publisher; mod reader; mod subscribe; +mod subscribe_namespace; +mod subscribe_namespace_received; mod subscribed; mod subscriber; mod track_status_requested; mod writer; -pub use announce::*; -pub use announced::*; pub use error::*; +pub use publish_namespace::*; +pub use publish_namespace_received::*; +pub use publish_received::*; +pub use published::*; pub use publisher::*; pub use subscribe::*; +pub use subscribe_namespace::*; +pub use subscribe_namespace_received::*; pub use subscribed::*; pub use subscriber::*; pub use track_status_requested::*; @@ -28,32 +32,13 @@ use writer::*; use futures::{stream::FuturesUnordered, StreamExt}; use std::sync::{atomic, Arc, Mutex}; -use crate::coding::{KeyValuePairs, Value}; +use crate::coding::KeyValuePairs; use crate::message::Message; use crate::mlog; use crate::watch::Queue; use crate::{message, setup}; use std::path::PathBuf; -/// The transport protocol negotiated for this MoQT connection. -/// -/// MoQT can run over either WebTransport (HTTP/3 + QUIC) or raw QUIC. -/// The transport type affects protocol behavior — for example, the PATH -/// parameter is only sent in CLIENT_SETUP for raw QUIC connections, -/// since WebTransport carries the path in the HTTP/3 CONNECT URL. -/// -/// This enum is intentionally extensible for future transport options -/// (e.g., QMUX, WebSocket fallback). -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Transport { - /// WebTransport over HTTP/3 (RFC 9220). - /// ALPN: "h3". Path carried in HTTP/3 CONNECT :path pseudo-header. - WebTransport, - /// Raw QUIC with MoQT framing directly on QUIC streams. - /// ALPN: "moq-00". Path carried in CLIENT_SETUP PATH parameter. - RawQuic, -} - /// Session object for managing all communications in a single QUIC connection. #[must_use = "run() must be called"] pub struct Session { @@ -72,426 +57,15 @@ pub struct Session { /// Optional mlog writer for MoQ Transport events /// Wrapped in Arc> to share across send/recv tasks when enabled mlog: Option>>, - - /// The transport protocol negotiated for this connection. - transport: Transport, - - /// The connection path, derived from the WebTransport URL path or CLIENT_SETUP PATH parameter. - /// For incoming connections: extracted during accept() from the WebTransport CONNECT URL - /// (takes precedence) or the CLIENT_SETUP PATH parameter (key 0x1). - /// For outgoing connections: auto-extracted from the session URL in connect(). - connection_path: Option, } impl Session { - const MAX_CONNECTION_PATH_LEN: usize = 1024; - - /// Normalize and validate a connection path. - /// - /// Returns `Ok(None)` for empty or root-only paths. Returns `Err` for - /// paths that are too long, don't start with `/`, contain empty, - /// dot, or percent-encoded segments, or are otherwise malformed. - /// - /// Percent-encoded characters are rejected rather than decoded because - /// scope identity must be unambiguous: `/foo%2Fbar` and `/foo/bar` - /// must not silently map to different scopes, and `%2E%2E` must not - /// bypass the dot-segment check. - /// - /// This is used internally by `accept()` and `connect()`, but is also - /// available for callers that need to validate paths from other sources - /// (e.g., announce URLs used for forward connections). - pub fn normalize_connection_path(raw: &str) -> Result, SessionError> { - if raw.is_empty() || raw == "/" { - return Ok(None); - } - - if raw.len() > Self::MAX_CONNECTION_PATH_LEN { - return Err(SessionError::InvalidPath("path too long".to_string())); - } - - if !raw.starts_with('/') { - return Err(SessionError::InvalidPath( - "path must start with '/'".to_string(), - )); - } - - let trimmed = raw.trim_end_matches('/'); - if trimmed.is_empty() { - return Ok(None); - } - - let mut segments = trimmed.split('/'); - let _ = segments.next(); - for segment in segments { - if segment.is_empty() { - return Err(SessionError::InvalidPath( - "path contains empty segment".to_string(), - )); - } - if segment.contains('%') { - return Err(SessionError::InvalidPath( - "path must not contain percent-encoded characters".to_string(), - )); - } - if segment == "." || segment == ".." { - return Err(SessionError::InvalidPath( - "path contains invalid segment".to_string(), - )); - } - } - - Ok(Some(trimmed.to_string())) - } - - fn decode_client_setup_path(params: &KeyValuePairs) -> Result, SessionError> { - let Some(kvp) = params.get(setup::ParameterType::Path.into()) else { - return Ok(None); - }; - - let bytes = match &kvp.value { - Value::BytesValue(bytes) => bytes, - _ => { - return Err(SessionError::InvalidPath( - "PATH parameter must be bytes-encoded".to_string(), - )) - } - }; - - if bytes.len() > Self::MAX_CONNECTION_PATH_LEN { - return Err(SessionError::InvalidPath("path too long".to_string())); - } - - let path = std::str::from_utf8(bytes) - .map_err(|_| SessionError::InvalidPath("path must be UTF-8".to_string()))?; - - Self::normalize_connection_path(path) - } - - /// Returns the negotiated transport protocol for this connection. - pub fn transport(&self) -> Transport { - self.transport - } - - /// Returns the connection path, if one was present on the incoming connection. - /// - /// For server-side sessions (created via `accept()`), this is derived from: - /// 1. The WebTransport CONNECT URL path (takes precedence), or - /// 2. The CLIENT_SETUP PATH parameter (key 0x1), used for raw QUIC connections. - /// - /// Returns `None` if no path was present or if the path was just "/". - pub fn connection_path(&self) -> Option<&str> { - self.connection_path.as_deref() - } - - // Helper for determining the largest supported version - fn largest_common(a: &[T], b: &[T]) -> Option { - a.iter() - .filter(|x| b.contains(x)) // keep only items also in b - .cloned() // clone because we return T, not &T - .max() // take the largest - } - - /// Log a control message with structured fields for observability. - /// Uses target "moq_transport::control" so it can be filtered independently. - fn log_control_message(msg: &Message, direction: &str) { - match msg { - Message::Subscribe(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE", - subscribe_id = m.id, - namespace = %m.track_namespace, - track_name = %m.track_name, - filter_type = ?m.filter_type, - "MoQT control message" - ); - } - Message::SubscribeOk(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE_OK", - subscribe_id = m.id, - track_alias = m.track_alias, - content_exists = m.content_exists, - "MoQT control message" - ); - } - Message::SubscribeError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE_ERROR", - subscribe_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } - Message::SubscribeUpdate(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE_UPDATE", - request_id = m.id, - subscription_request_id = m.subscription_request_id, - "MoQT control message" - ); - } - Message::Unsubscribe(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "UNSUBSCRIBE", - subscribe_id = m.id, - "MoQT control message" - ); - } - Message::PublishNamespace(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "PUBLISH_NAMESPACE", - request_id = m.id, - namespace = %m.track_namespace, - "MoQT control message" - ); - } - Message::PublishNamespaceOk(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "PUBLISH_NAMESPACE_OK", - request_id = m.id, - "MoQT control message" - ); - } - Message::PublishNamespaceError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "PUBLISH_NAMESPACE_ERROR", - request_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } - Message::PublishNamespaceDone(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "PUBLISH_NAMESPACE_DONE", - namespace = %m.track_namespace, - "MoQT control message" - ); - } - Message::PublishNamespaceCancel(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "PUBLISH_NAMESPACE_CANCEL", - namespace = %m.track_namespace, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } - Message::TrackStatus(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "TRACK_STATUS", - request_id = m.id, - namespace = %m.track_namespace, - track_name = %m.track_name, - "MoQT control message" - ); - } - Message::TrackStatusOk(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "TRACK_STATUS_OK", - request_id = m.id, - track_alias = m.track_alias, - content_exists = m.content_exists, - "MoQT control message" - ); - } - Message::TrackStatusError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "TRACK_STATUS_ERROR", - request_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } - Message::SubscribeNamespace(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE_NAMESPACE", - request_id = m.id, - namespace_prefix = %m.track_namespace_prefix, - "MoQT control message" - ); - } - Message::SubscribeNamespaceOk(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE_NAMESPACE_OK", - request_id = m.id, - "MoQT control message" - ); - } - Message::SubscribeNamespaceError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE_NAMESPACE_ERROR", - request_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } - Message::UnsubscribeNamespace(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "UNSUBSCRIBE_NAMESPACE", - namespace_prefix = %m.track_namespace_prefix, - "MoQT control message" - ); - } - Message::Fetch(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "FETCH", - request_id = m.id, - fetch_type = ?m.fetch_type, - "MoQT control message" - ); - } - Message::FetchOk(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "FETCH_OK", - request_id = m.id, - end_of_track = m.end_of_track, - "MoQT control message" - ); - } - Message::FetchError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "FETCH_ERROR", - request_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } - Message::FetchCancel(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "FETCH_CANCEL", - request_id = m.id, - "MoQT control message" - ); - } - Message::Publish(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "PUBLISH", - request_id = m.id, - namespace = %m.track_namespace, - track_name = %m.track_name, - track_alias = m.track_alias, - "MoQT control message" - ); - } - Message::PublishOk(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "PUBLISH_OK", - request_id = m.id, - filter_type = ?m.filter_type, - "MoQT control message" - ); - } - Message::PublishError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "PUBLISH_ERROR", - request_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } - Message::PublishDone(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "PUBLISH_DONE", - request_id = m.id, - status_code = m.status_code, - stream_count = m.stream_count, - "MoQT control message" - ); - } - Message::GoAway(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "GOAWAY", - uri = %m.uri.0, - "MoQT control message" - ); - } - Message::MaxRequestId(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "MAX_REQUEST_ID", - request_id = m.request_id, - "MoQT control message" - ); - } - Message::RequestsBlocked(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "REQUESTS_BLOCKED", - max_request_id = m.max_request_id, - "MoQT control message" - ); - } - } - } - fn new( webtransport: web_transport::Session, sender: Writer, recver: Reader, first_requestid: u64, mlog: Option, - transport: Transport, - connection_path: Option, ) -> (Self, Option, Option) { let next_requestid = Arc::new(atomic::AtomicU64::new(first_requestid)); let outgoing = Queue::default().split(); @@ -519,8 +93,6 @@ impl Session { subscriber: subscriber.clone(), outgoing: outgoing.1, mlog: mlog_shared, - transport, - connection_path, }; (session, publisher, subscriber) @@ -528,76 +100,36 @@ impl Session { /// Create an outbound/client QUIC connection, by opening a bi-directional QUIC stream for /// MOQT control messaging. Performs SETUP messaging and version negotiation. - /// - /// If the session URL contains a non-trivial path (not empty or "/"), the PATH - /// parameter (key 0x1) is automatically sent in CLIENT_SETUP. This propagates - /// the connection path (App ID / MoQT scope) to the remote peer, which is needed - /// for relay-to-relay connections. To connect without sending PATH, use a URL - /// with no path component. pub async fn connect( session: web_transport::Session, mlog_path: Option, - transport: Transport, ) -> Result<(Session, Publisher, Subscriber), SessionError> { - // Auto-extract path from the session URL. - // This aligns with the unified moqt:// URI scheme direction (IETF PR #1486) - // where the path is always part of the URI regardless of transport. - let url_path = session.url().path(); - let path = Self::normalize_connection_path(url_path)?; let mlog = mlog_path.and_then(|path| { mlog::MlogWriter::new(path) - .map_err(|e| tracing::warn!("Failed to create mlog: {}", e)) + .map_err(|e| log::warn!("Failed to create mlog: {}", e)) .ok() }); let control = session.open_bi().await?; let mut sender = Writer::new(control.0); let mut recver = Reader::new(control.1); - let versions: setup::Versions = [setup::Version::DRAFT_14].into(); - - // TODO SLG - make configurable? let mut params = KeyValuePairs::default(); params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); - // Only send PATH in CLIENT_SETUP for raw QUIC connections. - // For WebTransport, the path is already carried in the HTTP/3 CONNECT URL. - if let Some(ref path) = path { - if transport == Transport::RawQuic { - params.set_bytesvalue(setup::ParameterType::Path.into(), path.as_bytes().to_vec()); - } - } + let client = setup::Client { params }; - let client = setup::Client { - versions: versions.clone(), - params, - }; - - tracing::debug!( - target: "moq_transport::control", - direction = "sent", - msg_type = "CLIENT_SETUP", - versions = ?client.versions, - ?transport, - path = path.as_deref(), - "MoQT control message" - ); + log::debug!("sending CLIENT_SETUP: {:?}", client); sender.encode(&client).await?; // TODO: emit client_setup_created event when we add that let server: setup::Server = recver.decode().await?; - tracing::debug!( - target: "moq_transport::control", - direction = "recv", - msg_type = "SERVER_SETUP", - version = ?server.version, - "MoQT control message" - ); + log::debug!("received SERVER_SETUP: {:?}", server); // TODO: emit server_setup_parsed event // We are the client, so the first request id is 0 - let session = Session::new(session, sender, recver, 0, mlog, transport, path); + let session = Session::new(session, sender, recver, 0, mlog); Ok((session.0, session.1.unwrap(), session.2.unwrap())) } @@ -606,11 +138,10 @@ impl Session { pub async fn accept( session: web_transport::Session, mlog_path: Option, - transport: Transport, ) -> Result<(Session, Option, Option), SessionError> { let mut mlog = mlog_path.and_then(|path| { mlog::MlogWriter::new(path) - .map_err(|e| tracing::warn!("Failed to create mlog: {}", e)) + .map_err(|e| log::warn!("Failed to create mlog: {}", e)) .ok() }); let control = session.accept_bi().await?; @@ -618,39 +149,7 @@ impl Session { let mut recver = Reader::new(control.1); let client: setup::Client = recver.decode().await?; - tracing::debug!( - target: "moq_transport::control", - direction = "recv", - msg_type = "CLIENT_SETUP", - versions = ?client.versions, - "MoQT control message" - ); - - // Extract WebTransport URL path from the underlying session. - // For WebTransport connections, this comes from the HTTP/3 CONNECT :path. - // For raw QUIC, this is the placeholder URL ("moqt://localhost") and has no meaningful path. - let wt_url_path = session.url().path(); - let wt_path = Self::normalize_connection_path(wt_url_path)?; - - // Extract CLIENT_SETUP PATH parameter (key 0x1, BytesValue). - // Used for raw QUIC connections where there's no HTTP CONNECT URL. - let client_setup_path = if wt_path.is_none() { - Self::decode_client_setup_path(&client.params)? - } else { - None - }; - - // Combine: WebTransport URL path takes precedence over CLIENT_SETUP PATH. - // WebTransport connections always have the path in the CONNECT URL. - // Raw QUIC connections only have CLIENT_SETUP PATH. - let connection_path = wt_path.or(client_setup_path); - - if connection_path.is_some() { - tracing::debug!( - connection_path = connection_path.as_deref(), - "Connection path resolved" - ); - } + log::debug!("received CLIENT_SETUP: {:?}", client); // Emit mlog event for CLIENT_SETUP parsed if let Some(ref mut mlog) = mlog { @@ -658,49 +157,24 @@ impl Session { let _ = mlog.add_event(event); } - let server_versions = setup::Versions(vec![setup::Version::DRAFT_14]); - - if let Some(largest_common_version) = - Self::largest_common(&server_versions, &client.versions) - { - // TODO SLG - make configurable? - let mut params = KeyValuePairs::default(); - params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); + // TODO SLG - make configurable? + let mut params = KeyValuePairs::default(); + params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); - let server = setup::Server { - version: largest_common_version, - params, - }; + let server = setup::Server { params }; - tracing::debug!( - target: "moq_transport::control", - direction = "sent", - msg_type = "SERVER_SETUP", - version = ?server.version, - "MoQT control message" - ); - - // Emit mlog event for SERVER_SETUP created - if let Some(ref mut mlog) = mlog { - let event = mlog::events::server_setup_created(mlog.elapsed_ms(), 0, &server); - let _ = mlog.add_event(event); - } + log::debug!("sending SERVER_SETUP: {:?}", server); - sender.encode(&server).await?; - - // We are the server, so the first request id is 1 - Ok(Session::new( - session, - sender, - recver, - 1, - mlog, - transport, - connection_path, - )) - } else { - Err(SessionError::Version(client.versions, server_versions)) + // Emit mlog event for SERVER_SETUP created + if let Some(ref mut mlog) = mlog { + let event = mlog::events::server_setup_created(mlog.elapsed_ms(), 0, &server); + let _ = mlog.add_event(event); } + + sender.encode(&server).await?; + + // We are the server, so the first request id is 1 + Ok(Session::new(session, sender, recver, 1, mlog)) } /// Run Tasks for the session, including sending of control messages, receiving and processing @@ -708,9 +182,10 @@ impl Session { /// and receiving and processing QUIC datagrams received pub async fn run(self) -> Result<(), SessionError> { tokio::select! { - res = Self::run_recv(self.recver, self.publisher, self.subscriber.clone(), self.mlog.clone()) => res, + res = Self::run_recv(self.recver, self.publisher.clone(), self.subscriber.clone(), self.mlog.clone()) => res, res = Self::run_send(self.sender, self.outgoing, self.mlog.clone()) => res, res = Self::run_streams(self.webtransport.clone(), self.subscriber.clone()) => res, + res = Self::run_bidi_streams(self.webtransport.clone(), self.publisher) => res, res = Self::run_datagrams(self.webtransport, self.subscriber) => res, } } @@ -722,8 +197,7 @@ impl Session { mlog: Option>>, ) -> Result<(), SessionError> { while let Some(msg) = outgoing.pop().await { - // Emit structured tracing log for sent control messages - Self::log_control_message(&msg, "sent"); + log::debug!("sending message: {:?}", msg); // Emit mlog event for sent control messages if let Some(ref mlog) = mlog { @@ -739,8 +213,8 @@ impl Session { Message::SubscribeOk(m) => { Some(mlog::events::subscribe_ok_created(time, stream_id, m)) } - Message::SubscribeError(m) => { - Some(mlog::events::subscribe_error_created(time, stream_id, m)) + Message::RequestError(m) => { + Some(mlog::events::reqeust_error_created(time, stream_id, m)) } Message::Unsubscribe(m) => { Some(mlog::events::unsubscribe_created(time, stream_id, m)) @@ -748,16 +222,22 @@ impl Session { Message::PublishNamespace(m) => { Some(mlog::events::publish_namespace_created(time, stream_id, m)) } - Message::PublishNamespaceOk(m) => Some( - mlog::events::publish_namespace_ok_created(time, stream_id, m), - ), - Message::PublishNamespaceError(m) => Some( - mlog::events::publish_namespace_error_created(time, stream_id, m), - ), + Message::RequestOk(m) => { + Some(mlog::events::reqeust_ok_created(time, stream_id, m)) + } Message::GoAway(m) => { Some(mlog::events::go_away_created(time, stream_id, m)) } - _ => None, // TODO: Add other message types + Message::Publish(m) => { + Some(mlog::events::publish_created(time, stream_id, m)) + } + Message::PublishOk(m) => { + Some(mlog::events::publish_ok_created(time, stream_id, m)) + } + Message::PublishDone(m) => { + Some(mlog::events::publish_done_created(time, stream_id, m)) + } + _ => None, }; if let Some(event) = event { @@ -783,11 +263,11 @@ impl Session { mut subscriber: Option, mlog: Option>>, ) -> Result<(), SessionError> { + log::debug!("[SESSION] run_recv: starting message receive loop"); loop { + log::trace!("[SESSION] run_recv: waiting for next message..."); let msg: message::Message = recver.decode().await?; - - // Emit structured tracing log for received control messages - Self::log_control_message(&msg, "recv"); + log::debug!("received message: {:?}", msg); // Emit mlog event for received control messages if let Some(ref mlog) = mlog { @@ -803,8 +283,8 @@ impl Session { Message::SubscribeOk(m) => { Some(mlog::events::subscribe_ok_parsed(time, stream_id, m)) } - Message::SubscribeError(m) => { - Some(mlog::events::subscribe_error_parsed(time, stream_id, m)) + Message::RequestError(m) => { + Some(mlog::events::request_error_parsed(time, stream_id, m)) } Message::Unsubscribe(m) => { Some(mlog::events::unsubscribe_parsed(time, stream_id, m)) @@ -812,16 +292,22 @@ impl Session { Message::PublishNamespace(m) => { Some(mlog::events::publish_namespace_parsed(time, stream_id, m)) } - Message::PublishNamespaceOk(m) => Some( - mlog::events::publish_namespace_ok_parsed(time, stream_id, m), - ), - Message::PublishNamespaceError(m) => Some( - mlog::events::publish_namespace_error_parsed(time, stream_id, m), - ), + Message::RequestOk(m) => { + Some(mlog::events::request_ok_parsed(time, stream_id, m)) + } Message::GoAway(m) => { Some(mlog::events::go_away_parsed(time, stream_id, m)) } - _ => None, // TODO: Add other message types + Message::Publish(m) => { + Some(mlog::events::publish_parsed(time, stream_id, m)) + } + Message::PublishOk(m) => { + Some(mlog::events::publish_ok_parsed(time, stream_id, m)) + } + Message::PublishDone(m) => { + Some(mlog::events::publish_done_parsed(time, stream_id, m)) + } + _ => None, }; if let Some(event) = event { @@ -830,6 +316,29 @@ impl Session { } } + // RequestOk and RequestError are bidirectional — they can be responses + // to requests originated by either side (e.g., PUBLISH_NAMESPACE from the + // publisher or SUBSCRIBE_NAMESPACE from the subscriber). We must try both + // handlers so the response reaches whichever side owns that request ID. + match &msg { + Message::RequestOk(_) | Message::RequestError(_) => { + // Try subscriber handler first (for SUBSCRIBE_NAMESPACE responses) + if let Ok(pub_msg) = TryInto::::try_into(msg.clone()) { + if let Some(sub) = subscriber.as_mut() { + let _ = sub.recv_message(pub_msg); + } + } + // Also try publisher handler (for PUBLISH_NAMESPACE responses) + if let Ok(sub_msg) = TryInto::::try_into(msg) { + if let Some(pub_) = publisher.as_mut() { + let _ = pub_.recv_message(sub_msg); + } + } + continue; + } + _ => {} + } + let msg = match TryInto::::try_into(msg) { Ok(msg) => { subscriber @@ -853,7 +362,7 @@ impl Session { }; // TODO GOAWAY, MAX_REQUEST_ID, REQUESTS_BLOCKED - tracing::warn!("Unimplemented message type received: {:?}", msg); + log::warn!("Unimplemented message type received: {:?}", msg); return Err(SessionError::unimplemented(&format!( "message type {:?}", msg @@ -878,8 +387,54 @@ impl Session { tasks.push(async move { if let Err(err) = Subscriber::recv_stream(subscriber, stream).await { - tracing::warn!("failed to serve stream: {}", err); + log::warn!("failed to serve stream: {}", err); + }; + }); + }, + _ = tasks.next(), if !tasks.is_empty() => {}, + }; + } + } + + /// Accepts bidirectional QUIC streams for messages like SUBSCRIBE_NAMESPACE. + /// In draft-16, SUBSCRIBE_NAMESPACE uses its own bidirectional stream. + async fn run_bidi_streams( + webtransport: web_transport::Session, + publisher: Option, + ) -> Result<(), SessionError> { + let mut tasks = FuturesUnordered::new(); + + loop { + tokio::select! { + res = webtransport.accept_bi() => { + let (_send, recv) = res?; + let mut publisher = publisher.clone().ok_or(SessionError::RoleViolation)?; + + tasks.push(async move { + let mut reader = Reader::new(recv); + + // Read the message from the bidi stream + let msg: message::Message = match reader.decode().await { + Ok(msg) => msg, + Err(e) => { + log::warn!("failed to decode message on bidi stream: {}", e); + return; + } }; + + log::debug!("received message on bidi stream: {:?}", msg); + + // Handle SUBSCRIBE_NAMESPACE on its dedicated bidi stream + match msg { + Message::SubscribeNamespace(subscribe_ns) => { + if let Err(e) = publisher.recv_message(message::Subscriber::SubscribeNamespace(subscribe_ns)) { + log::warn!("failed to handle SUBSCRIBE_NAMESPACE: {}", e); + } + } + other => { + log::warn!("unexpected message type on bidi stream: {:?}", other); + } + } }); }, _ = tasks.next(), if !tasks.is_empty() => {}, @@ -902,78 +457,3 @@ impl Session { } } } - -#[cfg(test)] -mod tests { - use super::*; - - // ======================================================================== - // normalize_connection_path - // ======================================================================== - - #[test] - fn normalize_empty_and_root() { - assert_eq!(Session::normalize_connection_path("").unwrap(), None); - assert_eq!(Session::normalize_connection_path("/").unwrap(), None); - assert_eq!(Session::normalize_connection_path("///").unwrap(), None); - } - - #[test] - fn normalize_valid_paths() { - assert_eq!( - Session::normalize_connection_path("/app").unwrap(), - Some("/app".to_string()) - ); - assert_eq!( - Session::normalize_connection_path("/tenant/stream-1").unwrap(), - Some("/tenant/stream-1".to_string()) - ); - // Trailing slash is trimmed - assert_eq!( - Session::normalize_connection_path("/app/").unwrap(), - Some("/app".to_string()) - ); - } - - #[test] - fn normalize_rejects_missing_leading_slash() { - assert!(Session::normalize_connection_path("app").is_err()); - } - - #[test] - fn normalize_rejects_empty_segments() { - assert!(Session::normalize_connection_path("/app//stream").is_err()); - } - - #[test] - fn normalize_rejects_dot_segments() { - assert!(Session::normalize_connection_path("/app/./stream").is_err()); - assert!(Session::normalize_connection_path("/app/../secret").is_err()); - assert!(Session::normalize_connection_path("/..").is_err()); - } - - #[test] - fn normalize_rejects_percent_encoded_characters() { - // %2F = '/' — would create scope ambiguity - assert!(Session::normalize_connection_path("/foo%2Fbar").is_err()); - // %2E%2E = '..' — would bypass dot-segment check - assert!(Session::normalize_connection_path("/%2E%2E/secret").is_err()); - // %00 = null — general injection risk - assert!(Session::normalize_connection_path("/app/%00").is_err()); - // Uppercase hex digits - assert!(Session::normalize_connection_path("/app/%2e%2e").is_err()); - } - - #[test] - fn normalize_rejects_too_long_path() { - let long_path = format!("/{}", "a".repeat(Session::MAX_CONNECTION_PATH_LEN)); - assert!(Session::normalize_connection_path(&long_path).is_err()); - } - - #[test] - fn normalize_accepts_max_length_path() { - // Exactly at the limit (1024 total including leading slash) - let path = format!("/{}", "a".repeat(Session::MAX_CONNECTION_PATH_LEN - 1)); - assert!(Session::normalize_connection_path(&path).is_ok()); - } -} diff --git a/moq-transport/src/session/publish_namespace.rs b/moq-transport/src/session/publish_namespace.rs new file mode 100644 index 00000000..14f7fcee --- /dev/null +++ b/moq-transport/src/session/publish_namespace.rs @@ -0,0 +1,157 @@ +use std::ops; + +use crate::coding::TrackNamespace; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::Publisher; + +#[derive(Debug, Clone)] +pub struct PublishNamespaceInfo { + pub request_id: u64, + pub namespace: TrackNamespace, +} + +/// Internal state for PublishNamespace. +/// +/// PublishNamespace is a namespace registry that advertises to subscribers +/// that a publisher has tracks available in a namespace. It does NOT route +/// subscriptions - that happens via PUBLISH/SUBSCRIBE messages directly. +struct PublishNamespaceState { + ok: bool, + closed: Result<(), ServeError>, +} + +impl Default for PublishNamespaceState { + fn default() -> Self { + Self { + ok: false, + closed: Ok(()), + } + } +} + +/// Represents an outbound PUBLISH_NAMESPACE request (publisher side). +/// When dropped, sends PUBLISH_NAMESPACE_DONE to the peer. +#[must_use = "sends PUBLISH_NAMESPACE_DONE on drop"] +pub struct PublishNamespace { + publisher: Publisher, + state: State, + + pub info: PublishNamespaceInfo, +} + +impl PublishNamespace { + pub(super) fn new( + mut publisher: Publisher, + request_id: u64, + namespace: TrackNamespace, + ) -> (PublishNamespace, PublishNamespaceRecv) { + let info = PublishNamespaceInfo { + request_id, + namespace: namespace.clone(), + }; + + publisher.send_message(message::PublishNamespace { + id: request_id, + track_namespace: namespace.clone(), + params: Default::default(), + }); + + let (send, recv) = State::default().split(); + + let send = Self { + publisher, + info, + state: send, + }; + let recv = PublishNamespaceRecv { + state: recv, + request_id, + }; + + (send, recv) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } + + pub async fn ok(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + if state.ok { + return Ok(()); + } + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } +} + +impl Drop for PublishNamespace { + fn drop(&mut self) { + if self.state.lock().closed.is_err() { + return; + } + + self.publisher.send_message(message::PublishNamespaceDone { + track_namespace: self.namespace.clone(), + }); + } +} + +impl ops::Deref for PublishNamespace { + type Target = PublishNamespaceInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +pub(super) struct PublishNamespaceRecv { + state: State, + pub request_id: u64, +} + +impl PublishNamespaceRecv { + pub fn recv_ok(&mut self) -> Result<(), ServeError> { + if let Some(mut state) = self.state.lock_mut() { + if state.ok { + return Err(ServeError::Duplicate); + } + + state.ok = true; + } + + Ok(()) + } + + pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } +} diff --git a/moq-transport/src/session/announced.rs b/moq-transport/src/session/publish_namespace_received.rs similarity index 57% rename from moq-transport/src/session/announced.rs rename to moq-transport/src/session/publish_namespace_received.rs index e87b96be..a25e6b76 100644 --- a/moq-transport/src/session/announced.rs +++ b/moq-transport/src/session/publish_namespace_received.rs @@ -1,37 +1,33 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::ops; use crate::coding::{ReasonPhrase, TrackNamespace}; use crate::watch::State; use crate::{message, serve::ServeError}; -use super::{AnnounceInfo, Subscriber}; +use super::{PublishNamespaceInfo, Subscriber}; -// There's currently no feedback from the peer, so the shared state is empty. -// If Unannounce contained an error code then we'd be talking. #[derive(Default)] -struct AnnouncedState {} +struct PublishNamespaceReceivedState {} -pub struct Announced { +/// Represents an inbound PUBLISH_NAMESPACE that was received (subscriber side). +/// When dropped, sends PUBLISH_NAMESPACE_CANCEL (if ok'd) or PUBLISH_NAMESPACE_ERROR. +pub struct PublishNamespaceReceived { session: Subscriber, - state: State, + state: State, - pub info: AnnounceInfo, + pub info: PublishNamespaceInfo, ok: bool, error: Option, } -impl Announced { +impl PublishNamespaceReceived { pub(super) fn new( session: Subscriber, request_id: u64, namespace: TrackNamespace, - ) -> (Announced, AnnouncedRecv) { - let info = AnnounceInfo { + ) -> (PublishNamespaceReceived, PublishNamespaceReceivedRecv) { + let info = PublishNamespaceInfo { request_id, namespace, }; @@ -44,19 +40,19 @@ impl Announced { error: None, state: send, }; - let recv = AnnouncedRecv { _state: recv }; + let recv = PublishNamespaceReceivedRecv { _state: recv }; (send, recv) } - // Send an ANNOUNCE_OK pub fn ok(&mut self) -> Result<(), ServeError> { if self.ok { return Err(ServeError::Duplicate); } - self.session.send_message(message::PublishNamespaceOk { + self.session.send_message(message::RequestOk { id: self.info.request_id, + params: Default::default(), }); self.ok = true; @@ -66,8 +62,6 @@ impl Announced { pub async fn closed(&self) -> Result<(), ServeError> { loop { - // Wow this is dumb and yet pretty cool. - // Basically loop until the state changes and exit when Recv is dropped. self.state .lock() .modified() @@ -82,19 +76,18 @@ impl Announced { } } -impl ops::Deref for Announced { - type Target = AnnounceInfo; +impl ops::Deref for PublishNamespaceReceived { + type Target = PublishNamespaceInfo; - fn deref(&self) -> &AnnounceInfo { + fn deref(&self) -> &PublishNamespaceInfo { &self.info } } -impl Drop for Announced { +impl Drop for PublishNamespaceReceived { fn drop(&mut self) { let err = self.error.clone().unwrap_or(ServeError::Done); - // TODO SLG - ServeError's do not align with draft-13 Announce error codes (section 8.25) if self.ok { self.session.send_message(message::PublishNamespaceCancel { track_namespace: self.namespace.clone(), @@ -102,22 +95,22 @@ impl Drop for Announced { reason_phrase: ReasonPhrase(err.to_string()), }); } else { - self.session.send_message(message::PublishNamespaceError { + self.session.send_message(message::RequestError { id: self.info.request_id, error_code: err.code(), + retry_interval: 0, reason_phrase: ReasonPhrase(err.to_string()), }); } } } -pub(super) struct AnnouncedRecv { - _state: State, +pub(super) struct PublishNamespaceReceivedRecv { + _state: State, } -impl AnnouncedRecv { - pub fn recv_unannounce(self) -> Result<(), ServeError> { - // Will cause the state to be dropped +impl PublishNamespaceReceivedRecv { + pub fn recv_done(self) -> Result<(), ServeError> { Ok(()) } } diff --git a/moq-transport/src/session/publish_received.rs b/moq-transport/src/session/publish_received.rs new file mode 100644 index 00000000..e0703aba --- /dev/null +++ b/moq-transport/src/session/publish_received.rs @@ -0,0 +1,297 @@ +use std::ops; + +use crate::coding::{ReasonPhrase, TrackNamespace}; +use crate::data::ExtensionHeaders; +use crate::serve::ServeError; +use crate::watch::State; +use crate::{data, message, serve}; + +use super::Subscriber; + +#[derive(Debug, Clone)] +pub struct PublishReceivedInfo { + pub id: u64, + pub track_namespace: TrackNamespace, + pub track_name: String, + pub track_alias: u64, + /// Forward parameter from PUBLISH (0x10): true = forward immediately, false = paused + pub forward: bool, + /// Track extensions from the original PUBLISH message + pub track_extensions: ExtensionHeaders, +} + +impl PublishReceivedInfo { + pub fn new_from_publish(msg: &message::Publish) -> Self { + // Forward parameter (0x10): default to true if not present + // Value of 0 means paused, 1 (or non-zero) means forward + let forward = msg + .params + .get_intvalue(0x10) // ParameterType::Forward + .map(|v| v != 0) + .unwrap_or(true); + + Self { + id: msg.id, + track_namespace: msg.track_namespace.clone(), + track_name: msg.track_name.clone(), + track_alias: msg.track_alias, + forward, + track_extensions: msg.track_extensions.clone(), + } + } +} + +struct PublishReceivedState { + ok: bool, + closed: Result<(), ServeError>, + writer: Option, +} + +impl Default for PublishReceivedState { + fn default() -> Self { + Self { + ok: false, + closed: Ok(()), + writer: None, + } + } +} + +#[must_use = "sends PUBLISH_ERROR on drop if not accepted"] +pub struct PublishReceived { + subscriber: Subscriber, + pub info: PublishReceivedInfo, + state: State, + ok: bool, +} + +impl PublishReceived { + pub(super) fn new( + subscriber: Subscriber, + msg: &message::Publish, + ) -> (Self, PublishReceivedRecv) { + let info = PublishReceivedInfo::new_from_publish(msg); + + let (send, recv) = State::default().split(); + + let send = Self { + subscriber, + info, + state: send, + ok: false, + }; + + let recv = PublishReceivedRecv { + state: recv, + writer_mode: None, + }; + + (send, recv) + } + + pub fn accept( + mut self, + track: serve::TrackWriter, + publish_msg: message::PublishOk, + ) -> Result<(), ServeError> { + let state = self.state.lock(); + if state.ok { + return Err(ServeError::Duplicate); + } + state.closed.clone()?; + + self.subscriber.send_message(publish_msg); + + if let Some(mut state) = state.into_mut() { + state.ok = true; + state.writer = Some(track); + } + + self.ok = true; + + std::mem::forget(self); + + Ok(()) + } + + pub fn reject(mut self, error_code: u64, reason: &str) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + self.subscriber.send_message(message::RequestError { + id: self.info.id, + error_code, + retry_interval: 0, + reason_phrase: ReasonPhrase(reason.to_string()), + }); + + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Closed(error_code)); + } + + std::mem::forget(self); + + Ok(()) + } + + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notify) => notify, + None => return Ok(()), + } + } + .await; + } + } +} + +impl ops::Deref for PublishReceived { + type Target = PublishReceivedInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +impl Drop for PublishReceived { + fn drop(&mut self) { + if self.ok { + return; + } + + let state = self.state.lock(); + let err = state + .closed + .as_ref() + .err() + .cloned() + .unwrap_or(ServeError::NotFound); + drop(state); + + self.subscriber.send_message(message::RequestError { + id: self.info.id, + error_code: err.code(), + retry_interval: 0, + reason_phrase: ReasonPhrase(err.to_string()), + }); + } +} + +pub(super) struct PublishReceivedRecv { + state: State, + writer_mode: Option, +} + +impl PublishReceivedRecv { + pub fn track_alias(&self) -> Option { + None + } + + pub fn recv_done(&mut self) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Done); + } + + Ok(()) + } + + fn take_writer(&mut self) -> Result { + if let Some(writer) = self.writer_mode.take() { + return Ok(writer); + } + + let mut state = self.state.lock_mut().ok_or(ServeError::Done)?; + let writer = state.writer.take().ok_or(ServeError::Done)?; + Ok(writer.into()) + } + + fn put_writer(&mut self, writer: serve::TrackWriterMode) { + self.writer_mode = Some(writer); + } + + pub fn subgroup( + &mut self, + header: data::SubgroupHeader, + ) -> Result { + let writer = self.take_writer()?; + + let mut subgroups = match writer { + serve::TrackWriterMode::Track(track) => track.subgroups()?, + serve::TrackWriterMode::Subgroups(subgroups) => subgroups, + _ => return Err(ServeError::Mode), + }; + + let result = subgroups.create(serve::Subgroup { + group_id: header.group_id, + subgroup_id: header.subgroup_id.unwrap_or(0), + priority: header.publisher_priority.unwrap_or(127), + header_type: Some(header.header_type), + }); + + // Always put writer back, even on error, to avoid losing it + self.put_writer(subgroups.into()); + + result + } + + pub fn datagram(&mut self, datagram: data::Datagram) -> Result<(), ServeError> { + let writer = self.take_writer()?; + + // Determine status from datagram type or explicit status field + let status = if datagram.datagram_type.is_end_of_group() { + Some(crate::data::ObjectStatus::EndOfGroup) + } else { + datagram.status + }; + + match writer { + serve::TrackWriterMode::Track(track) => { + let mut datagrams = track.datagrams()?; + datagrams.write(serve::Datagram { + group_id: datagram.group_id, + object_id: datagram.object_id.unwrap_or(0), + priority: datagram.publisher_priority.unwrap_or(127), + payload: datagram.payload.unwrap_or_default(), + extension_headers: datagram.extension_headers.unwrap_or_default(), + status, + })?; + self.put_writer(serve::TrackWriterMode::Datagrams(datagrams)); + Ok(()) + } + serve::TrackWriterMode::Datagrams(mut datagrams) => { + datagrams.write(serve::Datagram { + group_id: datagram.group_id, + object_id: datagram.object_id.unwrap_or(0), + priority: datagram.publisher_priority.unwrap_or(127), + payload: datagram.payload.unwrap_or_default(), + extension_headers: datagram.extension_headers.unwrap_or_default(), + status, + })?; + self.put_writer(serve::TrackWriterMode::Datagrams(datagrams)); + Ok(()) + } + other => { + self.put_writer(other); + Err(ServeError::Mode) + } + } + } +} diff --git a/moq-transport/src/session/published.rs b/moq-transport/src/session/published.rs new file mode 100644 index 00000000..aba4e724 --- /dev/null +++ b/moq-transport/src/session/published.rs @@ -0,0 +1,717 @@ +use std::ops; +use std::sync::{Arc, Mutex}; + +use futures::stream::FuturesUnordered; +use futures::StreamExt; + +use crate::coding::{Encode, Location, ReasonPhrase, TrackNamespace}; +use crate::data::ExtensionHeaders; +use crate::message::ParameterType; +use crate::mlog; +use crate::serve::{ServeError, TrackReaderMode}; +use crate::watch::State; +use crate::{data, message, serve}; + +use super::{Publisher, SessionError, Writer}; + +/// Callback for object observation and filtering during streaming. +/// Called for each object before it's forwarded. +/// +/// Arguments: +/// - group_id: The group ID of the object +/// - object_id: The object ID within the group +/// - extension_headers: The extension headers on the object +/// +/// Returns: +/// - `true` to forward the object +/// - `false` to skip/drop the object +/// +/// Use cases: +/// - Update TopN tracker with audio level metrics from extension headers +/// - Filter objects based on whether the track is in top-N for the subscriber +pub type ObjectObserverFn = Box bool + Send + Sync>; + +#[derive(Debug, Clone)] +pub struct PublishInfo { + pub id: u64, + pub track_namespace: TrackNamespace, + pub track_name: String, + pub track_alias: u64, +} + +impl PublishInfo { + pub fn new_from_publish(msg: &message::Publish) -> Self { + Self { + id: msg.id, + track_namespace: msg.track_namespace.clone(), + track_name: msg.track_name.clone(), + track_alias: msg.track_alias, + } + } +} + +#[derive(Debug)] +struct PublishedState { + ok: bool, + forward: bool, + subscriber_priority: u8, + group_order: message::GroupOrder, + largest_location: Option, + closed: Result<(), ServeError>, +} + +impl PublishedState { + fn update_largest_location(&mut self, group_id: u64, object_id: u64) -> Result<(), ServeError> { + let new_location = Location::new(group_id, object_id); + if let Some(current) = self.largest_location { + if new_location > current { + self.largest_location = Some(new_location); + } + } else { + self.largest_location = Some(new_location); + } + Ok(()) + } +} + +impl Default for PublishedState { + fn default() -> Self { + Self { + ok: false, + forward: true, + subscriber_priority: 128, + group_order: message::GroupOrder::Ascending, + largest_location: None, + closed: Ok(()), + } + } +} + +#[must_use = "sends PUBLISH_DONE on drop"] +pub struct Published { + publisher: Publisher, + pub info: PublishInfo, + state: State, + ok: bool, + mlog: Option>>, +} + +impl Published { + pub(super) fn new( + mut publisher: Publisher, + msg: message::Publish, + mlog: Option>>, + ) -> (Self, PublishedRecv) { + let info = PublishInfo::new_from_publish(&msg); + + publisher.send_message(msg); + + let (send, recv) = State::default().split(); + + let send = Self { + publisher, + info, + state: send, + ok: false, + mlog, + }; + + let recv = PublishedRecv { state: recv }; + + (send, recv) + } + + pub async fn ok(&mut self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + if state.ok { + self.ok = true; + return Ok(()); + } + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notify) => notify, + None => return Ok(()), + } + } + .await; + } + } + + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } + + pub async fn serve(mut self, track: serve::TrackReader) -> Result<(), SessionError> { + let res = self.serve_inner(track, None).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + res + } + + /// Serve with an observer callback that's called for each object. + /// The observer receives object extension headers and can update external state (e.g., TopN tracker). + /// All objects are forwarded regardless of observer return. + pub async fn serve_with_observer( + mut self, + track: serve::TrackReader, + observer: ObjectObserverFn, + ) -> Result<(), SessionError> { + let res = self.serve_inner(track, Some(Arc::new(observer))).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + res + } + + /// Serve using a pre-acquired TrackReaderMode. + /// Use this when you need to acquire the mode early (before network round trips) + /// to avoid missing frames in late-join scenarios. + pub async fn serve_mode(mut self, mode: TrackReaderMode) -> Result<(), SessionError> { + let res = self.serve_mode_inner(mode).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + res + } + + /// Serve immediately without waiting for PUBLISH_OK. + /// Use this for relay scenarios where you want to start forwarding data right away. + /// The subscriber will receive data as soon as they're ready. + pub async fn serve_immediately(mut self, track: serve::TrackReader) -> Result<(), SessionError> { + let res = self.serve_immediately_inner(track).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + res + } + + async fn serve_inner( + &mut self, + track: serve::TrackReader, + observer: Option>, + ) -> Result<(), SessionError> { + self.ok().await?; + + let forward = { + let state = self.state.lock(); + state.forward + }; + + if !forward { + self.closed().await?; + return Ok(()); + } + + match track.mode().await? { + TrackReaderMode::Stream(_stream) => panic!("deprecated"), + TrackReaderMode::Subgroups(subgroups) => { + self.serve_subgroups(subgroups, observer).await + } + TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + } + } + + async fn serve_mode_inner(&mut self, mode: TrackReaderMode) -> Result<(), SessionError> { + self.ok().await?; + + let forward = { + let state = self.state.lock(); + state.forward + }; + + if !forward { + self.closed().await?; + return Ok(()); + } + + match mode { + TrackReaderMode::Stream(_stream) => panic!("deprecated"), + TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups, None).await, + TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + } + } + + async fn serve_immediately_inner(&mut self, track: serve::TrackReader) -> Result<(), SessionError> { + // Don't wait for PUBLISH_OK - start streaming immediately + // This is useful for relay scenarios where we want minimal latency + + match track.mode().await? { + TrackReaderMode::Stream(_stream) => panic!("deprecated"), + TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups, None).await, + TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + } + } + + async fn serve_subgroups( + &mut self, + mut subgroups: serve::SubgroupsReader, + observer: Option>, + ) -> Result<(), SessionError> { + let mut tasks = FuturesUnordered::new(); + let mut done: Option> = None; + + loop { + tokio::select! { + res = subgroups.next(), if done.is_none() => match res { + Ok(Some(subgroup)) => { + // Header type will be determined in serve_subgroup based on extension headers + let track_alias = self.info.track_alias; + let publisher = self.publisher.clone(); + let state = self.state.clone(); + let info = subgroup.info.clone(); + let mlog = self.mlog.clone(); + let observer = observer.clone(); + + tasks.push(async move { + if let Err(err) = Self::serve_subgroup(track_alias, subgroup, publisher, state, mlog, observer).await { + log::warn!("failed to serve subgroup: {:?}, error: {}", info, err); + } + }); + }, + Ok(None) => done = Some(Ok(())), + Err(err) => done = Some(Err(err)), + }, + res = self.closed(), if done.is_none() => done = Some(res), + _ = tasks.next(), if !tasks.is_empty() => {}, + else => return Ok(done.unwrap()?), + } + } + } + + async fn serve_subgroup( + track_alias: u64, + mut subgroup_reader: serve::SubgroupReader, + mut publisher: Publisher, + state: State, + mlog: Option>>, + observer: Option>, + ) -> Result<(), SessionError> { + log::debug!( + "[PUBLISHED] serve_subgroup: starting - track_alias={}, group_id={}, subgroup_id={:?}, priority={}", + track_alias, + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + subgroup_reader.priority + ); + + // Read the first object to determine if we have extension headers + let first_object = match subgroup_reader.next().await? { + Some(obj) => obj, + None => { + log::debug!("[PUBLISHED] serve_subgroup: no objects in subgroup, skipping"); + return Ok(()); + } + }; + + // Call observer for first object if present (always forward first object to establish stream) + let should_forward_first = if let Some(ref obs) = observer { + obs( + subgroup_reader.group_id, + first_object.object_id, + &first_object.extension_headers, + ) + } else { + true + }; + + // Use preserved header type if available, otherwise determine from extension headers + let has_extension_headers = !first_object.extension_headers.is_empty(); + let header_type = subgroup_reader.info.header_type.unwrap_or_else(|| { + // Fallback: determine header type based on extension headers + if has_extension_headers { + data::StreamHeaderType::SubgroupZeroIdExtEndOfGroup + } else { + data::StreamHeaderType::SubgroupZeroIdEndOfGroup + } + }); + + // If we're not writing extension headers but the preserved header type has extensions, + // convert to the non-Ext variant to avoid mismatch between header and object encoding + let header_type = if !has_extension_headers && header_type.has_extension_headers() { + log::debug!( + "[PUBLISHED] serve_subgroup: converting header_type {:?} to non-Ext variant (objects have no extensions)", + header_type + ); + header_type.without_extensions() + } else { + header_type + }; + + // Set subgroup_id based on header type (ZeroId variants don't include it on wire) + let subgroup_id = if header_type.has_subgroup_id() { + Some(subgroup_reader.subgroup_id) + } else { + None + }; + + let header = data::SubgroupHeader { + header_type, + track_alias, + group_id: subgroup_reader.group_id, + subgroup_id, + publisher_priority: Some(subgroup_reader.priority), + }; + + let mut send_stream = publisher.open_uni().await?; + send_stream.set_priority(subgroup_reader.priority as i32); + + let mut writer = Writer::new(send_stream); + + log::debug!( + "[PUBLISHED] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}, header_type={:?}, has_ext={}", + header.track_alias, + header.group_id, + header.subgroup_id, + header.publisher_priority, + header.header_type, + has_extension_headers + ); + + writer.encode(&header).await?; + + if let Some(ref mlog) = mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let event = mlog::subgroup_header_created(time, stream_id, &header); + let _ = mlog_guard.add_event(event); + } + } + + // Helper to write an object + async fn write_object( + writer: &mut Writer, + object_reader: &mut serve::SubgroupObjectReader, + has_extension_headers: bool, + object_count: u64, + subgroup_reader: &serve::SubgroupReader, + state: &State, + mlog: &Option>>, + ) -> Result<(), SessionError> { + if has_extension_headers { + let subgroup_object = data::SubgroupObjectExt { + object_id_delta: 0, + extension_headers: object_reader.extension_headers.clone(), + payload_length: object_reader.size, + status: if object_reader.size == 0 { + Some(object_reader.status) + } else { + None + }, + }; + + log::debug!( + "[PUBLISHED] serve_subgroup: sending object #{} (ext) - object_id={}, payload_length={}, status={:?}", + object_count + 1, + object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status + ); + + writer.encode(&subgroup_object).await?; + + if let Some(ref mlog) = mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let event = mlog::subgroup_object_ext_created( + time, + stream_id, + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + object_reader.object_id, + &subgroup_object, + ); + let _ = mlog_guard.add_event(event); + } + } + } else { + let subgroup_object = data::SubgroupObject { + object_id_delta: 0, + payload_length: object_reader.size, + status: if object_reader.size == 0 { + Some(object_reader.status) + } else { + None + }, + }; + + log::debug!( + "[PUBLISHED] serve_subgroup: sending object #{} - object_id={}, payload_length={}, status={:?}", + object_count + 1, + object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status + ); + + writer.encode(&subgroup_object).await?; + + // No mlog for non-ext objects currently + } + + state + .lock_mut() + .ok_or(ServeError::Done)? + .update_largest_location( + subgroup_reader.group_id, + object_reader.object_id, + )?; + + while let Some(chunk) = object_reader.read().await? { + writer.write(&chunk).await?; + } + + Ok(()) + } + + // Write the first object that we already read (if observer allows) + let mut object_count = 0; + let mut first_object = first_object; + if should_forward_first { + write_object( + &mut writer, + &mut first_object, + has_extension_headers, + object_count, + &subgroup_reader, + &state, + &mlog, + ) + .await?; + object_count += 1; + } else { + // Consume the object data even if not forwarding + while first_object.read().await?.is_some() {} + log::debug!( + "[PUBLISHED] serve_subgroup: skipped first object (group_id={}, object_id={}) - filtered by observer", + subgroup_reader.group_id, + first_object.object_id + ); + } + + // Continue with remaining objects + while let Some(mut subgroup_object_reader) = subgroup_reader.next().await? { + // Call observer for each object if present - observer returns whether to forward + let should_forward = if let Some(ref obs) = observer { + obs( + subgroup_reader.group_id, + subgroup_object_reader.object_id, + &subgroup_object_reader.extension_headers, + ) + } else { + true + }; + + if should_forward { + write_object( + &mut writer, + &mut subgroup_object_reader, + has_extension_headers, + object_count, + &subgroup_reader, + &state, + &mlog, + ) + .await?; + object_count += 1; + } else { + // Consume the object data even if not forwarding + while subgroup_object_reader.read().await?.is_some() {} + log::debug!( + "[PUBLISHED] serve_subgroup: skipped object (group_id={}, object_id={}) - filtered by observer", + subgroup_reader.group_id, + subgroup_object_reader.object_id + ); + } + } + + log::info!( + "[PUBLISHED] serve_subgroup: completed subgroup (group_id={}, subgroup_id={:?}, {} objects sent, header_type={:?})", + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + object_count, + header_type + ); + + Ok(()) + } + + async fn serve_datagrams( + &mut self, + mut datagrams: serve::DatagramsReader, + ) -> Result<(), SessionError> { + log::debug!("[PUBLISHED] serve_datagrams: starting"); + + let mut datagram_count = 0; + while let Some(datagram) = datagrams.read().await? { + let has_extension_headers = !datagram.extension_headers.is_empty(); + let datagram_type = if has_extension_headers { + data::DatagramType::ObjectIdPayloadExt + } else { + data::DatagramType::ObjectIdPayload + }; + + let encoded_datagram = data::Datagram { + datagram_type, + track_alias: self.info.track_alias, + group_id: datagram.group_id, + object_id: Some(datagram.object_id), + publisher_priority: Some(datagram.priority), + extension_headers: if has_extension_headers { + Some(datagram.extension_headers.clone()) + } else { + None + }, + status: None, + payload: Some(datagram.payload), + }; + + let payload_len = encoded_datagram + .payload + .as_ref() + .map(|p| p.len()) + .unwrap_or(0); + let mut buffer = bytes::BytesMut::with_capacity(payload_len + 100); + encoded_datagram.encode(&mut buffer)?; + + log::debug!( + "[PUBLISHED] serve_datagrams: sending datagram #{} - track_alias={}, group_id={}, object_id={}, priority={:?}, payload_len={}", + datagram_count + 1, + encoded_datagram.track_alias, + encoded_datagram.group_id, + encoded_datagram.object_id.unwrap(), + encoded_datagram.publisher_priority, + payload_len + ); + + if let Some(ref mlog) = self.mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let _ = mlog_guard.add_event(mlog::object_datagram_created( + time, + stream_id, + &encoded_datagram, + )); + } + } + + self.publisher.send_datagram(buffer.into()).await?; + + self.state + .lock_mut() + .ok_or(ServeError::Done)? + .update_largest_location( + encoded_datagram.group_id, + encoded_datagram.object_id.unwrap(), + )?; + + datagram_count += 1; + } + + log::info!( + "[PUBLISHED] serve_datagrams: completed ({} datagrams sent)", + datagram_count + ); + + Ok(()) + } +} + +impl ops::Deref for Published { + type Target = PublishInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +impl Drop for Published { + fn drop(&mut self) { + let state = self.state.lock(); + let err = state + .closed + .as_ref() + .err() + .cloned() + .unwrap_or(ServeError::Done); + drop(state); + + self.publisher.send_message(message::PublishDone { + id: self.info.id, + status_code: err.code(), + stream_count: 0, // TODO SLG + reason: ReasonPhrase(err.to_string()), + }); + } +} + +pub(super) struct PublishedRecv { + state: State, +} + +impl PublishedRecv { + pub fn recv_ok(&mut self, msg: &message::PublishOk) -> Result<(), ServeError> { + let state = self.state.lock(); + if state.ok { + return Err(ServeError::Duplicate); + } + + if let Some(mut state) = state.into_mut() { + state.ok = true; + + // Extract subscription properties from parameters (draft-16) + if let Some(v) = msg.params.get_intvalue(ParameterType::Forward.into()) { + state.forward = v == 1; + } + if let Some(v) = msg.params.get_intvalue(ParameterType::SubscriberPriority.into()) { + state.subscriber_priority = v as u8; + } + if let Some(v) = msg.params.get_intvalue(ParameterType::GroupOrder.into()) { + state.group_order = match v { + 0x0 => message::GroupOrder::Publisher, + 0x1 => message::GroupOrder::Ascending, + 0x2 => message::GroupOrder::Descending, + _ => message::GroupOrder::Ascending, + }; + } + } + + Ok(()) + } + + pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } +} diff --git a/moq-transport/src/session/publisher.rs b/moq-transport/src/session/publisher.rs index 8b8c5085..afedac74 100644 --- a/moq-transport/src/session/publisher.rs +++ b/moq-transport/src/session/publisher.rs @@ -1,59 +1,49 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{ - collections::{hash_map, HashMap}, + collections::{hash_map, HashMap, HashSet}, sync::{atomic, Arc, Mutex}, }; -use futures::{stream::FuturesUnordered, StreamExt}; - use crate::{ - coding::TrackNamespace, - message::{self, Message}, + coding::{KeyValuePairs, TrackNamespace}, + message::{self, GroupOrder, Message, ParameterType}, mlog, - serve::{ServeError, TracksReader}, + serve::{self, ServeError, TracksReader}, }; use crate::watch::Queue; use super::{ - Announce, AnnounceRecv, Session, SessionError, Subscribed, SubscribedRecv, TrackStatusRequested, + PublishNamespace, PublishNamespaceRecv, Published, PublishedRecv, Session, SessionError, + SubscribeNamespaceReceived, SubscribeNamespaceReceivedRecv, Subscribed, SubscribedRecv, + TrackStatusRequested, }; -// TODO remove Clone. #[derive(Clone)] pub struct Publisher { webtransport: web_transport::Session, - /// When the announce method is used, a new entry is added to this HashMap to track outbound announcement - announces: Arc>>, + publish_namespaces: Arc>>, + + filtered_namespaces: Arc>>, - /// When a Subscribe is received and we have a previous announce for the namespace, then a new entry is - /// added to this HashMap to track the inbound subscription subscribeds: Arc>>, - /// When a Subscribe is received and we DO NOT have a previous announce for the namespace, then a new entry is - /// added to this Queue to track the inbound subscription unknown_subscribed: Queue, - /// When a TrackStatus is received and we DO NOT have a previous announce for the namespace, then a new entry is - /// added to this Queue to track the inbound track status request unknown_track_status_requested: Queue, - /// The queue we will write any outbound control messages we want to sent, the session run_send task - /// will process the queue and send the message on the control stream. + subscribe_namespaces_received: Arc>>, + + subscribe_namespace_received_queue: Queue, + + publisheds: Arc>>, + + next_track_alias: Arc, + outgoing: Queue, - /// When we need a new Request Id for sending a request, we can get it from here. Note: The instance - /// of AtomicU64 is shared with the Subscriber, so the session uses unique request ids for all requests - /// generated. Note: If we initiated the QUIC connection then request id's start at 0 and increment by 2 - /// for each request (even numbers). If we accepted an inbound QUIC connection then request id's start at 1 and - /// increment by 2 for each request (odd numbers). next_requestid: Arc, - /// Optional mlog writer for logging transport events mlog: Option>>, } @@ -66,107 +56,71 @@ impl Publisher { ) -> Self { Self { webtransport, - announces: Default::default(), + publish_namespaces: Default::default(), + filtered_namespaces: Default::default(), subscribeds: Default::default(), unknown_subscribed: Default::default(), unknown_track_status_requested: Default::default(), + subscribe_namespaces_received: Default::default(), + subscribe_namespace_received_queue: Default::default(), + publisheds: Default::default(), + next_track_alias: Arc::new(atomic::AtomicU64::new(0)), outgoing, next_requestid, mlog, } } + pub fn next_track_alias(&self) -> u64 { + self.next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed) + } + pub async fn accept( session: web_transport::Session, - transport: super::Transport, ) -> Result<(Session, Publisher), SessionError> { - let (session, publisher, _) = Session::accept(session, None, transport).await?; + let (session, publisher, _) = Session::accept(session, None).await?; Ok((session, publisher.unwrap())) } pub async fn connect( session: web_transport::Session, - transport: super::Transport, ) -> Result<(Session, Publisher), SessionError> { - let (session, publisher, _) = Session::connect(session, None, transport).await?; + let (session, publisher, _) = Session::connect(session, None).await?; Ok((session, publisher)) } - /// Announce a namespace and serve tracks using the provided [serve::TracksReader]. - /// The caller uses [serve::TracksWriter] for static tracks and [serve::TracksRequest] for dynamic tracks. - pub async fn announce(&mut self, tracks: TracksReader) -> Result<(), SessionError> { - // Check if annouce for this namespace already exists or not, and if not, then create a new Announce - let announce = match self - .announces + pub async fn publish_namespace( + &mut self, + namespace: TrackNamespace, + ) -> Result { + if self + .filtered_namespaces + .lock() + .unwrap() + .contains(&namespace) + { + return Err(ServeError::Cancel.into()); + } + + let publish_ns = match self + .publish_namespaces .lock() .unwrap() - .entry(tracks.namespace.clone()) + .entry(namespace.clone()) { - // Namespace already exists in HashMap (has already been announced) - return Duplicate error hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), - // This is a new announce, send announce message to peer. hash_map::Entry::Vacant(entry) => { - // Get the current next request id to use and increment the value for by 2 for the next request let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); - let (send, recv) = - Announce::new(self.clone(), request_id, tracks.namespace.clone()); + let (send, recv) = PublishNamespace::new(self.clone(), request_id, namespace); entry.insert(recv); send } }; - let mut subscribe_tasks = FuturesUnordered::new(); - let mut status_tasks = FuturesUnordered::new(); - let mut subscribe_done = false; - let mut status_done = false; - - // The code enters an infinite loop and waits for one of several events: - // - A new subscription arrives. - // - A new track status request arrives. - // - One of the spawned subscription-handling tasks completes. - // - One of the spawned status-handling tasks completes. - // Exit the loop when all input streams are done (None), and all tasks have completed - loop { - tokio::select! { - // Get next subscription to this announce - res = announce.subscribed(), if !subscribe_done => { - match res? { - Some(subscribed) => { - let tracks = tracks.clone(); - - subscribe_tasks.push(async move { - let info = subscribed.info.clone(); - if let Err(err) = Self::serve_subscribe(subscribed, tracks).await { - tracing::warn!("failed serving subscribe: {:?}, error: {}", info, err) - } - }); - }, - None => subscribe_done = true, - } - - }, - res = announce.track_status_requested(), if !status_done => { - match res? { - Some(status) => { - let tracks = tracks.clone(); - - status_tasks.push(async move { - let request_msg = status.request_msg.clone(); - if let Err(err) = Self::serve_track_status(status, tracks).await { - tracing::warn!("failed serving track status request: {:?}, error: {}", request_msg, err) - } - }); - }, - None => status_done = true, - } - }, - Some(res) = subscribe_tasks.next() => res, - Some(res) = status_tasks.next() => res, - else => return Ok(()) - } - } + Ok(publish_ns) } pub async fn serve_subscribe( @@ -212,16 +166,124 @@ impl Publisher { Ok(()) } - // Returns subscriptions that do not map to an active announce. pub async fn subscribed(&mut self) -> Option { self.unknown_subscribed.pop().await } - // Returns track_status requests that do not map to an active announce. pub async fn track_status_requested(&mut self) -> Option { self.unknown_track_status_requested.pop().await } + pub async fn subscribe_namespace_received(&mut self) -> Option { + self.subscribe_namespace_received_queue.pop().await + } + + pub async fn publish(&mut self, track: serve::TrackReader) -> Result { + let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); + let track_alias = self + .next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed); + + let mut params = KeyValuePairs::new(); + params.set_intvalue(ParameterType::GroupOrder.into(), GroupOrder::Ascending as u64); + params.set_intvalue(ParameterType::Forward.into(), 1); + if let Some(loc) = track.largest_location() { + let mut buf = bytes::BytesMut::new(); + use crate::coding::Encode; + loc.encode(&mut buf).ok(); + params.set_bytesvalue(ParameterType::LargestObject.into(), buf.to_vec()); + } + + let msg = message::Publish { + id: request_id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + track_alias, + params, + track_extensions: Default::default(), + }; + + let (send, recv) = Published::new(self.clone(), msg, self.mlog.clone()); + + self.publisheds.lock().unwrap().insert(request_id, recv); + + Ok(send) + } + + pub async fn publish_with_options( + &mut self, + track: serve::TrackReader, + group_order: GroupOrder, + forward: bool, + ) -> Result { + let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); + let track_alias = self + .next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed); + + let mut params = KeyValuePairs::new(); + params.set_intvalue(ParameterType::GroupOrder.into(), group_order as u64); + params.set_intvalue(ParameterType::Forward.into(), if forward { 1 } else { 0 }); + if let Some(loc) = track.largest_location() { + let mut buf = bytes::BytesMut::new(); + use crate::coding::Encode; + loc.encode(&mut buf).ok(); + params.set_bytesvalue(ParameterType::LargestObject.into(), buf.to_vec()); + } + + let msg = message::Publish { + id: request_id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + track_alias, + params, + track_extensions: Default::default(), + }; + + let (send, recv) = Published::new(self.clone(), msg, self.mlog.clone()); + + self.publisheds.lock().unwrap().insert(request_id, recv); + + Ok(send) + } + + /// Publish a track with specific track extensions (for relay forwarding) + pub async fn publish_with_extensions( + &mut self, + track: serve::TrackReader, + track_extensions: crate::data::ExtensionHeaders, + ) -> Result { + let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); + let track_alias = self + .next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed); + + let mut params = KeyValuePairs::new(); + params.set_intvalue(ParameterType::GroupOrder.into(), GroupOrder::Ascending as u64); + params.set_intvalue(ParameterType::Forward.into(), 1); + if let Some(loc) = track.largest_location() { + let mut buf = bytes::BytesMut::new(); + use crate::coding::Encode; + loc.encode(&mut buf).ok(); + params.set_bytesvalue(ParameterType::LargestObject.into(), buf.to_vec()); + } + + let msg = message::Publish { + id: request_id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + track_alias, + params, + track_extensions, + }; + + let (send, recv) = Published::new(self.clone(), msg, self.mlog.clone()); + + self.publisheds.lock().unwrap().insert(request_id, recv); + + Ok(send) + } + pub(crate) fn recv_message(&mut self, msg: message::Subscriber) -> Result<(), SessionError> { let res = match msg { message::Subscriber::Subscribe(msg) => self.recv_subscribe(msg), @@ -232,68 +294,45 @@ impl Publisher { Err(SessionError::unimplemented("FETCH_CANCEL")) } message::Subscriber::TrackStatus(msg) => self.recv_track_status(msg), - message::Subscriber::SubscribeNamespace(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE")) - } - message::Subscriber::UnsubscribeNamespace(_msg) => { - Err(SessionError::unimplemented("UNSUBSCRIBE_NAMESPACE")) - } + message::Subscriber::SubscribeNamespace(msg) => self.recv_subscribe_namespace(msg), message::Subscriber::PublishNamespaceCancel(msg) => { self.recv_publish_namespace_cancel(msg) } - message::Subscriber::PublishNamespaceOk(msg) => self.recv_publish_namespace_ok(msg), - message::Subscriber::PublishNamespaceError(msg) => { - self.recv_publish_namespace_error(msg) - } - message::Subscriber::PublishOk(_msg) => Err(SessionError::unimplemented("PUBLISH_OK")), - message::Subscriber::PublishError(_msg) => { - Err(SessionError::unimplemented("PUBLISH_ERROR")) - } + message::Subscriber::RequestOk(msg) => self.recv_request_ok(msg), + message::Subscriber::PublishOk(msg) => self.recv_publish_ok(msg), + message::Subscriber::RequestError(msg) => self.recv_request_error(msg), }; if let Err(err) = res { - tracing::warn!("failed to process message: {}", err); + log::warn!("failed to process message: {}", err); } Ok(()) } - fn recv_publish_namespace_ok( - &mut self, - msg: message::PublishNamespaceOk, - ) -> Result<(), SessionError> { - // We need to find the announce request using the request id, however the self.announces data structure - // is a HashMap indexed by Namespace (which is needed for handling PUBLISH_NAMESPACE_CANCEL). TODO - make more efficient. - // For now iterate through all self.annouces until we find the matching id. - let mut announces = self.announces.lock().unwrap(); - let announce = announces.iter_mut().find(|(_k, v)| v.request_id == msg.id); - - if let Some(announce) = announce { - announce.1.recv_ok()?; + fn recv_request_ok(&mut self, msg: message::RequestOk) -> Result<(), SessionError> { + let mut publish_namespaces = self.publish_namespaces.lock().unwrap(); + let entry = publish_namespaces + .iter_mut() + .find(|(_k, v)| v.request_id == msg.id); + + if let Some(entry) = entry { + entry.1.recv_ok()?; } Ok(()) } - fn recv_publish_namespace_error( - &mut self, - msg: message::PublishNamespaceError, - ) -> Result<(), SessionError> { - // We need to find the announce request using the request id, however the self.announces data structure - // is a HashMap indexed by Namespace (which is needed for handling PUBLISH_NAMESPACE_CANCEL). TODO - make more efficient. - // For now iterate through all self.annouces until we find the matching id. - let mut announces = self.announces.lock().unwrap(); + fn recv_request_error(&mut self, msg: message::RequestError) -> Result<(), SessionError> { + let mut publish_namespaces = self.publish_namespaces.lock().unwrap(); - // Find the key first (immutable borrow only) - let key_opt = announces + let key_opt = publish_namespaces .iter() .find(|(_k, v)| v.request_id == msg.id) .map(|(k, _)| k.clone()); - // Remove from HashMap and take ownership if let Some(key) = key_opt { - if let Some((_ns, v)) = announces.remove_entry(&key) { - // Step 3: call recv_error, consuming v + if let Some((_ns, v)) = publish_namespaces.remove_entry(&key) { v.recv_error(ServeError::Closed(msg.error_code))?; } } @@ -305,10 +344,21 @@ impl Publisher { &mut self, msg: message::PublishNamespaceCancel, ) -> Result<(), SessionError> { - // TODO: If a publisher receives new subscriptions for that namespace after receiving an ANNOUNCE_CANCEL, - // it SHOULD close the session as a 'Protocol Violation'. - if let Some(announce) = self.announces.lock().unwrap().remove(&msg.track_namespace) { - announce.recv_error(ServeError::Cancel)?; + if let Some(entry) = self + .publish_namespaces + .lock() + .unwrap() + .remove(&msg.track_namespace) + { + entry.recv_error(ServeError::Cancel)?; + } + + Ok(()) + } + + fn recv_publish_ok(&mut self, msg: message::PublishOk) -> Result<(), SessionError> { + if let Some(published) = self.publisheds.lock().unwrap().get_mut(&msg.id) { + published.recv_ok(&msg)?; } Ok(()) @@ -320,29 +370,18 @@ impl Publisher { let subscribed = { let mut subscribeds = self.subscribeds.lock().unwrap(); - // See if entry exists for this request id already, if so error out let entry = match subscribeds.entry(msg.id) { hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), hash_map::Entry::Vacant(entry) => entry, }; - // Create new Subscribed entry and add to HashMap let (send, recv) = Subscribed::new(self.clone(), msg, self.mlog.clone()); entry.insert(recv); send }; - // If we have an announce, route the subscribe to it. - if let Some(announce) = self.announces.lock().unwrap().get_mut(&namespace) { - return announce.recv_subscribe(subscribed).map_err(Into::into); - } - - // Otherwise, put it in the unknown queue. - // TODO Have some way to detect if the application is not reading from the unknown queue, - // then send SubscribeError. if let Err(err) = self.unknown_subscribed.push(subscribed) { - // Default to closing with a not found error I guess. err.close(ServeError::not_found_ctx(format!( "unknown_subscribed queue full for namespace {:?}", namespace @@ -361,26 +400,12 @@ impl Publisher { } fn recv_track_status(&mut self, msg: message::TrackStatus) -> Result<(), SessionError> { - let namespace = msg.track_namespace.clone(); - - // Create TrackStatusRequested to track this request let track_status_requested = TrackStatusRequested::new(self.clone(), msg); - // If we have an announce, route the track_status to it. - if let Some(announce) = self.announces.lock().unwrap().get_mut(&namespace) { - return announce - .recv_track_status_requested(track_status_requested) - .map_err(Into::into); - } - - // Otherwise, put it in the unknown_track_status queue. - // TODO Have some way to detect if the application is not reading from the unknown_track_status queue, - // then send TrackStatusError. if let Err(mut err) = self .unknown_track_status_requested .push(track_status_requested) { - // push only fails if the queue is dropped, send TrackStatusError, Internal error err.respond_error(0, "Internal error")?; } @@ -395,15 +420,48 @@ impl Publisher { Ok(()) } - /// Process a message before sending it, performing any necessary internal actions. + fn recv_subscribe_namespace( + &mut self, + msg: message::SubscribeNamespace, + ) -> Result<(), SessionError> { + let namespace_prefix = msg.track_namespace_prefix.clone(); + + self.filtered_namespaces + .lock() + .unwrap() + .remove(&namespace_prefix); + + let mut entries = self.subscribe_namespaces_received.lock().unwrap(); + + let entry = match entries.entry(msg.id) { + hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), + hash_map::Entry::Vacant(entry) => entry, + }; + + let (send, recv) = + SubscribeNamespaceReceived::new(self.clone(), msg.id, namespace_prefix, msg.params); + + if let Err(send) = self.subscribe_namespace_received_queue.push(send) { + send.reject(0x0, "Internal error")?; + return Ok(()); + } + + entry.insert(recv); + + Ok(()) + } + fn act_on_message_to_send>( &mut self, msg: T, ) -> message::Publisher { let msg = msg.into(); match &msg { - message::Publisher::PublishDone(m) => self.drop_subscribe(m.id), - message::Publisher::SubscribeError(m) => self.drop_subscribe(m.id), + message::Publisher::PublishDone(m) => { + self.drop_subscribe(m.id); + self.drop_published(m.id); + } + message::Publisher::RequestError(m) => self.drop_subscribe(m.id), message::Publisher::PublishNamespaceDone(m) => { self.drop_publish_namespace(&m.track_namespace); } @@ -415,7 +473,13 @@ impl Publisher { /// Send a message without waiting for it to be sent. pub(super) fn send_message + Into>(&mut self, msg: T) { let msg = self.act_on_message_to_send(msg); - self.outgoing.push(msg.into()).ok(); + let msg_name = format!("{:?}", msg); + let msg: Message = msg.into(); + log::debug!("[PUBLISHER] send_message: pushing {:?} to outgoing queue", msg_name); + match self.outgoing.push(msg) { + Ok(()) => log::debug!("[PUBLISHER] send_message: push succeeded"), + Err(_) => log::warn!("[PUBLISHER] send_message: push FAILED (queue closed?)"), + } } /// Send a message and wait until it is sent (or at least popped off the outgoing control message queue) @@ -435,7 +499,11 @@ impl Publisher { } fn drop_publish_namespace(&mut self, namespace: &TrackNamespace) { - self.announces.lock().unwrap().remove(namespace); + self.publish_namespaces.lock().unwrap().remove(namespace); + } + + fn drop_published(&mut self, id: u64) { + self.publisheds.lock().unwrap().remove(&id); } pub(super) async fn open_uni(&mut self) -> Result { @@ -445,4 +513,16 @@ impl Publisher { pub(super) async fn send_datagram(&mut self, data: bytes::Bytes) -> Result<(), SessionError> { Ok(self.webtransport.send_datagram(data).await?) } + + /// Forward a PUBLISH message to the subscriber (used by relay for SUBSCRIBE_NAMESPACE flow). + /// This sends the message without tracking it for PUBLISH_OK response handling. + pub fn forward_publish(&mut self, msg: message::Publish) { + self.outgoing.push(msg.into()).ok(); + } + + /// Forward a NAMESPACE message to the subscriber (used by relay for SUBSCRIBE_NAMESPACE flow). + /// This announces a namespace that matches the subscriber's SUBSCRIBE_NAMESPACE prefix. + pub fn forward_namespace(&mut self, msg: message::Namespace) { + self.outgoing.push(msg.into()).ok(); + } } diff --git a/moq-transport/src/session/reader.rs b/moq-transport/src/session/reader.rs index b3326c3f..1dd05530 100644 --- a/moq-transport/src/session/reader.rs +++ b/moq-transport/src/session/reader.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{cmp, io}; use bytes::{Buf, Bytes, BytesMut}; @@ -24,7 +20,7 @@ impl Reader { } pub async fn decode(&mut self) -> Result { - tracing::trace!( + log::trace!( "[READER] decode: attempting to decode {} (buffer_len={})", std::any::type_name::(), self.buffer.len() @@ -38,7 +34,7 @@ impl Reader { Ok(msg) => { let consumed = cursor.position() as usize; self.buffer.advance(consumed); - tracing::debug!( + log::debug!( "[READER] decode: successfully decoded {} (consumed={} bytes, buffer_remaining={})", std::any::type_name::(), consumed, @@ -48,7 +44,7 @@ impl Reader { } Err(DecodeError::More(required)) => { let total_needed = self.buffer.len() + required; - tracing::trace!( + log::trace!( "[READER] decode: need more data for {} (current={} bytes, need={} more, total_required={})", std::any::type_name::(), self.buffer.len(), @@ -58,7 +54,7 @@ impl Reader { total_needed } Err(err) => { - tracing::error!( + log::error!( "[READER] decode: ERROR decoding {} - {:?} (buffer_len={})", std::any::type_name::(), err, @@ -73,7 +69,7 @@ impl Reader { loop { let before_read = self.buffer.len(); if self.stream.read_buf(&mut self.buffer).await?.is_none() { - tracing::warn!( + log::warn!( "[READER] decode: stream ended while waiting for data (have={} bytes, need={})", self.buffer.len(), required @@ -82,14 +78,14 @@ impl Reader { }; let read_amount = self.buffer.len() - before_read; - tracing::trace!( + log::trace!( "[READER] decode: read {} bytes from stream (buffer_len={})", read_amount, self.buffer.len() ); if self.buffer.len() >= required { - tracing::trace!( + log::trace!( "[READER] decode: have enough data now (buffer_len={}), retrying decode", self.buffer.len() ); @@ -100,7 +96,7 @@ impl Reader { } pub async fn read_chunk(&mut self, max: usize) -> Result, SessionError> { - tracing::trace!( + log::trace!( "[READER] read_chunk: requested max={} bytes (buffer_len={})", max, self.buffer.len() @@ -109,7 +105,7 @@ impl Reader { if !self.buffer.is_empty() { let size = cmp::min(max, self.buffer.len()); let data = self.buffer.split_to(size).freeze(); - tracing::trace!( + log::trace!( "[READER] read_chunk: returned {} bytes from buffer (buffer_remaining={})", data.len(), self.buffer.len() @@ -119,9 +115,9 @@ impl Reader { let chunk = self.stream.read(max).await?; if let Some(ref data) = chunk { - tracing::trace!("[READER] read_chunk: read {} bytes from stream", data.len()); + log::trace!("[READER] read_chunk: read {} bytes from stream", data.len()); } else { - tracing::trace!("[READER] read_chunk: stream returned None"); + log::trace!("[READER] read_chunk: stream returned None"); } Ok(chunk) } diff --git a/moq-transport/src/session/subscribe.rs b/moq-transport/src/session/subscribe.rs index ce63e6c3..f9bc1f47 100644 --- a/moq-transport/src/session/subscribe.rs +++ b/moq-transport/src/session/subscribe.rs @@ -1,13 +1,8 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::ops; use crate::{ - coding::{KeyValuePairs, Location, TrackNamespace}, - data, - message::{self, FilterType, GroupOrder}, + coding::{KeyValuePairs, TrackNamespace}, + data, message, serve::{self, ServeError, TrackWriter, TrackWriterMode}, }; @@ -21,22 +16,6 @@ pub struct SubscribeInfo { pub id: u64, pub track_namespace: TrackNamespace, pub track_name: String, - - /// Subscriber Priority - pub subscriber_priority: u8, - pub group_order: GroupOrder, - - /// Forward Flag - pub forward: bool, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. - pub end_group_id: Option, - /// Optional parameters pub params: KeyValuePairs, @@ -50,12 +29,6 @@ impl SubscribeInfo { id: msg.id, track_namespace: msg.track_namespace.clone(), track_name: msg.track_name.clone(), - subscriber_priority: msg.subscriber_priority, - group_order: msg.group_order, - forward: msg.forward, - filter_type: msg.filter_type, - start_location: msg.start_location, - end_group_id: msg.end_group_id, params: msg.params.clone(), track_status: false, } @@ -97,13 +70,6 @@ impl Subscribe { id: request_id, track_namespace: track.namespace.clone(), track_name: track.name.clone(), - // TODO add prioritization logic on the publisher side - subscriber_priority: 127, // default to mid value, see: https://github.com/moq-wg/moq-transport/issues/504 - group_order: GroupOrder::Publisher, // defer to publisher send order - forward: true, // default to forwarding objects - filter_type: FilterType::LargestObject, - start_location: None, - end_group_id: None, params: Default::default(), }; let info = SubscribeInfo::new_from_subscribe(&subscribe_message); @@ -140,32 +106,12 @@ impl Subscribe { .await; } } - - pub async fn ok(&self) -> Result<(), ServeError> { - loop { - { - let state = self.state.lock(); - state.closed.clone()?; - - if state.ok { - return Ok(()); - } - - match state.modified() { - Some(notify) => notify, - None => return Err(ServeError::Done), - } - } - .await; - } - } } impl Drop for Subscribe { fn drop(&mut self) { self.subscriber .send_message(message::Unsubscribe { id: self.info.id }); - self.subscriber.remove_subscribe(self.info.id); } } @@ -229,16 +175,20 @@ impl SubscribeRecv { _ => return Err(ServeError::Mode), }; - let writer = subgroups.create(serve::Subgroup { + let result = subgroups.create(serve::Subgroup { group_id: header.group_id, // When subgroup_id is not present in the header type, it implicitly means subgroup 0 subgroup_id: header.subgroup_id.unwrap_or(0), - priority: header.publisher_priority, - })?; + // When priority is not present (NoPriority header types), default to 0 + priority: header.publisher_priority.unwrap_or(0), + // Preserve the incoming header type for forwarding + header_type: Some(header.header_type), + }); + // Always put writer back, even on error, to avoid losing it self.writer = Some(subgroups.into()); - Ok(writer) + result } pub fn datagram(&mut self, datagram: data::Datagram) -> Result<(), ServeError> { @@ -248,23 +198,39 @@ impl SubscribeRecv { TrackWriterMode::Track(track) => { // convert Track -> Datagrams writer, write, then put Datagrams back let mut datagrams = track.datagrams()?; + // Determine status from datagram type or explicit status field + let status = if datagram.datagram_type.is_end_of_group() { + Some(crate::data::ObjectStatus::EndOfGroup) + } else { + datagram.status + }; datagrams.write(serve::Datagram { group_id: datagram.group_id, object_id: datagram.object_id.unwrap_or(0), - priority: datagram.publisher_priority, + // When priority is not present (NoPriority datagram types), default to 0 + priority: datagram.publisher_priority.unwrap_or(0), payload: datagram.payload.unwrap_or_default(), extension_headers: datagram.extension_headers.unwrap_or_default(), + status, })?; self.writer = Some(TrackWriterMode::Datagrams(datagrams)); Ok(()) } TrackWriterMode::Datagrams(mut datagrams) => { + // Determine status from datagram type or explicit status field + let status = if datagram.datagram_type.is_end_of_group() { + Some(crate::data::ObjectStatus::EndOfGroup) + } else { + datagram.status + }; datagrams.write(serve::Datagram { group_id: datagram.group_id, object_id: datagram.object_id.unwrap_or(0), - priority: datagram.publisher_priority, + // When priority is not present (NoPriority datagram types), default to 0 + priority: datagram.publisher_priority.unwrap_or(0), payload: datagram.payload.unwrap_or_default(), extension_headers: datagram.extension_headers.unwrap_or_default(), + status, })?; self.writer = Some(TrackWriterMode::Datagrams(datagrams)); Ok(()) diff --git a/moq-transport/src/session/subscribe_namespace.rs b/moq-transport/src/session/subscribe_namespace.rs new file mode 100644 index 00000000..ddce76ed --- /dev/null +++ b/moq-transport/src/session/subscribe_namespace.rs @@ -0,0 +1,146 @@ +use std::ops; + +use crate::coding::TrackNamespace; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::Subscriber; + +#[derive(Debug, Clone)] +pub struct SubscribeNsInfo { + pub request_id: u64, + pub namespace_prefix: TrackNamespace, +} + +struct SubscribeNsState { + ok: bool, + closed: Result<(), ServeError>, +} + +impl Default for SubscribeNsState { + fn default() -> Self { + Self { + ok: false, + closed: Ok(()), + } + } +} + +/// Represents an outbound SUBSCRIBE_NAMESPACE request (subscriber side). +/// When dropped, sends UNSUBSCRIBE_NAMESPACE to the peer. +#[must_use = "sends UNSUBSCRIBE_NAMESPACE on drop"] +pub struct SubscribeNs { + subscriber: Subscriber, + state: State, + + pub info: SubscribeNsInfo, +} + +impl SubscribeNs { + pub(super) fn new( + mut subscriber: Subscriber, + request_id: u64, + namespace_prefix: TrackNamespace, + params: crate::coding::KeyValuePairs, + ) -> (SubscribeNs, SubscribeNsRecv) { + let info = SubscribeNsInfo { + request_id, + namespace_prefix: namespace_prefix.clone(), + }; + + let mut msg = message::SubscribeNamespace::new( + request_id, + namespace_prefix, + 1, + ); + msg.params = params; + subscriber.send_message(msg); + + let (send, recv) = State::default().split(); + + let send = Self { + subscriber, + info, + state: send, + }; + let recv = SubscribeNsRecv { state: recv }; + + (send, recv) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } + + pub async fn ok(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + if state.ok { + return Ok(()); + } + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } +} + +impl Drop for SubscribeNs { + fn drop(&mut self) { + // In draft-16, SUBSCRIBE_NAMESPACE uses its own bidirectional stream. + // Closing the stream implicitly unsubscribes. + } +} + +impl ops::Deref for SubscribeNs { + type Target = SubscribeNsInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +pub(super) struct SubscribeNsRecv { + state: State, +} + +impl SubscribeNsRecv { + pub fn recv_ok(&mut self) -> Result<(), ServeError> { + if let Some(mut state) = self.state.lock_mut() { + if state.ok { + return Err(ServeError::Duplicate); + } + + state.ok = true; + } + + Ok(()) + } + + pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } +} diff --git a/moq-transport/src/session/subscribe_namespace_received.rs b/moq-transport/src/session/subscribe_namespace_received.rs new file mode 100644 index 00000000..3debd381 --- /dev/null +++ b/moq-transport/src/session/subscribe_namespace_received.rs @@ -0,0 +1,151 @@ +use std::ops; + +use crate::coding::{KeyValuePairs, ReasonPhrase, TrackNamespace}; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::Publisher; + +#[derive(Debug, Clone)] +pub struct SubscribeNamespaceReceivedInfo { + pub request_id: u64, + pub namespace_prefix: TrackNamespace, + pub params: KeyValuePairs, +} + +struct SubscribeNamespaceReceivedState { + closed: Result<(), ServeError>, +} + +impl Default for SubscribeNamespaceReceivedState { + fn default() -> Self { + Self { closed: Ok(()) } + } +} + +#[must_use = "sends SUBSCRIBE_NAMESPACE_ERROR on drop if not accepted"] +pub struct SubscribeNamespaceReceived { + publisher: Publisher, + state: State, + pub info: SubscribeNamespaceReceivedInfo, + ok: bool, +} + +impl SubscribeNamespaceReceived { + pub(super) fn new( + publisher: Publisher, + request_id: u64, + namespace_prefix: TrackNamespace, + params: KeyValuePairs, + ) -> (Self, SubscribeNamespaceReceivedRecv) { + let info = SubscribeNamespaceReceivedInfo { + request_id, + namespace_prefix: namespace_prefix.clone(), + params, + }; + + let (send, recv) = State::default().split(); + + let send = Self { + publisher, + info, + state: send, + ok: false, + }; + + let recv = SubscribeNamespaceReceivedRecv { + state: recv, + namespace_prefix, + }; + + (send, recv) + } + + pub fn ok(&mut self) -> Result<(), ServeError> { + if self.ok { + return Err(ServeError::Duplicate); + } + + self.publisher.send_message(message::RequestOk { + id: self.info.request_id, + params: Default::default(), + }); + + self.ok = true; + + Ok(()) + } + + pub fn reject(mut self, error_code: u64, reason: &str) -> Result<(), ServeError> { + self.publisher.send_message(message::RequestError { + id: self.info.request_id, + error_code, + retry_interval: 0, + reason_phrase: ReasonPhrase(reason.to_string()), + }); + + self.ok = true; + + Ok(()) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } +} + +impl ops::Deref for SubscribeNamespaceReceived { + type Target = SubscribeNamespaceReceivedInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +impl Drop for SubscribeNamespaceReceived { + fn drop(&mut self) { + if self.ok { + return; + } + + self.publisher.send_message(message::RequestError { + id: self.info.request_id, + error_code: ServeError::NotFound.code(), + retry_interval: 0, + reason_phrase: ReasonPhrase("SUBSCRIBE_NAMESPACE not handled".to_string()), + }); + } +} + +pub(super) struct SubscribeNamespaceReceivedRecv { + state: State, + namespace_prefix: TrackNamespace, +} + +impl SubscribeNamespaceReceivedRecv { + pub fn namespace_prefix(&self) -> &TrackNamespace { + &self.namespace_prefix + } + + pub fn recv_unsubscribe(&mut self) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Cancel); + } + + Ok(()) + } +} diff --git a/moq-transport/src/session/subscribed.rs b/moq-transport/src/session/subscribed.rs index 5847e6e4..fc17c129 100644 --- a/moq-transport/src/session/subscribed.rs +++ b/moq-transport/src/session/subscribed.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::ops; use std::sync::{Arc, Mutex}; @@ -24,7 +20,20 @@ struct SubscribedState { closed: Result<(), ServeError>, } +impl Default for SubscribedState { + fn default() -> Self { + Self { + largest_location: None, + closed: Ok(()), + } + } +} + impl SubscribedState { + fn is_closed(&self) -> bool { + self.closed.is_err() + } + fn update_largest_location(&mut self, group_id: u64, object_id: u64) -> Result<(), ServeError> { if let Some(current_largest_location) = self.largest_location { let update_largest_location = Location::new(group_id, object_id); @@ -37,15 +46,6 @@ impl SubscribedState { } } -impl Default for SubscribedState { - fn default() -> Self { - Self { - largest_location: None, - closed: Ok(()), - } - } -} - pub struct Subscribed { /// The sessions Publisher manager, used to send control messages, /// create new QUIC streams, and send datagrams @@ -70,7 +70,7 @@ impl Subscribed { msg: message::Subscribe, mlog: Option>>, ) -> (Self, SubscribedRecv) { - let (send, recv) = State::default().split(); + let (send, recv) = State::new(SubscribedState::default()).split(); let info = SubscribeInfo::new_from_subscribe(&msg); let send = Self { publisher, @@ -106,14 +106,12 @@ impl Subscribed { // Send SubscribeOk using send_message_and_wait to ensure it is sent at least to the QUIC stack before // we start serving the track. If a subscriber gets the stream before SubscribeOk // then they won't recognize the track_alias in the stream header. + let track_alias = self.publisher.next_track_alias(); self.publisher .send_message_and_wait(message::SubscribeOk { id: self.info.id, - track_alias: self.info.id, // use subscription id as track alias - expires: 0, // TODO SLG - group_order: message::GroupOrder::Descending, // TODO: resolve correct value from publisher / subscriber prefs - content_exists: largest_location.is_some(), - largest_location, + track_alias, + track_extensions: Default::default(), params: Default::default(), }) .await; @@ -124,8 +122,12 @@ impl Subscribed { match track.mode().await? { // TODO cancel track/datagrams on closed TrackReaderMode::Stream(_stream) => panic!("deprecated"), - TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups).await, - TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + TrackReaderMode::Subgroups(subgroups) => { + self.serve_subgroups(subgroups, track_alias).await + } + TrackReaderMode::Datagrams(datagrams) => { + self.serve_datagrams(datagrams, track_alias).await + } } } @@ -182,9 +184,10 @@ impl Drop for Subscribed { reason: ReasonPhrase(err.to_string()), }); } else { - self.publisher.send_message(message::SubscribeError { + self.publisher.send_message(message::RequestError { id: self.info.id, error_code: err.code(), + retry_interval: 0, reason_phrase: ReasonPhrase(err.to_string()), }); }; @@ -195,6 +198,7 @@ impl Subscribed { async fn serve_subgroups( &mut self, mut subgroups: serve::SubgroupsReader, + track_alias: u64, ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); let mut done: Option> = None; @@ -203,12 +207,19 @@ impl Subscribed { tokio::select! { res = subgroups.next(), if done.is_none() => match res { Ok(Some(subgroup)) => { + // Use preserved header type if available, otherwise default to SubgroupIdExt + let header_type = subgroup.info.header_type.unwrap_or(data::StreamHeaderType::SubgroupIdExt); + let subgroup_id = if header_type.has_subgroup_id() { + Some(subgroup.subgroup_id) + } else { + None + }; let header = data::SubgroupHeader { - header_type: data::StreamHeaderType::SubgroupIdExt, // SubGroupId = Yes, Extensions = Yes, ContainsEndOfGroup = No - track_alias: self.info.id, // use subscription id as track_alias + header_type, + track_alias, group_id: subgroup.group_id, - subgroup_id: Some(subgroup.subgroup_id), - publisher_priority: subgroup.priority, + subgroup_id, + publisher_priority: Some(subgroup.priority), }; let publisher = self.publisher.clone(); @@ -218,7 +229,7 @@ impl Subscribed { tasks.push(async move { if let Err(err) = Self::serve_subgroup(header, subgroup, publisher, state, mlog).await { - tracing::warn!("failed to serve subgroup: {:?}, error: {}", info, err); + log::warn!("failed to serve subgroup: {:?}, error: {}", info, err); } }); }, @@ -239,7 +250,7 @@ impl Subscribed { state: State, mlog: Option>>, ) -> Result<(), SessionError> { - tracing::debug!( + log::debug!( "[PUBLISHER] serve_subgroup: starting - group_id={}, subgroup_id={:?}, priority={}", subgroup_reader.group_id, subgroup_reader.subgroup_id, @@ -247,15 +258,15 @@ impl Subscribed { ); let mut send_stream = publisher.open_uni().await?; - tracing::trace!("[PUBLISHER] serve_subgroup: opened unidirectional stream"); + log::trace!("[PUBLISHER] serve_subgroup: opened unidirectional stream"); // TODO figure out u32 vs u64 priority send_stream.set_priority(subgroup_reader.priority as i32); let mut writer = Writer::new(send_stream); - tracing::debug!( - "[PUBLISHER] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={}, header_type={:?}", + log::debug!( + "[PUBLISHER] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}, header_type={:?}", header.track_alias, header.group_id, header.subgroup_id, @@ -275,47 +286,78 @@ impl Subscribed { } } + let has_extension_headers = header.header_type.has_extension_headers(); let mut object_count = 0; while let Some(mut subgroup_object_reader) = subgroup_reader.next().await? { - let subgroup_object = data::SubgroupObjectExt { - object_id_delta: 0, // before delta logic, used to be subgroup_object_reader.object_id, - extension_headers: subgroup_object_reader.extension_headers.clone(), // Pass through extension headers - payload_length: subgroup_object_reader.size, - status: if subgroup_object_reader.size == 0 { - // Only set status if payload length is zero - Some(subgroup_object_reader.status) - } else { - None - }, - }; + if state.lock().is_closed() { + log::debug!( + "[PUBLISHER] serve_subgroup: subscription cancelled, stopping (group_id={}, subgroup_id={:?}, {} objects sent)", + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + object_count + ); + return Ok(()); + } - tracing::debug!( - "[PUBLISHER] serve_subgroup: sending object #{} - object_id={}, object_id_delta={}, payload_length={}, status={:?}, extension_headers={:?}", - object_count + 1, - subgroup_object_reader.object_id, - subgroup_object.object_id_delta, - subgroup_object.payload_length, - subgroup_object.status, - subgroup_object.extension_headers - ); + // Encode object based on header type - must match what receiver expects + if has_extension_headers { + let subgroup_object = data::SubgroupObjectExt { + object_id_delta: 0, + extension_headers: subgroup_object_reader.extension_headers.clone(), + payload_length: subgroup_object_reader.size, + status: if subgroup_object_reader.size == 0 { + Some(subgroup_object_reader.status) + } else { + None + }, + }; - writer.encode(&subgroup_object).await?; + log::debug!( + "[PUBLISHER] serve_subgroup: sending object #{} (ext) - object_id={}, payload_length={}, status={:?}, extension_headers={:?}", + object_count + 1, + subgroup_object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status, + subgroup_object.extension_headers + ); - // Log subgroup object created/sent - if let Some(ref mlog) = mlog { - if let Ok(mut mlog_guard) = mlog.lock() { - let time = mlog_guard.elapsed_ms(); - let stream_id = 0; // TODO: Placeholder, need actual QUIC stream ID - let event = mlog::subgroup_object_ext_created( - time, - stream_id, - subgroup_reader.group_id, - subgroup_reader.subgroup_id, - subgroup_object_reader.object_id, - &subgroup_object, - ); - let _ = mlog_guard.add_event(event); + writer.encode(&subgroup_object).await?; + + if let Some(ref mlog) = mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let event = mlog::subgroup_object_ext_created( + time, + stream_id, + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + subgroup_object_reader.object_id, + &subgroup_object, + ); + let _ = mlog_guard.add_event(event); + } } + } else { + let subgroup_object = data::SubgroupObject { + object_id_delta: 0, + payload_length: subgroup_object_reader.size, + status: if subgroup_object_reader.size == 0 { + Some(subgroup_object_reader.status) + } else { + None + }, + }; + + log::debug!( + "[PUBLISHER] serve_subgroup: sending object #{} - object_id={}, payload_length={}, status={:?}", + object_count + 1, + subgroup_object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status + ); + + writer.encode(&subgroup_object).await?; } state @@ -329,7 +371,14 @@ impl Subscribed { let mut chunks_sent = 0; let mut bytes_sent = 0; while let Some(chunk) = subgroup_object_reader.read().await? { - tracing::trace!( + if state.lock().is_closed() { + log::debug!( + "[PUBLISHER] serve_subgroup: subscription cancelled during payload transfer" + ); + return Ok(()); + } + + log::trace!( "[PUBLISHER] serve_subgroup: sending payload chunk #{} for object #{} ({} bytes)", chunks_sent + 1, object_count + 1, @@ -340,7 +389,7 @@ impl Subscribed { chunks_sent += 1; } - tracing::trace!( + log::trace!( "[PUBLISHER] serve_subgroup: completed object #{} ({} chunks, {} bytes total)", object_count + 1, chunks_sent, @@ -349,7 +398,7 @@ impl Subscribed { object_count += 1; } - tracing::info!( + log::info!( "[PUBLISHER] serve_subgroup: completed subgroup (group_id={}, subgroup_id={:?}, {} objects sent)", subgroup_reader.group_id, subgroup_reader.subgroup_id, @@ -362,12 +411,20 @@ impl Subscribed { async fn serve_datagrams( &mut self, mut datagrams: serve::DatagramsReader, + track_alias: u64, ) -> Result<(), SessionError> { - tracing::debug!("[PUBLISHER] serve_datagrams: starting"); + log::debug!("[PUBLISHER] serve_datagrams: starting"); let mut datagram_count = 0; while let Some(datagram) = datagrams.read().await? { - // Determine datagram type based on extension headers presence + if self.state.lock().is_closed() { + log::debug!( + "[PUBLISHER] serve_datagrams: subscription cancelled, stopping ({} datagrams sent)", + datagram_count + ); + return Ok(()); + } + let has_extension_headers = !datagram.extension_headers.is_empty(); let datagram_type = if has_extension_headers { data::DatagramType::ObjectIdPayloadExt @@ -377,10 +434,10 @@ impl Subscribed { let encoded_datagram = data::Datagram { datagram_type, - track_alias: self.info.id, // use subscription id as track_alias + track_alias, group_id: datagram.group_id, object_id: Some(datagram.object_id), - publisher_priority: datagram.priority, + publisher_priority: Some(datagram.priority), extension_headers: if has_extension_headers { Some(datagram.extension_headers.clone()) } else { @@ -398,8 +455,8 @@ impl Subscribed { let mut buffer = bytes::BytesMut::with_capacity(payload_len + 100); encoded_datagram.encode(&mut buffer)?; - tracing::debug!( - "[PUBLISHER] serve_datagrams: sending datagram #{} - track_alias={}, group_id={}, object_id={}, priority={}, payload_len={}, extension_headers={:?}, total_encoded_len={}", + log::debug!( + "[PUBLISHER] serve_datagrams: sending datagram #{} - track_alias={}, group_id={}, object_id={}, priority={:?}, payload_len={}, extension_headers={:?}, total_encoded_len={}", datagram_count + 1, encoded_datagram.track_alias, encoded_datagram.group_id, @@ -436,7 +493,7 @@ impl Subscribed { datagram_count += 1; } - tracing::info!( + log::info!( "[PUBLISHER] serve_datagrams: completed ({} datagrams sent)", datagram_count ); diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index 2653abee..199491ab 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{ collections::{hash_map, HashMap}, io, @@ -21,41 +17,40 @@ use crate::{ use crate::watch::Queue; -use super::{Announced, AnnouncedRecv, Reader, Session, SessionError, Subscribe, SubscribeRecv}; +use super::{ + PublishNamespaceReceived, PublishNamespaceReceivedRecv, PublishReceived, PublishReceivedRecv, + Reader, Session, SessionError, Subscribe, SubscribeNs, SubscribeNsRecv, SubscribeRecv, +}; // Default timeout for waiting for subscribe aliases to become available via SUBSCRIBE_OK (1 second) const DEFAULT_ALIAS_WAIT_TIME_MS: u64 = 1000; -// TODO remove Clone. #[derive(Clone)] pub struct Subscriber { - /// The currently active inbound announces, keyed by namespace. - announced: Arc>>, + publish_namespaces_received: Arc>>, + + publish_namespace_received_queue: Queue, - /// Queue of announced namespaces we have received from the Publisher, waiting to be processed. - announced_queue: Queue, + subscribe_namespaces: Arc>>, - /// The currently active outbound subscribes, keyed by request id. subscribes: Arc>>, - /// Map of track alias to subscription id for quick lookup when receiving streams/datagrams. subscribe_alias_map: Arc>>, - /// Notify when subscribe alias map is updated subscribe_alias_notify: Arc, - /// The queue we will write any outbound control messages we want to send, the session run_send task - /// will process the queue and send the message on the control stream. + publishes_received: Arc>>, + + publish_received_queue: Queue, + + publish_alias_map: Arc>>, + + publish_alias_notify: Arc, + outgoing: Queue, - /// When we need a new Request Id for sending a request, we can get it from here. Note: The instance - /// of AtomicU64 is shared with the Subscriber, so the session uses unique request ids for all requests - /// generated. Note: If we initiated the QUIC connection then request id's start at 0 and increment by 2 - /// for each request (even numbers). If we accepted an inbound QUIC connection then request id's start at 1 and - /// increment by 2 for each request (odd numbers). next_requestid: Arc, - /// Optional mlog writer for logging transport events mlog: Option>>, } @@ -66,42 +61,44 @@ impl Subscriber { mlog: Option>>, ) -> Self { Self { - announced: Default::default(), - announced_queue: Default::default(), + publish_namespaces_received: Default::default(), + publish_namespace_received_queue: Default::default(), + subscribe_namespaces: Default::default(), subscribes: Default::default(), subscribe_alias_map: Default::default(), + subscribe_alias_notify: Arc::new(Notify::new()), + publishes_received: Default::default(), + publish_received_queue: Default::default(), + publish_alias_map: Default::default(), + publish_alias_notify: Arc::new(Notify::new()), outgoing, next_requestid, mlog, - subscribe_alias_notify: Arc::new(Notify::new()), } } /// Create an inbound/server QUIC connection, by accepting a bi-directional QUIC stream for control messages. - pub async fn accept( - session: web_transport::Session, - transport: super::Transport, - ) -> Result<(Session, Self), SessionError> { - let (session, _, subscriber) = Session::accept(session, None, transport).await?; + pub async fn accept(session: web_transport::Session) -> Result<(Session, Self), SessionError> { + let (session, _, subscriber) = Session::accept(session, None).await?; Ok((session, subscriber.unwrap())) } /// Create an outbound/client QUIC connection, by opening a bi-directional QUIC stream for control messages. - pub async fn connect( - session: web_transport::Session, - transport: super::Transport, - ) -> Result<(Session, Self), SessionError> { - let (session, _, subscriber) = Session::connect(session, None, transport).await?; + pub async fn connect(session: web_transport::Session) -> Result<(Session, Self), SessionError> { + let (session, _, subscriber) = Session::connect(session, None).await?; Ok((session, subscriber)) } - /// Wait for the next announced namespace from the publisher, if any. - pub async fn announced(&mut self) -> Option { - self.announced_queue.pop().await + pub async fn publish_ns_recvd(&mut self) -> Option { + self.publish_namespace_received_queue.pop().await + } + + pub async fn publish_received(&mut self) -> Option { + self.publish_received_queue.pop().await } /// Get the current next request id to use and increment the value for by 2 for the next request - fn get_next_request_id(&self) -> u64 { + pub fn get_next_request_id(&self) -> u64 { self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed) } @@ -123,88 +120,92 @@ impl Subscriber { /// Subscribe to a track by creating a new subscribe request to the publisher. Block until subscription is closed. pub async fn subscribe(&mut self, track: serve::TrackWriter) -> Result<(), ServeError> { - let subscribe = self.subscribe_open(track).await?; - subscribe.closed().await + let request_id = self.get_next_request_id(); + let (send, recv) = Subscribe::new(self.clone(), request_id, track); + self.subscribes.lock().unwrap().insert(request_id, recv); + + send.closed().await + } + + pub fn subscribe_ns( + &mut self, + namespace_prefix: TrackNamespace, + ) -> Result { + self.subscribe_ns_with_params(namespace_prefix, crate::coding::KeyValuePairs::new()) } - /// Subscribe to a track and wait until the publisher acknowledges it. - pub async fn subscribe_open( + pub fn subscribe_ns_with_params( &mut self, - track: serve::TrackWriter, - ) -> Result { + namespace_prefix: TrackNamespace, + params: crate::coding::KeyValuePairs, + ) -> Result { let request_id = self.get_next_request_id(); - let (send, recv) = Subscribe::new(self.clone(), request_id, track); - self.subscribes.lock().unwrap().insert(request_id, recv); - send.ok().await?; + + let mut subscribe_namespaces = self.subscribe_namespaces.lock().unwrap(); + let entry = match subscribe_namespaces.entry(request_id) { + hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate), + hash_map::Entry::Vacant(entry) => entry, + }; + + let (send, recv) = SubscribeNs::new(self.clone(), request_id, namespace_prefix, params); + entry.insert(recv); + Ok(send) } /// Send a message to the publisher via the control stream. - pub(super) fn send_message>(&mut self, msg: M) { + pub fn send_message>(&mut self, msg: M) { let msg = msg.into(); // Remove our entry on terminal state. - match &msg { - message::Subscriber::PublishNamespaceCancel(msg) => { - self.drop_publish_namespace(&msg.track_namespace) - } - // TODO SLG - there is no longer a namespace in the error, need to map via request id - message::Subscriber::PublishNamespaceError(_msg) => {} // Not implemented yet - need request id mapping - _ => {} + if let message::Subscriber::PublishNamespaceCancel(msg) = &msg { + self.drop_publish_namespace(&msg.track_namespace) } // TODO report dropped messages? let _ = self.outgoing.push(msg.into()); } - /// Receive a message from the publisher via the control stream. pub(super) fn recv_message(&mut self, msg: message::Publisher) -> Result<(), SessionError> { let res = match &msg { message::Publisher::PublishNamespace(msg) => self.recv_publish_namespace(msg), - message::Publisher::PublishNamespaceDone(msg) => self.recv_publish_namespace_done(msg), - message::Publisher::Publish(_msg) => Err(SessionError::unimplemented("PUBLISH")), + message::Publisher::PublishNamespaceDone(msg) => self.recv_publish_ns_done(msg), + message::Publisher::Namespace(msg) => self.recv_namespace(msg), + message::Publisher::Publish(msg) => self.recv_publish(msg), message::Publisher::PublishDone(msg) => self.recv_publish_done(msg), message::Publisher::SubscribeOk(msg) => self.recv_subscribe_ok(msg), - message::Publisher::SubscribeError(msg) => self.recv_subscribe_error(msg), + message::Publisher::RequestError(msg) => self.recv_request_error(msg), message::Publisher::TrackStatusOk(msg) => self.recv_track_status_ok(msg), - message::Publisher::TrackStatusError(_msg) => { - Err(SessionError::unimplemented("TRACK_STATUS_ERROR")) - } message::Publisher::FetchOk(_msg) => Err(SessionError::unimplemented("FETCH_OK")), - message::Publisher::FetchError(_msg) => Err(SessionError::unimplemented("FETCH_ERROR")), - message::Publisher::SubscribeNamespaceOk(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE_OK")) - } - message::Publisher::SubscribeNamespaceError(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE_ERROR")) - } + message::Publisher::RequestOk(msg) => self.recv_request_ok(msg), }; if let Err(SessionError::Serve(err)) = res { - tracing::debug!("failed to process message: {:?} {}", msg, err); + log::debug!("failed to process message: {:?} {}", msg, err); return Ok(()); } res } - /// Handle the reception of a PublishNamespace message from the publisher. fn recv_publish_namespace( &mut self, msg: &message::PublishNamespace, ) -> Result<(), SessionError> { - let mut announces = self.announced.lock().unwrap(); + let mut entries = self.publish_namespaces_received.lock().unwrap(); - // Check for duplicate namespace announcement - let entry = match announces.entry(msg.track_namespace.clone()) { + let entry = match entries.entry(msg.track_namespace.clone()) { hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), hash_map::Entry::Vacant(entry) => entry, }; - // Create the announced namespace and insert it into our map of active announces, and the announced queue. - let (announced, recv) = Announced::new(self.clone(), msg.id, msg.track_namespace.clone()); - if let Err(announced) = self.announced_queue.push(announced) { - announced.close(ServeError::Cancel)?; + let (publish_ns_received, recv) = + PublishNamespaceReceived::new(self.clone(), msg.id, msg.track_namespace.clone()); + if let Err(publish_ns_received) = self + .publish_namespace_received_queue + .push(publish_ns_received) + { + publish_ns_received.close(ServeError::Cancel)?; return Ok(()); } entry.insert(recv); @@ -212,18 +213,59 @@ impl Subscriber { Ok(()) } - /// Handle the reception of a PublishNamespaceDone message from the publisher. - fn recv_publish_namespace_done( + fn recv_publish_ns_done( &mut self, msg: &message::PublishNamespaceDone, ) -> Result<(), SessionError> { - if let Some(announce) = self.announced.lock().unwrap().remove(&msg.track_namespace) { - announce.recv_unannounce()?; + if let Some(entry) = self + .publish_namespaces_received + .lock() + .unwrap() + .remove(&msg.track_namespace) + { + entry.recv_done()?; } Ok(()) } + /// Handle NAMESPACE message (draft-16) - relay forwards this in response to SUBSCRIBE_NAMESPACE + fn recv_namespace(&mut self, msg: &message::Namespace) -> Result<(), SessionError> { + log::info!( + "received NAMESPACE for {:?} (request_id={})", + msg.track_namespace, + msg.id + ); + // TODO: Implement proper handling - notify the SUBSCRIBE_NAMESPACE handler + // For now, just log and accept + Ok(()) + } + + fn recv_publish(&mut self, msg: &message::Publish) -> Result<(), SessionError> { + let mut entries = self.publishes_received.lock().unwrap(); + + let entry = match entries.entry(msg.id) { + hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), + hash_map::Entry::Vacant(entry) => entry, + }; + + let (publish_received, recv) = PublishReceived::new(self.clone(), msg); + + self.publish_alias_map + .lock() + .unwrap() + .insert(msg.track_alias, msg.id); + self.publish_alias_notify.notify_waiters(); + + if let Err(publish_received) = self.publish_received_queue.push(publish_received) { + publish_received.close(ServeError::Cancel)?; + return Ok(()); + } + entry.insert(recv); + + Ok(()) + } + /// Handle the reception of a SubscribeOk message from the publisher. fn recv_subscribe_ok(&mut self, msg: &message::SubscribeOk) -> Result<(), SessionError> { if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&msg.id) { @@ -244,7 +286,7 @@ impl Subscriber { } /// Remove a subscribe from our map of active subscribes, and the alias map if present. - pub(super) fn remove_subscribe(&mut self, id: u64) -> Option { + fn remove_subscribe(&mut self, id: u64) -> Option { if let Some(subscribe) = self.subscribes.lock().unwrap().remove(&id) { // Remove from alias map if present if let Some(track_alias) = subscribe.track_alias() { @@ -259,19 +301,45 @@ impl Subscriber { } } - /// Handle the reception of a SubscribeError message from the publisher. - fn recv_subscribe_error(&mut self, msg: &message::SubscribeError) -> Result<(), SessionError> { + fn recv_request_ok(&mut self, msg: &message::RequestOk) -> Result<(), SessionError> { + if let Some(subscribe_ns) = self.subscribe_namespaces.lock().unwrap().get_mut(&msg.id) { + subscribe_ns.recv_ok()?; + return Ok(()); + } + + log::warn!( + "[SUBSCRIBER] recv_request_ok: request id {} not found", + msg.id + ); + Ok(()) + } + + fn recv_request_error(&mut self, msg: &message::RequestError) -> Result<(), SessionError> { if let Some(subscribe) = self.remove_subscribe(msg.id) { subscribe.error(ServeError::Closed(msg.error_code))?; + return Ok(()); + } + + if let Some(subscribe_ns) = self.subscribe_namespaces.lock().unwrap().remove(&msg.id) { + subscribe_ns.recv_error(ServeError::Closed(msg.error_code))?; + return Ok(()); } + log::warn!( + "[SUBSCRIBER] recv_request_error: request id {} not found", + msg.id + ); Ok(()) } - /// Handle the reception of a PublishDone message from the publisher. fn recv_publish_done(&mut self, msg: &message::PublishDone) -> Result<(), SessionError> { if let Some(subscribe) = self.remove_subscribe(msg.id) { subscribe.error(ServeError::Closed(msg.status_code))?; + return Ok(()); + } + + if let Some(mut publish_recv) = self.remove_publish_received(msg.id) { + publish_recv.recv_done()?; } Ok(()) @@ -285,23 +353,32 @@ impl Subscriber { Ok(()) } - /// Remove an announced namespace from our map of active announces. fn drop_publish_namespace(&mut self, namespace: &TrackNamespace) { - self.announced.lock().unwrap().remove(namespace); + self.publish_namespaces_received + .lock() + .unwrap() + .remove(namespace); + } + + fn remove_publish_received(&mut self, id: u64) -> Option { + if let Some(publish_recv) = self.publishes_received.lock().unwrap().remove(&id) { + if let Some(track_alias) = publish_recv.track_alias() { + self.publish_alias_map.lock().unwrap().remove(&track_alias); + } + Some(publish_recv) + } else { + None + } } - /// Get a subscribe id by track alias, waiting up to the specified timeout if not present. - /// If timeout_ms is None, only check if already present and return None if not. async fn get_subscribe_id_by_alias( &self, track_alias: u64, timeout_ms: Option, ) -> Option { - // If no timeout specified, don't wait let timeout_ms = match timeout_ms { Some(ms) => ms, None => { - // Just check once return self .subscribe_alias_map .lock() @@ -311,14 +388,11 @@ impl Subscriber { } }; - // Wait for it to appear, checking after each notification let timeout_duration = Duration::from_millis(timeout_ms); tokio::time::timeout(timeout_duration, async { loop { - // Register for notification before checking map let notified = self.subscribe_alias_notify.notified(); - // Check Map for alias if let Some(id) = self .subscribe_alias_map .lock() @@ -329,7 +403,45 @@ impl Subscriber { return id; } - // Alias not present yet, wait for notification + notified.await; + } + }) + .await + .ok() + } + + async fn get_publish_id_by_alias( + &self, + track_alias: u64, + timeout_ms: Option, + ) -> Option { + let timeout_ms = match timeout_ms { + Some(ms) => ms, + None => { + return self + .publish_alias_map + .lock() + .unwrap() + .get(&track_alias) + .cloned(); + } + }; + + let timeout_duration = Duration::from_millis(timeout_ms); + tokio::time::timeout(timeout_duration, async { + loop { + let notified = self.publish_alias_notify.notified(); + + if let Some(id) = self + .publish_alias_map + .lock() + .unwrap() + .get(&track_alias) + .cloned() + { + return id; + } + notified.await; } }) @@ -342,12 +454,12 @@ impl Subscriber { mut self, stream: web_transport::RecvStream, ) -> Result<(), SessionError> { - tracing::trace!("[SUBSCRIBER] recv_stream: new stream received, decoding header"); + log::trace!("[SUBSCRIBER] recv_stream: new stream received, decoding header"); let mut reader = Reader::new(stream); // Decode the stream header let stream_header: data::StreamHeader = reader.decode().await?; - tracing::debug!( + log::debug!( "[SUBSCRIBER] recv_stream: decoded stream header type={:?}", stream_header.header_type ); @@ -370,7 +482,7 @@ impl Subscriber { } let track_alias = stream_header.subgroup_header.as_ref().unwrap().track_alias; - tracing::trace!( + log::trace!( "[SUBSCRIBER] recv_stream: stream for subscription track_alias={}", track_alias ); @@ -378,7 +490,7 @@ impl Subscriber { let mlog = self.mlog.clone(); let res = self.recv_stream_inner(reader, stream_header, mlog).await; if let Err(SessionError::Serve(err)) = &res { - tracing::warn!( + log::warn!( "[SUBSCRIBER] recv_stream: stream processing error for track_alias={}: {:?}", track_alias, err @@ -395,7 +507,6 @@ impl Subscriber { res } - /// Continue handling the reception of a new stream from the QUIC session. async fn recv_stream_inner( &mut self, reader: Reader, @@ -403,24 +514,33 @@ impl Subscriber { mlog: Option>>, ) -> Result<(), SessionError> { let track_alias = stream_header.subgroup_header.as_ref().unwrap().track_alias; - tracing::trace!( + log::trace!( "[SUBSCRIBER] recv_stream_inner: processing stream for track_alias={}", track_alias ); - // This is super silly, but I couldn't figure out a way to avoid the mutex guard across awaits. enum Writer { - //Fetch(serve::FetchWriter), Subgroup(serve::SubgroupWriter), } + // First check both maps WITHOUT waiting - this is the fast path for subsequent groups + // where the alias mapping is already established + let (subscribe_id_immediate, publish_id_immediate) = { + let subscribe_id = self.get_subscribe_id_by_alias(track_alias, None).await; + let publish_id = self.get_publish_id_by_alias(track_alias, None).await; + (subscribe_id, publish_id) + }; + + log::debug!( + "[SUBSCRIBER] recv_stream_inner: track_alias={}, subscribe_id_immediate={:?}, publish_id_immediate={:?}", + track_alias, subscribe_id_immediate, publish_id_immediate + ); + + // Determine which path to use, waiting only if neither map has the alias yet let writer = { - // Look up the subscribe id for this track alias - if let Some(subscribe_id) = self - .get_subscribe_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) - .await - { - // Look up the subscribe by id + if let Some(subscribe_id) = subscribe_id_immediate { + // Found in subscribe map immediately + log::debug!("[SUBSCRIBER] recv_stream_inner: using SUBSCRIBE path (immediate)"); let mut subscribes = self.subscribes.lock().unwrap(); let subscribe = subscribes.get_mut(&subscribe_id).ok_or_else(|| { ServeError::not_found_ctx(format!( @@ -429,10 +549,37 @@ impl Subscriber { )) })?; - // Create the appropriate writer based on the stream header type if stream_header.header_type.is_subgroup() { - tracing::trace!("[SUBSCRIBER] recv_stream_inner: creating subgroup writer"); - Writer::Subgroup(subscribe.subgroup(stream_header.subgroup_header.unwrap())?) + log::trace!( + "[SUBSCRIBER] recv_stream_inner: creating subgroup writer from subscribe" + ); + Writer::Subgroup( + subscribe.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) + } else { + return Err(SessionError::Serve(ServeError::internal_ctx(format!( + "unsupported stream header type={}", + stream_header.header_type + )))); + } + } else if let Some(publish_id) = publish_id_immediate { + // Found in publish map immediately + log::debug!("[SUBSCRIBER] recv_stream_inner: using PUBLISH path (immediate)"); + let mut publishes = self.publishes_received.lock().unwrap(); + let publish_recv = publishes.get_mut(&publish_id).ok_or_else(|| { + ServeError::not_found_ctx(format!( + "publish_id={} not found for track_alias={}", + publish_id, track_alias + )) + })?; + + if stream_header.header_type.is_subgroup() { + log::trace!( + "[SUBSCRIBER] recv_stream_inner: creating subgroup writer from publish" + ); + Writer::Subgroup( + publish_recv.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) } else { return Err(SessionError::Serve(ServeError::internal_ctx(format!( "unsupported stream header type={}", @@ -440,24 +587,77 @@ impl Subscriber { )))); } } else { - return Err(SessionError::Serve(ServeError::not_found_ctx(format!( - "subscription track_alias={} not found", + // Not found in either map - wait for either to become available + // This only happens for the first stream before control messages establish the mapping + log::debug!( + "[SUBSCRIBER] recv_stream_inner: track_alias={} NOT FOUND in either map, WAITING for alias mapping", track_alias - )))); + ); + + // Race both lookups with timeout + let subscribe_fut = self.get_subscribe_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + let publish_fut = self.get_publish_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + + tokio::select! { + Some(subscribe_id) = subscribe_fut => { + let mut subscribes = self.subscribes.lock().unwrap(); + let subscribe = subscribes.get_mut(&subscribe_id).ok_or_else(|| { + ServeError::not_found_ctx(format!( + "subscribe_id={} not found for track_alias={}", + subscribe_id, track_alias + )) + })?; + + if stream_header.header_type.is_subgroup() { + Writer::Subgroup( + subscribe.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) + } else { + return Err(SessionError::Serve(ServeError::internal_ctx(format!( + "unsupported stream header type={}", + stream_header.header_type + )))); + } + } + Some(publish_id) = publish_fut => { + let mut publishes = self.publishes_received.lock().unwrap(); + let publish_recv = publishes.get_mut(&publish_id).ok_or_else(|| { + ServeError::not_found_ctx(format!( + "publish_id={} not found for track_alias={}", + publish_id, track_alias + )) + })?; + + if stream_header.header_type.is_subgroup() { + Writer::Subgroup( + publish_recv.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) + } else { + return Err(SessionError::Serve(ServeError::internal_ctx(format!( + "unsupported stream header type={}", + stream_header.header_type + )))); + } + } + else => { + return Err(SessionError::Serve(ServeError::not_found_ctx(format!( + "subscription track_alias={} not found", + track_alias + )))); + } + } } }; - // Handle the stream based on the writer type match writer { - //Writer::Fetch(fetch) => Self::recv_fetch(fetch, reader).await?, Writer::Subgroup(subgroup_writer) => { - tracing::trace!("[SUBSCRIBER] recv_stream_inner: receiving subgroup data"); + log::trace!("[SUBSCRIBER] recv_stream_inner: receiving subgroup data"); Self::recv_subgroup(stream_header.header_type, subgroup_writer, reader, mlog) .await? } }; - tracing::debug!( + log::debug!( "[SUBSCRIBER] recv_stream_inner: completed processing stream for track_alias={}", track_alias ); @@ -471,7 +671,7 @@ impl Subscriber { mut reader: Reader, mlog: Option>>, ) -> Result<(), SessionError> { - tracing::debug!( + log::debug!( "[SUBSCRIBER] recv_subgroup: starting - group_id={}, subgroup_id={}, priority={}", subgroup_writer.info.group_id, subgroup_writer.info.subgroup_id, @@ -481,7 +681,7 @@ impl Subscriber { let mut object_count = 0; let mut current_object_id = 0u64; while !reader.done().await? { - tracing::trace!( + log::trace!( "[SUBSCRIBER] recv_subgroup: reading object #{} (has_ext_headers={})", object_count + 1, stream_header_type.has_extension_headers() @@ -493,7 +693,7 @@ impl Subscriber { match stream_header_type.has_extension_headers() { true => { let object = reader.decode::().await?; - tracing::debug!( + log::debug!( "[SUBSCRIBER] recv_subgroup: object #{} with extension headers - object_id_delta={}, payload_length={}, status={:?}, extension_headers={:?}", object_count + 1, object.object_id_delta, @@ -506,12 +706,12 @@ impl Subscriber { // Check for Immutable Extensions (type 0xB = 11) if object.extension_headers.has(0xB) { - tracing::info!( + log::info!( "[SUBSCRIBER] recv_subgroup: object #{} contains IMMUTABLE EXTENSIONS (type 0xB) - will be forwarded", object_count + 1 ); if let Some(immutable_ext) = object.extension_headers.get(0xB) { - tracing::debug!( + log::debug!( "[SUBSCRIBER] recv_subgroup: immutable extension details: {:?}", immutable_ext ); @@ -520,18 +720,32 @@ impl Subscriber { // Check for Prior Group ID Gap (type 0x3C = 60) if object.extension_headers.has(0x3C) { - tracing::info!( + log::info!( "[SUBSCRIBER] recv_subgroup: object #{} contains PRIOR GROUP ID GAP (type 0x3C)", object_count + 1 ); if let Some(gap_ext) = object.extension_headers.get(0x3C) { - tracing::debug!( + log::debug!( "[SUBSCRIBER] recv_subgroup: prior group id gap details: {:?}", gap_ext ); } } + // Check for Prior Object ID Gap (type 0x3E = 62) + if object.extension_headers.has(0x3E) { + log::info!( + "[SUBSCRIBER] recv_subgroup: object #{} contains PRIOR OBJECT ID GAP (type 0x3E)", + object_count + 1 + ); + if let Some(gap_ext) = object.extension_headers.get(0x3E) { + log::debug!( + "[SUBSCRIBER] recv_subgroup: prior object id gap details: {:?}", + gap_ext + ); + } + } + let obj_copy = object.clone(); ( object.payload_length, @@ -542,7 +756,7 @@ impl Subscriber { } false => { let object = reader.decode::().await?; - tracing::debug!( + log::debug!( "[SUBSCRIBER] recv_subgroup: object #{} - object_id_delta={}, payload_length={}, status={:?}", object_count + 1, object.object_id_delta, @@ -600,11 +814,13 @@ impl Subscriber { } } - // Pass extension headers through to the serve layer - // TODO SLG - object_id_delta and object status are still being ignored - - let mut object_writer = subgroup_writer.create(remaining_bytes, extension_headers)?; - tracing::trace!( + // Pass extension headers and status through to the serve layer + let mut object_writer = subgroup_writer.create_with_status( + remaining_bytes, + extension_headers, + status.unwrap_or(crate::data::ObjectStatus::NormalObject), + )?; + log::trace!( "[SUBSCRIBER] recv_subgroup: reading payload for object #{} ({} bytes)", object_count + 1, remaining_bytes @@ -616,14 +832,14 @@ impl Subscriber { .read_chunk(remaining_bytes) .await? .ok_or_else(|| { - tracing::error!( + log::error!( "[SUBSCRIBER] recv_subgroup: ERROR - stream ended with {} bytes remaining for object #{}", remaining_bytes, object_count + 1 ); SessionError::WrongSize })?; - tracing::trace!( + log::trace!( "[SUBSCRIBER] recv_subgroup: received payload chunk #{} for object #{} ({} bytes, {} remaining)", chunks_read + 1, object_count + 1, @@ -635,7 +851,7 @@ impl Subscriber { chunks_read += 1; } - tracing::trace!( + log::trace!( "[SUBSCRIBER] recv_subgroup: completed object #{} ({} chunks)", object_count + 1, chunks_read @@ -643,11 +859,27 @@ impl Subscriber { object_count += 1; } - tracing::info!( - "[SUBSCRIBER] recv_subgroup: completed subgroup (group_id={}, subgroup_id={}, {} objects received)", + // If the stream header type signals end-of-group, write an EndOfGroup marker + // This forwards the "stream end = group end" semantic to downstream subscribers + if stream_header_type.signals_end_of_group() { + log::debug!( + "[SUBSCRIBER] recv_subgroup: writing EndOfGroup marker (header_type={:?} signals EOG)", + stream_header_type + ); + if let Err(e) = subgroup_writer.end_of_group() { + log::warn!( + "[SUBSCRIBER] recv_subgroup: failed to write EndOfGroup marker: {}", + e + ); + } + } + + log::info!( + "[SUBSCRIBER] recv_subgroup: completed subgroup (group_id={}, subgroup_id={}, {} objects received, eog={})", subgroup_writer.info.group_id, subgroup_writer.info.subgroup_id, - object_count + object_count, + stream_header_type.signals_end_of_group() ); Ok(()) @@ -669,7 +901,7 @@ impl Subscriber { // Check for extension headers in the datagram if let Some(ref ext_headers) = datagram.extension_headers { - tracing::debug!( + log::debug!( "[SUBSCRIBER] recv_datagram: datagram contains extension headers: {:?}", ext_headers ); @@ -678,11 +910,11 @@ impl Subscriber { // Check for Immutable Extensions (type 0xB = 11) if ext_headers.has(0xB) { - tracing::info!( + log::info!( "[SUBSCRIBER] recv_datagram: datagram contains IMMUTABLE EXTENSIONS (type 0xB)" ); if let Some(immutable_ext) = ext_headers.get(0xB) { - tracing::debug!( + log::debug!( "[SUBSCRIBER] recv_datagram: immutable extension details: {:?}", immutable_ext ); @@ -691,27 +923,43 @@ impl Subscriber { // Check for Prior Group ID Gap (type 0x3C = 60) if ext_headers.has(0x3C) { - tracing::info!( + log::info!( "[SUBSCRIBER] recv_datagram: datagram contains PRIOR GROUP ID GAP (type 0x3C)" ); if let Some(gap_ext) = ext_headers.get(0x3C) { - tracing::debug!( + log::debug!( "[SUBSCRIBER] recv_datagram: prior group id gap details: {:?}", gap_ext ); } } + + // Check for Prior Object ID Gap (type 0x3E = 62) + if ext_headers.has(0x3E) { + log::info!( + "[SUBSCRIBER] recv_datagram: datagram contains PRIOR OBJECT ID GAP (type 0x3E)" + ); + if let Some(gap_ext) = ext_headers.get(0x3E) { + log::debug!( + "[SUBSCRIBER] recv_datagram: prior object id gap details: {:?}", + gap_ext + ); + } + } } - // Look up the subscribe id for this track alias - if let Some(subscribe_id) = self - .get_subscribe_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) - .await - { - // Look up the subscribe by id + // Fast path: check both maps immediately WITHOUT waiting + // This allows datagrams to flow at full rate once alias mapping is established + let (subscribe_id_immediate, publish_id_immediate) = { + let subscribe_id = self.get_subscribe_id_by_alias(datagram.track_alias, None).await; + let publish_id = self.get_publish_id_by_alias(datagram.track_alias, None).await; + (subscribe_id, publish_id) + }; + + if let Some(subscribe_id) = subscribe_id_immediate { if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&subscribe_id) { - tracing::trace!( - "[SUBSCRIBER] recv_datagram: track_alias={}, group_id={}, object_id={}, publisher_priority={}, status={}, payload_length={}", + log::trace!( + "[SUBSCRIBER] recv_datagram: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", datagram.track_alias, datagram.group_id, datagram.object_id.unwrap_or(0), @@ -720,100 +968,65 @@ impl Subscriber { datagram.payload.as_ref().map_or(0, |p| p.len())); subscribe.datagram(datagram)?; } + } else if let Some(publish_id) = publish_id_immediate { + if let Some(publish_recv) = self.publishes_received.lock().unwrap().get_mut(&publish_id) + { + log::trace!( + "[SUBSCRIBER] recv_datagram from publish: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + publish_recv.datagram(datagram)?; + } } else { - tracing::warn!( - "[SUBSCRIBER] recv_datagram: discarded due to unknown track_alias: track_alias={}, group_id={}, object_id={}, publisher_priority={}, status={}, payload_length={}", - datagram.track_alias, - datagram.group_id, - datagram.object_id.unwrap_or(0), - datagram.publisher_priority, - datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), - datagram.payload.as_ref().map_or(0, |p| p.len())); + // Slow path: alias not found immediately, wait with timeout (only for first datagram) + let subscribe_fut = self.get_subscribe_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + let publish_fut = self.get_publish_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + + tokio::select! { + Some(subscribe_id) = subscribe_fut => { + if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&subscribe_id) { + log::trace!( + "[SUBSCRIBER] recv_datagram (waited): track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + subscribe.datagram(datagram)?; + } + } + Some(publish_id) = publish_fut => { + if let Some(publish_recv) = self.publishes_received.lock().unwrap().get_mut(&publish_id) + { + log::trace!( + "[SUBSCRIBER] recv_datagram from publish (waited): track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + publish_recv.datagram(datagram)?; + } + } + else => { + log::warn!( + "[SUBSCRIBER] recv_datagram: discarded due to unknown track_alias: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + } + } } Ok(()) } } - -#[cfg(test)] -mod tests { - use std::{sync::atomic, task::Poll}; - - use super::*; - use crate::{ - message::{self, GroupOrder}, - serve::Track, - }; - - fn subscriber() -> Subscriber { - Subscriber::new(Queue::default(), Arc::new(atomic::AtomicU64::new(0)), None) - } - - #[tokio::test] - async fn subscribe_open_cleans_up_when_cancelled_before_ok() { - let mut subscriber = subscriber(); - let observer = subscriber.clone(); - let (writer, _reader) = - Track::new(TrackNamespace::from_utf8_path("test"), "0.mp4".into()).produce(); - - { - let subscribe = subscriber.subscribe_open(writer); - futures::pin_mut!(subscribe); - - assert!(matches!(futures::poll!(&mut subscribe), Poll::Pending)); - assert_eq!(observer.subscribes.lock().unwrap().len(), 1); - } - - assert!(observer.subscribes.lock().unwrap().is_empty()); - assert!(observer.subscribe_alias_map.lock().unwrap().is_empty()); - } - - #[tokio::test] - async fn dropping_open_subscribe_removes_recv_state() { - let mut subscriber = subscriber(); - let observer = subscriber.clone(); - let (writer, _reader) = - Track::new(TrackNamespace::from_utf8_path("test"), "0.mp4".into()).produce(); - - let subscribe = subscriber.subscribe_open(writer); - futures::pin_mut!(subscribe); - - assert!(matches!(futures::poll!(&mut subscribe), Poll::Pending)); - assert_eq!(observer.subscribes.lock().unwrap().len(), 1); - - let mut receiver = observer.clone(); - receiver - .recv_subscribe_ok(&message::SubscribeOk { - id: 0, - track_alias: 10, - expires: 0, - group_order: GroupOrder::Publisher, - content_exists: false, - largest_location: None, - params: Default::default(), - }) - .unwrap(); - - let subscribe = match futures::poll!(&mut subscribe) { - Poll::Ready(Ok(subscribe)) => subscribe, - Poll::Ready(Err(err)) => panic!("subscribe failed: {err}"), - Poll::Pending => panic!("subscribe remained pending after SubscribeOk"), - }; - - assert_eq!(observer.subscribes.lock().unwrap().len(), 1); - assert_eq!( - observer - .subscribe_alias_map - .lock() - .unwrap() - .get(&10) - .copied(), - Some(0) - ); - - drop(subscribe); - - assert!(observer.subscribes.lock().unwrap().is_empty()); - assert!(observer.subscribe_alias_map.lock().unwrap().is_empty()); - } -} diff --git a/moq-transport/src/session/track_status_requested.rs b/moq-transport/src/session/track_status_requested.rs index 8a5a5e5c..587610be 100644 --- a/moq-transport/src/session/track_status_requested.rs +++ b/moq-transport/src/session/track_status_requested.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use super::{Publisher, SessionError}; use crate::coding::ReasonPhrase; use crate::message; @@ -24,9 +21,10 @@ impl TrackStatusRequested { error_code: u64, error_message: &str, ) -> Result<(), SessionError> { - let status_error = message::TrackStatusError { + let status_error = message::RequestError { id: self.request_msg.id, error_code, + retry_interval: 0, reason_phrase: ReasonPhrase(error_message.to_string()), }; self.publisher.send_message(status_error); diff --git a/moq-transport/src/session/writer.rs b/moq-transport/src/session/writer.rs index 47d9299f..4db0aa2c 100644 --- a/moq-transport/src/session/writer.rs +++ b/moq-transport/src/session/writer.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::io; use crate::coding::{Encode, EncodeError}; @@ -24,14 +20,14 @@ impl Writer { pub async fn encode(&mut self, msg: &T) -> Result<(), SessionError> { self.buffer.clear(); - tracing::trace!( + log::trace!( "[WRITER] encode: encoding {} to buffer", std::any::type_name::() ); msg.encode(&mut self.buffer)?; let encoded_len = self.buffer.len(); - tracing::debug!( + log::debug!( "[WRITER] encode: encoded {} ({} bytes), sending to stream", std::any::type_name::(), encoded_len @@ -41,7 +37,7 @@ impl Writer { while !self.buffer.is_empty() { let written = self.stream.write_buf(&mut self.buffer).await?; total_written += written; - tracing::trace!( + log::trace!( "[WRITER] encode: wrote {} bytes to stream (total={}/{}, remaining={})", written, total_written, @@ -50,7 +46,7 @@ impl Writer { ); } - tracing::debug!( + log::debug!( "[WRITER] encode: finished sending {} ({} bytes total)", std::any::type_name::(), total_written @@ -60,7 +56,7 @@ impl Writer { } pub async fn write(&mut self, buf: &[u8]) -> Result<(), SessionError> { - tracing::trace!("[WRITER] write: writing {} bytes to stream", buf.len()); + log::trace!("[WRITER] write: writing {} bytes to stream", buf.len()); let mut cursor = io::Cursor::new(buf); let total_len = buf.len(); @@ -69,14 +65,14 @@ impl Writer { while cursor.has_remaining() { let size = self.stream.write_buf(&mut cursor).await?; if size == 0 { - tracing::error!( + log::error!( "[WRITER] write: ERROR - wrote 0 bytes with {} bytes remaining", cursor.remaining() ); return Err(EncodeError::More(cursor.remaining()).into()); } total_written += size; - tracing::trace!( + log::trace!( "[WRITER] write: wrote {} bytes (total={}/{}, remaining={})", size, total_written, @@ -85,7 +81,7 @@ impl Writer { ); } - tracing::debug!("[WRITER] write: finished writing {} bytes", total_written); + log::debug!("[WRITER] write: finished writing {} bytes", total_written); Ok(()) } diff --git a/moq-transport/src/setup/auth_token.rs b/moq-transport/src/setup/auth_token.rs new file mode 100644 index 00000000..a1b22be5 --- /dev/null +++ b/moq-transport/src/setup/auth_token.rs @@ -0,0 +1,298 @@ +//! Authorization Token support for MOQT. +//! +//! This module provides support for authorization tokens as defined in the MOQT specification. +//! Tokens can be sent inline or referenced by alias to avoid retransmission of large tokens. + +use std::collections::HashMap; + +/// Authorization Token Types +/// +/// Defines how an authorization token is transmitted in messages. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u8)] +pub enum AuthTokenType { + /// No authorization token present + None = 0x0, + /// Authorization token sent inline + Inline = 0x1, + /// Authorization token referenced by alias + Alias = 0x2, + /// Authorization token cached with new alias + Store = 0x3, + /// Use previously stored token (DELETE is not allowed in CLIENT_SETUP) + UseAlias = 0x4, +} + +impl TryFrom for AuthTokenType { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 0x0 => Ok(Self::None), + 0x1 => Ok(Self::Inline), + 0x2 => Ok(Self::Alias), + 0x3 => Ok(Self::Store), + 0x4 => Ok(Self::UseAlias), + _ => Err(()), + } + } +} + +/// An authorization token value +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct AuthToken { + /// The raw token bytes + pub token: Vec, + /// Optional alias for caching + pub alias: Option, +} + +impl AuthToken { + /// Create a new authorization token + pub fn new(token: Vec) -> Self { + Self { token, alias: None } + } + + /// Create a new authorization token with an alias for caching + pub fn with_alias(token: Vec, alias: u64) -> Self { + Self { + token, + alias: Some(alias), + } + } + + /// Check if the token is empty + pub fn is_empty(&self) -> bool { + self.token.is_empty() + } +} + +/// Authorization Token Cache +/// +/// Stores authorization tokens by their alias for efficient re-use across multiple messages. +/// The cache enforces a maximum size limit as negotiated during setup. +#[derive(Debug)] +pub struct AuthTokenCache { + /// Maximum number of tokens that can be cached + max_size: usize, + /// Cached tokens by alias + tokens: HashMap>, + /// Next available alias (for server-assigned aliases) + next_alias: u64, +} + +impl Default for AuthTokenCache { + fn default() -> Self { + Self::new(0) + } +} + +impl AuthTokenCache { + /// Create a new auth token cache with the specified maximum size + pub fn new(max_size: usize) -> Self { + Self { + max_size, + tokens: HashMap::new(), + next_alias: 0, + } + } + + /// Get the maximum cache size + pub fn max_size(&self) -> usize { + self.max_size + } + + /// Set the maximum cache size (typically from setup negotiation) + pub fn set_max_size(&mut self, max_size: usize) { + self.max_size = max_size; + } + + /// Get the current number of cached tokens + pub fn len(&self) -> usize { + self.tokens.len() + } + + /// Check if the cache is empty + pub fn is_empty(&self) -> bool { + self.tokens.is_empty() + } + + /// Check if the cache is at capacity + pub fn is_full(&self) -> bool { + self.tokens.len() >= self.max_size + } + + /// Store a token with the given alias + /// + /// Returns an error if: + /// - The cache is at capacity + /// - The alias is already in use + pub fn store(&mut self, alias: u64, token: Vec) -> Result<(), AuthTokenCacheError> { + if self.max_size == 0 { + return Err(AuthTokenCacheError::CacheDisabled); + } + if self.tokens.len() >= self.max_size { + return Err(AuthTokenCacheError::CacheOverflow); + } + if self.tokens.contains_key(&alias) { + return Err(AuthTokenCacheError::DuplicateAlias(alias)); + } + self.tokens.insert(alias, token); + Ok(()) + } + + /// Store a token with an auto-generated alias + /// + /// Returns the assigned alias, or an error if the cache is full + pub fn store_with_auto_alias(&mut self, token: Vec) -> Result { + if self.max_size == 0 { + return Err(AuthTokenCacheError::CacheDisabled); + } + if self.tokens.len() >= self.max_size { + return Err(AuthTokenCacheError::CacheOverflow); + } + + // Find next available alias + while self.tokens.contains_key(&self.next_alias) { + self.next_alias = self.next_alias.wrapping_add(1); + } + + let alias = self.next_alias; + self.tokens.insert(alias, token); + self.next_alias = self.next_alias.wrapping_add(1); + + Ok(alias) + } + + /// Get a token by its alias + pub fn get(&self, alias: u64) -> Option<&Vec> { + self.tokens.get(&alias) + } + + /// Remove a token by its alias + pub fn remove(&mut self, alias: u64) -> Option> { + self.tokens.remove(&alias) + } + + /// Clear all cached tokens + pub fn clear(&mut self) { + self.tokens.clear(); + } +} + +/// Errors that can occur when working with the auth token cache +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum AuthTokenCacheError { + /// The cache is disabled (max_size is 0) + CacheDisabled, + /// The cache is full and cannot accept more tokens + CacheOverflow, + /// The alias is already in use + DuplicateAlias(u64), + /// The alias was not found in the cache + UnknownAlias(u64), + /// The token is malformed + MalformedToken, + /// The token has expired + ExpiredToken, +} + +impl std::fmt::Display for AuthTokenCacheError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::CacheDisabled => write!(f, "authorization token cache is disabled"), + Self::CacheOverflow => write!(f, "authorization token cache is full"), + Self::DuplicateAlias(alias) => { + write!(f, "duplicate authorization token alias: {}", alias) + } + Self::UnknownAlias(alias) => { + write!(f, "unknown authorization token alias: {}", alias) + } + Self::MalformedToken => write!(f, "malformed authorization token"), + Self::ExpiredToken => write!(f, "expired authorization token"), + } + } +} + +impl std::error::Error for AuthTokenCacheError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_auth_token_type_conversion() { + assert_eq!(AuthTokenType::try_from(0u8), Ok(AuthTokenType::None)); + assert_eq!(AuthTokenType::try_from(1u8), Ok(AuthTokenType::Inline)); + assert_eq!(AuthTokenType::try_from(2u8), Ok(AuthTokenType::Alias)); + assert_eq!(AuthTokenType::try_from(3u8), Ok(AuthTokenType::Store)); + assert_eq!(AuthTokenType::try_from(4u8), Ok(AuthTokenType::UseAlias)); + assert!(AuthTokenType::try_from(5u8).is_err()); + } + + #[test] + fn test_auth_token() { + let token = AuthToken::new(vec![1, 2, 3, 4]); + assert!(!token.is_empty()); + assert!(token.alias.is_none()); + + let token_with_alias = AuthToken::with_alias(vec![5, 6, 7, 8], 42); + assert_eq!(token_with_alias.alias, Some(42)); + + let empty_token = AuthToken::default(); + assert!(empty_token.is_empty()); + } + + #[test] + fn test_auth_token_cache() { + let mut cache = AuthTokenCache::new(3); + assert_eq!(cache.max_size(), 3); + assert!(cache.is_empty()); + + // Store tokens + cache.store(1, vec![1, 2, 3]).unwrap(); + cache.store(2, vec![4, 5, 6]).unwrap(); + assert_eq!(cache.len(), 2); + assert!(!cache.is_full()); + + // Get token + assert_eq!(cache.get(1), Some(&vec![1, 2, 3])); + assert_eq!(cache.get(2), Some(&vec![4, 5, 6])); + assert_eq!(cache.get(3), None); + + // Store with auto-alias + let alias = cache.store_with_auto_alias(vec![7, 8, 9]).unwrap(); + assert!(cache.is_full()); + + // Cache overflow + assert_eq!( + cache.store(99, vec![10, 11]), + Err(AuthTokenCacheError::CacheOverflow) + ); + + // Duplicate alias + cache.remove(alias); + assert_eq!( + cache.store(1, vec![10, 11]), + Err(AuthTokenCacheError::DuplicateAlias(1)) + ); + + // Remove and clear + assert!(cache.remove(1).is_some()); + cache.clear(); + assert!(cache.is_empty()); + } + + #[test] + fn test_auth_token_cache_disabled() { + let mut cache = AuthTokenCache::new(0); + assert_eq!( + cache.store(1, vec![1, 2, 3]), + Err(AuthTokenCacheError::CacheDisabled) + ); + assert_eq!( + cache.store_with_auto_alias(vec![1, 2, 3]), + Err(AuthTokenCacheError::CacheDisabled) + ); + } +} diff --git a/moq-transport/src/setup/client.rs b/moq-transport/src/setup/client.rs index fcbe8870..2d5b7f7e 100644 --- a/moq-transport/src/setup/client.rs +++ b/moq-transport/src/setup/client.rs @@ -1,18 +1,9 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use super::Versions; use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; /// Sent by the client to setup the session. -/// This CLIENT_SETUP message is used by moq-transport draft versions 11 and later. -/// Id = 0x20 vs 0x40 for versions <= 10. +/// Draft-16: version negotiation uses ALPN; no Versions field in CLIENT_SETUP. #[derive(Debug)] pub struct Client { - /// The list of supported versions in preferred order. - pub versions: Versions, - /// Setup Parameters, ie: PATH, MAX_REQUEST_ID, /// MAX_AUTH_TOKEN_CACHE_SIZE, AUTHORIZATION_TOKEN, etc. pub params: KeyValuePairs, @@ -30,10 +21,9 @@ impl Decode for Client { let _len = u16::decode(r)?; // TODO: Check the length of the message. - let versions = Versions::decode(r)?; let params = KeyValuePairs::decode(r)?; - Ok(Self { versions, params }) + Ok(Self { params }) } } @@ -49,7 +39,6 @@ impl Encode for Client { // write the length later, to avoid the copy of the message bytes? let mut buf = Vec::new(); - self.versions.encode(&mut buf).unwrap(); self.params.encode(&mut buf).unwrap(); // Make sure buf.len() <= u16::MAX @@ -70,7 +59,7 @@ impl Encode for Client { #[cfg(test)] mod tests { use super::*; - use crate::setup::{ParameterType, Version}; + use crate::setup::ParameterType; use bytes::BytesMut; #[test] @@ -80,26 +69,22 @@ mod tests { let mut params = KeyValuePairs::default(); params.set_bytesvalue(ParameterType::Path.into(), "testpath".as_bytes().to_vec()); - let client = Client { - versions: [Version::DRAFT_13].into(), - params, - }; + let client = Client { params }; client.encode(&mut buf).unwrap(); + // Draft-16: no Versions field, just Type + Length + Parameters #[rustfmt::skip] assert_eq!( buf.to_vec(), vec![ - 0x20, // Type - 0x00, 0x14, // Length - 0x01, // 1 Version - 0xC0, 0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x0D, // Version DRAFT_13 (0xff00000D) - 0x01, // 1 Param - 0x01, 0x08, 0x74, 0x65, 0x73, 0x74, 0x70, 0x61, 0x74, 0x68, // Key=1 (Path), Value="testpath" + 0x20, // Type (CLIENT_SETUP) + 0x00, 0x0b, // Length = 11 bytes + 0x01, // 1 Parameter (count) + // Delta=1 (Path), Length=8, "testpath" + 0x01, 0x08, 0x74, 0x65, 0x73, 0x74, 0x70, 0x61, 0x74, 0x68, ] ); let decoded = Client::decode(&mut buf).unwrap(); - assert_eq!(decoded.versions, client.versions); assert_eq!(decoded.params, client.params); } } diff --git a/moq-transport/src/setup/mod.rs b/moq-transport/src/setup/mod.rs index 3a16ead5..1098ed7a 100644 --- a/moq-transport/src/setup/mod.rs +++ b/moq-transport/src/setup/mod.rs @@ -1,21 +1,19 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - //! Messages used for the MoQ Transport handshake. //! //! After establishing the WebTransport session, the client creates a bidirectional QUIC stream. //! The client sends the [Client] message and the server responds with the [Server] message. //! Both sides negotate the [Version] and [Role]. +mod auth_token; mod client; mod param_types; mod server; mod version; +pub use auth_token::*; pub use client::*; pub use param_types::*; pub use server::*; pub use version::*; -pub const ALPN: &[u8] = b"moq-00"; +pub const ALPN: &[u8] = b"moqt-16"; diff --git a/moq-transport/src/setup/param_types.rs b/moq-transport/src/setup/param_types.rs index 7776e0b9..65f731e5 100644 --- a/moq-transport/src/setup/param_types.rs +++ b/moq-transport/src/setup/param_types.rs @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - /// Setup Parameter Types #[derive(Clone, Copy, Debug, Eq, PartialEq)] #[repr(u64)] @@ -10,7 +7,11 @@ pub enum ParameterType { AuthorizationToken = 0x3, MaxAuthTokenCacheSize = 0x4, Authority = 0x5, + /// Maximum number of Range pairs allowed per subscription/fetch (PR #1518) + MaxFilterRanges = 0x6, MOQTImplementation = 0x7, + /// Maximum value for MaxTracksSelected parameter in TRACK_FILTER (PR #1518) + MaxTracksSelected = 0x8, } impl From for u64 { diff --git a/moq-transport/src/setup/server.rs b/moq-transport/src/setup/server.rs index 5a4b952f..7880228b 100644 --- a/moq-transport/src/setup/server.rs +++ b/moq-transport/src/setup/server.rs @@ -1,18 +1,9 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use super::Version; use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; /// Sent by the server in response to a client setup. -/// This SERVER_SETUP message is used by moq-transport draft versions 11 and later. -/// Id = 0x21 vs 0x41 for versions <= 10. +/// Draft-16: version negotiation uses ALPN; no Versions field in SERVER_SETUP. #[derive(Debug)] pub struct Server { - /// The list of supported versions in preferred order. - pub version: Version, - /// Setup Parameters, ie: MAX_REQUEST_ID, MAX_AUTH_TOKEN_CACHE_SIZE, /// AUTHORIZATION_TOKEN, etc. pub params: KeyValuePairs, @@ -30,10 +21,9 @@ impl Decode for Server { let _len = u16::decode(r)?; // TODO: Check the length of the message. - let version = Version::decode(r)?; let params = KeyValuePairs::decode(r)?; - Ok(Self { version, params }) + Ok(Self { params }) } } @@ -48,7 +38,6 @@ impl Encode for Server { // write the length later, to avoid the copy of the message bytes? let mut buf = Vec::new(); - self.version.encode(&mut buf).unwrap(); self.params.encode(&mut buf).unwrap(); // Make sure buf.len() <= u16::MAX @@ -79,27 +68,24 @@ mod tests { let mut params = KeyValuePairs::default(); params.set_intvalue(ParameterType::MaxRequestId.into(), 1000); - let server = Server { - version: Version::DRAFT_14, - params, - }; + let server = Server { params }; server.encode(&mut buf).unwrap(); + // Draft-16: no Versions field, just Type + Length + Parameters #[rustfmt::skip] assert_eq!( buf.to_vec(), vec![ - 0x21, // Type - 0x00, 0x0c, // Length - 0xC0, 0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x0E, // Version DRAFT_14 (0xff00000E) - 0x01, // 1 Param - 0x02, 0x43, 0xe8, // Key=2 (MaxRequestId), Value=1000 + 0x21, // Type (SERVER_SETUP) + 0x00, 0x04, // Length = 4 bytes + 0x01, // 1 Parameter (count) + // Delta=2 (MaxRequestId), Value=1000 + 0x02, 0x43, 0xe8, ] ); let decoded = Server::decode(&mut buf).unwrap(); - assert_eq!(decoded.version, server.version); assert_eq!(decoded.params, server.params); } } diff --git a/moq-transport/src/setup/version.rs b/moq-transport/src/setup/version.rs index 20704d97..2fb41ae4 100644 --- a/moq-transport/src/setup/version.rs +++ b/moq-transport/src/setup/version.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use crate::coding::{Decode, DecodeError, Encode, EncodeError, VarInt}; use std::fmt; @@ -27,6 +23,12 @@ impl Version { /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-14.html pub const DRAFT_14: Version = Version(0xff00000e); + + /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-15.html + pub const DRAFT_15: Version = Version(0xff00000f); + + /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-16.html + pub const DRAFT_16: Version = Version(0xff000010); } impl From for Version { diff --git a/moq-transport/src/util/mod.rs b/moq-transport/src/util/mod.rs index 5a7fe787..0263f841 100644 --- a/moq-transport/src/util/mod.rs +++ b/moq-transport/src/util/mod.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - mod queue; mod state; mod watch; diff --git a/moq-transport/src/util/queue.rs b/moq-transport/src/util/queue.rs index a3d28216..26290192 100644 --- a/moq-transport/src/util/queue.rs +++ b/moq-transport/src/util/queue.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::collections::VecDeque; use super::Watch; diff --git a/moq-transport/src/util/state.rs b/moq-transport/src/util/state.rs index 4bdc4296..905a378f 100644 --- a/moq-transport/src/util/state.rs +++ b/moq-transport/src/util/state.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{ fmt, future::Future, diff --git a/moq-transport/src/util/watch.rs b/moq-transport/src/util/watch.rs index 0d5a8844..57edbc8b 100644 --- a/moq-transport/src/util/watch.rs +++ b/moq-transport/src/util/watch.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{ fmt, future::Future, diff --git a/moq-transport/src/watch/mod.rs b/moq-transport/src/watch/mod.rs index d531ec2c..b14ad78a 100644 --- a/moq-transport/src/watch/mod.rs +++ b/moq-transport/src/watch/mod.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - mod queue; mod state; diff --git a/moq-transport/src/watch/queue.rs b/moq-transport/src/watch/queue.rs index 23427513..9ab92f84 100644 --- a/moq-transport/src/watch/queue.rs +++ b/moq-transport/src/watch/queue.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use super::State; use futures::channel::oneshot; use std::collections::VecDeque; diff --git a/moq-transport/src/watch/state.rs b/moq-transport/src/watch/state.rs index 368e4cac..c686f99d 100644 --- a/moq-transport/src/watch/state.rs +++ b/moq-transport/src/watch/state.rs @@ -1,7 +1,3 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - use std::{ fmt, future::Future, diff --git a/package.nix b/package.nix index 30db0b10..12b8e7dc 100644 --- a/package.nix +++ b/package.nix @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -# SPDX-License-Identifier: MIT OR Apache-2.0 - { lib, rustPlatform, diff --git a/tools/analyze_flamegraph.py b/tools/analyze_flamegraph.py new file mode 100755 index 00000000..5af30f03 --- /dev/null +++ b/tools/analyze_flamegraph.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 +"""Analyze collapsed perf stacks and produce a text breakdown of CPU usage for moq-rs.""" + +import sys +from collections import defaultdict + + +CATEGORIES = [ + ("TopN Compute", [ + "top_n_tracker", + "compute_top_n", + "rebuild_snapshot", + "update_value", + "register_track", + "subscriber_registry", + ]), + ("TopN Filter", [ + "track_filter", + "property_check", + ]), + ("QUIC Transport", [ + "quinn::", + "quic::", + "quinn_proto", + "quinn_udp", + ]), + ("HTTP/WebTransport", [ + "h3::", + "web_transport", + "webtransport", + ]), + ("Tokio Runtime", [ + "tokio::", + "mio::", + "poll::", + ]), + ("MOQ Protocol", [ + "moq_transport", + "moq_relay", + "encode", + "decode", + ]), + ("Memory Allocation", [ + "alloc::", + "malloc", + "free", + "realloc", + "jemalloc", + ]), +] + +TOPN_PATTERNS = { + "update_value": "top_n_tracker::.*update_value", + "rebuild_snapshot": "top_n_tracker::.*rebuild_snapshot", + "register_track": "top_n_tracker::.*register_track", + "compute_top_n": "top_n_tracker::.*compute_top_n", + "subscriber_registry": "subscriber_registry::", +} + + +def parse_collapsed(filepath): + """Parse a collapsed stack file (output of stackcollapse-perf.pl).""" + stacks = [] + total_weight = 0 + for line in open(filepath): + line = line.strip() + if not line: + continue + parts = line.rsplit(" ", 1) + if len(parts) != 2: + continue + stack_str, weight_str = parts + try: + weight = int(weight_str) + except ValueError: + try: + weight = int(float(weight_str)) + except ValueError: + continue + stacks.append((stack_str, weight)) + total_weight += weight + return stacks, total_weight + + +def categorize_inclusive(stacks, total_weight): + """Categorize by inclusive time (frame appears anywhere in stack).""" + cat_weights = defaultdict(int) + for stack_str, weight in stacks: + matched_cats = set() + for cat_name, patterns in CATEGORIES: + for pat in patterns: + if pat in stack_str: + matched_cats.add(cat_name) + break + for cat in matched_cats: + cat_weights[cat] += weight + return cat_weights + + +def categorize_exclusive(stacks, total_weight): + """Categorize by exclusive time (leaf frame only).""" + cat_weights = defaultdict(int) + for stack_str, weight in stacks: + frames = stack_str.split(";") + leaf = frames[-1] if frames else "" + matched = False + for cat_name, patterns in CATEGORIES: + for pat in patterns: + if pat in leaf: + cat_weights[cat_name] += weight + matched = True + break + if matched: + break + if not matched: + cat_weights["Other"] += weight + return cat_weights + + +def top_functions_self(stacks, total_weight, n=25): + """Get top N functions by self (leaf) time.""" + func_weights = defaultdict(int) + for stack_str, weight in stacks: + frames = stack_str.split(";") + leaf = frames[-1] if frames else "" + func_weights[leaf] += weight + sorted_funcs = sorted(func_weights.items(), key=lambda x: -x[1]) + return sorted_funcs[:n] + + +def topn_specific_breakdown(stacks, total_weight): + """Detailed breakdown of top-N related work.""" + detail = defaultdict(int) + for stack_str, weight in stacks: + frames = stack_str.split(";") + leaf = frames[-1] if frames else "" + for label, pat in TOPN_PATTERNS.items(): + if pat.replace("::.*", "::") in leaf or pat.replace(".*", "") in leaf: + detail[label] += weight + break + return detail + + +def main(): + if len(sys.argv) < 2: + print("Usage: analyze_flamegraph.py [output_file]") + print(" Generate collapsed stacks with:") + print(" perf script | stackcollapse-perf.pl > collapsed.txt") + sys.exit(1) + + filepath = sys.argv[1] + output = sys.argv[2] if len(sys.argv) > 2 else None + + stacks, total_weight = parse_collapsed(filepath) + total_samples = len(stacks) + + lines = [] + def p(s=""): + lines.append(s) + + p("=" * 80) + p(" FLAMEGRAPH ANALYSIS - moq-rs relay CPU profile") + p("=" * 80) + p() + p(f"Total unique stacks: {total_samples}") + p(f"Total weight: {total_weight:,}") + p() + + # Inclusive breakdown + p("-" * 80) + p("INCLUSIVE CPU TIME (frame appears anywhere in the call stack)") + p(" Note: percentages overlap because callers include callee time") + p("-" * 80) + inclusive = categorize_inclusive(stacks, total_weight) + for cat_name, _ in CATEGORIES: + w = inclusive.get(cat_name, 0) + pct = w / total_weight * 100 if total_weight else 0 + bar = "#" * int(pct / 2) + p(f" {cat_name:<35s} {pct:6.2f}% {bar}") + p() + + # Exclusive breakdown + p("-" * 80) + p("EXCLUSIVE CPU TIME (leaf frame only - where CPU actually spends cycles)") + p(" Note: percentages do NOT overlap, total = 100%") + p("-" * 80) + exclusive = categorize_exclusive(stacks, total_weight) + sorted_excl = sorted(exclusive.items(), key=lambda x: -x[1]) + for cat_name, w in sorted_excl: + pct = w / total_weight * 100 if total_weight else 0 + bar = "#" * int(pct / 2) + p(f" {cat_name:<35s} {pct:6.2f}% {bar}") + p() + + # Top-N specific + p("-" * 80) + p("TOP-N SPECIFIC BREAKDOWN (self time in top-N related functions)") + p("-" * 80) + topn_detail = topn_specific_breakdown(stacks, total_weight) + topn_total = sum(topn_detail.values()) + topn_pct = topn_total / total_weight * 100 if total_weight else 0 + p(f" Total Top-N self time: {topn_pct:.2f}% of all CPU") + p() + sorted_topn = sorted(topn_detail.items(), key=lambda x: -x[1]) + for func, w in sorted_topn: + pct = w / total_weight * 100 if total_weight else 0 + p(f" {func:<40s} {pct:5.2f}%") + p() + + # Top functions by self time + p("-" * 80) + p("TOP 25 FUNCTIONS BY SELF (LEAF) TIME") + p("-" * 80) + top_funcs = top_functions_self(stacks, total_weight) + for i, (func, w) in enumerate(top_funcs, 1): + pct = w / total_weight * 100 if total_weight else 0 + name = func if len(func) <= 70 else func[:67] + "..." + p(f" {i:2d}. {pct:5.2f}% {name}") + p() + + # Summary + p("=" * 80) + p("SUMMARY") + p("=" * 80) + topn_incl = inclusive.get("TopN Compute", 0) / total_weight * 100 if total_weight else 0 + topn_filt_incl = inclusive.get("TopN Filter", 0) / total_weight * 100 if total_weight else 0 + quic_incl = inclusive.get("QUIC Transport", 0) / total_weight * 100 if total_weight else 0 + quic_excl = exclusive.get("QUIC Transport", 0) / total_weight * 100 if total_weight else 0 + topn_excl = exclusive.get("TopN Compute", 0) / total_weight * 100 if total_weight else 0 + moq_incl = inclusive.get("MOQ Protocol", 0) / total_weight * 100 if total_weight else 0 + p(f" Top-N compute (inclusive): {topn_incl:.1f}% - ranking decisions & value tracking") + p(f" Top-N filter (inclusive): {topn_filt_incl:.1f}% - property check + filter interception") + p(f" Top-N self time: {topn_pct:.1f}% - actual CPU in top-N code (no callees)") + p(f" QUIC transport (self): {quic_excl:.1f}% - packet I/O & stream management") + p(f" MOQ protocol (inclusive): {moq_incl:.1f}% - encode/decode + session management") + p() + p("-" * 80) + p("TOP-N vs QUIC RELATIVE COST") + p("-" * 80) + if quic_excl > 0 and topn_pct > 0: + self_ratio = topn_pct / quic_excl + p(f" Self vs self: Top-N {topn_pct:.1f}% vs QUIC {quic_excl:.1f}% → Top-N is 1/{int(1/self_ratio)} of QUIC") + else: + p(f" Self vs self: Top-N {topn_pct:.1f}% vs QUIC {quic_excl:.1f}%") + if quic_incl > 0 and topn_incl > 0: + incl_ratio = topn_incl / quic_incl + p(f" Inclusive vs inclusive: Top-N {topn_incl:.1f}% vs QUIC {quic_incl:.1f}% → Top-N is 1/{int(1/incl_ratio)} of QUIC") + else: + p(f" Inclusive vs inclusive: Top-N {topn_incl:.1f}% vs QUIC {quic_incl:.1f}%") + p() + p(f" Interpretation: Top-N ranking adds ~{topn_pct:.1f}% CPU overhead (self time).") + p(" The vast majority of CPU goes to QUIC transport and fan-out to subscribers.") + p() + + report = "\n".join(lines) + print(report) + + if output: + with open(output, "w") as f: + f.write(report + "\n") + print(f"\nReport written to: {output}") + + +if __name__ == "__main__": + main() diff --git a/tools/analyze_speech.py b/tools/analyze_speech.py new file mode 100644 index 00000000..a932e0ec --- /dev/null +++ b/tools/analyze_speech.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +import json, sys, re +from collections import defaultdict + +events = [] +with open(sys.argv[1]) as f: + for line in f: + m = re.search(r"TOPN_EVENT:(.*)", line) + if m: + try: + events.append(json.loads(m.group(1))) + except: + pass + +# Track speaking sessions per publisher +speaking_sessions = defaultdict(list) +current_speaking = {} +test_duration_ms = 0 + +for ev in events: + ts = ev.get("ts_ms", 0) + test_duration_ms = max(test_duration_ms, ts) + + if ev.get("event") == "value_updated": + pub_id = ev.get("publisher_id") + if pub_id is None: + track = ev.get("track", "") + m2 = re.search(r"speaker-(\d+)", track) + if not m2: + continue + pub_id = int(m2.group(1)) + new_value = ev.get("new_value", 0) + old_value = ev.get("old_value", 0) + + if new_value > 0 and old_value == 0: + current_speaking[pub_id] = ts + elif new_value == 0 and old_value > 0: + if pub_id in current_speaking: + start = current_speaking.pop(pub_id) + speaking_sessions[pub_id].append((start, ts)) + +# Close any still-speaking +for pub_id, start in current_speaking.items(): + speaking_sessions[pub_id].append((start, test_duration_ms)) + +total_speakers = len(speaking_sessions) +all_sessions = [] +for pub_id, sessions in speaking_sessions.items(): + for s in sessions: + all_sessions.append((pub_id, s[0], s[1])) +all_sessions.sort(key=lambda x: x[1]) + +# Silence analysis +timeline = [] +for pub_id, sessions in speaking_sessions.items(): + for start, end in sessions: + timeline.append((start, 1)) + timeline.append((end, -1)) +timeline.sort() + +silence_periods = [] +active_count = 0 +last_ts = 0 +total_silence_ms = 0 +max_concurrent = 0 + +for ts, delta in timeline: + if active_count == 0 and ts > last_ts: + silence_periods.append((last_ts, ts)) + total_silence_ms += (ts - last_ts) + active_count += delta + max_concurrent = max(max_concurrent, active_count) + last_ts = ts + +if active_count == 0 and last_ts < test_duration_ms: + silence_periods.append((last_ts, test_duration_ms)) + total_silence_ms += (test_duration_ms - last_ts) + +# Per-speaker stats +speaker_stats = [] +for pub_id in sorted(speaking_sessions.keys()): + sessions = speaking_sessions[pub_id] + total_speaking = sum(end - start for start, end in sessions) + speaker_stats.append((pub_id, len(sessions), total_speaking)) + +print("=" * 60) +print(" SPEECH ACTIVITY ANALYSIS") +print("=" * 60) +print() +print(f"Test duration: {test_duration_ms/1000:.1f}s") +print(f"Total unique speakers: {total_speakers} / 80 publishers") +print(f"Total speaking sessions: {len(all_sessions)}") +print(f"Max concurrent speakers: {max_concurrent}") +avg_sessions = len(all_sessions) / max(1, total_speakers) +avg_duration = sum(s[2]-s[1] for s in all_sessions) / max(1, len(all_sessions)) +print(f"Avg sessions per speaker: {avg_sessions:.1f}") +print(f"Avg speaking duration: {avg_duration/1000:.1f}s") +print() +print(f"Total silence (nobody speaking): {total_silence_ms/1000:.1f}s ({100*total_silence_ms/max(1,test_duration_ms):.1f}%)") +print(f"Number of silence gaps: {len(silence_periods)}") +if silence_periods: + longest_silence = max(end - start for start, end in silence_periods) + print(f"Longest silence gap: {longest_silence/1000:.1f}s") +print() +print("-" * 60) +print("PER-SPEAKER BREAKDOWN (sorted by total speaking time)") +print("-" * 60) +hdr = f"{'Pub':>4} {'Sessions':>8} {'Total':>8} {'Avg':>6}" +print(hdr) +for pub_id, num_sessions, total_ms in sorted(speaker_stats, key=lambda x: -x[2]): + avg = total_ms / max(1, num_sessions) + print(f"{pub_id:>4} {num_sessions:>8} {total_ms/1000:>7.1f}s {avg/1000:>5.1f}s") +print() +print("-" * 60) +print("SILENCE PERIODS (gaps where no publisher was speaking)") +print("-" * 60) +if not silence_periods: + print(" None - at least one speaker active throughout the test") +else: + for i, (start, end) in enumerate(silence_periods[:20]): + dur = (end - start) / 1000 + print(f" {start/1000:>7.1f}s - {end/1000:>7.1f}s ({dur:.1f}s)") + if len(silence_periods) > 20: + print(f" ... and {len(silence_periods) - 20} more gaps") diff --git a/tools/gen_correctness_report.py b/tools/gen_correctness_report.py new file mode 100644 index 00000000..3bf20341 --- /dev/null +++ b/tools/gen_correctness_report.py @@ -0,0 +1,694 @@ +#!/usr/bin/env python3 +"""Generate an interactive HTML correctness report from moq-topn-test e2e output. + +Shows: +- Click on a publisher to see their speech timeline +- See which subscribers hear that publisher at each moment +- View ranking changes over time +""" + +import json +import re +import sys +from collections import defaultdict + + +def parse_events(filepath): + events = [] + with open(filepath) as f: + for line in f: + m = re.search(r"TOPN_EVENT:(.*)", line) + if m: + try: + events.append(json.loads(m.group(1))) + except: + pass + return events + + +def parse_results(filepath): + results = {} + with open(filepath) as f: + for line in f: + if "Accuracy:" in line: + m = re.search(r"(\d+\.\d+)%", line) + if m: + results["accuracy"] = float(m.group(1)) + if "Correct deliveries:" in line: + m = re.search(r"(\d+)", line.split(":")[-1]) + if m: + results["correct"] = int(m.group(1)) + if "Incorrect deliveries:" in line: + m = re.search(r"(\d+)", line.split(":")[-1]) + if m: + results["incorrect"] = int(m.group(1)) + if "Publishers (X):" in line: + m = re.search(r"(\d+)", line.split(":")[-1]) + if m: + results["publishers"] = int(m.group(1)) + if "Subscribers (Y):" in line: + m = re.search(r"(\d+)", line.split(":")[-1]) + if m: + results["subscribers"] = int(m.group(1)) + if "Top-N filter:" in line: + m = re.search(r"(\d+)", line.split(":")[-1]) + if m: + results["top_n"] = int(m.group(1)) + if "Duration:" in line and "duration" not in results: + m = re.search(r"(\d+)s", line) + if m: + results["duration"] = int(m.group(1)) + if "Mixed top-N:" in line: + m = re.search(r"Mixed top-N:\s*(.*)", line) + if m: + results["mixed_topn"] = m.group(1).strip() + return results + + +def build_timeline_data(events): + """Build per-publisher speech timeline and subscriber info.""" + # Speaking sessions + speaking = defaultdict(list) + current_speaking = {} + test_duration = 0 + + # Track all value changes for timeline blocks + value_changes = defaultdict(list) # pub_id -> [(ts, value)] + + # Subscriber info + subscribers = {} # sub_id -> {is_pub_sub, publisher_id, top_n} + + for ev in events: + ts = ev.get("ts_ms", 0) + test_duration = max(test_duration, ts) + + if ev.get("event") == "subscriber_registered": + sub_id = ev.get("subscriber_id") + subscribers[sub_id] = { + "is_pub_sub": ev.get("is_pub_sub", False), + "publisher_id": ev.get("publisher_id"), + } + + elif ev.get("event") == "value_updated": + pub_id = ev.get("publisher_id") + if pub_id is None: + continue + new_value = ev.get("new_value", 0) + old_value = ev.get("old_value", 0) + value_changes[pub_id].append((ts, new_value)) + + if new_value > 0 and old_value == 0: + current_speaking[pub_id] = ts + elif new_value == 0 and old_value > 0: + if pub_id in current_speaking: + start = current_speaking.pop(pub_id) + speaking[pub_id].append({"start": start, "end": ts}) + + for pub_id, start in current_speaking.items(): + speaking[pub_id].append({"start": start, "end": test_duration}) + + # Build timeline blocks per publisher + pub_timelines = {} + for pub_id in sorted(value_changes.keys()): + changes = value_changes[pub_id] + blocks = [] + for i, (ts, val) in enumerate(changes): + end_ts = changes[i + 1][0] if i + 1 < len(changes) else test_duration + state = "silent" if val == 0 else "start" if val == 2 else "speaking" + blocks.append({"start": ts, "end": end_ts, "state": state, "value": val}) + # Add initial silent block if first event is not at t=0 + if changes and changes[0][0] > 0: + blocks.insert(0, {"start": 0, "end": changes[0][0], "state": "silent", "value": 0}) + pub_timelines[pub_id] = blocks + + return speaking, pub_timelines, subscribers, test_duration + + +def generate_html(filepath, output_path): + events = parse_events(filepath) + results = parse_results(filepath) + speaking, pub_timelines, subscribers, test_duration = build_timeline_data(events) + + num_pubs = results.get("publishers", len(pub_timelines)) + top_n = results.get("top_n", 5) + total_subs = results.get("subscribers", len(subscribers)) + accuracy = results.get("accuracy", 0) + correct = results.get("correct", 0) + incorrect = results.get("incorrect", 0) + duration = results.get("duration", test_duration / 1000) + + # Prepare JSON data for JS + data = { + "publishers": {str(k): v for k, v in pub_timelines.items()}, + "speaking": {str(k): v for k, v in speaking.items()}, + "subscribers": {str(k): v for k, v in subscribers.items()}, + "events": events, + "testDuration": test_duration, + "numPubs": num_pubs, + "topN": top_n, + "totalSubs": total_subs, + } + + html = f""" + + + + +Top-N E2E Correctness - Interactive Report + + + +
+ +

Top-N E2E Correctness Report

+

Real QUIC connections — {num_pubs} publishers, {total_subs} subscribers, N={top_n}, {duration}s

+ +
+
+
{accuracy:.1f}%
+
Accuracy
+
+
+
{correct:,}
+
Correct
+
+
+
{incorrect}
+
In-flight misses
+
+
+
{num_pubs}
+
Publishers
+
+
+
{total_subs}
+
Subscribers
+
+
+
N={top_n}
+
Top-N Filter
+
+
+ +
+

Select a Publisher

+

+ Click a publisher to see their speech timeline and which subscribers hear them. +

+
+
+ + + +
+

All Publishers Overview

+

+ Full timeline of all publishers. Click any row to select that publisher. +

+
+
+
+ +
+

Correctness Analysis

+

The {incorrect} incorrect deliveries ({100 - accuracy:.1f}%) are objects that were + already in-flight on the QUIC stream when a ranking transition occurred. This is + eventual consistency by design — not a bug.

+
    +
  • A publisher stops speaking (value → 0), triggering a snapshot rebuild
  • +
  • The epoch counter increments, but subscriber observers check on next object
  • +
  • Objects already queued in QUIC stream buffers are delivered before the filter updates
  • +
  • Convergence time is bounded by one group interval (33ms at 30 Hz)
  • +
+
+ +
+ + + +""" + + with open(output_path, "w") as f: + f.write(html) + print(f"Report written to: {output_path}") + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: gen_correctness_report.py [output.html]") + sys.exit(1) + input_file = sys.argv[1] + output_file = sys.argv[2] if len(sys.argv) > 2 else "correctness_report.html" + generate_html(input_file, output_file) diff --git a/tools/run_perf_e2e.sh b/tools/run_perf_e2e.sh new file mode 100755 index 00000000..c232d46a --- /dev/null +++ b/tools/run_perf_e2e.sh @@ -0,0 +1,352 @@ +#!/bin/bash +set -euo pipefail + +# Usage: ./run_perf_e2e.sh [top_n] [duration_sec] [output_dir] [mixed_topn] +# Example: ./run_perf_e2e.sh 80 720 45 120 ./perf-results +# Example: ./run_perf_e2e.sh 80 800 25 180 ./perf-results "1,2,4,8,16,32,64,75" + +PUB_SUBS=${1:?Usage: $0 [top_n] [duration_sec] [output_dir] [mixed_topn]} +PURE_SUBS=${2:?Usage: $0 [top_n] [duration_sec] [output_dir] [mixed_topn]} +TOP_N=${3:-45} +DURATION=${4:-120} +OUTPUT_DIR=${5:-./perf-results} +MIXED_TOPN=${6:-} + +TOTAL_SUBS=$((PUB_SUBS + PURE_SUBS)) +SETUP_WAIT=$((TOTAL_SUBS / 20 > 30 ? TOTAL_SUBS / 20 : 30)) +STEADY_DURATION=$((DURATION - SETUP_WAIT - 10)) +REMOTE="admin@snk-dev-1.m10x.org" +SSH_KEY="$HOME/.ssh/keys/snk-dev-server.pem" +SSH="ssh -i $SSH_KEY $REMOTE" +SCP="scp -i $SSH_KEY" +TIMESTAMP=$(date +%Y%m%d-%H%M%S) +REMOTE_DIR="/tmp/moq-rs-perf-$TIMESTAMP" +LOCAL_DIR="$OUTPUT_DIR/$TIMESTAMP" + +mkdir -p "$LOCAL_DIR" + +echo "=== MOQ-RS E2E Performance Test ===" +echo " Publishers (pub-sub): $PUB_SUBS" +echo " Pure subscribers: $PURE_SUBS" +echo " Total subscribers: $TOTAL_SUBS" +echo " Top-N filter: $TOP_N" +if [ -n "$MIXED_TOPN" ]; then +echo " Mixed top-N values: $MIXED_TOPN" +fi +echo " Duration: ${DURATION}s" +echo " Setup wait: ${SETUP_WAIT}s" +echo " Steady-state capture: ${STEADY_DURATION}s" +echo " Remote: $REMOTE" +echo " Remote dir: $REMOTE_DIR" +echo " Local output: $LOCAL_DIR" +echo "" + +# Step 1: Collect system info +echo "[1/7] Collecting system configuration..." +$SSH "bash -s" > "$LOCAL_DIR/sysinfo.txt" <<'SYSINFO' +echo "=== System Configuration ===" +echo "Hostname: $(hostname)" +echo "Date: $(date -u)" +echo "Kernel: $(uname -r)" +echo "Arch: $(uname -m)" +echo "CPU:" +lscpu | grep -E "^(Model name|CPU\(s\)|Thread|Core|Socket|CPU max)" +echo "" +echo "Memory:" +free -h | head -2 +echo "" +echo "OS:" +cat /etc/os-release 2>/dev/null | grep -E "^(NAME|VERSION)" || true +echo "" +echo "Rust:" +source ~/.cargo/env 2>/dev/null +rustc --version 2>/dev/null || echo "unknown" +echo "" +echo "Git commit:" +cd ~/moq-rs-top-n && git log --oneline -1 +echo "Git branch:" +cd ~/moq-rs-top-n && git branch --show-current +SYSINFO + +echo " Done." + +# Step 2: Build on remote +echo "[2/7] Building on remote..." +$SSH "cd ~/moq-rs-top-n && source ~/.cargo/env && cargo build --release --bin moq-relay-ietf --bin moq-topn-test 2>&1 | tail -3" +echo " Done." + +# Step 3: Run perf test (steady-state only) +echo "[3/8] Running perf test (${DURATION}s total, profiling last ${STEADY_DURATION}s)..." + +MIXED_ARG="" +if [ -n "$MIXED_TOPN" ]; then + MIXED_ARG="--mixed-topn \"$MIXED_TOPN\"" +fi + +$SSH "bash -s" </dev/null || true +sleep 1 + +# Start relay +./target/release/moq-relay-ietf --bind "[::]:4443" --tls-cert cert.pem --tls-key key.pem > $REMOTE_DIR/relay.log 2>&1 & +RELAY_PID=\$! +sleep 2 + +# Capture memory baseline +ps -o rss= -p \$RELAY_PID > $REMOTE_DIR/mem_before.txt + +# Run e2e test in background +./target/release/moq-topn-test -m e2e \\ + --relay https://localhost:4443 \\ + --tls-disable-verify \\ + -x $PUB_SUBS \\ + -y $TOTAL_SUBS \\ + -n $TOP_N \\ + $MIXED_ARG \\ + -d $DURATION \\ + --group-interval-ms 33 \\ + --connection-batch-size 50 \\ + > $REMOTE_DIR/test_output.txt 2>&1 & +TEST_PID=\$! + +# Wait for connections to establish before profiling +sleep $SETUP_WAIT + +# Start perf recording for steady-state only +perf record -F 999 -p \$RELAY_PID -g -o $REMOTE_DIR/perf.data -- sleep $STEADY_DURATION & +PERF_PID=\$! + +# Wait for test to finish +wait \$TEST_PID 2>/dev/null || true + +# Capture memory after test +ps -o rss= -p \$RELAY_PID > $REMOTE_DIR/mem_after.txt 2>/dev/null || echo "0" > $REMOTE_DIR/mem_after.txt + +wait \$PERF_PID 2>/dev/null || true + +# Generate flamegraph +perf script -i $REMOTE_DIR/perf.data | ~/FlameGraph/stackcollapse-perf.pl > $REMOTE_DIR/collapsed.txt +~/FlameGraph/flamegraph.pl $REMOTE_DIR/collapsed.txt > $REMOTE_DIR/flamegraph.svg + +# Run analysis +python3 ~/moq-rs-top-n/tools/analyze_flamegraph.py $REMOTE_DIR/collapsed.txt $REMOTE_DIR/analysis.txt + +kill \$RELAY_PID 2>/dev/null || true +echo "DONE" > $REMOTE_DIR/status.txt +PERF_SCRIPT + +echo " Done." + +# Step 4: Wait for completion and verify +echo "[4/8] Verifying completion..." +STATUS=$($SSH "cat $REMOTE_DIR/status.txt 2>/dev/null || echo FAILED") +if [ "$STATUS" != "DONE" ]; then + echo " ERROR: Test did not complete successfully" + exit 1 +fi +echo " Done." + +# Step 5: Collect results +echo "[5/8] Collecting results from remote..." +$SCP "$REMOTE:$REMOTE_DIR/flamegraph.svg" "$LOCAL_DIR/flamegraph.svg" +$SCP "$REMOTE:$REMOTE_DIR/collapsed.txt" "$LOCAL_DIR/collapsed.txt" +$SCP "$REMOTE:$REMOTE_DIR/analysis.txt" "$LOCAL_DIR/analysis.txt" +$SCP "$REMOTE:$REMOTE_DIR/test_output.txt" "$LOCAL_DIR/test_output.txt" +$SCP "$REMOTE:$REMOTE_DIR/relay.log" "$LOCAL_DIR/relay.log" +$SCP "$REMOTE:$REMOTE_DIR/mem_before.txt" "$LOCAL_DIR/mem_before.txt" +$SCP "$REMOTE:$REMOTE_DIR/mem_after.txt" "$LOCAL_DIR/mem_after.txt" +echo " Done." + +# Step 6: Run speech activity analysis +echo "[6/8] Running speech activity analysis..." +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "$SCRIPT_DIR/analyze_speech.py" "$LOCAL_DIR/test_output.txt" > "$LOCAL_DIR/speech_analysis.txt" 2>/dev/null || true +echo " Done." + +# Step 7: Extract test metrics from output +echo "[7/8] Extracting metrics..." +TEST_OUTPUT="$LOCAL_DIR/test_output.txt" + +# Memory +MEM_BEFORE=$(cat "$LOCAL_DIR/mem_before.txt" | tr -d ' ') +MEM_AFTER=$(cat "$LOCAL_DIR/mem_after.txt" | tr -d ' ') +MEM_BEFORE_MB=$(echo "scale=1; $MEM_BEFORE / 1024" | bc 2>/dev/null || echo "N/A") +MEM_AFTER_MB=$(echo "scale=1; $MEM_AFTER / 1024" | bc 2>/dev/null || echo "N/A") +MEM_GROWTH_MB=$(echo "scale=1; ($MEM_AFTER - $MEM_BEFORE) / 1024" | bc 2>/dev/null || echo "N/A") + +# Extract from analysis +TOPN_SELF=$(grep "Total Top-N self time" "$LOCAL_DIR/analysis.txt" | grep -o "[0-9.]*%" | head -1 || echo "N/A") +TOPN_INCLUSIVE=$(grep "Top-N compute (inclusive)" "$LOCAL_DIR/analysis.txt" | grep -o "[0-9.]*%" | head -1 || echo "N/A") +QUIC_SELF=$(grep "QUIC transport (self)" "$LOCAL_DIR/analysis.txt" | grep -o "[0-9.]*%" | head -1 || echo "N/A") +QUIC_INCLUSIVE=$(grep "QUIC Transport" "$LOCAL_DIR/analysis.txt" | head -1 | grep -o "[0-9.]*%" | head -1 || echo "N/A") +MOQ_INCLUSIVE=$(grep "MOQ protocol (inclusive)" "$LOCAL_DIR/analysis.txt" | grep -o "[0-9.]*%" | head -1 || echo "N/A") +ALLOC_SELF=$(grep "Memory Allocation" "$LOCAL_DIR/analysis.txt" | tail -1 | grep -o "[0-9.]*%" | head -1 || echo "N/A") + +echo " Done." + +# Step 8: Generate report +echo "[8/8] Generating report..." +REPORT="$LOCAL_DIR/report.md" +SYSINFO=$(cat "$LOCAL_DIR/sysinfo.txt") +ANALYSIS=$(cat "$LOCAL_DIR/analysis.txt") +SPEECH=$(cat "$LOCAL_DIR/speech_analysis.txt" 2>/dev/null || echo "No speech data available") + +MIXED_NOTE="" +if [ -n "$MIXED_TOPN" ]; then + MIXED_NOTE="| Mixed top-N values | $MIXED_TOPN |" +fi + +cat > "$REPORT" </dev/null || echo "N/A") KB | + +### Top-N vs QUIC Relative Cost + +- Self vs self: Top-N $TOPN_SELF vs QUIC $QUIC_SELF +- Inclusive vs inclusive: Top-N $TOPN_INCLUSIVE vs QUIC $QUIC_INCLUSIVE + +## Methodology + +### Profiling Approach + +1. **Steady-state isolation**: The relay and all ${TOTAL_SUBS} connections are established + during a ${SETUP_WAIT}s warm-up period. CPU profiling begins only after setup completes, + capturing ${STEADY_DURATION}s of pure steady-state operation. This excludes one-time costs + like TLS handshakes, session setup, SUBSCRIBE exchanges, and filter registration. + +2. **Sampling**: Linux \`perf record\` at 999 Hz with call-graph (\`-g\`) on the relay process. + Stacks are collapsed via FlameGraph's \`stackcollapse-perf.pl\` and categorized by + function name patterns into: Top-N Compute, Top-N Filter, QUIC Transport, Tokio Runtime, + MOQ Protocol, and Memory Allocation. + +3. **Speech simulation**: Publishers emit audio-level values at 30 Hz (33ms groups). + Each publisher independently transitions between silent (value=0) and speaking (value>0) + with p(start)=0.03/tick and random duration 3-10s. This creates realistic dynamic + top-N ranking churn. + +### Speech Simulator State Machine + +\`\`\` + p=0.03/tick + ┌─────────┐ ──────────────────► ┌──────────────┐ + │ │ │ │ + │ SILENT │ │ SPEECH_START │ + │ value=0 │ ◄────────────────── │ value=2 │ + │ │ duration expired │ │ + └─────────┘ └──────┬───────┘ + ▲ │ + │ │ next tick + │ ▼ + │ duration expired ┌──────────┐ + └──────────────────────────── │ SPEAKING │ + │ value=1 │ + │ │ + └──────────┘ + (3-10s random) + + Transitions: + SILENT → SPEECH_START : p=0.03 per tick (30 Hz), emit value=2 + SPEECH_START → SPEAKING: next tick, emit value=1 + SPEAKING → SPEAKING : each tick until duration expires, emit value=1 + SPEAKING → SILENT : duration (3-10s uniform) expired, emit value=0 + + Values sent as object payload each group interval (33ms): + 0 = silent (not ranked) + 1 = speaking (ranked, maintains position) + 2 = speech_start (ranked, signals new utterance) +\`\`\` + +4. **Memory**: RSS measured via \`ps -o rss\` before connections and after steady-state. + Growth reflects per-connection overhead (QUIC state, stream buffers, filter state). + +### What Top-N Does Per Object + +- **Ingest observer** (1 per track): reads each published object, calls \`update_track_value\` + if the audio level changed → mutex lock, BTreeMap update, ArcSwap snapshot rebuild. +- **Subscriber filter** (1 per subscriber per track): atomic epoch load + comparison. + If epoch unchanged, uses cached decision. If changed, loads ArcSwap snapshot and does + binary search on pre-sorted Vec. +- Both paths are O(1) amortized for the common case (epoch cache hit). + +## Speech Activity Analysis + +\`\`\` +$SPEECH +\`\`\` + +## Detailed Flamegraph Analysis + +\`\`\` +$ANALYSIS +\`\`\` + +## Artifacts + +- \`flamegraph.svg\` — Interactive flamegraph (open in browser) +- \`collapsed.txt\` — Collapsed stacks for custom analysis +- \`test_output.txt\` — Full test driver output +- \`relay.log\` — Relay server logs +- \`analysis.txt\` — Raw flamegraph analysis +- \`speech_analysis.txt\` — Speech activity breakdown + +EOF + +echo " Done." +echo "" +echo "=== Report generated: $REPORT ===" +echo "=== Flamegraph: $LOCAL_DIR/flamegraph.svg ===" +echo "" + +# Open results directory and flamegraph in browser (macOS) +if command -v open &>/dev/null; then + open "$LOCAL_DIR" + open -a "Google Chrome" "$LOCAL_DIR/flamegraph.svg" 2>/dev/null || open "$LOCAL_DIR/flamegraph.svg" +fi