diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..ce1f1c9d --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,32 @@ +# Repository Guidance + +## Scope And Constraints + +- This repository targets MoQT `draft-ietf-moq-transport-16`. +- Core protocol code lives in `moq-transport`; relay behavior lives in `moq-relay-ietf`. +- `moq-relay-ietf` is production critical, so prefer small, tested changes over broad refactors. +- Avoid breaking public APIs, wire behavior, CLI flags, and operator workflows unless explicitly requested. +- Do not add `unsafe` code unless there is a clear performance or FFI need and the safety invariant is documented at the call site. +- Do not commit or directly edit `docs/draft-16.txt` unless explicitly requested; it is local protocol reference material. +- Avoid `unwrap()`, `expect()`, `panic!`, `todo!`, `unimplemented!`, and `dbg!` in production paths. Propagate errors or convert lock poisoning to an internal error instead. + +## Commands + +- Run `cargo fmt --all` after Rust changes. +- Run `cargo test -p moq-transport` for protocol changes. +- Run `cargo test -p moq-relay-ietf` or `cargo build -p moq-relay-ietf` for relay-only changes. +- Run `cargo build --workspace` when shared APIs, relay behavior, or test-client behavior changes. +- Consider `cargo clippy --workspace --all-targets` before larger changes or when review asks for lint coverage. +- When tests fail, diagnose from compiler/test output first and keep fixes scoped to the failing behavior. + +## Draft-16 Terminology + +- Use `PUBLISH_NAMESPACE` terminology in protocol code and docs; older docs and comments may call this Announce. +- `PublishedNamespace` represents an inbound `PUBLISH_NAMESPACE` received by a subscriber. +- `TrackNamespacePrefix` is for `SUBSCRIBE_NAMESPACE` prefixes and may be empty; `TrackNamespace` is full and non-empty. +- `REQUEST_OK`, `REQUEST_ERROR`, and `REQUEST_UPDATE` are shared draft-16 request response/update messages. + +## External Contributions + +- Treat public GitHub mirror and community PRs with heightened scrutiny for protocol correctness, resource bounds, panic-free production paths, and safe peer-facing errors. +- Be careful with compatibility: do not rename public CLI flags or remove behavior solely for terminology cleanup without an explicit migration decision. diff --git a/Cargo.lock b/Cargo.lock index e44e6b2c..7182343a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1072,10 +1072,12 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.85" +version = "0.3.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" dependencies = [ + "cfg-if", + "futures-util", "once_cell", "wasm-bindgen", ] @@ -1105,7 +1107,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -1669,7 +1671,7 @@ dependencies = [ "quinn-udp", "rustc-hash 2.0.0", "rustls 0.23.31", - "socket2 0.6.0", + "socket2 0.5.7", "thiserror 2.0.17", "tokio", "tracing", @@ -1710,7 +1712,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.6.0", + "socket2 0.5.7", "tracing", "windows-sys 0.59.0", ] @@ -2833,9 +2835,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.108" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" dependencies = [ "cfg-if", "once_cell", @@ -2858,9 +2860,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.108" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2868,9 +2870,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.108" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" dependencies = [ "bumpalo", "proc-macro2", @@ -2881,9 +2883,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.108" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" dependencies = [ "unicode-ident", ] @@ -2902,9 +2904,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.85" +version = "0.3.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "312e32e551d92129218ea9a2452120f4aabc03529ef03e4d0d82fb2780608598" +checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" dependencies = [ "js-sys", "wasm-bindgen", @@ -2922,9 +2924,9 @@ dependencies = [ [[package]] name = "web-transport" -version = "0.10.1" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb0d7e0f5ab200ec6b36fef6f3eee02d41d561fe73b85d2ab822af1b3fb4c680" +checksum = "23c3f78eca5afa10eb7b8ab64b4e5e521a006f0cbd88de09e44d55ef37e8855a" dependencies = [ "bytes", "thiserror 2.0.17", @@ -2935,9 +2937,9 @@ dependencies = [ [[package]] name = "web-transport-proto" -version = "0.5.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17633ea7058419f87cbb7f341ab75ac5c1d6d187c154b0bd4c87539e66f4c4e4" +checksum = "0225d295c8ac00a2e9a498aefeaf3f3c6186da12a251c938189b15b82ea22808" dependencies = [ "bytes", "http", @@ -2949,9 +2951,9 @@ dependencies = [ [[package]] name = "web-transport-quinn" -version = "0.11.4" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b195557749e84091d7b912a25e190e9606283b5121d041faf538b0b55f40d7" +checksum = "82e77c81fe4cf56c1049e07c6ed9c00862a967010fe9da4f5e02dc7f4d71fdac" dependencies = [ "bytes", "futures", @@ -2969,18 +2971,18 @@ dependencies = [ [[package]] name = "web-transport-trait" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "802d6aa508f2c63c9050ceabc17265bbf90ed4d6f4e4357e987583883628e79c" +checksum = "cb67841c4a481ca3c1412ee4c9f463987401991e1ddc000903df2124f3dc85e9" dependencies = [ "bytes", ] [[package]] name = "web-transport-wasm" -version = "0.5.5" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03261961ff4d65f873dd0521909b6795e5d7fe40581df2b7897db05e62db9620" +checksum = "1ee22c228309b45651038d975a9f3c041525e9f2cc0f6c3bd8753a110804df11" dependencies = [ "bytes", "js-sys", diff --git a/Cargo.toml b/Cargo.toml index 61a90c1d..3fa211da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ members = [ resolver = "2" [workspace.dependencies] -web-transport = "0.10" +web-transport = "0.10.4" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/README.md b/README.md index 7b7e66a3..7066f36e 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,11 @@ An implementation of the Media over QUIC Transport (MoQT) protocol for live media delivery over QUIC, as specified by the IETF MoQ working group. -This codebase was originally created by [Luke Curley (@kixelated)](https://github.com/kixelated). [Mike English (@englishm)](https://github.com/englishm) contributed to early design and has maintained this IETF-aligned fork. The project is now maintained by Cloudflare. The implementation targets [draft-ietf-moq-transport-14](https://datatracker.ietf.org/doc/draft-ietf-moq-transport/14/). +This codebase was originally created by [Luke Curley (@kixelated)](https://github.com/kixelated). [Mike English (@englishm)](https://github.com/englishm) contributed to early design and has maintained this IETF-aligned fork. The project is now maintained by Cloudflare. The implementation targets [draft-ietf-moq-transport-16](https://datatracker.ietf.org/doc/draft-ietf-moq-transport/16/). ## Protocol Support -The `main` branch targets **draft-14** of the MoQT specification. For draft-07 compatibility (used in [Cloudflare's current production deployment](https://developers.cloudflare.com/moq/)), see the [`draft-ietf-moq-transport-07`](https://github.com/cloudflare/moq-rs/tree/draft-ietf-moq-transport-07) branch. +The `main` branch targets **draft-16** of the MoQT specification. For draft-07 compatibility (used in [Cloudflare's current production deployment](https://developers.cloudflare.com/moq/)), see the [`draft-ietf-moq-transport-07`](https://github.com/cloudflare/moq-rs/tree/draft-ietf-moq-transport-07) branch. ### What's Included diff --git a/deploy/MLOG_SETUP.md b/deploy/MLOG_SETUP.md index 92c37ca3..4b588600 100644 --- a/deploy/MLOG_SETUP.md +++ b/deploy/MLOG_SETUP.md @@ -40,8 +40,8 @@ curl https://interop-relay.cloudflare.mediaoverquic.com:443/mlog/22c73802597dcd9 The output is JSON-SEQ: each record starts with ASCII Record Separator (`0x1e`) and contains a JSON object. The first record is a header; subsequent records are events: ```json -{"time":0.179,"name":"moqt:control_message_parsed","data":{"message_type":"client_setup","supported_versions":["DRAFT_14"],...}} -{"time":0.216,"name":"moqt:control_message_created","data":{"message_type":"server_setup","selected_version":"DRAFT_14",...}} +{"time":0.179,"name":"moqt:control_message_parsed","data":{"message_type":"client_setup","supported_versions":["DRAFT_16"],...}} +{"time":0.216,"name":"moqt:control_message_created","data":{"message_type":"server_setup","selected_version":"DRAFT_16",...}} ``` - `control_message_parsed` = the relay **received** a message from your client @@ -74,12 +74,12 @@ cargo run --bin moq-test-client -- \ "stream_id": 0, "message_type": "client_setup", "number_of_supported_versions": 1, - "supported_versions": ["DRAFT_14"], + "supported_versions": ["DRAFT_16"], "parameters": [["2", "100"]] } } ``` -The relay parsed your CLIENT_SETUP. It offered version DRAFT_14. The `parameters` array contains SETUP parameters as `[id, value]` pairs (here, PATH with max length 100). +The relay parsed your CLIENT_SETUP. It offered version DRAFT_16. The `parameters` array contains SETUP parameters as `[id, value]` pairs (here, PATH with max length 100). ```json { @@ -89,18 +89,18 @@ The relay parsed your CLIENT_SETUP. It offered version DRAFT_14. The `parameters "event_type": "control_message_created", "stream_id": 0, "message_type": "server_setup", - "selected_version": "DRAFT_14", + "selected_version": "DRAFT_16", "parameters": [["2", "100"]] } } ``` -The relay responded with SERVER_SETUP, selecting DRAFT_14. The handshake is complete. +The relay responded with SERVER_SETUP, selecting DRAFT_16. The handshake is complete. **What to look for:** - If you see `client_setup` parsed but no `server_setup` created, the relay rejected your version or parameters - If you see nothing at all, your connection didn't reach the MoQ layer (check QUIC connectivity) -### Example 2: Publishing a namespace (`announce-only`) +### Example 2: Publishing a namespace (`publish-namespace-only`) After SETUP, announce a namespace and verify the relay accepts it. @@ -108,7 +108,7 @@ After SETUP, announce a namespace and verify the relay accepts it. cargo run --bin moq-test-client -- \ --relay https://interop-relay.cloudflare.mediaoverquic.com:443 \ --tls-disable-verify \ - --test announce-only + --test publish-namespace-only ``` **mlog output** (after the SETUP exchange): @@ -148,7 +148,7 @@ The relay accepted the namespace with PUBLISH_NAMESPACE_OK. Your client is now r - `publish_namespace_ok` created confirms it was accepted - The `request_id` ties the response to the request -### Example 3: Full publish-subscribe flow (`announce-subscribe`) +### Example 3: Full publish-subscribe flow (`publish-namespace-subscribe`) This test uses two connections: a publisher and a subscriber. The test client reports both Connection IDs: @@ -156,13 +156,13 @@ This test uses two connections: a publisher and a subscriber. The test client re cargo run --bin moq-test-client -- \ --relay https://interop-relay.cloudflare.mediaoverquic.com:443 \ --tls-disable-verify \ - --test announce-subscribe + --test publish-namespace-subscribe ``` **TAP output:** ``` -ok 1 - announce-subscribe +ok 1 - publish-namespace-subscribe --- duration_ms: 3436 publisher_connection_id: 71d4b5eb1a807779af03331c330d5fa9 diff --git a/moq-clock-ietf/src/main.rs b/moq-clock-ietf/src/main.rs index ac8a1ef6..619f84f3 100644 --- a/moq-clock-ietf/src/main.rs +++ b/moq-clock-ietf/src/main.rs @@ -66,7 +66,7 @@ async fn main() -> anyhow::Result<()> { tokio::select! { res = session.run() => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, - res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, + res = publisher.publish_namespace(tracks_reader) => res.context("failed to serve tracks")?, } } else { tracing::info!("publishing clock via streams"); @@ -82,7 +82,7 @@ async fn main() -> anyhow::Result<()> { tokio::select! { res = session.run() => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, - res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, + res = publisher.publish_namespace(tracks_reader) => res.context("failed to serve tracks")?, } } } else { diff --git a/moq-native-ietf/Cargo.toml b/moq-native-ietf/Cargo.toml index 9bc8efab..464e3c86 100644 --- a/moq-native-ietf/Cargo.toml +++ b/moq-native-ietf/Cargo.toml @@ -17,7 +17,7 @@ categories = ["multimedia", "network-programming", "web-programming"] [dependencies] moq-transport = { path = "../moq-transport", version = "0.14" } web-transport = { workspace = true } -web-transport-quinn = { version = "0.11", default-features = false, features = ["ring"] } +web-transport-quinn = { version = "0.11.8", default-features = false, features = ["ring"] } rustls = { version = "0.23", features = ["ring"] } rustls-pemfile = "2" diff --git a/moq-native-ietf/src/quic.rs b/moq-native-ietf/src/quic.rs index 59effe88..dba5abd2 100644 --- a/moq-native-ietf/src/quic.rs +++ b/moq-native-ietf/src/quic.rs @@ -426,9 +426,18 @@ impl Server { .await .context("failed to receive WebTransport request")?; + let moqt_protocol = std::str::from_utf8(moq_transport::setup::ALPN) + .context("invalid MoQT ALPN")? + .to_string(); + let response = if request.protocols.contains(&moqt_protocol) { + web_transport_quinn::proto::ConnectResponse::OK.with_protocol(moqt_protocol) + } else { + web_transport_quinn::proto::ConnectResponse::OK + }; + // Accept the CONNECT request. let session = request - .ok() + .respond(response) .await .context("failed to respond to WebTransport request")?; (session, Transport::WebTransport) @@ -552,10 +561,17 @@ impl Client { .to_string(); let (session, transport) = match url.scheme() { - "https" => ( - web_transport_quinn::Session::connect(connection, url.clone()).await?, - Transport::WebTransport, - ), + "https" => { + let moqt_protocol = std::str::from_utf8(moq_transport::setup::ALPN) + .context("invalid MoQT ALPN")? + .to_string(); + let request = web_transport_quinn::proto::ConnectRequest::new(url.clone()) + .with_protocol(moqt_protocol); + ( + web_transport_quinn::Session::connect(connection, request).await?, + Transport::WebTransport, + ) + } "moqt" => ( web_transport_quinn::Session::raw( connection, diff --git a/moq-pub/src/main.rs b/moq-pub/src/main.rs index 17d6403b..0a05239e 100644 --- a/moq-pub/src/main.rs +++ b/moq-pub/src/main.rs @@ -85,7 +85,7 @@ async fn main() -> anyhow::Result<()> { res = run_media(media) => { res.context("media error")? }, - res = publisher.announce(reader) => res.context("publisher error")?, + res = publisher.publish_namespace(reader) => res.context("publisher error")?, } Ok(()) diff --git a/moq-pub/src/media.rs b/moq-pub/src/media.rs index 3cb0d9e3..20ecf60a 100644 --- a/moq-pub/src/media.rs +++ b/moq-pub/src/media.rs @@ -174,7 +174,7 @@ impl Media { let mut selection_params = moq_catalog::SelectionParam::default(); let mut track = moq_catalog::Track { - init_track: Some(self.init.name.clone()), + init_track: Some(self.init.name.to_string()), name: name.clone(), namespace: Some(self.broadcast.namespace.to_utf8_path()), packaging: Some(moq_catalog::TrackPackaging::Cmaf), diff --git a/moq-relay-ietf/src/bin/moq-relay-ietf/main.rs b/moq-relay-ietf/src/bin/moq-relay-ietf/main.rs index 9ccfe462..ce209fe5 100644 --- a/moq-relay-ietf/src/bin/moq-relay-ietf/main.rs +++ b/moq-relay-ietf/src/bin/moq-relay-ietf/main.rs @@ -32,8 +32,8 @@ pub struct Cli { #[arg(long)] pub mlog_dir: Option, - /// Forward all announces to the provided server for authentication/routing. - /// If not provided, the relay accepts every unique announce. + /// Forward all PUBLISH_NAMESPACE messages to the provided server for auth/routing. + /// If not provided, the relay accepts every unique namespace publish. #[arg(long)] pub announce: Option, diff --git a/moq-relay-ietf/src/consumer.rs b/moq-relay-ietf/src/consumer.rs index 7b25ae78..5984ae7c 100644 --- a/moq-relay-ietf/src/consumer.rs +++ b/moq-relay-ietf/src/consumer.rs @@ -7,7 +7,7 @@ use anyhow::Context; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_transport::{ serve::Tracks, - session::{Announced, SessionError, Subscriber}, + session::{PublishedNamespace, SessionError, Subscriber}, }; use crate::{metrics::GaugeGuard, Coordinator, Locals, Producer}; @@ -42,27 +42,31 @@ impl Consumer { } } - /// Run the consumer to serve announce requests. + /// Run the consumer to handle inbound PUBLISH_NAMESPACE requests. pub async fn run(mut self) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); loop { tokio::select! { - // Handle a new announce request - Some(announce) = self.subscriber.announced() => { + Some(published_ns) = self.subscriber.published_namespace() => { metrics::counter!("moq_relay_publishers_total").increment(1); let this = self.clone(); tasks.push(async move { - let info = announce.clone(); + let info = published_ns.clone(); let namespace = info.namespace.to_utf8_path(); - tracing::info!(namespace = %namespace, "serving announce: {:?}", info); - - // Serve the announce request - if let Err(err) = this.serve(announce).await { - tracing::warn!(namespace = %namespace, error = %err, "failed serving announce: {:?}, error: {}", info, err); - // Note: phase-specific error counters are incremented in serve() + tracing::info!( + namespace = %namespace, + "serving PUBLISH_NAMESPACE: {:?}", info + ); + + if let Err(err) = this.serve(published_ns).await { + tracing::warn!( + namespace = %namespace, + error = %err, + "failed serving PUBLISH_NAMESPACE: {:?}", info + ); } }); }, @@ -72,22 +76,18 @@ impl Consumer { } } - /// Serve an announce request. - async fn serve(mut self, mut announce: Announced) -> Result<(), anyhow::Error> { - // Track active publishers - decrements when this function returns + /// Serve an inbound PUBLISH_NAMESPACE. + async fn serve(mut self, mut published_ns: PublishedNamespace) -> Result<(), anyhow::Error> { + // Track active publishers - decrements when this function returns. let _publisher_guard = GaugeGuard::new("moq_relay_active_publishers"); let mut tasks = FuturesUnordered::new(); - // Produce the tracks for this announce and return the reader - let (_, mut request, reader) = Tracks::new(announce.namespace.clone()).produce(); - - // should we allow the same namespace being served from multiple relays?? - // Manish: NO. + let (_, mut request, reader) = Tracks::new(published_ns.namespace.clone()).produce(); let ns = reader.namespace.to_utf8_path(); - // Register the local tracks, unregister on drop + // Register the namespace locally so downstream subscribers can be served. tracing::debug!(namespace = %ns, "registering namespace in locals"); let _register = match self .locals @@ -103,10 +103,7 @@ impl Consumer { }; tracing::debug!(namespace = %ns, "namespace registered in locals"); - // NOTE(mpandit): once the track is pulled from origin, internally it will be relayed - // from this metal only, because now coordinator will have entry for the namespace. - - // Register namespace with the coordinator + // Register namespace with the coordinator so other relay nodes can route to us. tracing::debug!(namespace = %ns, "registering namespace with coordinator"); let _namespace_registration = match self .coordinator @@ -122,55 +119,60 @@ impl Consumer { }; tracing::debug!(namespace = %ns, "namespace registered with coordinator"); - // Accept the announce with an OK response - if let Err(err) = announce.ok() { + // Accept the PUBLISH_NAMESPACE with REQUEST_OK. + if let Err(err) = published_ns.ok() { metrics::counter!("moq_relay_announce_errors_total", "phase" => "send_ok").increment(1); return Err(err.into()); } - tracing::debug!(namespace = %ns, "sent ANNOUNCE_OK"); - - // Successfully sent ANNOUNCE_OK + tracing::debug!(namespace = %ns, "sent REQUEST_OK for PUBLISH_NAMESPACE"); metrics::counter!("moq_relay_announce_ok_total").increment(1); - // Forward the announce, if needed + // Forward the namespace upstream, if configured. if let Some(mut forward) = self.forward { tasks.push( async move { let namespace = reader.namespace.to_utf8_path(); - tracing::info!(namespace = %namespace, "forwarding announce: {:?}", reader.info); + tracing::info!( + namespace = %namespace, + "forwarding PUBLISH_NAMESPACE: {:?}", reader.info + ); forward - .announce(reader) + .publish_namespace(reader) .await - .context("failed forwarding announce") + .context("failed forwarding PUBLISH_NAMESPACE") } .boxed(), ); } - // Serve subscribe requests loop { tokio::select! { - // If the announce is closed, return the error - Err(err) = announce.closed() => { - let ns = announce.namespace.to_utf8_path(); - tracing::info!(namespace = %ns, error = %err, "announce closed"); - return Err(err.into()); + res = published_ns.closed() => { + let ns = published_ns.namespace.to_utf8_path(); + res?; + tracing::info!(namespace = %ns, "PUBLISH_NAMESPACE closed"); + return Ok(()); }, - - // Wait for the next subscriber and serve the track. Some(track) = request.next() => { let mut subscriber = self.subscriber.clone(); - // Spawn a new task to handle the subscribe tasks.push(async move { let info = track.clone(); let namespace = info.namespace.to_utf8_path(); let track_name = info.name.clone(); - tracing::info!(namespace = %namespace, track = %track_name, "forwarding subscribe: {:?}", info); + tracing::info!( + namespace = %namespace, + track = %track_name, + "forwarding subscribe: {:?}", info + ); - // Forward the subscribe request if let Err(err) = subscriber.subscribe(track).await { - tracing::warn!(namespace = %namespace, track = %track_name, error = %err, "failed forwarding subscribe: {:?}, error: {}", info, err) + tracing::warn!( + namespace = %namespace, + track = %track_name, + error = %err, + "failed forwarding subscribe: {:?}", info + ) } Ok(()) diff --git a/moq-relay-ietf/src/metrics.rs b/moq-relay-ietf/src/metrics.rs index aac3455b..9cd70875 100644 --- a/moq-relay-ietf/src/metrics.rs +++ b/moq-relay-ietf/src/metrics.rs @@ -23,9 +23,9 @@ //! | `moq_relay_connections_total` | - | Total incoming connections accepted | //! | `moq_relay_connections_closed_total` | - | Total connections that have closed (graceful or error) | //! | `moq_relay_connection_errors_total` | `stage` | Connection failures (stage: session_accept, session_run) | -//! | `moq_relay_publishers_total` | - | Total publishers (ANNOUNCE requests) received | -//! | `moq_relay_announce_ok_total` | - | Successful ANNOUNCE_OK responses sent | -//! | `moq_relay_announce_errors_total` | `phase` | Announce failures (phase: coordinator_register, local_register, send_ok) | +//! | `moq_relay_publishers_total` | - | Total publishers (PUBLISH_NAMESPACE requests) received | +//! | `moq_relay_announce_ok_total` | - | Successful REQUEST_OK responses sent for PUBLISH_NAMESPACE | +//! | `moq_relay_announce_errors_total` | `phase` | PUBLISH_NAMESPACE failures (phase: coordinator_register, local_register, send_ok) | //! | `moq_relay_subscribers_total` | - | Total subscribers (SUBSCRIBE requests) received | //! | `moq_relay_subscribe_not_found_total` | - | Track not found after checking all sources | //! | `moq_relay_subscribe_route_errors_total` | - | Infrastructure failure when routing to remote | @@ -39,7 +39,7 @@ //! | `moq_relay_active_publishers` | Current number of active publishers | //! | `moq_relay_active_subscriptions` | Current number of active subscriptions | //! | `moq_relay_active_tracks` | Current number of tracks being served | -//! | `moq_relay_announced_namespaces` | Current number of registered namespaces | +//! | `moq_relay_announced_namespaces` | Current number of namespaces registered via PUBLISH_NAMESPACE | //! | `moq_relay_upstream_connections` | Current number of upstream/origin connections | //! //! ## Histograms @@ -74,15 +74,15 @@ pub fn describe_metrics() { ); describe_counter!( "moq_relay_publishers_total", - "Total publishers (ANNOUNCE requests) received" + "Total publishers (PUBLISH_NAMESPACE requests) received" ); describe_counter!( "moq_relay_announce_ok_total", - "Successful ANNOUNCE_OK responses sent" + "Successful REQUEST_OK responses sent for PUBLISH_NAMESPACE" ); describe_counter!( "moq_relay_announce_errors_total", - "Announce failures by phase (coordinator_register, local_register, send_ok)" + "PUBLISH_NAMESPACE failures by phase (coordinator_register, local_register, send_ok)" ); describe_counter!( "moq_relay_subscribers_total", diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 9387b6a1..14e1aebd 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -39,9 +39,9 @@ impl Producer { } } - /// Announce new tracks to the remote server. - pub async fn announce(&mut self, tracks: TracksReader) -> Result<(), SessionError> { - self.publisher.announce(tracks).await + /// Send PUBLISH_NAMESPACE for a set of tracks to the remote peer. + pub async fn publish_namespace(&mut self, tracks: TracksReader) -> Result<(), SessionError> { + self.publisher.publish_namespace(tracks).await } /// Run the producer to serve subscribe requests. @@ -69,7 +69,11 @@ impl Producer { // Serve the subscribe request if let Err(err) = this.serve_subscribe(subscribed).await { - tracing::warn!(namespace = %namespace, track = %track_name, error = %err, "failed serving subscribe: {:?}, error: {}", info, err); + if Self::is_expected_serve_shutdown(&err) { + tracing::debug!(namespace = %namespace, track = %track_name, subscribe_info = ?info, error = %err, "stopped serving subscribe"); + } else { + tracing::warn!(namespace = %namespace, track = %track_name, subscribe_info = ?info, error = %err, "failed serving subscribe"); + } } }.boxed()) }, @@ -169,6 +173,16 @@ impl Producer { Err(err.into()) } + fn is_expected_serve_shutdown(err: &anyhow::Error) -> bool { + matches!( + err.downcast_ref::(), + Some(SessionError::Serve(ServeError::Cancel | ServeError::Done)) + ) || matches!( + err.downcast_ref::(), + Some(ServeError::Cancel | ServeError::Done) + ) + } + /// Serve a track_status request. async fn serve_track_status( self, @@ -210,7 +224,10 @@ impl Producer { } }*/ - track_status_requested.respond_error(4, "Track not found")?; + track_status_requested.respond_error( + moq_transport::message::RequestErrorCode::DoesNotExist as u64, + "track not found", + )?; Err(ServeError::not_found_ctx(format!( "track '{}/{}' not found for track_status", @@ -220,3 +237,36 @@ impl Producer { .into()) } } + +#[cfg(test)] +mod tests { + use moq_transport::{serve::ServeError, session::SessionError}; + + use super::Producer; + + #[test] + fn expected_serve_shutdown_accepts_wrapped_session_errors() { + assert!(Producer::is_expected_serve_shutdown(&anyhow::Error::new( + SessionError::Serve(ServeError::Cancel) + ))); + assert!(Producer::is_expected_serve_shutdown(&anyhow::Error::new( + SessionError::Serve(ServeError::Done) + ))); + assert!(!Producer::is_expected_serve_shutdown(&anyhow::Error::new( + SessionError::Serve(ServeError::NotFound) + ))); + } + + #[test] + fn expected_serve_shutdown_accepts_direct_serve_errors() { + assert!(Producer::is_expected_serve_shutdown(&anyhow::Error::new( + ServeError::Cancel + ))); + assert!(Producer::is_expected_serve_shutdown(&anyhow::Error::new( + ServeError::Done + ))); + assert!(!Producer::is_expected_serve_shutdown(&anyhow::Error::new( + ServeError::NotFound + ))); + } +} diff --git a/moq-relay-ietf/src/relay.rs b/moq-relay-ietf/src/relay.rs index 06a43c9e..09cdc8ce 100644 --- a/moq-relay-ietf/src/relay.rs +++ b/moq-relay-ietf/src/relay.rs @@ -44,7 +44,7 @@ pub struct RelayConfig { /// Directory to write mlog files (one per connection) pub mlog_dir: Option, - /// Forward all announcements to the (optional) URL. + /// Forward all PUBLISH_NAMESPACE messages to the (optional) upstream URL. pub announce: Option, /// Our hostname which we advertise to other origins. @@ -137,7 +137,7 @@ impl Relay { // Start the forwarder, if any let forward_producer = if let Some(url) = &announce_url { - tracing::info!("forwarding announces to {}", url); + tracing::info!("forwarding PUBLISH_NAMESPACE messages to {}", url); // Establish a QUIC connection to the forward URL let (session, _quic_client_initial_cid, transport) = quic_endpoints[0] diff --git a/moq-relay-ietf/src/remote.rs b/moq-relay-ietf/src/remote.rs index f86ba374..49ddc4ad 100644 --- a/moq-relay-ietf/src/remote.rs +++ b/moq-relay-ietf/src/remote.rs @@ -7,7 +7,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Weak}; use moq_native_ietf::quic; -use moq_transport::coding::TrackNamespace; +use moq_transport::coding::{TrackName, TrackNamespace}; use moq_transport::serve::{Track, TrackReader}; use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; @@ -21,7 +21,7 @@ use crate::{metrics::GaugeGuard, Coordinator, CoordinatorError}; /// only when both match. type RemoteCacheKey = (Url, Option); type RemoteSlot = Arc>>; -type TrackCacheKey = (TrackNamespace, String); +type TrackCacheKey = (TrackNamespace, TrackName); type TrackSlot = Arc>>; /// Manages connections to remote relays. @@ -56,8 +56,9 @@ impl RemoteManager { &self, scope: Option<&str>, namespace: &TrackNamespace, - track_name: &str, + track_name: impl Into, ) -> anyhow::Result> { + let track_name = track_name.into(); let (origin, client) = match self.coordinator.lookup(scope, namespace).await { Ok(result) => result, Err(CoordinatorError::NamespaceNotFound) => return Ok(None), @@ -78,10 +79,7 @@ impl RemoteManager { } }; - match remote - .subscribe(namespace.clone(), track_name.to_string()) - .await - { + match remote.subscribe(namespace.clone(), track_name).await { Ok(reader) => Ok(reader), Err(err) => { tracing::warn!(remote_url = %url, error = %err, "remote subscribe failed, removing from cache"); @@ -350,7 +348,7 @@ impl Remote { async fn subscribe( &self, namespace: TrackNamespace, - track_name: String, + track_name: TrackName, ) -> anyhow::Result> { let key = (namespace.clone(), track_name.clone()); diff --git a/moq-relay-ietf/src/session.rs b/moq-relay-ietf/src/session.rs index d83d8595..ddb0b14f 100644 --- a/moq-relay-ietf/src/session.rs +++ b/moq-relay-ietf/src/session.rs @@ -56,17 +56,15 @@ impl Session { /// Drain incoming PUBLISH_NAMESPACEs and reject each one. /// - /// The transport `Subscriber` queues incoming PUBLISH_NAMESPACE messages - /// as `Announced` events. Dropping an `Announced` without calling `ok()` - /// triggers its `Drop` impl, which sends PUBLISH_NAMESPACE_ERROR back - /// to the peer. + /// Dropping a `PublishedNamespace` without calling `ok()` triggers its + /// `Drop` impl, which sends REQUEST_ERROR back to the peer. async fn drain_and_reject_publishes(mut subscriber: Subscriber) -> Result<(), SessionError> { - while let Some(announced) = subscriber.announced().await { + while let Some(published_ns) = subscriber.published_namespace().await { tracing::debug!( - namespace = %announced.namespace, + namespace = %published_ns.namespace, "rejecting PUBLISH_NAMESPACE: publish not permitted for this session" ); - drop(announced); + drop(published_ns); } Ok(()) } diff --git a/moq-test-client/README.md b/moq-test-client/README.md index 240af2d3..9db28c8f 100644 --- a/moq-test-client/README.md +++ b/moq-test-client/README.md @@ -45,11 +45,11 @@ moq-test-client --relay https://localhost:4443 --tls-disable-verify | Test | Description | |------|-------------| | `setup-only` | Connect, complete SETUP exchange, close gracefully | -| `announce-only` | Connect, announce namespace, receive OK, close | +| `publish-namespace-only` | Connect, send PUBLISH_NAMESPACE, receive REQUEST_OK, close | | `subscribe-error` | Subscribe to non-existent track, expect error | -| `announce-subscribe` | Publisher announces, subscriber subscribes, verify handshake | -| `subscribe-before-announce` | Subscriber subscribes before publisher announces | -| `publish-namespace-done` | Announce namespace, send PUBLISH_NAMESPACE_DONE | +| `publish-namespace-subscribe` | Publisher sends PUBLISH_NAMESPACE, subscriber subscribes, verify handshake | +| `subscribe-before-publish-namespace` | Subscriber subscribes before publisher sends PUBLISH_NAMESPACE | +| `publish-namespace-done` | Send PUBLISH_NAMESPACE, then send PUBLISH_NAMESPACE_DONE | ## Running with moq-relay @@ -78,10 +78,10 @@ MoQT Interop Test Client Relay: https://localhost:4443 ✓ setup-only (42 ms) -✓ announce-only (38 ms) +✓ publish-namespace-only (38 ms) ✓ subscribe-error (51 ms) -✓ announce-subscribe (127 ms) -✓ subscribe-before-announce (89 ms) +✓ publish-namespace-subscribe (127 ms) +✓ subscribe-before-publish-namespace (89 ms) ✓ publish-namespace-done (45 ms) Results: 6 passed, 0 failed @@ -112,13 +112,13 @@ The test cases implemented here correspond to the specifications in [moq-interop | Test Case | Protocol References | |-----------|---------------------| | `setup-only` | MoQT §3.3, §9.3 | -| `announce-only` | MoQT §6.2, §9.23-9.24 | +| `publish-namespace-only` | MoQT §6.2, §9.23-9.24 | | `publish-namespace-done` | MoQT §6.2, §9.26 | | `subscribe-error` | MoQT §5.1, §9.7, §9.9 | -| `announce-subscribe` | MoQT §5.1, §6.2, §9.7-9.8, §9.23-9.24 | -| `subscribe-before-announce` | MoQT §5.1, §6.2 | +| `publish-namespace-subscribe` | MoQT §5.1, §6.2, §9.7-9.8, §9.23-9.24 | +| `subscribe-before-publish-namespace` | MoQT §5.1, §6.2 | -Protocol references are to [draft-ietf-moq-transport-14](https://www.ietf.org/archive/id/draft-ietf-moq-transport-14.html). +Protocol references are to [draft-ietf-moq-transport-16](https://www.ietf.org/archive/id/draft-ietf-moq-transport-16.html). ## Design Goals diff --git a/moq-test-client/src/main.rs b/moq-test-client/src/main.rs index 29796d51..f6e250b5 100644 --- a/moq-test-client/src/main.rs +++ b/moq-test-client/src/main.rs @@ -69,15 +69,15 @@ pub struct Args { pub enum TestCase { /// T0.1: Connect, complete SETUP exchange, close gracefully SetupOnly, - /// T0.2: Connect, announce namespace, receive OK, close - AnnounceOnly, + /// T0.2: Connect, send PUBLISH_NAMESPACE, receive REQUEST_OK, close + PublishNamespaceOnly, /// T0.3: Subscribe to non-existent track, expect error SubscribeError, - /// T0.4: Publisher announces, subscriber subscribes, verify handshake - AnnounceSubscribe, - /// T0.5: Subscriber subscribes before publisher announces - SubscribeBeforeAnnounce, - /// T0.6: Announce namespace, receive OK, send PUBLISH_NAMESPACE_DONE + /// T0.4: Publisher sends PUBLISH_NAMESPACE, subscriber subscribes, verify handshake + PublishNamespaceSubscribe, + /// T0.5: Subscriber subscribes before publisher sends PUBLISH_NAMESPACE + SubscribeBeforePublishNamespace, + /// T0.6: Send PUBLISH_NAMESPACE, receive REQUEST_OK, send PUBLISH_NAMESPACE_DONE PublishNamespaceDone, } @@ -85,10 +85,10 @@ impl TestCase { fn all() -> Vec { vec![ TestCase::SetupOnly, - TestCase::AnnounceOnly, + TestCase::PublishNamespaceOnly, TestCase::SubscribeError, - TestCase::AnnounceSubscribe, - TestCase::SubscribeBeforeAnnounce, + TestCase::PublishNamespaceSubscribe, + TestCase::SubscribeBeforePublishNamespace, TestCase::PublishNamespaceDone, ] } @@ -96,10 +96,10 @@ impl TestCase { fn name(&self) -> &'static str { match self { TestCase::SetupOnly => "setup-only", - TestCase::AnnounceOnly => "announce-only", + TestCase::PublishNamespaceOnly => "publish-namespace-only", TestCase::SubscribeError => "subscribe-error", - TestCase::AnnounceSubscribe => "announce-subscribe", - TestCase::SubscribeBeforeAnnounce => "subscribe-before-announce", + TestCase::PublishNamespaceSubscribe => "publish-namespace-subscribe", + TestCase::SubscribeBeforePublishNamespace => "subscribe-before-publish-namespace", TestCase::PublishNamespaceDone => "publish-namespace-done", } } @@ -143,10 +143,14 @@ async fn run_test(args: &Args, test_case: TestCase) -> TestResult { let result = match test_case { TestCase::SetupOnly => scenarios::test_setup_only(args).await, - TestCase::AnnounceOnly => scenarios::test_announce_only(args).await, + TestCase::PublishNamespaceOnly => scenarios::test_publish_namespace_only(args).await, TestCase::SubscribeError => scenarios::test_subscribe_error(args).await, - TestCase::AnnounceSubscribe => scenarios::test_announce_subscribe(args).await, - TestCase::SubscribeBeforeAnnounce => scenarios::test_subscribe_before_announce(args).await, + TestCase::PublishNamespaceSubscribe => { + scenarios::test_publish_namespace_subscribe(args).await + } + TestCase::SubscribeBeforePublishNamespace => { + scenarios::test_subscribe_before_publish_namespace(args).await + } TestCase::PublishNamespaceDone => scenarios::test_publish_namespace_done(args).await, }; @@ -174,7 +178,7 @@ fn print_tap_result(test_number: usize, result: &TestResult, verbose: bool) { 2 => { // Multi-connection tests: first is publisher, second is subscriber // (except subscribe-before-announce where subscriber connects first) - if result.test_case == TestCase::SubscribeBeforeAnnounce { + if result.test_case == TestCase::SubscribeBeforePublishNamespace { println!(" subscriber_connection_id: {}", result.cids[0]); println!(" publisher_connection_id: {}", result.cids[1]); } else { diff --git a/moq-test-client/src/scenarios.rs b/moq-test-client/src/scenarios.rs index 75fa4191..829dd3c3 100644 --- a/moq-test-client/src/scenarios.rs +++ b/moq-test-client/src/scenarios.rs @@ -23,7 +23,7 @@ const TEST_TIMEOUT: Duration = Duration::from_secs(10); /// Namespace used for test operations const TEST_NAMESPACE: &str = "moq-test/interop"; -/// Track name used for test operations +/// Track name used for test operations const TEST_TRACK: &str = "test-track"; /// Helper to connect to a relay and establish a session @@ -57,7 +57,6 @@ impl TestConnectionIds { /// T0.1: Setup Only /// /// Connect to relay, complete CLIENT_SETUP/SERVER_SETUP exchange, close gracefully. -/// This is the simplest possible test - if this fails, nothing else will work. pub async fn test_setup_only(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { let (session, cid, transport) = @@ -65,29 +64,25 @@ pub async fn test_setup_only(args: &Args) -> Result { let mut cids = TestConnectionIds::default(); cids.add(cid); - // Session::connect performs the SETUP exchange let (session, _publisher, _subscriber) = Session::connect(session, None, transport) .await .context("SETUP exchange failed")?; tracing::info!("SETUP exchange completed successfully"); - - // We don't need to run the session, just verify setup worked - // Dropping the session will close the connection drop(session); - Ok(cids) }) .await .context("test timed out")? } -/// T0.2: Announce Only +/// T0.2: Publish Namespace Only /// -/// Connect to relay, announce a namespace, receive PUBLISH_NAMESPACE_OK, close. -pub async fn test_announce_only(args: &Args) -> Result { +/// Connect to relay, send PUBLISH_NAMESPACE, receive REQUEST_OK, close. +pub async fn test_publish_namespace_only(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { - let (session, cid, transport) = connect(args).await.context("failed to connect to relay")?; + let (session, cid, transport) = + connect(args).await.context("failed to connect to relay")?; let mut cids = TestConnectionIds::default(); cids.add(cid); @@ -98,31 +93,27 @@ pub async fn test_announce_only(args: &Args) -> Result { let namespace = TrackNamespace::from_utf8_path(TEST_NAMESPACE); let (_, _, reader) = Tracks::new(namespace.clone()).produce(); - tracing::info!("Announcing namespace: {}", TEST_NAMESPACE); + tracing::info!("Sending PUBLISH_NAMESPACE for: {}", TEST_NAMESPACE); - // Run announce with a timeout - we want to verify we get PUBLISH_NAMESPACE_OK. - // NOTE: The announce() method blocks waiting for subscriptions after getting OK. - // If we get PUBLISH_NAMESPACE_ERROR instead of OK, the method returns Err immediately. - // So timing out here means: either (a) got OK and waiting for subs, or (b) relay never responded. - // We accept this limitation since (b) would indicate a broken relay anyway. - // TODO: For stricter verification, use lower-level Announce::ok() method directly. - let announce_result = tokio::select! { - res = publisher.announce(reader) => res, + // publish_namespace() blocks waiting for subscriptions after receiving REQUEST_OK. + // If we receive REQUEST_ERROR instead, it returns Err immediately. + // Timing out here means we received REQUEST_OK and are now waiting for subscribers, + // which is the expected success case. + let result = tokio::select! { + res = publisher.publish_namespace(reader) => res, res = session.run() => { res.context("session error")?; - anyhow::bail!("session ended before announce completed"); + anyhow::bail!("session ended before PUBLISH_NAMESPACE completed"); } _ = tokio::time::sleep(Duration::from_secs(2)) => { - // If we got an error from the relay, announce() would have returned already. - // Timing out means we're past the OK and now waiting for subscriptions. - tracing::info!("Announce succeeded (no error received, waiting for subscriptions timed out)"); + tracing::info!( + "PUBLISH_NAMESPACE succeeded (REQUEST_OK received, waiting for subscribers)" + ); return Ok(cids); } }; - // If we get here, announce completed (which means it errored or namespace was cancelled) - announce_result.context("announce failed")?; - + result.context("PUBLISH_NAMESPACE failed")?; Ok(cids) }) .await @@ -131,7 +122,7 @@ pub async fn test_announce_only(args: &Args) -> Result { /// T0.3: Subscribe Error /// -/// Subscribe to a non-existent track and verify we get SUBSCRIBE_ERROR. +/// Subscribe to a non-existent track and verify we get a subscription error. pub async fn test_subscribe_error(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { let (session, cid, transport) = @@ -146,7 +137,6 @@ pub async fn test_subscribe_error(args: &Args) -> Result { let namespace = TrackNamespace::from_utf8_path("nonexistent/namespace"); let (mut writer, _, _reader) = Tracks::new(namespace.clone()).produce(); - // Create a track to subscribe to let track = writer .create(TEST_TRACK) .ok_or_else(|| anyhow::anyhow!("failed to create track (already exists?)"))?; @@ -157,7 +147,6 @@ pub async fn test_subscribe_error(args: &Args) -> Result { TEST_TRACK ); - // Run subscribe - we expect an error let subscribe_result = tokio::select! { res = subscriber.subscribe(track) => res, res = session.run() => { @@ -166,30 +155,25 @@ pub async fn test_subscribe_error(args: &Args) -> Result { } }; - // We expect this to fail with a "not found" or similar error match subscribe_result { Ok(()) => { anyhow::bail!("subscribe succeeded but should have failed (track doesn't exist)"); } Err(e) => { - // Validate that the error is related to the track not existing. - // Different relays may return different error messages, but they should - // indicate the track/namespace was not found. let err_str = e.to_string().to_lowercase(); - let is_expected_error = err_str.contains("not found") + let is_expected = err_str.contains("not found") || err_str.contains("notfound") || err_str.contains("no such") || err_str.contains("doesn't exist") || err_str.contains("does not exist") || err_str.contains("unknown"); - if is_expected_error { + if is_expected { tracing::info!("Got expected 'not found' error: {}", e); } else { - // Log warning but still pass - relay may use different error text tracing::warn!( "Got error but not clearly 'not found': {}. \ - This may indicate a different error type than expected.", + Relay may use different error text.", e ); } @@ -201,22 +185,21 @@ pub async fn test_subscribe_error(args: &Args) -> Result { .context("test timed out")? } -/// T0.4: Announce + Subscribe +/// T0.4: Publish Namespace + Subscribe /// -/// Two clients: publisher announces a namespace, subscriber subscribes to a track. +/// Publisher sends PUBLISH_NAMESPACE; subscriber subscribes to a track in that namespace. /// Verifies the relay correctly routes the subscription to the publisher. -pub async fn test_announce_subscribe(args: &Args) -> Result { +pub async fn test_publish_namespace_subscribe(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { let mut cids = TestConnectionIds::default(); - // Publisher connection - let (pub_session, pub_cid, pub_transport) = connect(args).await.context("publisher failed to connect")?; + let (pub_session, pub_cid, pub_transport) = + connect(args).await.context("publisher failed to connect")?; cids.add(pub_cid); let (pub_session, mut publisher, _) = Session::connect(pub_session, None, pub_transport) .await .context("publisher SETUP failed")?; - // Subscriber connection let (sub_session, sub_cid, sub_transport) = connect(args) .await .context("subscriber failed to connect")?; @@ -227,15 +210,11 @@ pub async fn test_announce_subscribe(args: &Args) -> Result { let namespace = TrackNamespace::from_utf8_path(TEST_NAMESPACE); - // Publisher: set up tracks and announce let (mut pub_writer, _, pub_reader) = Tracks::new(namespace.clone()).produce(); - - // Create the track that subscriber will request let _track_writer = pub_writer.create(TEST_TRACK); - tracing::info!("Publisher announcing namespace: {}", TEST_NAMESPACE); + tracing::info!("Publisher sending PUBLISH_NAMESPACE: {}", TEST_NAMESPACE); - // Subscriber: set up tracks and subscribe let (mut sub_writer, _, _sub_reader) = Tracks::new(namespace.clone()).produce(); let sub_track = sub_writer .create(TEST_TRACK) @@ -247,35 +226,27 @@ pub async fn test_announce_subscribe(args: &Args) -> Result { TEST_TRACK ); - // Run everything concurrently. We expect the subscriber to get a response - // (either SUBSCRIBE_OK or error) within the timeout. tokio::select! { - // Publisher announces and waits for subscriptions - res = publisher.announce(pub_reader) => { - res.context("publisher announce failed")?; - tracing::info!("Publisher announce completed"); + res = publisher.publish_namespace(pub_reader) => { + res.context("publisher PUBLISH_NAMESPACE failed")?; + tracing::info!("Publisher PUBLISH_NAMESPACE completed"); } - // Subscriber subscribes - this is the main thing we're testing res = subscriber.subscribe(sub_track) => { match res { - Ok(()) => tracing::info!("Subscriber got SUBSCRIBE_OK - relay routed subscription correctly"), - Err(e) => tracing::info!("Subscriber got error: {} - subscription was processed", e), + Ok(()) => tracing::info!( + "Subscriber got subscription response - relay routed correctly" + ), + Err(e) => tracing::info!( + "Subscriber got error: {} - subscription was processed", e + ), } } - // Run publisher session - res = pub_session.run() => { - res.context("publisher session error")?; - } - // Run subscriber session - res = sub_session.run() => { - res.context("subscriber session error")?; - } - // Timeout: give the relay time to route the subscription + res = pub_session.run() => res.context("publisher session error")?, + res = sub_session.run() => res.context("subscriber session error")?, _ = tokio::time::sleep(Duration::from_secs(3)) => { - // If we hit this timeout, the subscription may still be pending. - // This isn't necessarily a failure - some relays may hold subscriptions - // until the track has data. Log for visibility. - tracing::info!("Test timeout reached - subscription routing may still be in progress"); + tracing::info!( + "Test timeout reached - subscription routing may still be in progress" + ); } }; @@ -285,9 +256,9 @@ pub async fn test_announce_subscribe(args: &Args) -> Result { .context("test timed out")? } -/// T0.6: Publish Namespace Done (Letter L) +/// T0.6: Publish Namespace Done /// -/// Announce a namespace, receive OK, then send PUBLISH_NAMESPACE_DONE. +/// Send PUBLISH_NAMESPACE, receive REQUEST_OK, then send PUBLISH_NAMESPACE_DONE. /// Verifies the relay handles namespace unpublishing correctly. pub async fn test_publish_namespace_done(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { @@ -303,29 +274,25 @@ pub async fn test_publish_namespace_done(args: &Args) -> Result res, + res = publisher.publish_namespace(reader) => res, res = session.run() => { res.context("session error")?; - anyhow::bail!("session ended before announce completed"); + anyhow::bail!("session ended before PUBLISH_NAMESPACE completed"); } _ = tokio::time::sleep(Duration::from_secs(2)) => { - // No error received - announce is active and waiting for subscriptions - tracing::info!("Announce active, now sending PUBLISH_NAMESPACE_DONE"); - // Dropping out of this block will drop the announce, which sends PUBLISH_NAMESPACE_DONE + // No error received: REQUEST_OK arrived and we are waiting for subscribers. + // Drop publish_namespace here to send PUBLISH_NAMESPACE_DONE. + tracing::info!("PUBLISH_NAMESPACE active; sending PUBLISH_NAMESPACE_DONE"); Ok(()) } }; - result.context("announce failed")?; + result.context("PUBLISH_NAMESPACE failed")?; - // Small delay to ensure PUBLISH_NAMESPACE_DONE is sent before we close tokio::time::sleep(Duration::from_millis(100)).await; - tracing::info!("PUBLISH_NAMESPACE_DONE sent successfully"); Ok(cids) }) @@ -333,15 +300,15 @@ pub async fn test_publish_namespace_done(args: &Args) -> Result Result { +pub async fn test_subscribe_before_publish_namespace(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { let mut cids = TestConnectionIds::default(); - // Subscriber connection - connects first + // Subscriber connects first. let (sub_session, sub_cid, sub_transport) = connect(args) .await .context("subscriber failed to connect")?; @@ -352,19 +319,17 @@ pub async fn test_subscribe_before_announce(args: &Args) -> Result res, @@ -376,10 +341,10 @@ pub async fn test_subscribe_before_announce(args: &Args) -> Result Result { - res.context("publisher announce failed")?; - } - res = pub_session.run() => { - res.context("publisher session error")?; + res = publisher.publish_namespace(pub_reader) => { + res.context("publisher PUBLISH_NAMESPACE failed")?; } + res = pub_session.run() => res.context("publisher session error")?, _ = tokio::time::sleep(Duration::from_secs(3)) => { - tracing::info!("Publisher announce timeout (expected)"); + tracing::info!("Publisher PUBLISH_NAMESPACE timeout (expected)"); } }; - // Check subscriber result tokio::select! { res = sub_handle => { match res { diff --git a/moq-transport/src/coding/decode.rs b/moq-transport/src/coding/decode.rs index c47d58df..abc712b1 100644 --- a/moq-transport/src/coding/decode.rs +++ b/moq-transport/src/coding/decode.rs @@ -75,11 +75,26 @@ pub enum DecodeError { #[error("key-value-pair length exceeded")] KeyValuePairLengthExceeded(), + /// Delta-encoded KVP type would overflow u64 (draft-16 §1.4.2 PROTOCOL_VIOLATION). + #[error("key-value-pair type delta overflow")] + KvpTypeOverflow, + #[error("field '{0}' too large")] FieldBoundsExceeded(String), + /// A namespace field had zero length (draft-16 §2.4.1 PROTOCOL_VIOLATION). + #[error("namespace field must not be empty")] + EmptyNamespaceField, + + /// A full track name exceeded 4096 bytes (draft-16 §2.4.1 PROTOCOL_VIOLATION). + #[error("full track name exceeds 4096 bytes")] + TrackNameTooLong, + #[error("invalid datagram type")] InvalidDatagramType, + + #[error("invalid subscribe namespace option: {0}")] + InvalidSubscribeOptions(u64), } impl From for DecodeError { diff --git a/moq-transport/src/coding/encode.rs b/moq-transport/src/coding/encode.rs index e6edfb57..3da80724 100644 --- a/moq-transport/src/coding/encode.rs +++ b/moq-transport/src/coding/encode.rs @@ -43,6 +43,18 @@ pub enum EncodeError { #[error("field '{0}' too large")] FieldBoundsExceeded(String), + + /// KVP keys must be in non-decreasing order during encode. + #[error("key-value-pair keys must be non-decreasing")] + KvpKeyOrder, + + /// Bytes-typed KVP value exceeds maximum length of 2^16-1. + #[error("key-value-pair bytes value too long")] + KeyValuePairLengthExceeded, + + /// A namespace field had zero length (draft-16 §2.4.1 PROTOCOL_VIOLATION). + #[error("namespace field must not be empty")] + EmptyNamespaceField, } impl From for EncodeError { diff --git a/moq-transport/src/coding/kvp.rs b/moq-transport/src/coding/kvp.rs index fba12977..2771620e 100644 --- a/moq-transport/src/coding/kvp.rs +++ b/moq-transport/src/coding/kvp.rs @@ -1,9 +1,35 @@ // SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 +//! Key-Value-Pair (KVP) encoding as defined in draft-ietf-moq-transport-16 §1.4.2. +//! +//! KVPs encode a Type value as a **delta from the previous Type** (or from 0 if +//! there is no previous Type). On the wire: +//! +//! ```text +//! Key-Value-Pair { +//! Delta Type (i), -- varint delta from previous absolute type +//! [Length (i),] -- only present when absolute Type is odd +//! Value (..) -- varint when Type is even, bytes when Type is odd +//! } +//! ``` +//! +//! The previous-Type-plus-Delta MUST NOT exceed 2^64-1. If it would, the +//! session MUST be closed with PROTOCOL_VIOLATION. +//! +//! A sequence of KVPs is prefixed by a varint count of the number of pairs. + use crate::coding::{Decode, DecodeError, Encode, EncodeError}; use std::fmt; +/// Maximum byte-value length for a bytes-typed KVP (2^16 − 1). +const MAX_BYTES_VALUE_LEN: usize = u16::MAX as usize; + +/// Smallest possible encoded KVP: a one-byte delta plus a one-byte varint value. +const MIN_KVP_WIRE_LEN: usize = 2; + +// ─── Value ──────────────────────────────────────────────────────────────────── + #[derive(Clone, Eq, PartialEq)] pub enum Value { IntValue(u64), @@ -15,7 +41,6 @@ impl fmt::Debug for Value { match self { Value::IntValue(v) => write!(f, "{}", v), Value::BytesValue(bytes) => { - // Show up to 16 bytes in hex for readability let preview: Vec = bytes .iter() .take(16) @@ -27,6 +52,13 @@ impl fmt::Debug for Value { } } +// ─── KeyValuePair ───────────────────────────────────────────────────────────── + +/// A single Key-Value-Pair with an absolute (resolved) key. +/// +/// The delta encoding is handled by [`KeyValuePairs`]; individual pairs always +/// carry their absolute key so they can be compared and looked up without +/// needing ordering context. #[derive(Clone, Eq, PartialEq)] pub struct KeyValuePair { pub key: u64, @@ -51,85 +83,103 @@ impl KeyValuePair { value: Value::BytesValue(value), } } -} - -impl Decode for KeyValuePair { - fn decode(r: &mut R) -> Result { - let key = u64::decode(r)?; - if key % 2 == 0 { - // VarInt variant + /// Decode a single KVP from the wire given the previous absolute type. + /// + /// Returns `(pair, new_absolute_type)`. + pub(crate) fn decode_with_prev( + r: &mut R, + prev: u64, + ) -> Result<(Self, u64), DecodeError> { + let delta = u64::decode(r)?; + + // Draft-16 §1.4.2: prev + delta MUST NOT overflow u64. + let abs_type = prev + .checked_add(delta) + .ok_or(DecodeError::KvpTypeOverflow)?; + + let pair = if abs_type % 2 == 0 { + // Even type → varint value. let value = u64::decode(r)?; - tracing::trace!("[KVP] Decoded even key={}, value={}", key, value); - Ok(KeyValuePair::new_int(key, value)) + KeyValuePair::new_int(abs_type, value) } else { - // Bytes variant + // Odd type → length-prefixed bytes value. let length = usize::decode(r)?; - tracing::trace!("[KVP] Decoded odd key={}, length={}", key, length); - if length > u16::MAX as usize { - tracing::error!( - "[KVP] Length exceeded! key={}, length={} (max={})", - key, - length, - u16::MAX - ); + if length > MAX_BYTES_VALUE_LEN { return Err(DecodeError::KeyValuePairLengthExceeded()); } - - Self::decode_remaining(r, length)?; - let mut buf = vec![0; length]; + ::decode_remaining(r, length)?; + let mut buf = vec![0u8; length]; r.copy_to_slice(&mut buf); - Ok(KeyValuePair::new_bytes(key, buf)) - } + KeyValuePair::new_bytes(abs_type, buf) + }; + + Ok((pair, abs_type)) } -} -impl Encode for KeyValuePair { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + /// Encode a single KVP onto the wire given the previous absolute type. + /// + /// Writes the delta and value; returns the new absolute type (== `self.key`). + pub(crate) fn encode_with_prev( + &self, + w: &mut W, + prev: u64, + ) -> Result { + // Keys must be consistent with their value parity. + match &self.value { + Value::IntValue(_) if self.key % 2 != 0 => return Err(EncodeError::InvalidValue), + Value::BytesValue(_) if self.key % 2 == 0 => return Err(EncodeError::InvalidValue), + _ => {} + } + + // Delta MUST NOT underflow (keys must be in ascending order within a sequence). + let delta = self.key.checked_sub(prev).ok_or(EncodeError::KvpKeyOrder)?; + delta.encode(w)?; + match &self.value { Value::IntValue(v) => { - // key must be even for IntValue - if !self.key.is_multiple_of(2) { - return Err(EncodeError::InvalidValue); - } - self.key.encode(w)?; (*v).encode(w)?; - Ok(()) } Value::BytesValue(v) => { - // key must be odd for BytesValue - if self.key.is_multiple_of(2) { - return Err(EncodeError::InvalidValue); + if v.len() > MAX_BYTES_VALUE_LEN { + return Err(EncodeError::KeyValuePairLengthExceeded); } - self.key.encode(w)?; v.len().encode(w)?; - Self::encode_remaining(w, v.len())?; + ::encode_remaining(w, v.len())?; w.put_slice(v); - Ok(()) } } + + Ok(self.key) } } +// Note: `KeyValuePair` intentionally does NOT implement `Decode`/`Encode` +// directly, because a single pair is always decoded/encoded with a `prev` +// context. Use `decode_with_prev` / `encode_with_prev` (or go through +// `KeyValuePairs` / `ExtensionHeaders`) so the delta state is always correct. + impl fmt::Debug for KeyValuePair { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{{{}: {:?}}}", self.key, self.value) } } -/// A collection of KeyValuePair entries, where the number of key-value-pairs are encoded/decoded first. -/// This structure is appropriate for Control message parameters. -/// Since duplicate parameters are allowed for unknown parameters, we don't do duplicate checking here. +// ─── KeyValuePairs ──────────────────────────────────────────────────────────── + +/// An ordered, count-prefixed sequence of [`KeyValuePair`] entries. +/// +/// On the wire the keys are delta-encoded from 0, so they MUST be in +/// non-decreasing order. Internally keys are stored as absolute values. #[derive(Default, Clone, Eq, PartialEq)] pub struct KeyValuePairs(pub Vec); -// TODO: These set/get API's all assume no duplicate keys. We can add API's to support duplicates if needed. impl KeyValuePairs { pub fn new() -> Self { Self::default() } - /// Insert or replace a KeyValuePair with the same key. + /// Insert or replace the pair with matching key. pub fn set(&mut self, kvp: KeyValuePair) { if let Some(existing) = self.0.iter_mut().find(|k| k.key == kvp.key) { *existing = kvp; @@ -153,16 +203,31 @@ impl KeyValuePairs { pub fn get(&self, key: u64) -> Option<&KeyValuePair> { self.0.iter().find(|k| k.key == key) } + + /// Return `true` if any key appears more than once. + pub fn has_duplicate_keys(&self) -> bool { + let mut seen = std::collections::HashSet::new(); + self.0.iter().any(|k| !seen.insert(k.key)) + } } impl Decode for KeyValuePairs { - fn decode(mut r: &mut R) -> Result { - let mut kvps = Vec::new(); - + fn decode(r: &mut R) -> Result { let count = u64::decode(r)?; + + // `count` is peer-controlled, so do not allocate directly from it. + // This is only a capacity hint: with the bytes currently buffered, at + // most `remaining / MIN_KVP_WIRE_LEN` pairs can be decoded before the + // normal decode loop asks for more bytes via `DecodeError::More`. + let count_capacity = usize::try_from(count).unwrap_or(usize::MAX); + let payload_capacity = r.remaining() / MIN_KVP_WIRE_LEN; + let mut kvps = Vec::with_capacity(count_capacity.min(payload_capacity)); + let mut prev = 0u64; + for _ in 0..count { - let kvp = KeyValuePair::decode(&mut r)?; - kvps.push(kvp); + let (pair, new_prev) = KeyValuePair::decode_with_prev(r, prev)?; + prev = new_prev; + kvps.push(pair); } Ok(KeyValuePairs(kvps)) @@ -171,10 +236,16 @@ impl Decode for KeyValuePairs { impl Encode for KeyValuePairs { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.0.len().encode(w)?; + // Sort a working copy by ascending key before encoding so the delta is + // always non-negative. The internal Vec is not required to be sorted. + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|k| k.key); + + sorted.len().encode(w)?; - for kvp in &self.0 { - kvp.encode(w)?; + let mut prev = 0u64; + for kvp in &sorted { + prev = kvp.encode_with_prev(w, prev)?; } Ok(()) @@ -194,84 +265,272 @@ impl fmt::Debug for KeyValuePairs { } } +// ─── Tests ──────────────────────────────────────────────────────────────────── + #[cfg(test)] mod tests { use super::*; - use bytes::Bytes; use bytes::BytesMut; + // ── single pair helpers ─────────────────────────────────────────────────── + + fn round_trip_pair(pairs: &[(u64, Value)]) -> Vec { + // Build the wire bytes by hand for a sequence of pairs (delta-encoded). + // This is the canonical wire format we expect encode/decode to match. + let mut buf = BytesMut::new(); + let mut prev = 0u64; + for (key, value) in pairs { + let delta = key - prev; + delta.encode(&mut buf).unwrap(); + match value { + Value::IntValue(v) => v.encode(&mut buf).unwrap(), + Value::BytesValue(b) => { + b.len().encode(&mut buf).unwrap(); + buf.extend_from_slice(b); + } + } + prev = *key; + } + buf.to_vec() + } + + // ── encode / decode single pair ─────────────────────────────────────────── + + #[test] + fn single_int_pair_roundtrip() { + // key=0 (even, int), value=42 → delta=0, value=42 + let mut buf = BytesMut::new(); + let kvps = KeyValuePairs(vec![KeyValuePair::new_int(0, 42)]); + kvps.encode(&mut buf).unwrap(); + // wire: count=1, delta=0, value=42 + assert_eq!(buf.to_vec(), vec![0x01, 0x00, 0x2a]); + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + assert_eq!(decoded, kvps); + } + + #[test] + fn single_bytes_pair_roundtrip() { + // key=1 (odd, bytes), value=[0xAB, 0xCD] + let mut buf = BytesMut::new(); + let kvps = KeyValuePairs(vec![KeyValuePair::new_bytes(1, vec![0xAB, 0xCD])]); + kvps.encode(&mut buf).unwrap(); + // wire: count=1, delta=1, length=2, 0xAB, 0xCD + assert_eq!(buf.to_vec(), vec![0x01, 0x01, 0x02, 0xAB, 0xCD]); + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + assert_eq!(decoded, kvps); + } + + // ── delta encoding ──────────────────────────────────────────────────────── + + #[test] + fn delta_encoding_multiple_pairs() { + // Three pairs: key=0 (int=1), key=2 (int=2), key=100 (int=3) + // Deltas on wire: 0, 2, 98 + let mut buf = BytesMut::new(); + let mut kvps = KeyValuePairs::new(); + kvps.set_intvalue(0, 1); + kvps.set_intvalue(2, 2); + kvps.set_intvalue(100, 3); + kvps.encode(&mut buf).unwrap(); + + let expected_wire = round_trip_pair(&[ + (0, Value::IntValue(1)), + (2, Value::IntValue(2)), + (100, Value::IntValue(3)), + ]); + // count prefix + assert_eq!(buf[1..], expected_wire[..]); + assert_eq!(buf[0], 0x03); // 3 pairs + + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + assert_eq!(decoded.0.len(), 3); + assert_eq!(decoded.get(0).unwrap().value, Value::IntValue(1)); + assert_eq!(decoded.get(2).unwrap().value, Value::IntValue(2)); + assert_eq!(decoded.get(100).unwrap().value, Value::IntValue(3)); + } + + #[test] + fn encode_sorts_before_delta() { + // Insert out of order; encode must sort them so deltas are non-negative. + let mut kvps = KeyValuePairs::new(); + kvps.set_intvalue(100, 99); + kvps.set_intvalue(0, 1); + kvps.set_intvalue(2, 2); + + let mut buf = BytesMut::new(); + kvps.encode(&mut buf).unwrap(); + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + + // All three keys must survive the round-trip regardless of insertion order. + assert_eq!(decoded.0.len(), 3); + assert_eq!(decoded.get(0).unwrap().value, Value::IntValue(1)); + assert_eq!(decoded.get(2).unwrap().value, Value::IntValue(2)); + assert_eq!(decoded.get(100).unwrap().value, Value::IntValue(99)); + } + + // ── parity enforcement ──────────────────────────────────────────────────── + + #[test] + fn encode_rejects_odd_key_with_int_value() { + let kvp = KeyValuePair::new(1, Value::IntValue(0)); // odd key, int value → invalid + let mut buf = BytesMut::new(); + // Wrap in a collection to exercise the encode_with_prev path. + let kvps = KeyValuePairs(vec![kvp]); + assert!(matches!( + kvps.encode(&mut buf).unwrap_err(), + EncodeError::InvalidValue + )); + } + + #[test] + fn encode_rejects_even_key_with_bytes_value() { + let kvp = KeyValuePair::new(0, Value::BytesValue(vec![0x01])); // even key, bytes value → invalid + let kvps = KeyValuePairs(vec![kvp]); + let mut buf = BytesMut::new(); + assert!(matches!( + kvps.encode(&mut buf).unwrap_err(), + EncodeError::InvalidValue + )); + } + + // ── overflow / bounds ───────────────────────────────────────────────────── + + #[test] + fn decode_detects_delta_overflow() { + // A single QUIC varint delta is at most 2^62-1. To overflow u64::MAX via + // cumulative addition we need 5 max-delta pairs: + // 5 * (2^62-1) = 23058430092136939515 > u64::MAX (18446744073709551615) + // + // Abs-type parity after each pair (k * max_delta): + // k=1: 4611686018427387903 (odd → bytes, length=0) + // k=2: 9223372036854775806 (even → int, value=0) + // k=3: 13835058055282163709 (odd → bytes, length=0) + // k=4: 18446744073709551612 (even → int, value=0) + // k=5: overflow → KvpTypeOverflow + let max_delta: u64 = (1u64 << 62) - 1; + let mut buf = BytesMut::new(); + (5u64).encode(&mut buf).unwrap(); // count = 5 + + // Pair 1: abs = max_delta (odd → bytes, length=0) + max_delta.encode(&mut buf).unwrap(); + (0usize).encode(&mut buf).unwrap(); + + // Pair 2: abs = 2*max_delta (even → int, value=0) + max_delta.encode(&mut buf).unwrap(); + (0u64).encode(&mut buf).unwrap(); + + // Pair 3: abs = 3*max_delta (odd → bytes, length=0) + max_delta.encode(&mut buf).unwrap(); + (0usize).encode(&mut buf).unwrap(); + + // Pair 4: abs = 4*max_delta (even → int, value=0) + max_delta.encode(&mut buf).unwrap(); + (0u64).encode(&mut buf).unwrap(); + + // Pair 5: abs = 5*max_delta → overflow + max_delta.encode(&mut buf).unwrap(); + + let err = KeyValuePairs::decode(&mut buf).unwrap_err(); + assert!( + matches!(err, DecodeError::KvpTypeOverflow), + "expected KvpTypeOverflow, got {:?}", + err + ); + } + #[test] - fn encode_decode_keyvaluepair() { + fn decode_rejects_bytes_value_too_long() { + // Craft: count=1, delta=1 (abs=1, odd), length = u16::MAX + 1 let mut buf = BytesMut::new(); + (1u64).encode(&mut buf).unwrap(); // count + (1u64).encode(&mut buf).unwrap(); // delta → abs_type = 1 (odd) + let too_long = MAX_BYTES_VALUE_LEN + 1; + too_long.encode(&mut buf).unwrap(); // length field + + let err = KeyValuePairs::decode(&mut buf).unwrap_err(); + assert!( + matches!(err, DecodeError::KeyValuePairLengthExceeded()), + "expected KeyValuePairLengthExceeded, got {:?}", + err + ); + } - // Type=1, VarInt value=0 - illegal with odd key/type - let kvp = KeyValuePair::new(1, Value::IntValue(0)); - let encoded = kvp.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::InvalidValue)); - - // Type=0, VarInt value=0 - let kvp = KeyValuePair::new(0, Value::IntValue(0)); - kvp.encode(&mut buf).unwrap(); - assert_eq!(buf.to_vec(), vec![0x00, 0x00]); - let decoded = KeyValuePair::decode(&mut buf).unwrap(); - assert_eq!(decoded, kvp); - - // Type=100, VarInt value=100 - let kvp = KeyValuePair::new(100, Value::IntValue(100)); - kvp.encode(&mut buf).unwrap(); - assert_eq!(buf.to_vec(), vec![0x40, 0x64, 0x40, 0x64]); // 2 2-byte VarInts with first 2 bits as 01 - let decoded = KeyValuePair::decode(&mut buf).unwrap(); - assert_eq!(decoded, kvp); - - // Type=0, Bytes value=[1,2,3,4,5] - illegal with even key/type - let kvp = KeyValuePair::new(0, Value::BytesValue(vec![0x01, 0x02, 0x03, 0x04, 0x05])); - let decoded = kvp.encode(&mut buf); - assert!(matches!(decoded.unwrap_err(), EncodeError::InvalidValue)); - - // Type=1, Bytes value=[1,2,3,4,5] - let kvp = KeyValuePair::new(1, Value::BytesValue(vec![0x01, 0x02, 0x03, 0x04, 0x05])); - kvp.encode(&mut buf).unwrap(); - assert_eq!(buf.to_vec(), vec![0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05]); - let decoded = KeyValuePair::decode(&mut buf).unwrap(); - assert_eq!(decoded, kvp); + #[test] + fn decode_large_count_does_not_allocate_count_capacity() { + let mut buf = BytesMut::new(); + ((1u64 << 62) - 1).encode(&mut buf).unwrap(); + + let err = KeyValuePairs::decode(&mut buf).unwrap_err(); + assert!(matches!(err, DecodeError::More(_))); } + // ── duplicate key detection ─────────────────────────────────────────────── + #[test] - fn decode_badtype() { - // Simulate a VarInt value of 5, but with an odd key/type - let data: Vec = vec![0x01, 0x05]; - let mut buf: Bytes = data.into(); - let decoded = KeyValuePair::decode(&mut buf); - assert!(matches!(decoded.unwrap_err(), DecodeError::More(_))); // Framing will be off now + fn has_duplicate_keys_detects_duplicates() { + let mut kvps = KeyValuePairs::new(); + kvps.0.push(KeyValuePair::new_int(0, 1)); + kvps.0.push(KeyValuePair::new_int(0, 2)); // duplicate key + assert!(kvps.has_duplicate_keys()); } #[test] - fn encode_decode_keyvaluepairs() { + fn has_duplicate_keys_no_false_positive() { + let mut kvps = KeyValuePairs::new(); + kvps.set_intvalue(0, 1); + kvps.set_intvalue(2, 2); + assert!(!kvps.has_duplicate_keys()); + } + + // ── empty collection ────────────────────────────────────────────────────── + + #[test] + fn empty_kvps_roundtrip() { + let kvps = KeyValuePairs::new(); let mut buf = BytesMut::new(); + kvps.encode(&mut buf).unwrap(); + assert_eq!(buf.to_vec(), vec![0x00]); // count = 0 + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + assert_eq!(decoded, kvps); + } + + // ── legacy byte-vector compatibility ────────────────────────────────────── + // These verify the existing test expectations from the pre-draft-16 code + // still hold for common patterns. + #[test] + fn existing_single_bytes_compat() { + let mut buf = BytesMut::new(); let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); kvps.encode(&mut buf).unwrap(); assert_eq!( buf.to_vec(), vec![ - 0x01, // 1 KeyValuePair - 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // Key=1, Value=[1,2,3,4,5] + 0x01, // count=1 + 0x01, // delta=1 → abs_type=1 (odd, bytes) + 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // length=5, value ] ); let decoded = KeyValuePairs::decode(&mut buf).unwrap(); assert_eq!(decoded, kvps); + } + #[test] + fn existing_multi_compat() { + let mut buf = BytesMut::new(); let mut kvps = KeyValuePairs::new(); kvps.set_intvalue(0, 0); kvps.set_intvalue(100, 100); kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); kvps.encode(&mut buf).unwrap(); - let buf_vec = buf.to_vec(); - // Validate the encoded length and the KeyValuePair count - assert_eq!(14, buf_vec.len()); // 14 bytes total - assert_eq!(3, buf_vec[0]); // 3 KeyValuePairs let decoded = KeyValuePairs::decode(&mut buf).unwrap(); - assert_eq!(decoded, kvps); + assert_eq!(decoded.0.len(), 3); + assert_eq!(decoded.get(0).unwrap().value, Value::IntValue(0)); + assert_eq!(decoded.get(100).unwrap().value, Value::IntValue(100)); + assert_eq!( + decoded.get(1).unwrap().value, + Value::BytesValue(vec![0x01, 0x02, 0x03, 0x04, 0x05]) + ); } } diff --git a/moq-transport/src/coding/track_namespace.rs b/moq-transport/src/coding/track_namespace.rs index 8299e4a3..ff6f9701 100644 --- a/moq-transport/src/coding/track_namespace.rs +++ b/moq-transport/src/coding/track_namespace.rs @@ -1,23 +1,238 @@ // SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 +//! Track namespace and track name encoding per draft-ietf-moq-transport-16 §2.4.1. +//! +//! Rules enforced here: +//! +//! - A **full** `TrackNamespace` must have 1–32 fields; each field must be ≥ 1 byte. +//! - A **prefix** `TrackNamespacePrefix` (used in `SUBSCRIBE_NAMESPACE`) may have +//! 0–32 fields; each non-empty field must also be ≥ 1 byte. +//! - The total length of a Full Track Name (sum of all namespace field lengths + +//! track name length) MUST NOT exceed 4096 bytes. Validated by message +//! decoders via [`validate_full_track_name`]. +//! - Track names are arbitrary bytes and may be empty. + use super::{Decode, DecodeError, Encode, EncodeError, TupleField}; use core::hash::{Hash, Hasher}; +use std::borrow::Cow; use std::convert::TryFrom; use std::fmt; use thiserror::Error; -/// Error type for TrackNamespace conversion failures +/// Maximum total length of a Full Track Name (namespace fields + track name). +pub const MAX_FULL_TRACK_NAME_LEN: usize = 4096; + +fn namespace_fields_byte_len(fields: &[TupleField]) -> usize { + fields.iter().map(|f| f.value.len()).sum() +} + +fn namespace_path(fields: &[TupleField]) -> String { + let mut path = String::new(); + for field in fields { + path.push('/'); + path.push_str(&String::from_utf8_lossy(&field.value)); + } + path +} + +fn decode_namespace_fields( + r: &mut R, + min_fields: usize, + max_fields: usize, + kind: &str, +) -> Result, DecodeError> { + let count = usize::decode(r)?; + if count < min_fields || count > max_fields { + return Err(DecodeError::FieldBoundsExceeded(format!( + "{kind} must have {min_fields}-{max_fields} fields, got {count}" + ))); + } + + let mut fields = Vec::with_capacity(count); + for _ in 0..count { + let field = TupleField::decode(r)?; + if field.value.is_empty() { + return Err(DecodeError::EmptyNamespaceField); + } + fields.push(field); + } + + Ok(fields) +} + +fn encode_namespace_fields( + fields: &[TupleField], + min_fields: usize, + max_fields: usize, + kind: &str, + w: &mut W, +) -> Result<(), EncodeError> { + if fields.len() < min_fields || fields.len() > max_fields { + return Err(EncodeError::FieldBoundsExceeded(format!( + "{kind} must have {min_fields}-{max_fields} fields" + ))); + } + + fields.len().encode(w)?; + for field in fields { + if field.value.is_empty() { + return Err(EncodeError::EmptyNamespaceField); + } + field.encode(w)?; + } + + Ok(()) +} + +fn validate_track_namespace_fields( + fields: &[TupleField], + min_fields: usize, + max_fields: usize, +) -> Result<(), TrackNamespaceError> { + if fields.len() < min_fields { + return Err(TrackNamespaceError::TooFewFields); + } + if fields.len() > max_fields { + return Err(TrackNamespaceError::TooManyFields(fields.len(), max_fields)); + } + for field in fields { + if field.value.is_empty() { + return Err(TrackNamespaceError::EmptyField); + } + if field.value.len() > TupleField::MAX_VALUE_SIZE { + return Err(TrackNamespaceError::FieldTooLarge( + field.value.len(), + TupleField::MAX_VALUE_SIZE, + )); + } + } + + Ok(()) +} + +/// A Track Name is arbitrary bytes and may be empty. +#[derive(Clone, Default, Eq, Hash, PartialEq)] +pub struct TrackName { + value: Vec, +} + +impl TrackName { + pub fn new(value: Vec) -> Self { + Self { value } + } + + pub fn as_bytes(&self) -> &[u8] { + &self.value + } + + pub fn to_string_lossy(&self) -> Cow<'_, str> { + String::from_utf8_lossy(&self.value) + } +} + +impl AsRef<[u8]> for TrackName { + fn as_ref(&self) -> &[u8] { + self.as_bytes() + } +} + +impl From> for TrackName { + fn from(value: Vec) -> Self { + Self::new(value) + } +} + +impl From<&[u8]> for TrackName { + fn from(value: &[u8]) -> Self { + Self::new(value.to_vec()) + } +} + +impl From for TrackName { + fn from(value: String) -> Self { + Self::new(value.into_bytes()) + } +} + +impl From<&str> for TrackName { + fn from(value: &str) -> Self { + Self::new(value.as_bytes().to_vec()) + } +} + +impl From<&String> for TrackName { + fn from(value: &String) -> Self { + Self::from(value.as_str()) + } +} + +impl From<&TrackName> for TrackName { + fn from(value: &TrackName) -> Self { + value.clone() + } +} + +impl Decode for TrackName { + fn decode(r: &mut R) -> Result { + let size = usize::decode(r)?; + if size > MAX_FULL_TRACK_NAME_LEN { + return Err(DecodeError::TrackNameTooLong); + } + Self::decode_remaining(r, size)?; + + let mut value = vec![0; size]; + r.copy_to_slice(&mut value); + Ok(Self { value }) + } +} + +impl Encode for TrackName { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + if self.value.len() > MAX_FULL_TRACK_NAME_LEN { + return Err(EncodeError::FieldBoundsExceeded("TrackName".to_string())); + } + self.value.len().encode(w)?; + Self::encode_remaining(w, self.value.len())?; + w.put_slice(&self.value); + Ok(()) + } +} + +impl fmt::Debug for TrackName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{self}") + } +} + +impl fmt::Display for TrackName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.to_string_lossy()) + } +} + +// ─── Errors ─────────────────────────────────────────────────────────────────── + #[derive(Debug, Clone, Error, PartialEq, Eq)] pub enum TrackNamespaceError { #[error("too many fields: {0} exceeds maximum of {1}")] TooManyFields(usize, usize), + #[error("too few fields: full track namespace requires at least 1 field")] + TooFewFields, + #[error("field too large: {0} bytes exceeds maximum of {1}")] FieldTooLarge(usize, usize), + + #[error("empty field: namespace fields must be at least 1 byte")] + EmptyField, } -/// TrackNamespace +// ─── TrackNamespace (full, 1–32 non-empty fields) ───────────────────────────── + +/// A full Track Namespace: 1–32 non-empty byte fields. +/// +/// Used in `SUBSCRIBE`, `PUBLISH`, `PUBLISH_NAMESPACE`, `TRACK_STATUS`, etc. #[derive(Clone, Default, Eq, PartialEq)] pub struct TrackNamespace { pub fields: Vec, @@ -38,21 +253,33 @@ impl TrackNamespace { self.fields.clear(); } + /// Build from a `/`-separated UTF-8 path (each segment becomes a field). + /// Empty segments (e.g. leading `/`) are included as empty fields and will + /// fail validation; callers producing full namespaces should ensure no + /// empty segments. pub fn from_utf8_path(path: &str) -> Self { - let mut tuple = TrackNamespace::new(); + let mut ns = TrackNamespace::new(); for part in path.split('/') { - tuple.add(TupleField::from_utf8(part)); + ns.add(TupleField::from_utf8(part)); } - tuple + ns } pub fn to_utf8_path(&self) -> String { - let mut path = String::new(); - for field in &self.fields { - path.push('/'); - path.push_str(&String::from_utf8_lossy(&field.value)); + namespace_path(&self.fields) + } + + /// Sum of all field lengths. Used for full-track-name limit calculation. + pub fn namespace_byte_len(&self) -> usize { + namespace_fields_byte_len(&self.fields) + } + + fn validate_namespace_byte_len(&self) -> Result<(), DecodeError> { + if self.namespace_byte_len() > MAX_FULL_TRACK_NAME_LEN { + return Err(DecodeError::TrackNameTooLong); } - path + + Ok(()) } } @@ -64,46 +291,34 @@ impl Hash for TrackNamespace { impl Decode for TrackNamespace { fn decode(r: &mut R) -> Result { - let count = usize::decode(r)?; - if count > Self::MAX_FIELDS { - return Err(DecodeError::FieldBoundsExceeded( - "TrackNamespace tuples".to_string(), - )); - } - - let mut fields = Vec::new(); - for _ in 0..count { - fields.push(TupleField::decode(r)?); - } - Ok(Self { fields }) + // Draft-16 §2.4.1: full namespaces have 1-32 non-empty fields. + let fields = decode_namespace_fields(r, 1, Self::MAX_FIELDS, "TrackNamespace")?; + let namespace = Self { fields }; + namespace.validate_namespace_byte_len()?; + Ok(namespace) } } impl Encode for TrackNamespace { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - if self.fields.len() > Self::MAX_FIELDS { + if self.namespace_byte_len() > MAX_FULL_TRACK_NAME_LEN { return Err(EncodeError::FieldBoundsExceeded( - "TrackNamespace tuples".to_string(), + "TrackNamespace".to_string(), )); } - self.fields.len().encode(w)?; - for field in &self.fields { - field.encode(w)?; - } - Ok(()) + encode_namespace_fields(&self.fields, 1, Self::MAX_FIELDS, "TrackNamespace", w) } } impl fmt::Debug for TrackNamespace { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // Just reuse the Display formatting write!(f, "{self}") } } impl fmt::Display for TrackNamespace { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{0}", self.to_utf8_path()) + write!(f, "{0}", namespace_path(&self.fields)) } } @@ -111,20 +326,7 @@ impl TryFrom> for TrackNamespace { type Error = TrackNamespaceError; fn try_from(fields: Vec) -> Result { - if fields.len() > Self::MAX_FIELDS { - return Err(TrackNamespaceError::TooManyFields( - fields.len(), - Self::MAX_FIELDS, - )); - } - for field in &fields { - if field.value.len() > TupleField::MAX_VALUE_SIZE { - return Err(TrackNamespaceError::FieldTooLarge( - field.value.len(), - TupleField::MAX_VALUE_SIZE, - )); - } - } + validate_track_namespace_fields(&fields, 1, Self::MAX_FIELDS)?; Ok(Self { fields }) } } @@ -168,13 +370,105 @@ impl TryFrom> for TrackNamespace { } } +// ─── TrackNamespacePrefix (0–32 fields, for SUBSCRIBE_NAMESPACE) ────────────── + +/// A Track Namespace Prefix used in `SUBSCRIBE_NAMESPACE`. +/// +/// Unlike [`TrackNamespace`], a prefix is allowed to have 0 fields (matching +/// all namespaces). Fields that are present must still be non-empty. +#[derive(Clone, Default, Eq, PartialEq)] +pub struct TrackNamespacePrefix { + pub fields: Vec, +} + +impl TrackNamespacePrefix { + pub const MAX_FIELDS: usize = 32; + + pub fn new() -> Self { + Self::default() + } + + pub fn from_utf8_path(path: &str) -> Self { + let mut prefix = TrackNamespacePrefix::new(); + for part in path.split('/').filter(|s| !s.is_empty()) { + prefix.fields.push(TupleField::from_utf8(part)); + } + prefix + } + + pub fn to_utf8_path(&self) -> String { + namespace_path(&self.fields) + } +} + +impl Hash for TrackNamespacePrefix { + fn hash(&self, state: &mut H) { + self.fields.hash(state); + } +} + +impl Decode for TrackNamespacePrefix { + fn decode(r: &mut R) -> Result { + // Draft-16 §9.25: prefixes have 0-32 non-empty fields. + let fields = decode_namespace_fields(r, 0, Self::MAX_FIELDS, "TrackNamespacePrefix")?; + Ok(Self { fields }) + } +} + +impl Encode for TrackNamespacePrefix { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + encode_namespace_fields(&self.fields, 0, Self::MAX_FIELDS, "TrackNamespacePrefix", w) + } +} + +impl fmt::Debug for TrackNamespacePrefix { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{self}") + } +} + +impl fmt::Display for TrackNamespacePrefix { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{0}", namespace_path(&self.fields)) + } +} + +// ─── Full track name length helper ─────────────────────────────────────────── + +/// Compute the wire-encoded byte length of a Full Track Name. +/// +/// This is the sum of each namespace field's length plus the track name length. +/// If the result exceeds [`MAX_FULL_TRACK_NAME_LEN`] the caller should close +/// the session with PROTOCOL_VIOLATION. +pub fn full_track_name_len(namespace: &TrackNamespace, track_name: &[u8]) -> usize { + namespace.namespace_byte_len() + track_name.len() +} + +/// Validate that a Full Track Name is within the draft-16 4096-byte limit. +pub fn validate_full_track_name( + namespace: &TrackNamespace, + track_name: &[u8], +) -> Result<(), DecodeError> { + if full_track_name_len(namespace, track_name) > MAX_FULL_TRACK_NAME_LEN { + return Err(DecodeError::TrackNameTooLong); + } + + Ok(()) +} + +// ─── Add missing EncodeError variant ───────────────────────────────────────── +// (defined here to keep it adjacent to the validation logic) + +// ─── Tests ──────────────────────────────────────────────────────────────────── + #[cfg(test)] mod tests { use super::*; - use bytes::Bytes; - use bytes::BytesMut; + use bytes::{Bytes, BytesMut}; use std::convert::TryInto; + // ── TrackNamespace encode/decode ────────────────────────────────────────── + #[test] fn encode_decode() { let mut buf = BytesMut::new(); @@ -186,63 +480,115 @@ mod tests { buf.to_vec(), vec![ 0x04, // 4 tuple fields - // Field 1: "test" - 0x04, 0x74, 0x65, 0x73, 0x74, - // Field 2: "path" - 0x04, 0x70, 0x61, 0x74, 0x68, - // Field 3: "to" - 0x02, 0x74, 0x6f, - // Field 4: "resource" - 0x08, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65 + 0x04, 0x74, 0x65, 0x73, 0x74, // "test" + 0x04, 0x70, 0x61, 0x74, 0x68, // "path" + 0x02, 0x74, 0x6f, // "to" + 0x08, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, // "resource" ] ); let decoded = TrackNamespace::decode(&mut buf).unwrap(); assert_eq!(decoded, t); + } + + #[test] + fn track_name_allows_non_utf8_bytes() { + let name = TrackName::from(vec![0xff, 0x00, b'a']); + let mut buf = BytesMut::new(); + + name.encode(&mut buf).unwrap(); + let decoded = TrackName::decode(&mut buf).unwrap(); + + assert_eq!(decoded.as_bytes(), &[0xff, 0x00, b'a']); + } - // Alternate construction + #[test] + fn encode_single_field() { + let mut buf = BytesMut::new(); let mut t = TrackNamespace::new(); t.add(TupleField::from_utf8("test")); t.encode(&mut buf).unwrap(); - assert_eq!( - buf.to_vec(), - vec![ - 0x01, // 1 tuple field - // Field 1: "test" - 0x04, 0x74, 0x65, 0x73, 0x74 - ] - ); + assert_eq!(buf.to_vec(), vec![0x01, 0x04, 0x74, 0x65, 0x73, 0x74]); let decoded = TrackNamespace::decode(&mut buf).unwrap(); assert_eq!(decoded, t); } + // ── 0-field rejection ───────────────────────────────────────────────────── + #[test] - fn encode_too_large() { + fn decode_zero_fields_is_error() { + // wire: count = 0 + let data: Vec = vec![0x00]; + let mut buf: Bytes = data.into(); + let err = TrackNamespace::decode(&mut buf).unwrap_err(); + assert!( + matches!(err, DecodeError::FieldBoundsExceeded(_)), + "expected FieldBoundsExceeded, got {:?}", + err + ); + } + + #[test] + fn encode_zero_fields_is_error() { + let t = TrackNamespace::new(); // empty let mut buf = BytesMut::new(); + assert!(matches!( + t.encode(&mut buf).unwrap_err(), + EncodeError::FieldBoundsExceeded(_) + )); + } + + // ── >32 fields rejection ────────────────────────────────────────────────── + #[test] + fn encode_too_large() { + let mut buf = BytesMut::new(); let mut t = TrackNamespace::new(); - for i in 0..TrackNamespace::MAX_FIELDS + 1 { - t.add(TupleField::from_utf8(&format!("field{}", i))); + for i in 0..=TrackNamespace::MAX_FIELDS { + t.add(TupleField::from_utf8(&format!("f{}", i))); } - - let encoded = t.encode(&mut buf); assert!(matches!( - encoded.unwrap_err(), + t.encode(&mut buf).unwrap_err(), EncodeError::FieldBoundsExceeded(_) )); } #[test] fn decode_too_large() { - let mut data: Vec = vec![0x00; 256]; // Create a vector with 256 bytes - data[0] = (TrackNamespace::MAX_FIELDS + 1) as u8; // Set first byte (count) to 33 as a VarInt + let mut data: Vec = vec![0x00; 256]; + data[0] = (TrackNamespace::MAX_FIELDS + 1) as u8; // count = 33 let mut buf: Bytes = data.into(); - let decoded = TrackNamespace::decode(&mut buf); assert!(matches!( - decoded.unwrap_err(), + TrackNamespace::decode(&mut buf).unwrap_err(), DecodeError::FieldBoundsExceeded(_) )); } + // ── empty field rejection ───────────────────────────────────────────────── + + #[test] + fn decode_empty_field_is_error() { + // wire: count=1, field length=0 + let data: Vec = vec![0x01, 0x00]; // count=1, field_len=0 + let mut buf: Bytes = data.into(); + assert!(matches!( + TrackNamespace::decode(&mut buf).unwrap_err(), + DecodeError::EmptyNamespaceField + )); + } + + #[test] + fn encode_empty_field_is_error() { + let mut t = TrackNamespace::new(); + t.add(TupleField { value: vec![] }); // empty field + let mut buf = BytesMut::new(); + assert!(matches!( + t.encode(&mut buf).unwrap_err(), + EncodeError::EmptyNamespaceField + )); + } + + // ── TryFrom conversions ─────────────────────────────────────────────────── + #[test] fn try_from_str() { let ns: TrackNamespace = "test/path/to/resource".try_into().unwrap(); @@ -282,12 +628,21 @@ mod tests { assert_eq!(ns.to_utf8_path(), "/test/path"); } + #[test] + fn try_from_empty_vec_is_error() { + let fields: Vec = vec![]; + let result: Result = fields.try_into(); + assert!(matches!( + result.unwrap_err(), + TrackNamespaceError::TooFewFields + )); + } + #[test] fn try_from_too_many_fields() { - let mut fields = Vec::new(); - for i in 0..TrackNamespace::MAX_FIELDS + 1 { - fields.push(TupleField::from_utf8(&format!("field{}", i))); - } + let fields: Vec = (0..=TrackNamespace::MAX_FIELDS) + .map(|i| TupleField::from_utf8(&format!("f{}", i))) + .collect(); let result: Result = fields.try_into(); assert!(matches!( result.unwrap_err(), @@ -307,4 +662,117 @@ mod tests { TrackNamespaceError::FieldTooLarge(4097, 4096) )); } + + #[test] + fn try_from_empty_field_is_error() { + let fields = vec![TupleField { value: vec![] }]; + let result: Result = fields.try_into(); + assert!(matches!( + result.unwrap_err(), + TrackNamespaceError::EmptyField + )); + } + + // ── TrackNamespacePrefix ────────────────────────────────────────────────── + + #[test] + fn prefix_allows_zero_fields() { + let prefix = TrackNamespacePrefix::new(); + let mut buf = BytesMut::new(); + prefix.encode(&mut buf).unwrap(); // must not error + assert_eq!(buf.to_vec(), vec![0x00]); + let decoded = TrackNamespacePrefix::decode(&mut buf).unwrap(); + assert_eq!(decoded.fields.len(), 0); + } + + #[test] + fn prefix_roundtrip() { + let prefix = TrackNamespacePrefix::from_utf8_path("example.com/meeting=123"); + let mut buf = BytesMut::new(); + prefix.encode(&mut buf).unwrap(); + let decoded = TrackNamespacePrefix::decode(&mut buf).unwrap(); + assert_eq!(decoded.fields.len(), prefix.fields.len()); + assert_eq!(decoded.to_utf8_path(), prefix.to_utf8_path()); + } + + #[test] + fn prefix_rejects_too_many_fields() { + let mut prefix = TrackNamespacePrefix::new(); + for i in 0..=TrackNamespacePrefix::MAX_FIELDS { + prefix + .fields + .push(TupleField::from_utf8(&format!("f{}", i))); + } + let mut buf = BytesMut::new(); + assert!(matches!( + prefix.encode(&mut buf).unwrap_err(), + EncodeError::FieldBoundsExceeded(_) + )); + } + + #[test] + fn prefix_rejects_empty_field() { + let mut prefix = TrackNamespacePrefix::new(); + prefix.fields.push(TupleField { value: vec![] }); + let mut buf = BytesMut::new(); + assert!(matches!( + prefix.encode(&mut buf).unwrap_err(), + EncodeError::EmptyNamespaceField + )); + } + + // ── full_track_name_len ─────────────────────────────────────────────────── + + #[test] + fn full_track_name_len_basic() { + let ns = TrackNamespace::from_utf8_path("a/b"); // fields: "a" (1 byte), "b" (1 byte) + let track_name = b"mytrack"; // 7 bytes + assert_eq!(full_track_name_len(&ns, track_name), 1 + 1 + 7); + } + + #[test] + fn full_track_name_len_at_limit() { + // Build a namespace with one field of 4088 bytes and an empty track name (0 bytes). + // Total = 4088 ≤ 4096. + let big_field = vec![b'x'; 4088]; + let ns = TrackNamespace { + fields: vec![TupleField { value: big_field }], + }; + assert!(full_track_name_len(&ns, b"") <= MAX_FULL_TRACK_NAME_LEN); + } + + #[test] + fn decode_rejects_namespace_over_full_track_name_limit() { + let mut buf = BytesMut::new(); + (2usize).encode(&mut buf).unwrap(); + let field = vec![b'x'; 2049]; + field.len().encode(&mut buf).unwrap(); + buf.extend_from_slice(&field); + field.len().encode(&mut buf).unwrap(); + buf.extend_from_slice(&field); + + let err = TrackNamespace::decode(&mut buf).unwrap_err(); + assert!(matches!(err, DecodeError::TrackNameTooLong)); + } + + #[test] + fn full_track_name_len_over_limit() { + // 4092 (namespace) + 5 (track name "hello") = 4097 > 4096 + let big_field = vec![b'x'; 4092]; + let ns = TrackNamespace { + fields: vec![TupleField { value: big_field }], + }; + assert!(full_track_name_len(&ns, b"hello") > MAX_FULL_TRACK_NAME_LEN); + } + + #[test] + fn validate_full_track_name_rejects_over_limit() { + let big_field = vec![b'a'; MAX_FULL_TRACK_NAME_LEN]; + let ns = TrackNamespace { + fields: vec![TupleField { value: big_field }], + }; + + let err = validate_full_track_name(&ns, b"x").unwrap_err(); + assert!(matches!(err, DecodeError::TrackNameTooLong)); + } } diff --git a/moq-transport/src/data/datagram.rs b/moq-transport/src/data/datagram.rs index 88c7af96..e3c1e3ba 100644 --- a/moq-transport/src/data/datagram.rs +++ b/moq-transport/src/data/datagram.rs @@ -19,6 +19,23 @@ pub enum DatagramType { ObjectIdStatusExt = 0x21, } +impl DatagramType { + fn has_extension_headers(self) -> bool { + matches!( + self, + Self::ObjectIdPayloadExt + | Self::ObjectIdPayloadExtEndOfGroup + | Self::PayloadExt + | Self::PayloadExtEndOfGroup + | Self::ObjectIdStatusExt + ) + } + + fn has_status(self) -> bool { + matches!(self, Self::ObjectIdStatus | Self::ObjectIdStatusExt) + } +} + impl Decode for DatagramType { fn decode(r: &mut B) -> Result { match u64::decode(r)? { @@ -92,23 +109,29 @@ impl Decode for Datagram { let publisher_priority = u8::decode(r)?; // Decode Extension Headers if required - let extension_headers = match datagram_type { - DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::PayloadExt - | DatagramType::PayloadExtEndOfGroup - | DatagramType::ObjectIdStatusExt => Some(ExtensionHeaders::decode(r)?), - _ => None, + let extension_headers = if datagram_type.has_extension_headers() { + let headers = ExtensionHeaders::decode(r)?; + if headers.is_empty() { + return Err(DecodeError::InvalidValue); + } + Some(headers) + } else { + None }; // Decode Status if required - let status = match datagram_type { - DatagramType::ObjectIdStatus | DatagramType::ObjectIdStatusExt => { - Some(ObjectStatus::decode(r)?) - } - _ => None, + let status = if datagram_type.has_status() { + Some(ObjectStatus::decode(r)?) + } else { + None }; + if status.is_some_and(|status| status != ObjectStatus::NormalObject) + && extension_headers.is_some() + { + return Err(DecodeError::InvalidValue); + } + // Decode Payload if required let payload = match datagram_type { DatagramType::ObjectIdPayload @@ -168,6 +191,9 @@ impl Encode for Datagram { | DatagramType::PayloadExtEndOfGroup | DatagramType::ObjectIdStatusExt => { if let Some(extension_headers) = &self.extension_headers { + if extension_headers.is_empty() { + return Err(EncodeError::InvalidValue); + } extension_headers.encode(w)?; } else { return Err(EncodeError::MissingField("ExtensionHeaders".to_string())); @@ -180,6 +206,9 @@ impl Encode for Datagram { match self.datagram_type { DatagramType::ObjectIdStatus | DatagramType::ObjectIdStatusExt => { if let Some(status) = &self.status { + if self.extension_headers.is_some() && *status != ObjectStatus::NormalObject { + return Err(EncodeError::InvalidValue); + } status.encode(w)?; } else { return Err(EncodeError::MissingField("Status".to_string())); @@ -367,7 +396,7 @@ mod tests { object_id: Some(1234), publisher_priority: 127, extension_headers: None, - status: Some(ObjectStatus::EndOfTrack), + status: Some(ObjectStatus::NormalObject), payload: None, }; msg.encode(&mut buf).unwrap(); @@ -384,7 +413,7 @@ mod tests { object_id: Some(1234), publisher_priority: 127, extension_headers: Some(ext_hdrs.clone()), - status: Some(ObjectStatus::EndOfTrack), + status: Some(ObjectStatus::NormalObject), payload: None, }; msg.encode(&mut buf).unwrap(); @@ -538,4 +567,85 @@ mod tests { // TODO SLG - add tests } + + #[test] + fn decode_rejects_extension_bit_with_zero_length() { + let data = vec![ + 0x01, // ObjectIdPayloadExt + 0x01, // track alias + 0x01, // group id + 0x01, // object id + 0x7f, // publisher priority + 0x00, // extension headers length + ]; + let mut buf: Bytes = data.into(); + + assert!(matches!( + Datagram::decode(&mut buf).unwrap_err(), + DecodeError::InvalidValue + )); + } + + #[test] + fn encode_rejects_extension_bit_with_empty_headers() { + let mut buf = BytesMut::new(); + let msg = Datagram { + datagram_type: DatagramType::ObjectIdPayloadExt, + track_alias: 1, + group_id: 1, + object_id: Some(1), + publisher_priority: 1, + extension_headers: Some(ExtensionHeaders::default()), + status: None, + payload: Some(Bytes::new()), + }; + + assert!(matches!( + msg.encode(&mut buf).unwrap_err(), + EncodeError::InvalidValue + )); + } + + #[test] + fn decode_rejects_non_normal_status_with_extension_headers() { + let data = vec![ + 0x21, // ObjectIdStatusExt + 0x01, // track alias + 0x01, // group id + 0x01, // object id + 0x7f, // publisher priority + 0x02, // extension headers byte length + 0x00, // extension delta type + 0x01, // extension value + 0x04, // EndOfTrack + ]; + let mut buf: Bytes = data.into(); + + assert!(matches!( + Datagram::decode(&mut buf).unwrap_err(), + DecodeError::InvalidValue + )); + } + + #[test] + fn encode_rejects_non_normal_status_with_extension_headers() { + let mut ext_hdrs = ExtensionHeaders::new(); + ext_hdrs.set_intvalue(0, 1); + let mut buf = BytesMut::new(); + let msg = Datagram { + datagram_type: DatagramType::ObjectIdStatusExt, + track_alias: 1, + group_id: 1, + object_id: Some(1), + publisher_priority: 1, + extension_headers: Some(ext_hdrs), + status: Some(ObjectStatus::EndOfTrack), + payload: None, + }; + + assert!(matches!( + msg.encode(&mut buf).unwrap_err(), + EncodeError::InvalidValue + )); + } } diff --git a/moq-transport/src/data/extension_headers.rs b/moq-transport/src/data/extension_headers.rs index e457ce77..c3f48488 100644 --- a/moq-transport/src/data/extension_headers.rs +++ b/moq-transport/src/data/extension_headers.rs @@ -1,23 +1,37 @@ // SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 +//! Extension headers for MoQT data-plane objects (§2.5, §10.2.1.2). +//! +//! On the wire, extension headers are encoded as a **byte-length prefix** +//! followed by a sequence of delta-encoded Key-Value-Pairs (same KVP format +//! as control-plane parameters). The byte-length prefix distinguishes this +//! from [`KeyValuePairs`] which uses a count prefix. +//! +//! Because the pairs share a single running `prev` counter across the whole +//! sequence, the encoder sorts them by ascending key before writing. + use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePair}; use bytes::Buf; use std::fmt; -/// A collection of KeyValuePair entries, where the length in bytes of key-value-pairs are encoded/decoded first. -/// This structure is appropriate for Data plane extension headers. -/// Since duplicate parameters are allowed for unknown extension headers, we don't do duplicate checking here. +/// Smallest possible encoded extension KVP: one-byte delta plus one-byte value. +const MIN_EXTENSION_KVP_WIRE_LEN: usize = 2; + +/// A length-prefixed sequence of delta-encoded Key-Value-Pairs used for +/// data-plane object extension headers. +/// +/// Keys are stored internally as absolute values; delta encoding is applied +/// only on the wire. #[derive(Default, Clone, Eq, PartialEq)] pub struct ExtensionHeaders(pub Vec); -// TODO: These set/get API's all assume no duplicate keys. We can add API's to support duplicates if needed. impl ExtensionHeaders { pub fn new() -> Self { Self::default() } - /// Insert or replace a KeyValuePair with the same key. + /// Insert or replace the entry with matching key. pub fn set(&mut self, kvp: KeyValuePair) { if let Some(existing) = self.0.iter_mut().find(|k| k.key == kvp.key) { *existing = kvp; @@ -49,28 +63,25 @@ impl ExtensionHeaders { impl Decode for ExtensionHeaders { fn decode(r: &mut R) -> Result { - // Read total byte length of the encoded kvps - // Note: this is the difference between KeyValuePairs and ExtensionHeaders. - // KeyValuePairs encodes the count of kvps, whereas ExtensionHeaders encodes the total byte length. + // Extension headers are byte-length prefixed (unlike KeyValuePairs which + // are count-prefixed). let length = usize::decode(r)?; - - // Ensure we have that many bytes available in the input Self::decode_remaining(r, length)?; - // If zero length, return empty map if length == 0 { return Ok(ExtensionHeaders::new()); } - // Copy the exact slice that contains the encoded kvps and decode from it - let mut buf = vec![0u8; length]; - r.copy_to_slice(&mut buf); - let mut kvps_bytes = bytes::Bytes::from(buf); + // Decode KVPs from the exact byte slice with a shared prev. + let mut kvps_bytes = r.copy_to_bytes(length); + + let mut kvps = Vec::with_capacity(length / MIN_EXTENSION_KVP_WIRE_LEN); + let mut prev = 0u64; - let mut kvps = Vec::new(); while kvps_bytes.has_remaining() { - let kvp = KeyValuePair::decode(&mut kvps_bytes)?; - kvps.push(kvp); + let (pair, new_prev) = KeyValuePair::decode_with_prev(&mut kvps_bytes, prev)?; + prev = new_prev; + kvps.push(pair); } Ok(ExtensionHeaders(kvps)) @@ -79,14 +90,29 @@ impl Decode for ExtensionHeaders { impl Encode for ExtensionHeaders { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - // Encode all KeyValuePair entries into a temporary buffer to compute total byte length + if self.0.is_empty() { + 0usize.encode(w)?; + return Ok(()); + } + + // Encode into a temporary buffer to measure the byte length before writing + // the length prefix. let mut tmp = bytes::BytesMut::new(); - for kvp in &self.0 { - kvp.encode(&mut tmp)?; + + if self.0.len() == 1 { + self.0[0].encode_with_prev(&mut tmp, 0)?; + } else { + // Sort by ascending key so deltas are always non-negative. + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|k| k.key); + let mut prev = 0u64; + for kvp in &sorted { + prev = kvp.encode_with_prev(&mut tmp, prev)?; + } } - // Write total byte length (u64) followed by the encoded bytes - (tmp.len() as u64).encode(w)?; + // Write the byte-length prefix followed by the encoded pairs. + tmp.len().encode(w)?; w.put_slice(&tmp); Ok(()) @@ -106,38 +132,108 @@ impl fmt::Debug for ExtensionHeaders { } } +// ─── Tests ──────────────────────────────────────────────────────────────────── + #[cfg(test)] mod tests { use super::*; use bytes::BytesMut; + // ── single pair ─────────────────────────────────────────────────────────── + #[test] - fn encode_decode_extension_headers() { + fn single_bytes_pair_roundtrip() { + // key=1 (odd, bytes), value=[0x01..0x05] + // Wire: length=7, delta=1 (prev=0→abs=1), length_of_value=5, bytes let mut buf = BytesMut::new(); - - let mut ext_hdrs = ExtensionHeaders::new(); - ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); - ext_hdrs.encode(&mut buf).unwrap(); + let mut ext = ExtensionHeaders::new(); + ext.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); + ext.encode(&mut buf).unwrap(); assert_eq!( buf.to_vec(), vec![ - 0x07, // 7 bytes total length - 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // Key=1, Value=[1,2,3,4,5] + 0x07, // 7 bytes of KVP data + 0x01, // delta=1 → abs_type=1 (odd, bytes) + 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // length=5, value ] ); let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); - assert_eq!(decoded, ext_hdrs); + assert_eq!(decoded, ext); + } + + // ── multiple pairs with correct delta encoding ──────────────────────────── - let mut ext_hdrs = ExtensionHeaders::new(); - ext_hdrs.set_intvalue(0, 0); // 2 bytes - ext_hdrs.set_intvalue(100, 100); // 4 bytes - ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); // 1 byte key, 1 byte length, 5 bytes data = 7 bytes - ext_hdrs.encode(&mut buf).unwrap(); + #[test] + fn multi_pair_delta_encoding() { + // Three pairs: key=0 (int=0), key=1 (bytes=[1..5]), key=100 (int=100) + // Sorted order on wire: key=0, key=1, key=100 + // Deltas: 0→0=0 (even, int), 0→1=1 (odd, bytes), 1→100=99 (even, int) + let mut ext = ExtensionHeaders::new(); + ext.set_intvalue(0, 0); + ext.set_intvalue(100, 100); + ext.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); + + let mut buf = BytesMut::new(); + ext.encode(&mut buf).unwrap(); + + // Manually compute expected bytes (QUIC varints ≥64 take 2 bytes): + // key=0 (even,int): delta=0 (1B), value=0 (1B) = 2B + // key=1 (odd,bytes): delta=1 (1B), length=5 (1B), 5 bytes = 7B + // key=100 (even,int): delta=99 (2B, since 99≥64), value=100 (2B, since 100≥64) = 4B + // total KVP bytes = 13; length prefix = 1B (13 < 64 so 1B varint) let buf_vec = buf.to_vec(); - // Validate the encoded length and the KeyValuePair's length. - assert_eq!(14, buf_vec.len()); // 14 bytes total (length + 3 kvps) - assert_eq!(13, buf_vec[0]); // 13 bytes for the 3 KeyValuePairs data + assert_eq!(buf_vec[0], 13); // 13 bytes of KVP data + assert_eq!(buf_vec.len(), 14); // 1 (length prefix) + 13 + + // Decode and verify all three pairs survive + let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); + assert_eq!(decoded.0.len(), 3); + assert_eq!( + decoded.get(0).unwrap().value, + crate::coding::Value::IntValue(0) + ); + assert_eq!( + decoded.get(100).unwrap().value, + crate::coding::Value::IntValue(100) + ); + assert_eq!( + decoded.get(1).unwrap().value, + crate::coding::Value::BytesValue(vec![0x01, 0x02, 0x03, 0x04, 0x05]) + ); + } + + // ── round-trip with out-of-order insertion ──────────────────────────────── + + #[test] + fn encode_sorts_before_delta() { + // Insert in reverse order; encode must produce correct ascending deltas. + let mut ext = ExtensionHeaders::new(); + ext.set_intvalue(100, 99); + ext.set_intvalue(0, 1); + + let mut buf = BytesMut::new(); + ext.encode(&mut buf).unwrap(); + let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); + + assert_eq!( + decoded.get(0).unwrap().value, + crate::coding::Value::IntValue(1) + ); + assert_eq!( + decoded.get(100).unwrap().value, + crate::coding::Value::IntValue(99) + ); + } + + // ── empty ───────────────────────────────────────────────────────────────── + + #[test] + fn empty_roundtrip() { + let ext = ExtensionHeaders::new(); + let mut buf = BytesMut::new(); + ext.encode(&mut buf).unwrap(); + assert_eq!(buf.to_vec(), vec![0x00]); // length=0 let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); - assert_eq!(decoded, ext_hdrs); + assert_eq!(decoded, ext); } } diff --git a/moq-transport/src/data/header.rs b/moq-transport/src/data/header.rs index 328b0bdc..b73ddc8c 100644 --- a/moq-transport/src/data/header.rs +++ b/moq-transport/src/data/header.rs @@ -57,6 +57,16 @@ impl StreamHeaderType { | StreamHeaderType::SubgroupIdExtEndOfGroup ) } + + pub fn uses_first_object_id_as_subgroup_id(&self) -> bool { + matches!( + *self, + StreamHeaderType::SubgroupFirstObjectId + | StreamHeaderType::SubgroupFirstObjectIdExt + | StreamHeaderType::SubgroupFirstObjectIdEndOfGroup + | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup + ) + } } impl Encode for StreamHeaderType { @@ -110,7 +120,7 @@ impl Decode for StreamHeaderType { }; if let Ok(header_type_inner) = &header_type { - tracing::debug!( + tracing::trace!( "[DECODE] StreamHeaderType: {}, has_subgroup_id={}, has_extension_headers={}", header_type_inner, header_type_inner.has_subgroup_id(), @@ -175,7 +185,7 @@ impl Decode for StreamHeader { } }; - tracing::debug!( + tracing::trace!( "[DECODE] StreamHeader complete: type={:?}, has_subgroup={}, has_fetch={}, buffer_remaining={} bytes", header_type, subgroup_header.is_some(), @@ -225,7 +235,7 @@ impl Encode for StreamHeader { return Err(EncodeError::MissingField("FetchHeader".to_string())); } - tracing::debug!("[ENCODE] StreamHeader complete"); + tracing::trace!("[ENCODE] StreamHeader complete"); Ok(()) } @@ -258,6 +268,12 @@ mod tests { assert!(ht.is_subgroup()); assert!(!ht.is_fetch()); assert!(!ht.has_subgroup_id()); + + let ht = StreamHeaderType::SubgroupFirstObjectId; + assert!(ht.uses_first_object_id_as_subgroup_id()); + + let ht = StreamHeaderType::SubgroupId; + assert!(!ht.uses_first_object_id_as_subgroup_id()); } #[test] diff --git a/moq-transport/src/data/object_status.rs b/moq-transport/src/data/object_status.rs index df2eb2bc..8c43b3f0 100644 --- a/moq-transport/src/data/object_status.rs +++ b/moq-transport/src/data/object_status.rs @@ -3,11 +3,20 @@ use crate::coding::{Decode, DecodeError, Encode, EncodeError}; +/// Object status values per draft-ietf-moq-transport-16 §10.2.1.1. +/// +/// Note: value 0x1 (`ObjectDoesNotExist`) was present in earlier drafts but +/// was removed in draft-16. Any received value other than 0x0, 0x3, or 0x4 +/// is treated as a protocol error. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum ObjectStatus { + /// 0x0 — Normal object with payload. For non-zero length objects this + /// status is implicit; a zero-length object must encode it explicitly. NormalObject = 0x0, - ObjectDoesNotExist = 0x1, + /// 0x3 — End of Group. No objects with this Group ID and an Object ID + /// greater than or equal to the one specified will exist. EndOfGroup = 0x3, + /// 0x4 — End of Track. No objects at or beyond this location exist. EndOfTrack = 0x4, } @@ -15,7 +24,6 @@ impl Decode for ObjectStatus { fn decode(r: &mut B) -> Result { match u64::decode(r)? { 0x0 => Ok(Self::NormalObject), - 0x1 => Ok(Self::ObjectDoesNotExist), 0x3 => Ok(Self::EndOfGroup), 0x4 => Ok(Self::EndOfTrack), _ => Err(DecodeError::InvalidObjectStatus), @@ -30,3 +38,61 @@ impl Encode for ObjectStatus { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use bytes::{Bytes, BytesMut}; + + #[test] + fn encode_decode_valid_statuses() { + let mut buf = BytesMut::new(); + + for status in [ + ObjectStatus::NormalObject, + ObjectStatus::EndOfGroup, + ObjectStatus::EndOfTrack, + ] { + status.encode(&mut buf).unwrap(); + let decoded = ObjectStatus::decode(&mut buf).unwrap(); + assert_eq!(decoded, status); + } + } + + #[test] + fn decode_rejects_removed_does_not_exist_value() { + // 0x1 was ObjectDoesNotExist in pre-draft-16 but is no longer valid. + let data = vec![0x01u8]; + let mut buf: Bytes = data.into(); + assert!(matches!( + ObjectStatus::decode(&mut buf).unwrap_err(), + crate::coding::DecodeError::InvalidObjectStatus + )); + } + + #[test] + fn decode_rejects_unknown_status() { + let data = vec![0x02u8]; + let mut buf: Bytes = data.into(); + assert!(matches!( + ObjectStatus::decode(&mut buf).unwrap_err(), + crate::coding::DecodeError::InvalidObjectStatus + )); + } + + #[test] + fn normal_object_wire_value_is_zero() { + // Objects with payload use NormalObject = 0x0. + assert_eq!(ObjectStatus::NormalObject as u64, 0x0); + } + + #[test] + fn end_of_group_wire_value_is_three() { + assert_eq!(ObjectStatus::EndOfGroup as u64, 0x3); + } + + #[test] + fn end_of_track_wire_value_is_four() { + assert_eq!(ObjectStatus::EndOfTrack as u64, 0x4); + } +} diff --git a/moq-transport/src/data/subgroup.rs b/moq-transport/src/data/subgroup.rs index 82d61495..bc7158e0 100644 --- a/moq-transport/src/data/subgroup.rs +++ b/moq-transport/src/data/subgroup.rs @@ -70,7 +70,7 @@ impl SubgroupHeader { publisher_priority, }; - tracing::debug!( + tracing::trace!( "[DECODE] SubgroupHeader complete: track_alias={}, group_id={}, subgroup_id={:?}, priority={}", result.track_alias, result.group_id, @@ -135,7 +135,7 @@ impl Encode for SubgroupHeader { ); let bytes_written = start_pos - w.remaining_mut(); - tracing::debug!( + tracing::trace!( "[ENCODE] SubgroupHeader complete: wrote {} bytes", bytes_written ); @@ -184,7 +184,7 @@ impl Decode for SubgroupObject { //Self::decode_remaining(r, payload_length); //let payload = r.copy_to_bytes(payload_length); - tracing::debug!( + tracing::trace!( "[DECODE] SubgroupObject complete: object_id_delta={}, payload_length={}, status={:?}, buffer_remaining={} bytes", object_id_delta, payload_length, @@ -234,7 +234,7 @@ impl Encode for SubgroupObject { //Self::encode_remaining(w, self.payload.len())?; //w.put_slice(&self.payload); - tracing::debug!("[ENCODE] SubgroupObject complete"); + tracing::trace!("[ENCODE] SubgroupObject complete"); Ok(()) } @@ -290,10 +290,16 @@ impl Decode for SubgroupObjectExt { } }; + if status.is_some_and(|status| status != ObjectStatus::NormalObject) + && !extension_headers.is_empty() + { + return Err(DecodeError::InvalidValue); + } + //Self::decode_remaining(r, payload_length); //let payload = r.copy_to_bytes(payload_length); - tracing::debug!( + tracing::trace!( "[DECODE] SubgroupObjectExt complete: object_id_delta={}, payload_length={}, status={:?}, buffer_remaining={} bytes", object_id_delta, payload_length, @@ -338,6 +344,9 @@ impl Encode for SubgroupObjectExt { if self.payload_length == 0 { if let Some(status) = self.status { + if status != ObjectStatus::NormalObject && !self.extension_headers.is_empty() { + return Err(EncodeError::InvalidValue); + } status.encode(w)?; tracing::trace!("[ENCODE] SubgroupObjectExt: encoded status={:?}", status); } else { @@ -348,7 +357,7 @@ impl Encode for SubgroupObjectExt { //Self::encode_remaining(w, self.payload.len())?; //w.put_slice(&self.payload); - tracing::debug!("[ENCODE] SubgroupObjectExt complete"); + tracing::trace!("[ENCODE] SubgroupObjectExt complete"); Ok(()) } @@ -358,6 +367,7 @@ impl Encode for SubgroupObjectExt { #[cfg(test)] mod tests { use super::*; + use bytes::Bytes; use bytes::BytesMut; #[test] @@ -392,4 +402,40 @@ mod tests { let decoded = SubgroupObjectExt::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } + + #[test] + fn decode_rejects_non_normal_status_with_extension_headers() { + let data = vec![ + 0x00, // object id delta + 0x02, // extension headers byte length + 0x00, // extension delta type + 0x01, // extension value + 0x00, // payload length + 0x04, // EndOfTrack + ]; + let mut buf: Bytes = data.into(); + + assert!(matches!( + SubgroupObjectExt::decode(&mut buf).unwrap_err(), + DecodeError::InvalidValue + )); + } + + #[test] + fn encode_rejects_non_normal_status_with_extension_headers() { + let mut ext_hdrs = ExtensionHeaders::new(); + ext_hdrs.set_intvalue(0, 1); + let msg = SubgroupObjectExt { + object_id_delta: 0, + extension_headers: ext_hdrs, + payload_length: 0, + status: Some(ObjectStatus::EndOfTrack), + }; + let mut buf = BytesMut::new(); + + assert!(matches!( + msg.encode(&mut buf).unwrap_err(), + EncodeError::InvalidValue + )); + } } diff --git a/moq-transport/src/message/fetch.rs b/moq-transport/src/message/fetch.rs index 89f16480..d8add288 100644 --- a/moq-transport/src/message/fetch.rs +++ b/moq-transport/src/message/fetch.rs @@ -2,14 +2,15 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use crate::coding::{ - Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, + validate_full_track_name, Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, + TrackName, TrackNamespace, }; -use crate::message::{FetchType, GroupOrder}; +use crate::message::FetchType; #[derive(Clone, Debug, Eq, PartialEq)] pub struct StandaloneFetch { pub track_namespace: TrackNamespace, - pub track_name: String, + pub track_name: TrackName, pub start_location: Location, pub end_location: Location, } @@ -17,7 +18,8 @@ pub struct StandaloneFetch { impl Decode for StandaloneFetch { fn decode(r: &mut R) -> Result { let track_namespace = TrackNamespace::decode(r)?; - let track_name = String::decode(r)?; + let track_name = TrackName::decode(r)?; + validate_full_track_name(&track_namespace, track_name.as_bytes())?; let start_location = Location::decode(r)?; let end_location = Location::decode(r)?; @@ -76,12 +78,6 @@ pub struct Fetch { /// The fetch request ID pub id: u64, - /// Subscriber Priority - pub subscriber_priority: u8, - - /// Object delivery order - pub group_order: GroupOrder, - /// Standalone fetch vs Relative Joining fetch vs Absolute Joining fetch pub fetch_type: FetchType, @@ -98,10 +94,6 @@ pub struct Fetch { impl Decode for Fetch { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; - - let subscriber_priority = u8::decode(r)?; - let group_order = GroupOrder::decode(r)?; - let fetch_type = FetchType::decode(r)?; let standalone_fetch: Option; @@ -121,8 +113,6 @@ impl Decode for Fetch { Ok(Self { id, - subscriber_priority, - group_order, fetch_type, standalone_fetch, joining_fetch, @@ -134,10 +124,6 @@ impl Decode for Fetch { impl Encode for Fetch { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; - - self.subscriber_priority.encode(w)?; - self.group_order.encode(w)?; - self.fetch_type.encode(w)?; match self.fetch_type { @@ -181,12 +167,10 @@ mod tests { // FetchType = Standlone let msg = Fetch { id: 12345, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, fetch_type: FetchType::Standalone, standalone_fetch: Some(StandaloneFetch { track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), + track_name: "audiotrack".into(), start_location: Location::new(34, 53), end_location: Location::new(34, 53), }), @@ -200,8 +184,6 @@ mod tests { // FetchType = RelativeJoining let msg = Fetch { id: 12345, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, fetch_type: FetchType::RelativeJoining, standalone_fetch: None, joining_fetch: Some(JoiningFetch { @@ -217,8 +199,6 @@ mod tests { // FetchType = AbsoluteJoining let msg = Fetch { id: 12345, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, fetch_type: FetchType::AbsoluteJoining, standalone_fetch: None, joining_fetch: Some(JoiningFetch { @@ -239,8 +219,6 @@ mod tests { // FetchType = Standlone - missing standalone_fetch let msg = Fetch { id: 12345, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, fetch_type: FetchType::Standalone, standalone_fetch: None, joining_fetch: None, @@ -252,8 +230,6 @@ mod tests { // FetchType = AbsoluteJoining - missing joining_fetch let msg = Fetch { id: 12345, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, fetch_type: FetchType::AbsoluteJoining, standalone_fetch: None, joining_fetch: None, diff --git a/moq-transport/src/message/fetch_error.rs b/moq-transport/src/message/fetch_error.rs deleted file mode 100644 index adfb42db..00000000 --- a/moq-transport/src/message/fetch_error.rs +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct FetchError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for FetchError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for FetchError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/fetch_ok.rs b/moq-transport/src/message/fetch_ok.rs index c9328bcb..0e34f854 100644 --- a/moq-transport/src/message/fetch_ok.rs +++ b/moq-transport/src/message/fetch_ok.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; -use crate::message::GroupOrder; +use crate::message::TrackExtensions; /// A publisher sends a FETCH_OK control message in response to successful fetches. #[derive(Clone, Debug, Eq, PartialEq)] @@ -10,9 +10,6 @@ pub struct FetchOk { /// The Fetch request ID of the Fetch this message is replying to. pub id: u64, - /// Order groups will be delivered in - pub group_order: GroupOrder, - /// True if all objects have been published on this track pub end_of_track: bool, @@ -21,28 +18,25 @@ pub struct FetchOk { /// Optional parameters pub params: KeyValuePairs, + + /// Track extension headers. + pub track_extensions: TrackExtensions, } impl Decode for FetchOk { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; - - let group_order = GroupOrder::decode(r)?; - // GroupOrder enum has Publisher in it, but it's not allowed to be used in this - // FetchOk message, so validate it now so we can return a protocol error. - if group_order == GroupOrder::Publisher { - return Err(DecodeError::InvalidGroupOrder); - } let end_of_track = bool::decode(r)?; let end_location = Location::decode(r)?; let params = KeyValuePairs::decode(r)?; + let track_extensions = TrackExtensions::decode(r)?; Ok(Self { id, - group_order, end_of_track, end_location, params, + track_extensions, }) } } @@ -50,16 +44,10 @@ impl Decode for FetchOk { impl Encode for FetchOk { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; - - // GroupOrder enum has Publisher in it, but it's not allowed to be used in this - // FetchOk message. - if self.group_order == GroupOrder::Publisher { - return Err(EncodeError::InvalidValue); - } - self.group_order.encode(w)?; self.end_of_track.encode(w)?; self.end_location.encode(w)?; self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -80,28 +68,13 @@ mod tests { let msg = FetchOk { id: 12345, - group_order: GroupOrder::Descending, end_of_track: true, end_location: Location::new(2, 3), params: kvps.clone(), + track_extensions: TrackExtensions::default(), }; msg.encode(&mut buf).unwrap(); let decoded = FetchOk::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } - - #[test] - fn encode_bad_group_order() { - let mut buf = BytesMut::new(); - - let msg = FetchOk { - id: 12345, - group_order: GroupOrder::Publisher, - end_of_track: true, - end_location: Location::new(2, 3), - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::InvalidValue)); - } } diff --git a/moq-transport/src/message/fetch_type.rs b/moq-transport/src/message/fetch_type.rs index a027cc6c..1a7a7869 100644 --- a/moq-transport/src/message/fetch_type.rs +++ b/moq-transport/src/message/fetch_type.rs @@ -13,7 +13,7 @@ pub enum FetchType { impl Encode for FetchType { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - let val = *self as u8; + let val = *self as u64; val.encode(w)?; Ok(()) } @@ -21,7 +21,7 @@ impl Encode for FetchType { impl Decode for FetchType { fn decode(r: &mut R) -> Result { - match u8::decode(r)? { + match u64::decode(r)? { 0x1 => Ok(Self::Standalone), 0x2 => Ok(Self::RelativeJoining), 0x3 => Ok(Self::AbsoluteJoining), diff --git a/moq-transport/src/message/mod.rs b/moq-transport/src/message/mod.rs index 8a790097..fdaa620f 100644 --- a/moq-transport/src/message/mod.rs +++ b/moq-transport/src/message/mod.rs @@ -2,227 +2,485 @@ // SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 -//! Low-level message sent over the wire, as defined in the specification. +//! Control messages sent over the bidirectional control stream. //! -//! All of these messages are sent over a bidirectional QUIC stream. -//! This introduces some head-of-line blocking but preserves ordering. -//! The only exception are OBJECT "messages", which are sent over dedicated QUIC streams. +//! Wire format per draft-ietf-moq-transport-16 §9: //! +//! ```text +//! MOQT Control Message { +//! Message Type (i), +//! Message Length (16), ← 16-bit unsigned big-endian +//! Message Payload (..), +//! } +//! ``` +//! +//! The receiver MUST close the session with PROTOCOL_VIOLATION if the +//! payload length does not match Message Length. Unknown message types +//! MUST also close the session. mod fetch; mod fetch_cancel; -mod fetch_error; mod fetch_ok; mod fetch_type; mod filter_type; mod go_away; mod group_order; mod max_request_id; +mod namespace; +mod params; mod pubilsh_namespace_done; mod publish; mod publish_done; -mod publish_error; mod publish_namespace; mod publish_namespace_cancel; -mod publish_namespace_error; -mod publish_namespace_ok; mod publish_ok; mod publisher; +mod request_error; +mod request_ok; +mod request_update; mod requests_blocked; mod subscribe; -mod subscribe_error; mod subscribe_namespace; -mod subscribe_namespace_error; -mod subscribe_namespace_ok; mod subscribe_ok; -mod subscribe_update; mod subscriber; mod track_status; -mod track_status_error; -mod track_status_ok; mod unsubscribe; -mod unsubscribe_namespace; pub use fetch::*; pub use fetch_cancel::*; -pub use fetch_error::*; pub use fetch_ok::*; pub use fetch_type::*; pub use filter_type::*; pub use go_away::*; pub use group_order::*; pub use max_request_id::*; +pub use namespace::*; +pub use params::*; pub use pubilsh_namespace_done::*; pub use publish::*; pub use publish_done::*; -pub use publish_error::*; pub use publish_namespace::*; pub use publish_namespace_cancel::*; -pub use publish_namespace_error::*; -pub use publish_namespace_ok::*; pub use publish_ok::*; pub use publisher::*; +pub use request_error::*; +pub use request_ok::*; +pub use request_update::*; pub use requests_blocked::*; pub use subscribe::*; -pub use subscribe_error::*; pub use subscribe_namespace::*; -pub use subscribe_namespace_error::*; -pub use subscribe_namespace_ok::*; pub use subscribe_ok::*; -pub use subscribe_update::*; pub use subscriber::*; pub use track_status::*; -pub use track_status_error::*; -pub use track_status_ok::*; pub use unsubscribe::*; -pub use unsubscribe_namespace::*; use crate::coding::{Decode, DecodeError, Encode, EncodeError}; +use bytes::Buf as _; use std::fmt; -// Use a macro to generate the message types rather than copy-paste. -// This implements a decode/encode method that uses the specified type. +// Use a macro to generate the Message enum and its encode/decode impls. macro_rules! message_types { {$($name:ident = $val:expr,)*} => { - /// All supported message types. - #[derive(Clone)] - pub enum Message { - $($name($name)),* - } - - impl Decode for Message { - fn decode(r: &mut R) -> Result { - let t = u64::decode(r)?; - let _len = u16::decode(r)?; - - // TODO: Check the length of the message. - - match t { - $($val => { - let msg = $name::decode(r)?; - Ok(Self::$name(msg)) - })* - _ => Err(DecodeError::InvalidMessage(t)), - } - } - } - - impl Encode for Message { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - match self { - $(Self::$name(ref m) => { - self.id().encode(w)?; - - // Find out the length of the message - // by encoding it into a buffer and then encoding the length. - // This is a bit wasteful, but it's the only way to know the length. - // TODO SLG - perhaps we can store the position of the Length field in the BufMut and - // write the length later, to avoid the copy of the message bytes? - let mut buf = Vec::new(); - m.encode(&mut buf).unwrap(); + /// All supported control message types. + #[derive(Clone)] + pub enum Message { + $($name($name)),* + } + + impl Decode for Message { + fn decode(r: &mut R) -> Result { + let t = u64::decode(r)?; + let len = u16::decode(r)? as usize; + + // Enforce the length field: read exactly `len` bytes as the + // payload and decode from that slice, so a truncated or + // overlong payload is detected immediately. + ::decode_remaining(r, len)?; + let mut payload = r.copy_to_bytes(len); + + let msg = match t { + $($val => { + let msg = $name::decode(&mut payload)?; + Ok(Self::$name(msg)) + })* + _ => Err(DecodeError::InvalidMessage(t)), + }?; + + // Any bytes left in the payload slice mean the message was + // shorter than declared — that is a PROTOCOL_VIOLATION. + if payload.has_remaining() { + return Err(DecodeError::InvalidMessage(t)); + } + + Ok(msg) + } + } + + impl Encode for Message { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + match self { + $(Self::$name(ref m) => { + self.id().encode(w)?; + + let mut buf = Vec::new(); + m.encode(&mut buf)?; if buf.len() > u16::MAX as usize { return Err(EncodeError::MsgBoundsExceeded); } (buf.len() as u16).encode(w)?; - // At least don't encode the message twice. - // Instead, write the buffer directly to the writer. Self::encode_remaining(w, buf.len())?; - w.put_slice(&buf); - Ok(()) - },)* - } - } - } - - impl Message { - pub fn id(&self) -> u64 { - match self { - $(Self::$name(_) => { - $val - },)* - } - } - - pub fn name(&self) -> &'static str { - match self { - $(Self::$name(_) => { - stringify!($name) - },)* - } - } - } - - $(impl From<$name> for Message { - fn from(m: $name) -> Self { - Message::$name(m) - } - })* - - impl fmt::Debug for Message { - // Delegate to the message formatter - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - $(Self::$name(ref m) => m.fmt(f),)* - } - } - } + w.put_slice(&buf); + Ok(()) + },)* + } + } + } + + impl Message { + pub fn id(&self) -> u64 { + match self { + $(Self::$name(_) => $val,)* + } + } + + pub fn name(&self) -> &'static str { + match self { + $(Self::$name(_) => stringify!($name),)* + } + } + + /// Return the request ID if this message participates in request ID sequencing. + /// + /// Responses and cancellation messages reference existing request IDs + /// and therefore return `None`. This is used only for request ID + /// sequencing validation on receive. + pub fn sequenced_request_id(&self) -> Option { + match self { + Self::Subscribe(m) => Some(m.id), + Self::RequestUpdate(m) => Some(m.id), + Self::Fetch(m) => Some(m.id), + Self::TrackStatus(m) => Some(m.id), + Self::SubscribeNamespace(m) => Some(m.id), + Self::Publish(m) => Some(m.id), + Self::PublishNamespace(m) => Some(m.id), + _ => None, + } + } + } + + $(impl From<$name> for Message { + fn from(m: $name) -> Self { + Message::$name(m) + } + })* + + impl fmt::Debug for Message { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + $(Self::$name(ref m) => m.fmt(f),)* + } + } + } } } -// Each message is prefixed with the given VarInt type. +// Wire IDs per draft-ietf-moq-transport-16 Table 1. message_types! { - // NOTE: Setup messages are in another module. - // SetupClient = 0x20 - // SetupServer = 0x21 - // SetupClient = 0x40 // legacy, used in draft versions <= 10 - // SetupServer = 0x41 // legacy, used in draft versions <= 10 - - // Misc - GoAway = 0x10, - MaxRequestId = 0x15, - RequestsBlocked = 0x1a, + // NOTE: Setup messages live in a separate module (setup::Client/Server). + + // ── Shared request responses (new in draft-16) ─────────────────────────── + RequestUpdate = 0x2, + RequestError = 0x5, // draft-16: REQUEST_ERROR + RequestOk = 0x7, // draft-16: REQUEST_OK + + // ── SUBSCRIBE family ───────────────────────────────────────────────────── + Subscribe = 0x3, + SubscribeOk = 0x4, + Unsubscribe = 0xa, - // SUBSCRIBE family, sent by subscriber - SubscribeUpdate = 0x2, - Subscribe = 0x3, - Unsubscribe = 0xa, - // SUBSCRIBE family, sent by publisher - SubscribeOk = 0x4, - SubscribeError = 0x5, - - // ANNOUNCE family, sent by publisher - PublishNamespace = 0x6, - PublishNamespaceDone = 0x9, - // ANNOUNCE family, sent by subscriber - PublishNamespaceOk = 0x7, - PublishNamespaceError = 0x8, - PublishNamespaceCancel = 0xc, - - // TRACK_STATUS family, sent by subscriber - TrackStatus = 0xd, - // TRACK_STATUS family, sent by publisher - TrackStatusOk = 0xe, - TrackStatusError = 0xf, - - // NAMESPACE family, sent by subscriber + // ── PUBLISH_NAMESPACE family ────────────────────────────────────────────── + PublishNamespace = 0x6, + Namespace = 0x8, + PublishNamespaceDone = 0x9, + NamespaceDone = 0xe, + PublishNamespaceCancel = 0xc, + + // ── TRACK_STATUS ────────────────────────────────────────────────────────── + TrackStatus = 0xd, + + // ── PUBLISH family ──────────────────────────────────────────────────────── + Publish = 0x1d, + PublishDone = 0xb, + PublishOk = 0x1e, + + // ── FETCH family ───────────────────────────────────────────────────────── + Fetch = 0x16, + FetchCancel = 0x17, + FetchOk = 0x18, + + // ── SUBSCRIBE_NAMESPACE (bidi stream; §9.25) ────────────────────────────── SubscribeNamespace = 0x11, - UnsubscribeNamespace = 0x14, - // NAMESPACE family, sent by publisher - SubscribeNamespaceOk = 0x12, - SubscribeNamespaceError = 0x13, - - // FETCH family, sent by subscriber - Fetch = 0x16, - FetchCancel = 0x17, - // FETCH family, sent by publisher - FetchOk = 0x18, - FetchError = 0x19, - - // PUBLISH family, sent by publisher - Publish = 0x1d, - PublishDone = 0xb, - // PUBLISH family, sent by subscriber - PublishOk = 0x1e, - PublishError = 0x1f, + + // ── Session management ──────────────────────────────────────────────────── + GoAway = 0x10, + MaxRequestId = 0x15, + RequestsBlocked = 0x1a, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::coding::{ + KeyValuePairs, Location, ReasonPhrase, TrackNamespace, TrackNamespacePrefix, + }; + + fn namespace() -> TrackNamespace { + TrackNamespace::from_utf8_path("test/ns") + } + + fn assert_sequenced(msg: Message, id: u64) { + assert_eq!(msg.sequenced_request_id(), Some(id)); + } + + fn assert_not_sequenced(msg: Message) { + assert_eq!(msg.sequenced_request_id(), None); + } + + #[test] + fn sequenced_request_id_covers_all_request_start_messages() { + assert_sequenced( + Message::Subscribe(Subscribe { + id: 0, + track_namespace: namespace(), + track_name: "track".into(), + params: KeyValuePairs::default(), + }), + 0, + ); + + assert_sequenced( + Message::RequestUpdate(RequestUpdate { + id: 2, + existing_request_id: 0, + params: KeyValuePairs::default(), + }), + 2, + ); + + assert_sequenced( + Message::Fetch(Fetch { + id: 4, + fetch_type: FetchType::Standalone, + standalone_fetch: Some(StandaloneFetch { + track_namespace: namespace(), + track_name: "track".into(), + start_location: Location::new(0, 0), + end_location: Location::new(0, 1), + }), + joining_fetch: None, + params: KeyValuePairs::default(), + }), + 4, + ); + + assert_sequenced( + Message::TrackStatus(TrackStatus { + id: 6, + track_namespace: namespace(), + track_name: "track".into(), + params: KeyValuePairs::default(), + }), + 6, + ); + + assert_sequenced( + Message::SubscribeNamespace(SubscribeNamespace { + id: 8, + track_namespace_prefix: TrackNamespacePrefix::from_utf8_path("test/ns"), + subscribe_options: SubscribeOptions::Both, + params: KeyValuePairs::default(), + }), + 8, + ); + + assert_sequenced( + Message::Publish(Publish { + id: 10, + track_namespace: namespace(), + track_name: "track".into(), + track_alias: 1, + params: KeyValuePairs::default(), + track_extensions: TrackExtensions::default(), + }), + 10, + ); + + assert_sequenced( + Message::PublishNamespace(PublishNamespace { + id: 12, + track_namespace: namespace(), + params: KeyValuePairs::default(), + }), + 12, + ); + } + + #[test] + fn sequenced_request_id_ignores_messages_that_reference_existing_requests() { + assert_not_sequenced(Message::RequestOk(RequestOk { + id: 0, + params: KeyValuePairs::default(), + })); + + assert_not_sequenced(Message::RequestError(RequestError { + id: 0, + error_code: 0, + retry_interval: 0, + reason: ReasonPhrase(String::new()), + })); + + assert_not_sequenced(Message::SubscribeOk(SubscribeOk { + id: 0, + track_alias: 1, + params: KeyValuePairs::default(), + track_extensions: TrackExtensions::default(), + })); + + assert_not_sequenced(Message::Unsubscribe(Unsubscribe { id: 0 })); + + assert_not_sequenced(Message::FetchCancel(FetchCancel { id: 0 })); + + assert_not_sequenced(Message::FetchOk(FetchOk { + id: 0, + end_of_track: false, + end_location: Location::new(0, 0), + params: KeyValuePairs::default(), + track_extensions: TrackExtensions::default(), + })); + + assert_not_sequenced(Message::PublishOk(PublishOk { + id: 0, + params: KeyValuePairs::default(), + })); + + assert_not_sequenced(Message::PublishDone(PublishDone { + id: 0, + status_code: 0, + stream_count: 0, + reason: ReasonPhrase(String::new()), + })); + } + + #[test] + fn decode_rejects_legacy_stub_message_type() { + let mut buf = bytes::BytesMut::new(); + 0x100u64.encode(&mut buf).unwrap(); + 0u16.encode(&mut buf).unwrap(); + + let err = Message::decode(&mut buf).unwrap_err(); + assert!(matches!(err, DecodeError::InvalidMessage(0x100))); + } + + #[test] + fn draft16_wire_layouts_for_changed_control_messages() { + fn encoded(msg: Message) -> Vec { + let mut buf = bytes::BytesMut::new(); + msg.encode(&mut buf).unwrap(); + buf.to_vec() + } + + let ns = TrackNamespace::from_utf8_path("ns"); + let prefix = TrackNamespacePrefix::new(); + + assert_eq!( + encoded(Message::Subscribe(Subscribe { + id: 0, + track_namespace: ns.clone(), + track_name: "t".into(), + params: KeyValuePairs::default(), + })), + vec![0x03, 0x00, 0x08, 0x00, 0x01, 0x02, b'n', b's', 0x01, b't', 0x00] + ); + + assert_eq!( + encoded(Message::SubscribeOk(SubscribeOk { + id: 0, + track_alias: 1, + params: KeyValuePairs::default(), + track_extensions: TrackExtensions::default(), + })), + vec![0x04, 0x00, 0x03, 0x00, 0x01, 0x00] + ); + + assert_eq!( + encoded(Message::TrackStatus(TrackStatus { + id: 0, + track_namespace: ns.clone(), + track_name: "t".into(), + params: KeyValuePairs::default(), + })), + vec![0x0d, 0x00, 0x08, 0x00, 0x01, 0x02, b'n', b's', 0x01, b't', 0x00] + ); + + assert_eq!( + encoded(Message::Publish(Publish { + id: 0, + track_namespace: ns.clone(), + track_name: "t".into(), + track_alias: 5, + params: KeyValuePairs::default(), + track_extensions: TrackExtensions::default(), + })), + vec![0x1d, 0x00, 0x09, 0x00, 0x01, 0x02, b'n', b's', 0x01, b't', 0x05, 0x00] + ); + + assert_eq!( + encoded(Message::PublishOk(PublishOk { + id: 0, + params: KeyValuePairs::default(), + })), + vec![0x1e, 0x00, 0x02, 0x00, 0x00] + ); + + assert_eq!( + encoded(Message::Fetch(Fetch { + id: 0, + fetch_type: FetchType::Standalone, + standalone_fetch: Some(StandaloneFetch { + track_namespace: ns, + track_name: "t".into(), + start_location: Location::new(0, 0), + end_location: Location::new(0, 1), + }), + joining_fetch: None, + params: KeyValuePairs::default(), + })), + vec![ + 0x16, 0x00, 0x0d, 0x00, 0x01, 0x01, 0x02, b'n', b's', 0x01, b't', 0x00, 0x00, 0x00, + 0x01, 0x00 + ] + ); + + assert_eq!( + encoded(Message::FetchOk(FetchOk { + id: 0, + end_of_track: false, + end_location: Location::new(0, 1), + params: KeyValuePairs::default(), + track_extensions: TrackExtensions::default(), + })), + vec![0x18, 0x00, 0x05, 0x00, 0x00, 0x00, 0x01, 0x00] + ); + + assert_eq!( + encoded(Message::SubscribeNamespace(SubscribeNamespace { + id: 0, + track_namespace_prefix: prefix, + subscribe_options: SubscribeOptions::Both, + params: KeyValuePairs::default(), + })), + vec![0x11, 0x00, 0x04, 0x00, 0x00, 0x02, 0x00] + ); + } } diff --git a/moq-transport/src/message/namespace.rs b/moq-transport/src/message/namespace.rs new file mode 100644 index 00000000..3d1676fa --- /dev/null +++ b/moq-transport/src/message/namespace.rs @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc. +// SPDX-License-Identifier: MIT OR Apache-2.0 + +use crate::coding::{Decode, DecodeError, Encode, EncodeError, TrackNamespacePrefix}; + +/// NAMESPACE message sent on a SUBSCRIBE_NAMESPACE response stream. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Namespace { + pub track_namespace_suffix: TrackNamespacePrefix, +} + +impl Decode for Namespace { + fn decode(r: &mut R) -> Result { + let track_namespace_suffix = TrackNamespacePrefix::decode(r)?; + Ok(Self { + track_namespace_suffix, + }) + } +} + +impl Encode for Namespace { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.track_namespace_suffix.encode(w) + } +} + +/// NAMESPACE_DONE message sent on a SUBSCRIBE_NAMESPACE response stream. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct NamespaceDone { + pub track_namespace_suffix: TrackNamespacePrefix, +} + +impl Decode for NamespaceDone { + fn decode(r: &mut R) -> Result { + let track_namespace_suffix = TrackNamespacePrefix::decode(r)?; + Ok(Self { + track_namespace_suffix, + }) + } +} + +impl Encode for NamespaceDone { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.track_namespace_suffix.encode(w) + } +} diff --git a/moq-transport/src/message/params.rs b/moq-transport/src/message/params.rs new file mode 100644 index 00000000..83c19fe5 --- /dev/null +++ b/moq-transport/src/message/params.rs @@ -0,0 +1,424 @@ +// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc. +// SPDX-License-Identifier: MIT OR Apache-2.0 + +use bytes::Buf as _; + +use crate::coding::{ + Decode, DecodeError, Encode, EncodeError, KeyValuePair, KeyValuePairs, Location, Value, +}; +use crate::message::{FilterType, GroupOrder}; + +/// Draft-16 message-parameter type IDs. +pub mod parameter_type { + pub const DELIVERY_TIMEOUT: u64 = 0x02; + pub const AUTHORIZATION_TOKEN: u64 = 0x03; + pub const EXPIRES: u64 = 0x08; + pub const LARGEST_OBJECT: u64 = 0x09; + pub const FORWARD: u64 = 0x10; + pub const SUBSCRIBER_PRIORITY: u64 = 0x20; + pub const SUBSCRIPTION_FILTER: u64 = 0x21; + pub const GROUP_ORDER: u64 = 0x22; + pub const NEW_GROUP_REQUEST: u64 = 0x32; +} + +/// Draft-16 extension-header type IDs. +pub mod extension_type { + pub const DELIVERY_TIMEOUT: u64 = 0x02; + pub const MAX_CACHE_DURATION: u64 = 0x04; + pub const IMMUTABLE_EXTENSIONS: u64 = 0x0B; + pub const DEFAULT_PUBLISHER_PRIORITY: u64 = 0x0E; + pub const DEFAULT_PUBLISHER_GROUP_ORDER: u64 = 0x22; + pub const DYNAMIC_GROUPS: u64 = 0x30; + pub const PRIOR_GROUP_ID_GAP: u64 = 0x3C; + pub const PRIOR_OBJECT_ID_GAP: u64 = 0x3E; +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct SubscriptionFilter { + pub filter_type: FilterType, + pub start_location: Option, + pub end_group_id: Option, +} + +/// Draft-16 Track Extensions are a trailing sequence of KVPs with no count or length prefix. +#[derive(Default, Clone, Debug, Eq, PartialEq)] +pub struct TrackExtensions(pub Vec); + +impl TrackExtensions { + pub fn new() -> Self { + Self::default() + } + + pub fn set_extension(&mut self, kvp: KeyValuePair) { + if let Some(existing) = self.0.iter_mut().find(|k| k.key == kvp.key) { + *existing = kvp; + } else { + self.0.push(kvp); + } + } + + pub fn set_int_extension(&mut self, key: u64, value: u64) { + self.set_extension(KeyValuePair::new_int(key, value)); + } + + pub fn set_bytes_extension(&mut self, key: u64, value: Vec) { + self.set_extension(KeyValuePair::new_bytes(key, value)); + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn delivery_timeout(&self) -> Result, DecodeError> { + get_kvp_int(&self.0, extension_type::DELIVERY_TIMEOUT) + } + + pub fn set_delivery_timeout(&mut self, timeout: u64) { + self.set_int_extension(extension_type::DELIVERY_TIMEOUT, timeout); + } + + pub fn max_cache_duration(&self) -> Result, DecodeError> { + get_kvp_int(&self.0, extension_type::MAX_CACHE_DURATION) + } + + pub fn set_max_cache_duration(&mut self, duration: u64) { + self.set_int_extension(extension_type::MAX_CACHE_DURATION, duration); + } + + pub fn default_publisher_priority(&self) -> Result, DecodeError> { + get_kvp_u8(&self.0, extension_type::DEFAULT_PUBLISHER_PRIORITY) + } + + pub fn set_default_publisher_priority(&mut self, priority: u8) { + self.set_int_extension(extension_type::DEFAULT_PUBLISHER_PRIORITY, priority.into()); + } + + pub fn default_publisher_group_order(&self) -> Result, DecodeError> { + match get_kvp_int(&self.0, extension_type::DEFAULT_PUBLISHER_GROUP_ORDER)? { + Some(1) => Ok(Some(GroupOrder::Ascending)), + Some(2) => Ok(Some(GroupOrder::Descending)), + Some(_) => Err(DecodeError::InvalidGroupOrder), + None => Ok(None), + } + } + + pub fn set_default_publisher_group_order( + &mut self, + group_order: GroupOrder, + ) -> Result<(), EncodeError> { + match group_order { + GroupOrder::Ascending | GroupOrder::Descending => { + self.set_int_extension( + extension_type::DEFAULT_PUBLISHER_GROUP_ORDER, + group_order as u64, + ); + Ok(()) + } + GroupOrder::Publisher => Err(EncodeError::InvalidValue), + } + } + + pub fn dynamic_groups(&self) -> Result, DecodeError> { + match get_kvp_int(&self.0, extension_type::DYNAMIC_GROUPS)? { + Some(0) => Ok(Some(false)), + Some(1) => Ok(Some(true)), + Some(_) => Err(DecodeError::InvalidParameter), + None => Ok(None), + } + } + + pub fn set_dynamic_groups(&mut self, enabled: bool) { + self.set_int_extension(extension_type::DYNAMIC_GROUPS, if enabled { 1 } else { 0 }); + } +} + +impl Decode for TrackExtensions { + fn decode(r: &mut R) -> Result { + let mut extensions = Vec::new(); + let mut prev = 0u64; + + while r.has_remaining() { + let (pair, new_prev) = KeyValuePair::decode_with_prev(r, prev)?; + prev = new_prev; + extensions.push(pair); + } + + Ok(Self(extensions)) + } +} + +impl Encode for TrackExtensions { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|k| k.key); + + let mut prev = 0u64; + for kvp in &sorted { + prev = kvp.encode_with_prev(w, prev)?; + } + + Ok(()) + } +} + +impl SubscriptionFilter { + pub fn largest_object() -> Self { + Self { + filter_type: FilterType::LargestObject, + start_location: None, + end_group_id: None, + } + } +} + +impl Decode for SubscriptionFilter { + fn decode(r: &mut R) -> Result { + let filter_type = FilterType::decode(r)?; + + let (start_location, end_group_id) = match filter_type { + FilterType::AbsoluteStart => (Some(Location::decode(r)?), None), + FilterType::AbsoluteRange => (Some(Location::decode(r)?), Some(u64::decode(r)?)), + FilterType::NextGroupStart | FilterType::LargestObject => (None, None), + }; + + Ok(Self { + filter_type, + start_location, + end_group_id, + }) + } +} + +impl Encode for SubscriptionFilter { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.filter_type.encode(w)?; + + match self.filter_type { + FilterType::AbsoluteStart => { + self.start_location + .ok_or_else(|| EncodeError::MissingField("StartLocation".to_string()))? + .encode(w)?; + } + FilterType::AbsoluteRange => { + self.start_location + .ok_or_else(|| EncodeError::MissingField("StartLocation".to_string()))? + .encode(w)?; + self.end_group_id + .ok_or_else(|| EncodeError::MissingField("EndGroupId".to_string()))? + .encode(w)?; + } + FilterType::NextGroupStart | FilterType::LargestObject => {} + } + + Ok(()) + } +} + +fn encode_bytes_value(value: &T) -> Result, EncodeError> { + let mut buf = Vec::new(); + value.encode(&mut buf)?; + Ok(buf) +} + +fn decode_bytes_value(value: &[u8]) -> Result { + let mut payload = bytes::Bytes::copy_from_slice(value); + let decoded = T::decode(&mut payload)?; + if payload.has_remaining() { + return Err(DecodeError::InvalidParameter); + } + Ok(decoded) +} + +fn get_kvp_int(pairs: &[KeyValuePair], key: u64) -> Result, DecodeError> { + match pairs + .iter() + .find(|kvp| kvp.key == key) + .map(|kvp| &kvp.value) + { + Some(Value::IntValue(value)) => Ok(Some(*value)), + Some(Value::BytesValue(_)) => Err(DecodeError::InvalidParameter), + None => Ok(None), + } +} + +fn get_kvp_u8(pairs: &[KeyValuePair], key: u64) -> Result, DecodeError> { + match get_kvp_int(pairs, key)? { + Some(value) => u8::try_from(value) + .map(Some) + .map_err(|_| DecodeError::InvalidParameter), + None => Ok(None), + } +} + +impl KeyValuePairs { + fn int_parameter(&self, key: u64) -> Result, DecodeError> { + get_kvp_int(&self.0, key) + } + + pub fn set_forward(&mut self, forward: bool) { + self.set_intvalue(parameter_type::FORWARD, if forward { 1 } else { 0 }); + } + + pub fn forward(&self) -> Result, DecodeError> { + match self.int_parameter(parameter_type::FORWARD)? { + Some(0) => Ok(Some(false)), + Some(1) => Ok(Some(true)), + Some(_) => Err(DecodeError::InvalidParameter), + None => Ok(None), + } + } + + pub fn set_subscriber_priority(&mut self, priority: u8) { + self.set_intvalue(parameter_type::SUBSCRIBER_PRIORITY, priority.into()); + } + + pub fn subscriber_priority(&self) -> Result, DecodeError> { + get_kvp_u8(&self.0, parameter_type::SUBSCRIBER_PRIORITY) + } + + pub fn set_group_order(&mut self, group_order: GroupOrder) { + if group_order != GroupOrder::Publisher { + self.set_intvalue(parameter_type::GROUP_ORDER, group_order as u64); + } + } + + pub fn group_order(&self) -> Result, DecodeError> { + match self.int_parameter(parameter_type::GROUP_ORDER)? { + Some(0) => Err(DecodeError::InvalidGroupOrder), + Some(1) => Ok(Some(GroupOrder::Ascending)), + Some(2) => Ok(Some(GroupOrder::Descending)), + Some(_) => Err(DecodeError::InvalidGroupOrder), + None => Ok(None), + } + } + + pub fn set_subscription_filter( + &mut self, + filter: &SubscriptionFilter, + ) -> Result<(), EncodeError> { + self.set_bytesvalue( + parameter_type::SUBSCRIPTION_FILTER, + encode_bytes_value(filter)?, + ); + Ok(()) + } + + pub fn subscription_filter(&self) -> Result, DecodeError> { + match self + .get(parameter_type::SUBSCRIPTION_FILTER) + .map(|kvp| &kvp.value) + { + Some(Value::BytesValue(value)) => decode_bytes_value(value).map(Some), + Some(Value::IntValue(_)) => Err(DecodeError::InvalidParameter), + None => Ok(None), + } + } + + pub fn set_largest_object(&mut self, location: Location) -> Result<(), EncodeError> { + self.set_bytesvalue( + parameter_type::LARGEST_OBJECT, + encode_bytes_value(&location)?, + ); + Ok(()) + } + + pub fn largest_object(&self) -> Result, DecodeError> { + match self + .get(parameter_type::LARGEST_OBJECT) + .map(|kvp| &kvp.value) + { + Some(Value::BytesValue(value)) => decode_bytes_value(value).map(Some), + Some(Value::IntValue(_)) => Err(DecodeError::InvalidParameter), + None => Ok(None), + } + } + + pub fn set_expires(&mut self, expires: u64) { + self.set_intvalue(parameter_type::EXPIRES, expires); + } + + pub fn expires(&self) -> Result, DecodeError> { + self.int_parameter(parameter_type::EXPIRES) + } + + pub fn set_delivery_timeout(&mut self, timeout: u64) { + self.set_intvalue(parameter_type::DELIVERY_TIMEOUT, timeout); + } + + pub fn delivery_timeout(&self) -> Result, DecodeError> { + self.int_parameter(parameter_type::DELIVERY_TIMEOUT) + } + + pub fn set_new_group_request(&mut self, group_id: u64) { + self.set_intvalue(parameter_type::NEW_GROUP_REQUEST, group_id); + } + + pub fn new_group_request(&self) -> Result, DecodeError> { + self.int_parameter(parameter_type::NEW_GROUP_REQUEST) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn key_value_pairs_message_parameter_methods_round_trip_typed_values() { + let mut params = KeyValuePairs::default(); + let filter = SubscriptionFilter { + filter_type: FilterType::AbsoluteRange, + start_location: Some(Location::new(1, 2)), + end_group_id: Some(3), + }; + + params.set_forward(false); + params.set_subscriber_priority(7); + params.set_group_order(GroupOrder::Descending); + params.set_subscription_filter(&filter).unwrap(); + params.set_largest_object(Location::new(4, 5)).unwrap(); + params.set_expires(6); + params.set_delivery_timeout(8); + params.set_new_group_request(9); + + assert_eq!(params.forward().unwrap(), Some(false)); + assert_eq!(params.subscriber_priority().unwrap(), Some(7)); + assert_eq!(params.group_order().unwrap(), Some(GroupOrder::Descending)); + assert_eq!(params.subscription_filter().unwrap(), Some(filter)); + assert_eq!(params.largest_object().unwrap(), Some(Location::new(4, 5))); + assert_eq!(params.expires().unwrap(), Some(6)); + assert_eq!(params.delivery_timeout().unwrap(), Some(8)); + assert_eq!(params.new_group_request().unwrap(), Some(9)); + } + + #[test] + fn track_extensions_methods_round_trip_typed_values() { + let mut extensions = TrackExtensions::default(); + + extensions.set_delivery_timeout(10); + extensions.set_max_cache_duration(20); + extensions.set_default_publisher_priority(30); + extensions + .set_default_publisher_group_order(GroupOrder::Ascending) + .unwrap(); + extensions.set_dynamic_groups(true); + + assert_eq!(extensions.delivery_timeout().unwrap(), Some(10)); + assert_eq!(extensions.max_cache_duration().unwrap(), Some(20)); + assert_eq!(extensions.default_publisher_priority().unwrap(), Some(30)); + assert_eq!( + extensions.default_publisher_group_order().unwrap(), + Some(GroupOrder::Ascending) + ); + assert_eq!(extensions.dynamic_groups().unwrap(), Some(true)); + } + + #[test] + fn track_extensions_rejects_publisher_group_order_sentinel() { + let mut extensions = TrackExtensions::default(); + + assert!(matches!( + extensions.set_default_publisher_group_order(GroupOrder::Publisher), + Err(EncodeError::InvalidValue) + )); + } +} diff --git a/moq-transport/src/message/pubilsh_namespace_done.rs b/moq-transport/src/message/pubilsh_namespace_done.rs index 9fa6799e..256fbe0f 100644 --- a/moq-transport/src/message/pubilsh_namespace_done.rs +++ b/moq-transport/src/message/pubilsh_namespace_done.rs @@ -1,27 +1,30 @@ // SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 -use crate::coding::{Decode, DecodeError, Encode, EncodeError, TrackNamespace}; +//! PUBLISH_NAMESPACE_DONE message (draft-ietf-moq-transport-16 §9.22). +//! +//! Sent by the publisher to stop serving new subscriptions for a namespace. +//! Carries the Request ID of the corresponding PUBLISH_NAMESPACE. + +use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Sent by the publisher to terminate a PUBLISH_NAMESPACE. #[derive(Clone, Debug, Eq, PartialEq)] pub struct PublishNamespaceDone { - pub track_namespace: TrackNamespace, + /// The Request ID of the PUBLISH_NAMESPACE being terminated. + pub id: u64, } impl Decode for PublishNamespaceDone { fn decode(r: &mut R) -> Result { - let track_namespace = TrackNamespace::decode(r)?; - - Ok(Self { track_namespace }) + let id = u64::decode(r)?; + Ok(Self { id }) } } impl Encode for PublishNamespaceDone { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.track_namespace.encode(w)?; - - Ok(()) + self.id.encode(w) } } @@ -33,12 +36,17 @@ mod tests { #[test] fn encode_decode() { let mut buf = BytesMut::new(); - - let msg = PublishNamespaceDone { - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - }; + let msg = PublishNamespaceDone { id: 12345 }; msg.encode(&mut buf).unwrap(); let decoded = PublishNamespaceDone::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } + + #[test] + fn round_trips_id_zero() { + let mut buf = BytesMut::new(); + let msg = PublishNamespaceDone { id: 0 }; + msg.encode(&mut buf).unwrap(); + assert_eq!(PublishNamespaceDone::decode(&mut buf).unwrap(), msg); + } } diff --git a/moq-transport/src/message/publish.rs b/moq-transport/src/message/publish.rs index 2246b68b..677652f7 100644 --- a/moq-transport/src/message/publish.rs +++ b/moq-transport/src/message/publish.rs @@ -2,9 +2,10 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use crate::coding::{ - Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, + validate_full_track_name, Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackName, + TrackNamespace, }; -use crate::message::GroupOrder; +use crate::message::TrackExtensions; /// Sent by publisher to initiate a subscription to a track. #[derive(Clone, Debug, Eq, PartialEq)] @@ -14,17 +15,14 @@ pub struct Publish { /// Track properties pub track_namespace: TrackNamespace, - pub track_name: String, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) + pub track_name: TrackName, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) pub track_alias: u64, - pub group_order: GroupOrder, - pub content_exists: bool, - // The largest object available for this track, if content exists. - pub largest_location: Option, - pub forward: bool, - /// Optional parameters pub params: KeyValuePairs, + + /// Track extension headers. + pub track_extensions: TrackExtensions, } impl Decode for Publish { @@ -32,34 +30,19 @@ impl Decode for Publish { let id = u64::decode(r)?; let track_namespace = TrackNamespace::decode(r)?; - let track_name = String::decode(r)?; + let track_name = TrackName::decode(r)?; + validate_full_track_name(&track_namespace, track_name.as_bytes())?; let track_alias = u64::decode(r)?; - - let group_order = GroupOrder::decode(r)?; - // GroupOrder enum has Publisher in it, but it's not allowed to be used in this - // publish message, so validate it now so we can return a protocol error. - if group_order == GroupOrder::Publisher { - return Err(DecodeError::InvalidGroupOrder); - } - let content_exists = bool::decode(r)?; - let largest_location = match content_exists { - true => Some(Location::decode(r)?), - false => None, - }; - let forward = bool::decode(r)?; - let params = KeyValuePairs::decode(r)?; + let track_extensions = TrackExtensions::decode(r)?; Ok(Self { id, track_namespace, track_name, track_alias, - group_order, - content_exists, - largest_location, - forward, params, + track_extensions, }) } } @@ -71,23 +54,8 @@ impl Encode for Publish { self.track_namespace.encode(w)?; self.track_name.encode(w)?; self.track_alias.encode(w)?; - - // GroupOrder enum has Publisher in it, but it's not allowed to be used in this - // publish message. - if self.group_order == GroupOrder::Publisher { - return Err(EncodeError::InvalidValue); - } - self.group_order.encode(w)?; - self.content_exists.encode(w)?; - if self.content_exists { - if let Some(largest) = &self.largest_location { - largest.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - } - self.forward.encode(w)?; self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -106,74 +74,16 @@ mod tests { let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - // Content exists = true - let msg = Publish { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: true, - largest_location: Some(Location::new(2, 3)), - forward: true, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = Publish::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // Content exists = false let msg = Publish { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), + track_name: "audiotrack".into(), track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: false, - largest_location: None, - forward: true, params: kvps.clone(), + track_extensions: TrackExtensions::default(), }; msg.encode(&mut buf).unwrap(); let decoded = Publish::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } - - #[test] - fn encode_missing_fields() { - let mut buf = BytesMut::new(); - - let msg = Publish { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: true, - largest_location: None, - forward: true, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - } - - #[test] - fn encode_bad_group_order() { - let mut buf = BytesMut::new(); - - let msg = Publish { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - track_alias: 212, - group_order: GroupOrder::Publisher, - content_exists: false, - largest_location: None, - forward: true, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::InvalidValue)); - } } diff --git a/moq-transport/src/message/publish_done.rs b/moq-transport/src/message/publish_done.rs index c198788c..c7aff51f 100644 --- a/moq-transport/src/message/publish_done.rs +++ b/moq-transport/src/message/publish_done.rs @@ -3,7 +3,26 @@ use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; -// TODO SLG - add an enum for status_codes +/// Draft-16 §13.4.3 PUBLISH_DONE codes. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u64)] +pub enum PublishDoneCode { + InternalError = 0x0, + Unauthorized = 0x1, + TrackEnded = 0x2, + SubscriptionEnded = 0x3, + GoingAway = 0x4, + Expired = 0x5, + TooFarBehind = 0x6, + UpdateFailed = 0x8, + MalformedTrack = 0x12, +} + +impl From for u64 { + fn from(value: PublishDoneCode) -> Self { + value as u64 + } +} /// Sent by the publisher to cleanly terminate a Subscription. #[derive(Clone, Debug, Eq, PartialEq)] diff --git a/moq-transport/src/message/publish_error.rs b/moq-transport/src/message/publish_error.rs deleted file mode 100644 index c0e3f857..00000000 --- a/moq-transport/src/message/publish_error.rs +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct PublishError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for PublishError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for PublishError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/publish_namespace_cancel.rs b/moq-transport/src/message/publish_namespace_cancel.rs index 946d1ccb..fb6ba2f6 100644 --- a/moq-transport/src/message/publish_namespace_cancel.rs +++ b/moq-transport/src/message/publish_namespace_cancel.rs @@ -1,27 +1,38 @@ // SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase, TrackNamespace}; +//! PUBLISH_NAMESPACE_CANCEL message (draft-ietf-moq-transport-16 §9.24). +//! +//! Sent by the subscriber to revoke acceptance of a PUBLISH_NAMESPACE, for +//! example when authorization credentials expire. Carries the Request ID of +//! the corresponding PUBLISH_NAMESPACE rather than the namespace itself. -/// Sent by the subscriber to terminate an Announce after PUBLISH_NAMESPACE_OK +use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; + +/// Sent by the subscriber to revoke acceptance of a PUBLISH_NAMESPACE. +/// +/// The publisher may re-send PUBLISH_NAMESPACE with refreshed credentials or +/// discard the associated state. After receiving this, the publisher does NOT +/// send PUBLISH_NAMESPACE_DONE for the same request. #[derive(Clone, Debug, Eq, PartialEq)] pub struct PublishNamespaceCancel { - // Echo back the namespace that was reset - pub track_namespace: TrackNamespace, - // An error code. + /// The Request ID of the PUBLISH_NAMESPACE being cancelled. + pub id: u64, + + /// Error code explaining why the acceptance was revoked. pub error_code: u64, - // An optional, human-readable reason. + + /// Human-readable reason (max 1024 bytes, internal only — not shown to end users). pub reason_phrase: ReasonPhrase, } impl Decode for PublishNamespaceCancel { fn decode(r: &mut R) -> Result { - let track_namespace = TrackNamespace::decode(r)?; + let id = u64::decode(r)?; let error_code = u64::decode(r)?; let reason_phrase = ReasonPhrase::decode(r)?; - Ok(Self { - track_namespace, + id, error_code, reason_phrase, }) @@ -30,11 +41,9 @@ impl Decode for PublishNamespaceCancel { impl Encode for PublishNamespaceCancel { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.track_namespace.encode(w)?; + self.id.encode(w)?; self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) + self.reason_phrase.encode(w) } } @@ -46,14 +55,25 @@ mod tests { #[test] fn encode_decode() { let mut buf = BytesMut::new(); - let msg = PublishNamespaceCancel { - track_namespace: TrackNamespace::from_utf8_path("testpath/video"), - error_code: 0x2, - reason_phrase: ReasonPhrase("Timeout".to_string()), + id: 42, + error_code: 0x1, + reason_phrase: ReasonPhrase("credentials expired".to_string()), }; msg.encode(&mut buf).unwrap(); let decoded = PublishNamespaceCancel::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } + + #[test] + fn encode_decode_empty_reason() { + let mut buf = BytesMut::new(); + let msg = PublishNamespaceCancel { + id: 0, + error_code: 0, + reason_phrase: ReasonPhrase(String::new()), + }; + msg.encode(&mut buf).unwrap(); + assert_eq!(PublishNamespaceCancel::decode(&mut buf).unwrap(), msg); + } } diff --git a/moq-transport/src/message/publish_namespace_error.rs b/moq-transport/src/message/publish_namespace_error.rs deleted file mode 100644 index 6ba721d5..00000000 --- a/moq-transport/src/message/publish_namespace_error.rs +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an PUBLISH_NAMESPACE. -#[derive(Clone, Debug)] -pub struct PublishNamespaceError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for PublishNamespaceError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for PublishNamespaceError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/publish_namespace_ok.rs b/moq-transport/src/message/publish_namespace_ok.rs deleted file mode 100644 index 84dcc1b9..00000000 --- a/moq-transport/src/message/publish_namespace_ok.rs +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError}; - -/// Sent by the subscriber to accept a PUBLISH_NAMESPACE. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct PublishNamespaceOk { - /// The request ID of the PUBLISH_NAMESPACE this message is replying to. - pub id: u64, -} - -impl Decode for PublishNamespaceOk { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - Ok(Self { id }) - } -} - -impl Encode for PublishNamespaceOk { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = PublishNamespaceOk { id: 12345 }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishNamespaceOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/message/publish_ok.rs b/moq-transport/src/message/publish_ok.rs index ddd7c6b0..21ab9b9f 100644 --- a/moq-transport/src/message/publish_ok.rs +++ b/moq-transport/src/message/publish_ok.rs @@ -1,9 +1,7 @@ // SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; -use crate::message::FilterType; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; /// Sent by the subscriber to request all future objects for the given track. /// @@ -13,23 +11,6 @@ pub struct PublishOk { /// The request ID of the Publish this message is replying to. pub id: u64, - /// Forward Flag - pub forward: bool, - - /// Subscriber Priority - pub subscriber_priority: u8, - - /// The order the subscription will be delivered in - pub group_order: GroupOrder, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. - pub end_group_id: Option, - /// Optional parameters pub params: KeyValuePairs, } @@ -37,77 +18,15 @@ pub struct PublishOk { impl Decode for PublishOk { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; - - let forward = bool::decode(r)?; - let subscriber_priority = u8::decode(r)?; - let group_order = GroupOrder::decode(r)?; - - let filter_type = FilterType::decode(r)?; - let start_location: Option; - let end_group_id: Option; - match filter_type { - FilterType::AbsoluteStart => { - start_location = Some(Location::decode(r)?); - end_group_id = None; - } - FilterType::AbsoluteRange => { - start_location = Some(Location::decode(r)?); - end_group_id = Some(u64::decode(r)?); - } - _ => { - start_location = None; - end_group_id = None; - } - } - let params = KeyValuePairs::decode(r)?; - Ok(Self { - id, - forward, - subscriber_priority, - group_order, - filter_type, - start_location, - end_group_id, - params, - }) + Ok(Self { id, params }) } } impl Encode for PublishOk { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; - - self.forward.encode(w)?; - self.subscriber_priority.encode(w)?; - self.group_order.encode(w)?; - - self.filter_type.encode(w)?; - match self.filter_type { - FilterType::AbsoluteStart => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - // Just ignore end_group_id if it happens to be set - } - FilterType::AbsoluteRange => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - if let Some(end) = self.end_group_id { - end.encode(w)?; - } else { - return Err(EncodeError::MissingField("EndGroupId".to_string())); - } - } - _ => {} - } - self.params.encode(w)?; Ok(()) @@ -123,100 +42,15 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - // FilterType = NextGroupStart let msg = PublishOk { id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::NextGroupStart, - start_location: None, - end_group_id: None, params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); let decoded = PublishOk::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); - - // FilterType = AbsoluteStart - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteStart, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // FilterType = AbsoluteRange - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: Some(23456), - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } - - #[test] - fn encode_missing_fields() { - let mut buf = BytesMut::new(); - - // FilterType = AbsoluteStart - missing start_location - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteStart, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing start_location - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing end_group_id - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); } } diff --git a/moq-transport/src/message/publisher.rs b/moq-transport/src/message/publisher.rs index abe1c62d..a85e9b9b 100644 --- a/moq-transport/src/message/publisher.rs +++ b/moq-transport/src/message/publisher.rs @@ -48,18 +48,19 @@ macro_rules! publisher_msgs { } } -// Defines messages that a PUBLISHER would send, or that a SUBSCRIBER would handle +// Defines messages that a PUBLISHER would send, or that a SUBSCRIBER would handle. +// RequestOk and RequestError are shared responses (draft-16 §9.7 / §9.8). publisher_msgs! { + // Namespace advertisement and termination. PublishNamespace, PublishNamespaceDone, + // Publisher-initiated subscriptions. Publish, PublishDone, + // Responses to subscriber-initiated requests. SubscribeOk, - SubscribeError, - TrackStatusOk, - TrackStatusError, + RequestOk, + RequestError, + // FETCH response; FETCH itself is still unsupported by the session layer. FetchOk, - FetchError, - SubscribeNamespaceOk, - SubscribeNamespaceError, } diff --git a/moq-transport/src/message/request_error.rs b/moq-transport/src/message/request_error.rs new file mode 100644 index 00000000..2a7f6dd2 --- /dev/null +++ b/moq-transport/src/message/request_error.rs @@ -0,0 +1,185 @@ +// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc. +// SPDX-License-Identifier: MIT OR Apache-2.0 + +//! REQUEST_ERROR message (draft-ietf-moq-transport-16 §9.8). +//! +//! Sent in response to any request (SUBSCRIBE, FETCH, PUBLISH, +//! SUBSCRIBE_NAMESPACE, PUBLISH_NAMESPACE, TRACK_STATUS, REQUEST_UPDATE). +//! Replaces the per-request error messages from earlier drafts. + +use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; + +/// Draft-16 §13.4.2 REQUEST_ERROR codes. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u64)] +pub enum RequestErrorCode { + InternalError = 0x0, + Unauthorized = 0x1, + Timeout = 0x2, + NotSupported = 0x3, + MalformedAuthToken = 0x4, + ExpiredAuthToken = 0x5, + DoesNotExist = 0x10, + InvalidRange = 0x11, + MalformedTrack = 0x12, + DuplicateSubscription = 0x19, + Uninterested = 0x20, + PrefixOverlap = 0x30, + InvalidJoiningRequestId = 0x32, +} + +impl From for u64 { + fn from(c: RequestErrorCode) -> u64 { + c as u64 + } +} + +/// Sent to reject any request. +/// +/// `retry_interval`: minimum time (ms) before the request SHOULD be sent +/// again, plus one. A value of 0 means the request MUST NOT be retried; +/// a value of 1 means it can be retried immediately. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct RequestError { + /// The Request ID of the message this is rejecting. + pub id: u64, + + /// Error code identifying the reason for rejection. + pub error_code: u64, + + /// Minimum retry delay in milliseconds plus one, or 0 for no retry. + pub retry_interval: u64, + + /// Human-readable reason phrase (UTF-8, max 1024 bytes). + pub reason: ReasonPhrase, +} + +impl RequestError { + /// Convenience constructor from a [`RequestErrorCode`]. + pub fn new(id: u64, code: RequestErrorCode, retry_interval: u64, reason: &str) -> Self { + Self { + id, + error_code: code as u64, + retry_interval, + reason: ReasonPhrase(reason.to_string()), + } + } + + /// Return `true` if this error code indicates the request should not be retried. + pub fn is_fatal(&self) -> bool { + self.retry_interval == 0 + } +} + +impl Decode for RequestError { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let error_code = u64::decode(r)?; + let retry_interval = u64::decode(r)?; + let reason = ReasonPhrase::decode(r)?; + Ok(Self { + id, + error_code, + retry_interval, + reason, + }) + } +} + +impl Encode for RequestError { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.error_code.encode(w)?; + self.retry_interval.encode(w)?; + self.reason.encode(w)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode() { + let mut buf = BytesMut::new(); + let msg = RequestError { + id: 42, + error_code: RequestErrorCode::DoesNotExist as u64, + retry_interval: 0, + reason: ReasonPhrase("track not found".to_string()), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } + + #[test] + fn encode_decode_with_retry() { + let mut buf = BytesMut::new(); + let msg = RequestError::new(10, RequestErrorCode::Timeout, 5001, "upstream timeout"); + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded.id, 10); + assert_eq!(decoded.error_code, RequestErrorCode::Timeout as u64); + assert_eq!(decoded.retry_interval, 5001); + assert!(!decoded.is_fatal()); + } + + #[test] + fn is_fatal_when_retry_interval_zero() { + let msg = RequestError { + id: 1, + error_code: 0, + retry_interval: 0, + reason: ReasonPhrase(String::new()), + }; + assert!(msg.is_fatal()); + } + + #[test] + fn subscribe_rejection_uses_request_error() { + // Verify that subscription rejections can be expressed as REQUEST_ERROR + // with the correct error code. + let mut buf = bytes::BytesMut::new(); + let msg = RequestError::new(0, RequestErrorCode::DoesNotExist, 0, "track not found"); + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded.error_code, RequestErrorCode::DoesNotExist as u64); + assert!(decoded.is_fatal()); + } + + #[test] + fn duplicate_subscription_rejection() { + // Verify DUPLICATE_SUBSCRIPTION can be encoded and decoded. + let mut buf = bytes::BytesMut::new(); + let msg = RequestError::new( + 4, + RequestErrorCode::DuplicateSubscription, + 0, + "duplicate subscription", + ); + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded.id, 4); + assert_eq!( + decoded.error_code, + RequestErrorCode::DuplicateSubscription as u64 + ); + assert_eq!(decoded.retry_interval, 0); + assert!(decoded.is_fatal()); + } + + #[test] + fn not_supported_response_round_trips() { + // Verify that a NOT_SUPPORTED response encodes, decodes, and is fatal (retry_interval=0). + let mut buf = bytes::BytesMut::new(); + let msg = RequestError::new(10, RequestErrorCode::NotSupported, 0, "not supported"); + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded.id, 10); + assert_eq!(decoded.error_code, RequestErrorCode::NotSupported as u64); + assert!(decoded.is_fatal()); + } +} diff --git a/moq-transport/src/message/request_ok.rs b/moq-transport/src/message/request_ok.rs new file mode 100644 index 00000000..1223dd1f --- /dev/null +++ b/moq-transport/src/message/request_ok.rs @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc. +// SPDX-License-Identifier: MIT OR Apache-2.0 + +//! REQUEST_OK message (draft-ietf-moq-transport-16 §9.7). +//! +//! Sent in response to REQUEST_UPDATE, TRACK_STATUS, SUBSCRIBE_NAMESPACE, +//! and PUBLISH_NAMESPACE requests. The Request ID identifies which request +//! this acknowledgement is for. + +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; + +/// Sent to acknowledge a successful request update or status query. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct RequestOk { + /// The Request ID of the message this is replying to. + pub id: u64, + + /// Optional parameters (e.g. LARGEST_OBJECT for TRACK_STATUS responses). + pub params: KeyValuePairs, +} + +impl Decode for RequestOk { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let params = KeyValuePairs::decode(r)?; + Ok(Self { id, params }) + } +} + +impl Encode for RequestOk { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.params.encode(w)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode_no_params() { + let mut buf = BytesMut::new(); + let msg = RequestOk { + id: 42, + params: KeyValuePairs::default(), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } + + #[test] + fn encode_decode_with_params() { + let mut buf = BytesMut::new(); + let mut params = KeyValuePairs::new(); + params.set_intvalue(0x08, 3600); // EXPIRES example + let msg = RequestOk { id: 100, params }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } +} diff --git a/moq-transport/src/message/request_update.rs b/moq-transport/src/message/request_update.rs new file mode 100644 index 00000000..c463e424 --- /dev/null +++ b/moq-transport/src/message/request_update.rs @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc. +// SPDX-License-Identifier: MIT OR Apache-2.0 + +//! REQUEST_UPDATE message (draft-ietf-moq-transport-16 §9.11). +//! +//! The sender of a request (SUBSCRIBE, PUBLISH, FETCH, TRACK_STATUS, +//! PUBLISH_NAMESPACE, SUBSCRIBE_NAMESPACE) sends REQUEST_UPDATE to modify +//! it. The receiver responds with exactly one REQUEST_OK or REQUEST_ERROR. + +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; + +/// Sent to modify an existing request. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct RequestUpdate { + /// New Request ID for this update message. + pub id: u64, + + /// The Request ID of the request being modified. + pub existing_request_id: u64, + + /// Parameters to update. Absent parameters retain their current values. + pub params: KeyValuePairs, +} + +impl Decode for RequestUpdate { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let existing_request_id = u64::decode(r)?; + let params = KeyValuePairs::decode(r)?; + Ok(Self { + id, + existing_request_id, + params, + }) + } +} + +impl Encode for RequestUpdate { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.existing_request_id.encode(w)?; + self.params.encode(w)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode() { + let mut buf = BytesMut::new(); + let mut params = KeyValuePairs::new(); + params.set_intvalue(0x10, 1); // FORWARD=1 + let msg = RequestUpdate { + id: 4, + existing_request_id: 2, + params, + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestUpdate::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } + + #[test] + fn encode_decode_no_params() { + let mut buf = BytesMut::new(); + let msg = RequestUpdate { + id: 6, + existing_request_id: 4, + params: KeyValuePairs::default(), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestUpdate::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } +} diff --git a/moq-transport/src/message/subscribe.rs b/moq-transport/src/message/subscribe.rs index 6ab571c4..f0b823d9 100644 --- a/moq-transport/src/message/subscribe.rs +++ b/moq-transport/src/message/subscribe.rs @@ -3,10 +3,9 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use crate::coding::{ - Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, + validate_full_track_name, Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackName, + TrackNamespace, }; -use crate::message::FilterType; -use crate::message::GroupOrder; /// Sent by the subscriber to request all future objects for the given track. /// @@ -18,22 +17,7 @@ pub struct Subscribe { /// Track properties pub track_namespace: TrackNamespace, - pub track_name: String, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) - - /// Subscriber Priority - pub subscriber_priority: u8, - pub group_order: GroupOrder, - - /// Forward Flag - pub forward: bool, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. - pub end_group_id: Option, + pub track_name: TrackName, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) /// Optional parameters pub params: KeyValuePairs, @@ -44,30 +28,8 @@ impl Decode for Subscribe { let id = u64::decode(r)?; let track_namespace = TrackNamespace::decode(r)?; - let track_name = String::decode(r)?; - - let subscriber_priority = u8::decode(r)?; - let group_order = GroupOrder::decode(r)?; - - let forward = bool::decode(r)?; - - let filter_type = FilterType::decode(r)?; - let start_location: Option; - let end_group_id: Option; - match filter_type { - FilterType::AbsoluteStart => { - start_location = Some(Location::decode(r)?); - end_group_id = None; - } - FilterType::AbsoluteRange => { - start_location = Some(Location::decode(r)?); - end_group_id = Some(u64::decode(r)?); - } - _ => { - start_location = None; - end_group_id = None; - } - } + let track_name = TrackName::decode(r)?; + validate_full_track_name(&track_namespace, track_name.as_bytes())?; let params = KeyValuePairs::decode(r)?; @@ -75,12 +37,6 @@ impl Decode for Subscribe { id, track_namespace, track_name, - subscriber_priority, - group_order, - forward, - filter_type, - start_location, - end_group_id, params, }) } @@ -93,36 +49,6 @@ impl Encode for Subscribe { self.track_namespace.encode(w)?; self.track_name.encode(w)?; - self.subscriber_priority.encode(w)?; - self.group_order.encode(w)?; - - self.forward.encode(w)?; - - self.filter_type.encode(w)?; - match self.filter_type { - FilterType::AbsoluteStart => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - // Just ignore end_group_id if it happens to be set - } - FilterType::AbsoluteRange => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - if let Some(end) = self.end_group_id { - end.encode(w)?; - } else { - return Err(EncodeError::MissingField("EndGroupId".to_string())); - } - } - _ => {} - } - self.params.encode(w)?; Ok(()) @@ -138,56 +64,29 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - // FilterType = NextGroupStart - let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::NextGroupStart, - start_location: None, - end_group_id: None, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = Subscribe::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // FilterType = AbsoluteStart let msg = Subscribe { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteStart, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, + track_name: "audiotrack".into(), params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); let decoded = Subscribe::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); + } - // FilterType = AbsoluteRange + #[test] + fn default_params_roundtrip() { + // Verify a minimal SUBSCRIBE with no params still round-trips cleanly. + let mut buf = BytesMut::new(); let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: Some(23456), - params: kvps.clone(), + id: 0, + track_namespace: TrackNamespace::from_utf8_path("a/b"), + track_name: "t".into(), + params: KeyValuePairs::default(), }; msg.encode(&mut buf).unwrap(); let decoded = Subscribe::decode(&mut buf).unwrap(); @@ -195,55 +94,35 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn decode_rejects_full_track_name_over_limit() { let mut buf = BytesMut::new(); - - // FilterType = AbsoluteStart - missing start_location let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteStart, - start_location: None, - end_group_id: None, - params: Default::default(), + id: 0, + track_namespace: TrackNamespace { + fields: vec![crate::coding::TupleField { + value: vec![b'a'; crate::coding::MAX_FULL_TRACK_NAME_LEN], + }], + }, + track_name: "x".into(), + params: KeyValuePairs::default(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - // FilterType = AbsoluteRange - missing start_location - let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); + msg.encode(&mut buf).unwrap(); + let err = Subscribe::decode(&mut buf).unwrap_err(); + assert!(matches!(err, DecodeError::TrackNameTooLong)); + } - // FilterType = AbsoluteRange - missing end_group_id + #[test] + fn minimal_wire_format_has_no_fixed_subscription_fields() { + let mut buf = BytesMut::new(); let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: Default::default(), + id: 2, + track_namespace: TrackNamespace::from_utf8_path("ns/v"), + track_name: "track".into(), + params: KeyValuePairs::default(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); + msg.encode(&mut buf).unwrap(); + let decoded = Subscribe::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } } diff --git a/moq-transport/src/message/subscribe_error.rs b/moq-transport/src/message/subscribe_error.rs deleted file mode 100644 index a4761c64..00000000 --- a/moq-transport/src/message/subscribe_error.rs +++ /dev/null @@ -1,45 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct SubscribeError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for SubscribeError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for SubscribeError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/subscribe_namespace.rs b/moq-transport/src/message/subscribe_namespace.rs index 59a38311..6de5f0d0 100644 --- a/moq-transport/src/message/subscribe_namespace.rs +++ b/moq-transport/src/message/subscribe_namespace.rs @@ -1,7 +1,39 @@ // SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; +use crate::coding::{ + Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespacePrefix, +}; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u64)] +pub enum SubscribeOptions { + Publish = 0x00, + Namespace = 0x01, + Both = 0x02, +} + +impl Decode for SubscribeOptions { + fn decode(r: &mut R) -> Result { + let options = u64::decode(r)?; + match options { + 0x00 => Ok(SubscribeOptions::Publish), + 0x01 => Ok(SubscribeOptions::Namespace), + 0x02 => Ok(SubscribeOptions::Both), + _ => Err(DecodeError::InvalidSubscribeOptions(options)), + } + } +} + +impl Encode for SubscribeOptions { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + match self { + SubscribeOptions::Publish => 0x00u64.encode(w), + SubscribeOptions::Namespace => 0x01u64.encode(w), + SubscribeOptions::Both => 0x02u64.encode(w), + } + } +} /// Subscribe Namespace #[derive(Clone, Debug, Eq, PartialEq)] @@ -10,7 +42,9 @@ pub struct SubscribeNamespace { pub id: u64, /// The track namespace prefix - pub track_namespace_prefix: TrackNamespace, + pub track_namespace_prefix: TrackNamespacePrefix, + + pub subscribe_options: SubscribeOptions, /// Optional parameters pub params: KeyValuePairs, @@ -19,12 +53,14 @@ pub struct SubscribeNamespace { impl Decode for SubscribeNamespace { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; - let track_namespace_prefix = TrackNamespace::decode(r)?; + let track_namespace_prefix = TrackNamespacePrefix::decode(r)?; + let subscribe_options = SubscribeOptions::decode(r)?; let params = KeyValuePairs::decode(r)?; Ok(Self { id, track_namespace_prefix, + subscribe_options, params, }) } @@ -34,6 +70,7 @@ impl Encode for SubscribeNamespace { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; self.track_namespace_prefix.encode(w)?; + self.subscribe_options.encode(w)?; self.params.encode(w)?; Ok(()) @@ -55,7 +92,8 @@ mod tests { let msg = SubscribeNamespace { id: 12345, - track_namespace_prefix: TrackNamespace::from_utf8_path("path/prefix"), + track_namespace_prefix: TrackNamespacePrefix::from_utf8_path("path/prefix"), + subscribe_options: SubscribeOptions::Publish, params: kvps, }; msg.encode(&mut buf).unwrap(); diff --git a/moq-transport/src/message/subscribe_namespace_error.rs b/moq-transport/src/message/subscribe_namespace_error.rs deleted file mode 100644 index 10e47f18..00000000 --- a/moq-transport/src/message/subscribe_namespace_error.rs +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct SubscribeNamespaceError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for SubscribeNamespaceError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for SubscribeNamespaceError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/subscribe_namespace_ok.rs b/moq-transport/src/message/subscribe_namespace_ok.rs deleted file mode 100644 index 865f8c32..00000000 --- a/moq-transport/src/message/subscribe_namespace_ok.rs +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError}; - -/// Subscribe Namespace Ok -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct SubscribeNamespaceOk { - /// The SubscribeNamespace request ID this message is replying to. - pub id: u64, -} - -impl Decode for SubscribeNamespaceOk { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - Ok(Self { id }) - } -} - -impl Encode for SubscribeNamespaceOk { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = SubscribeNamespaceOk { id: 12345 }; - msg.encode(&mut buf).unwrap(); - let decoded = SubscribeNamespaceOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/message/subscribe_ok.rs b/moq-transport/src/message/subscribe_ok.rs index 2902a094..65913f57 100644 --- a/moq-transport/src/message/subscribe_ok.rs +++ b/moq-transport/src/message/subscribe_ok.rs @@ -2,8 +2,8 @@ // SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; +use crate::message::TrackExtensions; /// Sent by the publisher to accept a Subscribe. #[derive(Clone, Debug, Eq, PartialEq)] @@ -14,42 +14,25 @@ pub struct SubscribeOk { /// The identifier used for this track in Subgroups or Datagrams. pub track_alias: u64, - /// The time in milliseconds after which the subscription is not longer valid. - pub expires: u64, - - /// Order groups will be delivered in - pub group_order: GroupOrder, - - /// If content_exists, then largest_location is the location of the largest - /// object available for this track - pub content_exists: bool, - pub largest_location: Option, // Only provided if content_exists is 1/true - /// Subscribe Parameters pub params: KeyValuePairs, + + /// Track extension headers. + pub track_extensions: TrackExtensions, } impl Decode for SubscribeOk { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; let track_alias = u64::decode(r)?; - let expires = u64::decode(r)?; - let group_order = GroupOrder::decode(r)?; - let content_exists = bool::decode(r)?; - let largest_location = match content_exists { - true => Some(Location::decode(r)?), - false => None, - }; let params = KeyValuePairs::decode(r)?; + let track_extensions = TrackExtensions::decode(r)?; Ok(Self { id, track_alias, - expires, - group_order, - content_exists, - largest_location, params, + track_extensions, }) } } @@ -58,17 +41,8 @@ impl Encode for SubscribeOk { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; self.track_alias.encode(w)?; - self.expires.encode(w)?; - self.group_order.encode(w)?; - self.content_exists.encode(w)?; - if self.content_exists { - if let Some(largest) = &self.largest_location { - largest.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - } self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -90,11 +64,8 @@ mod tests { let msg = SubscribeOk { id: 12345, track_alias: 100, - expires: 3600, - group_order: GroupOrder::Publisher, - content_exists: true, - largest_location: Some(Location::new(2, 3)), params: kvps.clone(), + track_extensions: TrackExtensions::default(), }; msg.encode(&mut buf).unwrap(); let decoded = SubscribeOk::decode(&mut buf).unwrap(); @@ -102,19 +73,32 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn track_alias_independent_of_request_id() { + // track_alias can differ from the request id — it is chosen by the publisher. let mut buf = BytesMut::new(); + let msg = SubscribeOk { + id: 10, + track_alias: 42, + params: KeyValuePairs::default(), + track_extensions: TrackExtensions::default(), + }; + msg.encode(&mut buf).unwrap(); + let decoded = SubscribeOk::decode(&mut buf).unwrap(); + assert_eq!(decoded.id, 10); + assert_eq!(decoded.track_alias, 42); + } + #[test] + fn encode_decode_no_content() { + let mut buf = BytesMut::new(); let msg = SubscribeOk { - id: 12345, - track_alias: 100, - expires: 3600, - group_order: GroupOrder::Publisher, - content_exists: true, - largest_location: None, - params: Default::default(), + id: 0, + track_alias: 0, + params: KeyValuePairs::default(), + track_extensions: TrackExtensions::default(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); + msg.encode(&mut buf).unwrap(); + let decoded = SubscribeOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } } diff --git a/moq-transport/src/message/subscribe_update.rs b/moq-transport/src/message/subscribe_update.rs deleted file mode 100644 index 671e08f1..00000000 --- a/moq-transport/src/message/subscribe_update.rs +++ /dev/null @@ -1,104 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; - -/// Sent by the subscriber to request all future objects for the given track. -/// -/// Objects will use the provided ID instead of the full track name, to save bytes. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct SubscribeUpdate { - /// The request ID of this request - pub id: u64, - - /// The request ID of the SUBSCRIBE this message is updating. - pub subscription_request_id: u64, - - /// The starting location - pub start_location: Location, - /// The end Group ID, plus 1. A value of 0 means the subscription is open-ended. - pub end_group_id: u64, - - /// Subscriber Priority - pub subscriber_priority: u8, - - /// Forward Flag - pub forward: bool, - - /// Optional parameters - pub params: KeyValuePairs, -} - -impl Decode for SubscribeUpdate { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - - let subscription_request_id = u64::decode(r)?; - - let start_location = Location::decode(r)?; - let end_group_id = u64::decode(r)?; - - let subscriber_priority = u8::decode(r)?; - - let forward = bool::decode(r)?; - - let params = KeyValuePairs::decode(r)?; - - Ok(Self { - id, - subscription_request_id, - start_location, - end_group_id, - subscriber_priority, - forward, - params, - }) - } -} - -impl Encode for SubscribeUpdate { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - - self.subscription_request_id.encode(w)?; - - self.start_location.encode(w)?; - self.end_group_id.encode(w)?; - - self.subscriber_priority.encode(w)?; - - self.forward.encode(w)?; - - self.params.encode(w)?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - // One parameter for testing - let mut kvps = KeyValuePairs::new(); - kvps.set_intvalue(124, 456); - - let msg = SubscribeUpdate { - id: 1000, - subscription_request_id: 924, - start_location: Location::new(1, 1), - end_group_id: 100000, - subscriber_priority: 127, - forward: true, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = SubscribeUpdate::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/message/subscriber.rs b/moq-transport/src/message/subscriber.rs index a211c5e3..5c3eabfe 100644 --- a/moq-transport/src/message/subscriber.rs +++ b/moq-transport/src/message/subscriber.rs @@ -48,19 +48,21 @@ macro_rules! subscriber_msgs { } } -// Defines messages that a SUBSCRIBER would send, or that a PUBLISHER would handle +// Defines messages that a SUBSCRIBER would send, or that a PUBLISHER would handle. subscriber_msgs! { + // Subscriber-initiated requests. Subscribe, - SubscribeUpdate, + RequestUpdate, + // Shared responses used by subscriber-side request handlers. + RequestOk, + RequestError, + // Subscription and fetch control. Unsubscribe, Fetch, FetchCancel, TrackStatus, SubscribeNamespace, - UnsubscribeNamespace, + // Responses/control for publisher-initiated requests. PublishNamespaceCancel, - PublishNamespaceOk, - PublishNamespaceError, PublishOk, - PublishError, } diff --git a/moq-transport/src/message/track_status.rs b/moq-transport/src/message/track_status.rs index 29c2c426..68b88e51 100644 --- a/moq-transport/src/message/track_status.rs +++ b/moq-transport/src/message/track_status.rs @@ -2,10 +2,9 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use crate::coding::{ - Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, + validate_full_track_name, Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackName, + TrackNamespace, }; -use crate::message::FilterType; -use crate::message::GroupOrder; /// A potential subscriber sends a TrackStatus message to obtain information about /// the current status of a given track. @@ -16,22 +15,7 @@ pub struct TrackStatus { /// Track properties pub track_namespace: TrackNamespace, - pub track_name: String, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) - - /// Subscriber Priority - pub subscriber_priority: u8, - pub group_order: GroupOrder, - - /// Forward Flag - pub forward: bool, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. - pub end_group_id: Option, + pub track_name: TrackName, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) /// Optional parameters pub params: KeyValuePairs, @@ -42,30 +26,8 @@ impl Decode for TrackStatus { let id = u64::decode(r)?; let track_namespace = TrackNamespace::decode(r)?; - let track_name = String::decode(r)?; - - let subscriber_priority = u8::decode(r)?; - let group_order = GroupOrder::decode(r)?; - - let forward = bool::decode(r)?; - - let filter_type = FilterType::decode(r)?; - let start_location: Option; - let end_group_id: Option; - match filter_type { - FilterType::AbsoluteStart => { - start_location = Some(Location::decode(r)?); - end_group_id = None; - } - FilterType::AbsoluteRange => { - start_location = Some(Location::decode(r)?); - end_group_id = Some(u64::decode(r)?); - } - _ => { - start_location = None; - end_group_id = None; - } - } + let track_name = TrackName::decode(r)?; + validate_full_track_name(&track_namespace, track_name.as_bytes())?; let params = KeyValuePairs::decode(r)?; @@ -73,12 +35,6 @@ impl Decode for TrackStatus { id, track_namespace, track_name, - subscriber_priority, - group_order, - forward, - filter_type, - start_location, - end_group_id, params, }) } @@ -91,36 +47,6 @@ impl Encode for TrackStatus { self.track_namespace.encode(w)?; self.track_name.encode(w)?; - self.subscriber_priority.encode(w)?; - self.group_order.encode(w)?; - - self.forward.encode(w)?; - - self.filter_type.encode(w)?; - match self.filter_type { - FilterType::AbsoluteStart => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - // Just ignore end_group_id if it happens to be set - } - FilterType::AbsoluteRange => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - if let Some(end) = self.end_group_id { - end.encode(w)?; - } else { - return Err(EncodeError::MissingField("EndGroupId".to_string())); - } - } - _ => {} - } - self.params.encode(w)?; Ok(()) @@ -136,112 +62,17 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - // FilterType = NextGroupStart - let msg = TrackStatus { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::NextGroupStart, - start_location: None, - end_group_id: None, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = TrackStatus::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // FilterType = AbsoluteStart - let msg = TrackStatus { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteStart, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = TrackStatus::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // FilterType = AbsoluteRange let msg = TrackStatus { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: Some(23456), + track_name: "audiotrack".into(), params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); let decoded = TrackStatus::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } - - #[test] - fn encode_missing_fields() { - let mut buf = BytesMut::new(); - - // FilterType = AbsoluteStart - missing start_location - let msg = TrackStatus { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteStart, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing start_location - let msg = TrackStatus { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing end_group_id - let msg = TrackStatus { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - } } diff --git a/moq-transport/src/message/track_status_error.rs b/moq-transport/src/message/track_status_error.rs deleted file mode 100644 index 842c20da..00000000 --- a/moq-transport/src/message/track_status_error.rs +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct TrackStatusError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for TrackStatusError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for TrackStatusError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/track_status_ok.rs b/moq-transport/src/message/track_status_ok.rs deleted file mode 100644 index 35a2911d..00000000 --- a/moq-transport/src/message/track_status_ok.rs +++ /dev/null @@ -1,119 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; -use crate::message::GroupOrder; - -/// Sent by the publisher to accept a Subscribe. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct TrackStatusOk { - /// The request ID of the TRACK_STATUS this message is replying to - pub id: u64, - - /// The identifier used for this track in Subgroups or Datagrams. - pub track_alias: u64, - - /// The time in milliseconds after which the subscription is not longer valid. - pub expires: u64, - - /// Order groups will be delivered in - pub group_order: GroupOrder, - - /// If content_exists, then largest_location is the location of the largest - /// object available for this track - pub content_exists: bool, - pub largest_location: Option, // Only provided if content_exists is 1/true - - /// Subscribe Parameters - pub params: KeyValuePairs, -} - -impl Decode for TrackStatusOk { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let track_alias = u64::decode(r)?; - let expires = u64::decode(r)?; - let group_order = GroupOrder::decode(r)?; - let content_exists = bool::decode(r)?; - let largest_location = match content_exists { - true => Some(Location::decode(r)?), - false => None, - }; - let params = KeyValuePairs::decode(r)?; - - Ok(Self { - id, - track_alias, - expires, - group_order, - content_exists, - largest_location, - params, - }) - } -} - -impl Encode for TrackStatusOk { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.track_alias.encode(w)?; - self.expires.encode(w)?; - self.group_order.encode(w)?; - self.content_exists.encode(w)?; - if self.content_exists { - if let Some(largest) = &self.largest_location { - largest.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - } - self.params.encode(w)?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - // One parameter for testing - let mut kvps = KeyValuePairs::new(); - kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - - let msg = TrackStatusOk { - id: 12345, - track_alias: 100, - expires: 3600, - group_order: GroupOrder::Publisher, - content_exists: true, - largest_location: Some(Location::new(2, 3)), - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = TrackStatusOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } - - #[test] - fn encode_missing_fields() { - let mut buf = BytesMut::new(); - - let msg = TrackStatusOk { - id: 12345, - track_alias: 100, - expires: 3600, - group_order: GroupOrder::Publisher, - content_exists: true, - largest_location: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - } -} diff --git a/moq-transport/src/message/unsubscribe_namespace.rs b/moq-transport/src/message/unsubscribe_namespace.rs deleted file mode 100644 index 4ae21520..00000000 --- a/moq-transport/src/message/unsubscribe_namespace.rs +++ /dev/null @@ -1,45 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use crate::coding::{Decode, DecodeError, Encode, EncodeError, TrackNamespace}; - -/// Unsubscribe Namespace -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct UnsubscribeNamespace { - // Echo back the track namespace prefix from subscribe namespace - pub track_namespace_prefix: TrackNamespace, -} - -impl Decode for UnsubscribeNamespace { - fn decode(r: &mut R) -> Result { - let track_namespace_prefix = TrackNamespace::decode(r)?; - Ok(Self { - track_namespace_prefix, - }) - } -} - -impl Encode for UnsubscribeNamespace { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.track_namespace_prefix.encode(w)?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = UnsubscribeNamespace { - track_namespace_prefix: TrackNamespace::from_utf8_path("test/path/to/resource"), - }; - msg.encode(&mut buf).unwrap(); - let decoded = UnsubscribeNamespace::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/mlog/events.rs b/moq-transport/src/mlog/events.rs index a93820df..e78fa399 100644 --- a/moq-transport/src/mlog/events.rs +++ b/moq-transport/src/mlog/events.rs @@ -2,13 +2,13 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 // TODO: Unimplemented control message events (not yet needed for basic relay interop testing): -// - SubscribeUpdate (parsed/created) // - PublishNamespaceDone (parsed/created) // - PublishNamespaceCancel (parsed/created) -// - TrackStatus, TrackStatusOk, TrackStatusError (parsed/created) -// - SubscribeNamespace, SubscribeNamespaceOk, SubscribeNamespaceError, UnsubscribeNamespace (parsed/created) -// - Fetch, FetchOk, FetchError, FetchCancel (parsed/created) -// - Publish, PublishOk, PublishError, PublishDone (parsed/created) +// - TrackStatus (parsed/created) +// - SubscribeNamespace (parsed/created) +// - RequestUpdate (parsed/created) +// - Fetch, FetchOk, FetchCancel (parsed/created) +// - Publish, PublishOk, PublishDone (parsed/created) // - MaxRequestId (parsed/created) // - RequestsBlocked (parsed/created) // @@ -208,33 +208,29 @@ fn create_control_message_event( } } -/// Create a control_message_parsed event for CLIENT_SETUP +/// Create a control_message_parsed event for CLIENT_SETUP. +/// From draft-16 the setup payload carries only parameters; version is agreed via ALPN. pub fn client_setup_parsed(time: f64, stream_id: u64, msg: &setup::Client) -> Event { - let versions: Vec = msg.versions.0.iter().map(|v| format!("{:?}", v)).collect(); create_control_message_event( time, stream_id, true, "client_setup", - json!( - { - "number_of_supported_versions": msg.versions.0.len(), - "supported_versions": versions, + json!({ "parameters": key_value_pairs_to_vec(&msg.params.0), }), ) } -/// Create a control_message_created event for SERVER_SETUP +/// Create a control_message_created event for SERVER_SETUP. +/// From draft-16 the setup payload carries only parameters; version is agreed via ALPN. pub fn server_setup_created(time: f64, stream_id: u64, msg: &setup::Server) -> Event { create_control_message_event( time, stream_id, false, "server_setup", - json!( - { - "selected_version": format!("{:?}", msg.version), + json!({ "parameters": key_value_pairs_to_vec(&msg.params.0), }), ) @@ -242,26 +238,12 @@ pub fn server_setup_created(time: f64, stream_id: u64, msg: &setup::Server) -> E /// Helper to convert SUBSCRIBE message to JSON fn subscribe_to_json(msg: &message::Subscribe) -> JsonValue { - let mut json = json!({ + json!({ "subscribe_id": msg.id, "track_namespace": msg.track_namespace.to_string(), - "track_name": &msg.track_name, - "subscriber_priority": msg.subscriber_priority, - "group_order": format!("{:?}", msg.group_order), - "filter_type": format!("{:?}", msg.filter_type), + "track_name": msg.track_name.to_string(), "parameters": key_value_pairs_to_vec(&msg.params.0), - }); - - // Add optional fields based on filter type - if let Some(start_loc) = &msg.start_location { - json["start_group"] = json!(start_loc.group_id); - json["start_object"] = json!(start_loc.object_id); - } - if let Some(end_group) = msg.end_group_id { - json["end_group"] = json!(end_group); - } - - json + }) } /// Create a control_message_parsed event for SUBSCRIBE @@ -276,24 +258,12 @@ pub fn subscribe_created(time: f64, stream_id: u64, msg: &message::Subscribe) -> /// Helper to convert SUBSCRIBE_OK message to JSON fn subscribe_ok_to_json(msg: &message::SubscribeOk) -> JsonValue { - let mut json = json!({ + json!({ "subscribe_id": msg.id, "track_alias": msg.track_alias, - "expires": msg.expires, - "group_order": format!("{:?}", msg.group_order), - "content_exists": msg.content_exists, "parameters": key_value_pairs_to_vec(&msg.params.0), - }); - - // Add optional largest_location fields if content exists - if msg.content_exists { - if let Some(largest) = &msg.largest_location { - json["largest_group_id"] = json!(largest.group_id); - json["largest_object_id"] = json!(largest.object_id); - } - } - - json + "track_extensions": key_value_pairs_to_vec(&msg.track_extensions.0), + }) } /// Create a control_message_parsed event for SUBSCRIBE_OK @@ -318,37 +288,6 @@ pub fn subscribe_ok_created(time: f64, stream_id: u64, msg: &message::SubscribeO ) } -/// Helper to convert SUBSCRIBE_ERROR message to JSON -fn subscribe_error_to_json(msg: &message::SubscribeError) -> JsonValue { - json!({ - "subscribe_id": msg.id, - "error_code": msg.error_code, - "reason_phrase": &msg.reason_phrase.0, - }) -} - -/// Create a control_message_parsed event for SUBSCRIBE_ERROR -pub fn subscribe_error_parsed(time: f64, stream_id: u64, msg: &message::SubscribeError) -> Event { - create_control_message_event( - time, - stream_id, - true, - "subscribe_error", - subscribe_error_to_json(msg), - ) -} - -/// Create a control_message_created event for SUBSCRIBE_ERROR -pub fn subscribe_error_created(time: f64, stream_id: u64, msg: &message::SubscribeError) -> Event { - create_control_message_event( - time, - stream_id, - false, - "subscribe_error", - subscribe_error_to_json(msg), - ) -} - /// Helper to convert PUBLISH_NAMESPACE message to JSON fn publish_namespace_to_json(msg: &message::PublishNamespace) -> JsonValue { json!({ @@ -388,79 +327,85 @@ pub fn publish_namespace_created( ) } -/// Helper to convert PUBLISH_NAMESPACE_OK message to JSON -fn publish_namespace_ok_to_json(msg: &message::PublishNamespaceOk) -> JsonValue { +fn request_ok_to_json(request_kind: &str, msg: &message::RequestOk) -> JsonValue { json!({ "request_id": msg.id, + "request_kind": request_kind, + "parameters": key_value_pairs_to_vec(&msg.params.0), }) } -/// Create a control_message_parsed event for PUBLISH_NAMESPACE_OK (was ANNOUNCE_OK) -pub fn publish_namespace_ok_parsed( +/// Create a control_message_parsed event for REQUEST_OK. +pub fn request_ok_parsed( time: f64, stream_id: u64, - msg: &message::PublishNamespaceOk, + request_kind: &str, + msg: &message::RequestOk, ) -> Event { create_control_message_event( time, stream_id, true, - "publish_namespace_ok", - publish_namespace_ok_to_json(msg), + "request_ok", + request_ok_to_json(request_kind, msg), ) } -/// Create a control_message_created event for PUBLISH_NAMESPACE_OK -pub fn publish_namespace_ok_created( +/// Create a control_message_created event for REQUEST_OK. +pub fn request_ok_created( time: f64, stream_id: u64, - msg: &message::PublishNamespaceOk, + request_kind: &str, + msg: &message::RequestOk, ) -> Event { create_control_message_event( time, stream_id, false, - "publish_namespace_ok", - publish_namespace_ok_to_json(msg), + "request_ok", + request_ok_to_json(request_kind, msg), ) } -/// Helper to convert PUBLISH_NAMESPACE_ERROR message to JSON -fn publish_namespace_error_to_json(msg: &message::PublishNamespaceError) -> JsonValue { +fn request_error_to_json(request_kind: &str, msg: &message::RequestError) -> JsonValue { json!({ "request_id": msg.id, + "request_kind": request_kind, "error_code": msg.error_code, - "reason_phrase": &msg.reason_phrase.0, + "retry_interval": msg.retry_interval, + "reason_phrase": &msg.reason.0, }) } -/// Create a control_message_parsed event for PUBLISH_NAMESPACE_ERROR (was ANNOUNCE_ERROR) -pub fn publish_namespace_error_parsed( +/// Create a control_message_parsed event for REQUEST_ERROR. +pub fn request_error_parsed( time: f64, stream_id: u64, - msg: &message::PublishNamespaceError, + request_kind: &str, + msg: &message::RequestError, ) -> Event { create_control_message_event( time, stream_id, true, - "publish_namespace_error", - publish_namespace_error_to_json(msg), + "request_error", + request_error_to_json(request_kind, msg), ) } -/// Create a control_message_created event for PUBLISH_NAMESPACE_ERROR -pub fn publish_namespace_error_created( +/// Create a control_message_created event for REQUEST_ERROR. +pub fn request_error_created( time: f64, stream_id: u64, - msg: &message::PublishNamespaceError, + request_kind: &str, + msg: &message::RequestError, ) -> Event { create_control_message_event( time, stream_id, false, - "publish_namespace_error", - publish_namespace_error_to_json(msg), + "request_error", + request_error_to_json(request_kind, msg), ) } diff --git a/moq-transport/src/serve/subgroup.rs b/moq-transport/src/serve/subgroup.rs index daddb65d..fa2ce975 100644 --- a/moq-transport/src/serve/subgroup.rs +++ b/moq-transport/src/serve/subgroup.rs @@ -200,7 +200,7 @@ impl SubgroupsReader { state .latest_subgroup_reader .as_ref() - .map(|group| (group.group_id, group.latest())) + .and_then(|group| group.latest().map(|object_id| (group.group_id, object_id))) } /// Check if the subgroups writer has been closed or dropped. @@ -390,13 +390,9 @@ impl SubgroupReader { } } - pub fn latest(&self) -> u64 { + pub fn latest(&self) -> Option { let state = self.state.lock(); - state - .objects - .last() - .map(|o| o.object_id) - .unwrap_or_default() + state.objects.last().map(|o| o.object_id) } pub async fn read_next(&mut self) -> Result, ServeError> { diff --git a/moq-transport/src/serve/track.rs b/moq-transport/src/serve/track.rs index 3cb2edaf..3ad61392 100644 --- a/moq-transport/src/serve/track.rs +++ b/moq-transport/src/serve/track.rs @@ -22,7 +22,7 @@ use super::{ Datagrams, DatagramsReader, DatagramsWriter, ObjectsWriter, ServeError, Stream, StreamReader, StreamWriter, Subgroups, SubgroupsReader, SubgroupsWriter, }; -use crate::coding::{Location, TrackNamespace}; +use crate::coding::{Location, TrackName, TrackNamespace}; use paste::paste; use std::{ops::Deref, sync::Arc}; @@ -30,12 +30,15 @@ use std::{ops::Deref, sync::Arc}; #[derive(Debug, Clone, PartialEq)] pub struct Track { pub namespace: TrackNamespace, - pub name: String, + pub name: TrackName, } impl Track { - pub fn new(namespace: TrackNamespace, name: String) -> Self { - Self { namespace, name } + pub fn new(namespace: TrackNamespace, name: impl Into) -> Self { + Self { + namespace, + name: name.into(), + } } pub fn produce(self) -> (TrackWriter, TrackReader) { @@ -212,9 +215,10 @@ impl TrackReader { // Returns the largest group/sequence pub fn largest_location(&self) -> Option { - // We don't even know the mode yet. - // TODO populate from SUBSCRIBE_OK - None + let mode = self.state.lock().reader_mode.clone(); + mode.as_ref() + .and_then(|mode| mode.latest()) + .map(|(group_id, object_id)| Location::new(group_id, object_id)) } /// Wait until the track is closed, returning the closing error. diff --git a/moq-transport/src/serve/tracks.rs b/moq-transport/src/serve/tracks.rs index 674e2690..c8db1f81 100644 --- a/moq-transport/src/serve/tracks.rs +++ b/moq-transport/src/serve/tracks.rs @@ -17,14 +17,14 @@ use std::{collections::HashMap, ops::Deref, sync::Arc}; use super::{ServeError, Track, TrackReader, TrackWriter}; -use crate::coding::TrackNamespace; +use crate::coding::{TrackName, TrackNamespace}; use crate::watch::{Queue, State}; /// Full track identifier: namespace + track name #[derive(Hash, Eq, PartialEq, Clone, Debug)] pub struct FullTrackName { pub namespace: TrackNamespace, - pub name: String, + pub name: TrackName, } /// Static information about a broadcast. @@ -70,17 +70,18 @@ impl TracksWriter { /// Create a new track with the given name, inserting it into the broadcast. /// The track will use this writer's namespace. /// None is returned if all [TracksReader]s have been dropped. - pub fn create(&mut self, track: &str) -> Option { + pub fn create(&mut self, track: impl Into) -> Option { + let track = track.into(); let (writer, reader) = Track { namespace: self.namespace.clone(), - name: track.to_owned(), + name: track.clone(), } .produce(); // NOTE: We overwrite the track if it already exists. let full_name = FullTrackName { namespace: self.namespace.clone(), - name: track.to_owned(), + name: track, }; self.state.lock_mut()?.tracks.insert(full_name, reader); @@ -88,10 +89,14 @@ impl TracksWriter { } /// Remove a track from the broadcast by full name. - pub fn remove(&mut self, namespace: &TrackNamespace, track_name: &str) -> Option { + pub fn remove( + &mut self, + namespace: &TrackNamespace, + track_name: impl Into, + ) -> Option { let full_name = FullTrackName { namespace: namespace.clone(), - name: track_name.to_owned(), + name: track_name.into(), }; self.state.lock_mut()?.tracks.remove(&full_name) } @@ -176,12 +181,13 @@ impl TracksReader { pub fn get_track_reader( &mut self, namespace: &TrackNamespace, - track_name: &str, + track_name: impl Into, ) -> Option { + let track_name = track_name.into(); let state = self.state.lock(); let full_name = FullTrackName { namespace: namespace.clone(), - name: track_name.to_owned(), + name: track_name.clone(), }; if let Some(track_reader) = state.tracks.get(&full_name) { @@ -199,12 +205,13 @@ impl TracksReader { pub fn subscribe( &mut self, namespace: TrackNamespace, - track_name: &str, + track_name: impl Into, ) -> Option { + let track_name = track_name.into(); let state = self.state.lock(); let full_name = FullTrackName { namespace: namespace.clone(), - name: track_name.to_owned(), + name: track_name.clone(), }; // Check if we have a cached track that is still alive @@ -235,7 +242,7 @@ impl TracksReader { // Use the full requested namespace, not self.namespace let track_writer_reader = Track { namespace: namespace.clone(), - name: track_name.to_owned(), + name: track_name.clone(), } .produce(); @@ -307,7 +314,7 @@ mod tests { .await .expect("publisher should receive first track request"); - assert_eq!(track_writer_1.name, track_name); + assert_eq!(track_writer_1.name, TrackName::from(track_name)); // Publisher closes the track with an error (simulates connection failure) track_writer_1 @@ -346,7 +353,7 @@ mod tests { .unwrap() .expect("publisher should receive second track request"); - assert_eq!(track_writer_2.name, track_name); + assert_eq!(track_writer_2.name, TrackName::from(track_name)); // Verify that track_reader_2 is NOT already closed // (It should be a fresh, working track) diff --git a/moq-transport/src/session/announced.rs b/moq-transport/src/session/announced.rs deleted file mode 100644 index e87b96be..00000000 --- a/moq-transport/src/session/announced.rs +++ /dev/null @@ -1,123 +0,0 @@ -// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors -// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors -// SPDX-License-Identifier: MIT OR Apache-2.0 - -use std::ops; - -use crate::coding::{ReasonPhrase, TrackNamespace}; -use crate::watch::State; -use crate::{message, serve::ServeError}; - -use super::{AnnounceInfo, Subscriber}; - -// There's currently no feedback from the peer, so the shared state is empty. -// If Unannounce contained an error code then we'd be talking. -#[derive(Default)] -struct AnnouncedState {} - -pub struct Announced { - session: Subscriber, - state: State, - - pub info: AnnounceInfo, - - ok: bool, - error: Option, -} - -impl Announced { - pub(super) fn new( - session: Subscriber, - request_id: u64, - namespace: TrackNamespace, - ) -> (Announced, AnnouncedRecv) { - let info = AnnounceInfo { - request_id, - namespace, - }; - - let (send, recv) = State::default().split(); - let send = Self { - session, - info, - ok: false, - error: None, - state: send, - }; - let recv = AnnouncedRecv { _state: recv }; - - (send, recv) - } - - // Send an ANNOUNCE_OK - pub fn ok(&mut self) -> Result<(), ServeError> { - if self.ok { - return Err(ServeError::Duplicate); - } - - self.session.send_message(message::PublishNamespaceOk { - id: self.info.request_id, - }); - - self.ok = true; - - Ok(()) - } - - pub async fn closed(&self) -> Result<(), ServeError> { - loop { - // Wow this is dumb and yet pretty cool. - // Basically loop until the state changes and exit when Recv is dropped. - self.state - .lock() - .modified() - .ok_or(ServeError::Cancel)? - .await; - } - } - - pub fn close(mut self, err: ServeError) -> Result<(), ServeError> { - self.error = Some(err); - Ok(()) - } -} - -impl ops::Deref for Announced { - type Target = AnnounceInfo; - - fn deref(&self) -> &AnnounceInfo { - &self.info - } -} - -impl Drop for Announced { - fn drop(&mut self) { - let err = self.error.clone().unwrap_or(ServeError::Done); - - // TODO SLG - ServeError's do not align with draft-13 Announce error codes (section 8.25) - if self.ok { - self.session.send_message(message::PublishNamespaceCancel { - track_namespace: self.namespace.clone(), - error_code: err.code(), - reason_phrase: ReasonPhrase(err.to_string()), - }); - } else { - self.session.send_message(message::PublishNamespaceError { - id: self.info.request_id, - error_code: err.code(), - reason_phrase: ReasonPhrase(err.to_string()), - }); - } - } -} - -pub(super) struct AnnouncedRecv { - _state: State, -} - -impl AnnouncedRecv { - pub fn recv_unannounce(self) -> Result<(), ServeError> { - // Will cause the state to be dropped - Ok(()) - } -} diff --git a/moq-transport/src/session/error.rs b/moq-transport/src/session/error.rs index 86478ee6..fa7b5250 100644 --- a/moq-transport/src/session/error.rs +++ b/moq-transport/src/session/error.rs @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 -use crate::{coding, serve, setup}; +use crate::{coding, serve}; #[derive(thiserror::Error, Debug, Clone)] pub enum SessionError { @@ -15,10 +15,6 @@ pub enum SessionError { #[error("decode error: {0}")] Decode(#[from] coding::DecodeError), - // TODO move to a ConnectError - #[error("unsupported versions: client={0:?} server={1:?}")] - Version(setup::Versions, setup::Versions), - /// TODO SLG - eventually remove or morph into error for incorrect control message for publisher/subscriber /// The role negiotiated in the handshake was violated. For example, a publisher sent a SUBSCRIBE, or a subscriber sent an OBJECT. #[error("role violation")] @@ -43,6 +39,18 @@ pub enum SessionError { #[error("invalid connection path: {0}")] InvalidPath(String), + + /// Draft-16 §3.4 INVALID_REQUEST_ID (0x4): peer used an invalid request ID. + #[error("invalid request ID")] + InvalidRequestId, + + /// Draft-16 §3.4 TOO_MANY_REQUESTS (0x7): request ID meets or exceeds the maximum. + #[error("too many requests")] + TooManyRequests, + + /// Draft-16 §3.4 PROTOCOL_VIOLATION (0x3): peer violated a MUST rule. + #[error("protocol violation: {0}")] + ProtocolViolation(String), } // Session Termination Error Codes from draft-ietf-moq-transport-14 Section 13.1.1 @@ -58,14 +66,18 @@ impl SessionError { Self::Encode(_) => 0x1, Self::BoundsExceeded(_) => 0x1, Self::Internal => 0x1, - // VERSION_NEGOTIATION_FAILED (0x15) - Self::Version(..) => 0x15, // PROTOCOL_VIOLATION (0x3) - Malformed messages Self::Decode(_) => 0x3, Self::WrongSize => 0x3, Self::InvalidPath(_) => 0x3, // DUPLICATE_TRACK_ALIAS (0x5) Self::Duplicate => 0x5, + // INVALID_REQUEST_ID (0x4) + Self::InvalidRequestId => 0x4, + // TOO_MANY_REQUESTS (0x7) + Self::TooManyRequests => 0x7, + // PROTOCOL_VIOLATION (0x3) + Self::ProtocolViolation(_) => 0x3, // Delegate to ServeError for per-request error codes Self::Serve(err) => err.code(), } diff --git a/moq-transport/src/session/mod.rs b/moq-transport/src/session/mod.rs index a2313636..02a8acd1 100644 --- a/moq-transport/src/session/mod.rs +++ b/moq-transport/src/session/mod.rs @@ -2,21 +2,23 @@ // SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 -mod announce; -mod announced; mod error; +mod publish_namespace; +mod published_namespace; mod publisher; mod reader; +mod request_id; mod subscribe; mod subscribed; mod subscriber; mod track_status_requested; mod writer; -pub use announce::*; -pub use announced::*; pub use error::*; +pub use publish_namespace::*; +pub use published_namespace::*; pub use publisher::*; +pub use request_id::{RequestId, RequestIdAllocation}; pub use subscribe::*; pub use subscribed::*; pub use subscriber::*; @@ -26,7 +28,8 @@ use reader::*; use writer::*; use futures::{stream::FuturesUnordered, StreamExt}; -use std::sync::{atomic, Arc, Mutex}; +use request_id::max_request_id_from_params; +use std::sync::{Arc, Mutex}; use crate::coding::{KeyValuePairs, Value}; use crate::message::Message; @@ -50,7 +53,7 @@ pub enum Transport { /// ALPN: "h3". Path carried in HTTP/3 CONNECT :path pseudo-header. WebTransport, /// Raw QUIC with MoQT framing directly on QUIC streams. - /// ALPN: "moq-00". Path carried in CLIENT_SETUP PATH parameter. + /// ALPN: "moqt-16". Path carried in CLIENT_SETUP PATH parameter. RawQuic, } @@ -69,6 +72,10 @@ pub struct Session { /// Queue used by Publisher and Subscriber for sending Control Messages outgoing: Queue, + /// Session-level request ID manager. + /// Publisher and Subscriber share one outbound request ID sequence. + request_id: RequestId, + /// Optional mlog writer for MoQ Transport events /// Wrapped in Arc> to share across send/recv tasks when enabled mlog: Option>>, @@ -183,14 +190,6 @@ impl Session { self.connection_path.as_deref() } - // Helper for determining the largest supported version - fn largest_common(a: &[T], b: &[T]) -> Option { - a.iter() - .filter(|x| b.contains(x)) // keep only items also in b - .cloned() // clone because we return T, not &T - .max() // take the largest - } - /// Log a control message with structured fields for observability. /// Uses target "moq_transport::control" so it can be filtered independently. fn log_control_message(msg: &Message, direction: &str) { @@ -203,7 +202,6 @@ impl Session { subscribe_id = m.id, namespace = %m.track_namespace, track_name = %m.track_name, - filter_type = ?m.filter_type, "MoQT control message" ); } @@ -214,28 +212,6 @@ impl Session { msg_type = "SUBSCRIBE_OK", subscribe_id = m.id, track_alias = m.track_alias, - content_exists = m.content_exists, - "MoQT control message" - ); - } - Message::SubscribeError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE_ERROR", - subscribe_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } - Message::SubscribeUpdate(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE_UPDATE", - request_id = m.id, - subscription_request_id = m.subscription_request_id, "MoQT control message" ); } @@ -258,32 +234,30 @@ impl Session { "MoQT control message" ); } - Message::PublishNamespaceOk(m) => { + Message::PublishNamespaceDone(m) => { tracing::debug!( target: "moq_transport::control", direction, - msg_type = "PUBLISH_NAMESPACE_OK", + msg_type = "PUBLISH_NAMESPACE_DONE", request_id = m.id, "MoQT control message" ); } - Message::PublishNamespaceError(m) => { + Message::Namespace(m) => { tracing::debug!( target: "moq_transport::control", direction, - msg_type = "PUBLISH_NAMESPACE_ERROR", - request_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, + msg_type = "NAMESPACE", + namespace_suffix = %m.track_namespace_suffix, "MoQT control message" ); } - Message::PublishNamespaceDone(m) => { + Message::NamespaceDone(m) => { tracing::debug!( target: "moq_transport::control", direction, - msg_type = "PUBLISH_NAMESPACE_DONE", - namespace = %m.track_namespace, + msg_type = "NAMESPACE_DONE", + namespace_suffix = %m.track_namespace_suffix, "MoQT control message" ); } @@ -292,7 +266,7 @@ impl Session { target: "moq_transport::control", direction, msg_type = "PUBLISH_NAMESPACE_CANCEL", - namespace = %m.track_namespace, + request_id = m.id, error_code = m.error_code, reason = %m.reason_phrase.0, "MoQT control message" @@ -309,28 +283,6 @@ impl Session { "MoQT control message" ); } - Message::TrackStatusOk(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "TRACK_STATUS_OK", - request_id = m.id, - track_alias = m.track_alias, - content_exists = m.content_exists, - "MoQT control message" - ); - } - Message::TrackStatusError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "TRACK_STATUS_ERROR", - request_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } Message::SubscribeNamespace(m) => { tracing::debug!( target: "moq_transport::control", @@ -341,35 +293,6 @@ impl Session { "MoQT control message" ); } - Message::SubscribeNamespaceOk(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE_NAMESPACE_OK", - request_id = m.id, - "MoQT control message" - ); - } - Message::SubscribeNamespaceError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "SUBSCRIBE_NAMESPACE_ERROR", - request_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } - Message::UnsubscribeNamespace(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "UNSUBSCRIBE_NAMESPACE", - namespace_prefix = %m.track_namespace_prefix, - "MoQT control message" - ); - } Message::Fetch(m) => { tracing::debug!( target: "moq_transport::control", @@ -390,17 +313,6 @@ impl Session { "MoQT control message" ); } - Message::FetchError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "FETCH_ERROR", - request_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, - "MoQT control message" - ); - } Message::FetchCancel(m) => { tracing::debug!( target: "moq_transport::control", @@ -428,18 +340,6 @@ impl Session { direction, msg_type = "PUBLISH_OK", request_id = m.id, - filter_type = ?m.filter_type, - "MoQT control message" - ); - } - Message::PublishError(m) => { - tracing::debug!( - target: "moq_transport::control", - direction, - msg_type = "PUBLISH_ERROR", - request_id = m.id, - error_code = m.error_code, - reason = %m.reason_phrase.0, "MoQT control message" ); } @@ -481,6 +381,36 @@ impl Session { "MoQT control message" ); } + Message::RequestOk(m) => { + tracing::debug!( + target: "moq_transport::control", + direction, + msg_type = "REQUEST_OK", + request_id = m.id, + "MoQT control message" + ); + } + Message::RequestError(m) => { + tracing::debug!( + target: "moq_transport::control", + direction, + msg_type = "REQUEST_ERROR", + request_id = m.id, + error_code = m.error_code, + retry_interval = m.retry_interval, + "MoQT control message" + ); + } + Message::RequestUpdate(m) => { + tracing::debug!( + target: "moq_transport::control", + direction, + msg_type = "REQUEST_UPDATE", + request_id = m.id, + existing_request_id = m.existing_request_id, + "MoQT control message" + ); + } } } @@ -488,12 +418,11 @@ impl Session { webtransport: web_transport::Session, sender: Writer, recver: Reader, - first_requestid: u64, mlog: Option, transport: Transport, connection_path: Option, + request_id: RequestId, ) -> (Self, Option, Option) { - let next_requestid = Arc::new(atomic::AtomicU64::new(first_requestid)); let outgoing = Queue::default().split(); // Wrap mlog in Arc> for sharing across tasks @@ -502,13 +431,13 @@ impl Session { let publisher = Some(Publisher::new( outgoing.0.clone(), webtransport.clone(), - next_requestid.clone(), mlog_shared.clone(), + request_id.clone(), )); let subscriber = Some(Subscriber::new( outgoing.0, - next_requestid, mlog_shared.clone(), + request_id.clone(), )); let session = Self { @@ -518,6 +447,7 @@ impl Session { publisher: publisher.clone(), subscriber: subscriber.clone(), outgoing: outgoing.1, + request_id, mlog: mlog_shared, transport, connection_path, @@ -526,93 +456,113 @@ impl Session { (session, publisher, subscriber) } - /// Create an outbound/client QUIC connection, by opening a bi-directional QUIC stream for - /// MOQT control messaging. Performs SETUP messaging and version negotiation. + /// Create an outbound/client QUIC connection. /// - /// If the session URL contains a non-trivial path (not empty or "/"), the PATH - /// parameter (key 0x1) is automatically sent in CLIENT_SETUP. This propagates - /// the connection path (App ID / MoQT scope) to the remote peer, which is needed - /// for relay-to-relay connections. To connect without sending PATH, use a URL - /// with no path component. + /// Opens the bidirectional control stream, sends CLIENT_SETUP with + /// parameters only (version is agreed via ALPN), and waits for SERVER_SETUP. + /// + /// For native `moqt://` connections the PATH and AUTHORITY parameters are + /// sent automatically. For WebTransport the path is carried in the HTTP/3 + /// CONNECT URL so PATH is not sent. pub async fn connect( session: web_transport::Session, mlog_path: Option, transport: Transport, ) -> Result<(Session, Publisher, Subscriber), SessionError> { - // Auto-extract path from the session URL. - // This aligns with the unified moqt:// URI scheme direction (IETF PR #1486) - // where the path is always part of the URI regardless of transport. - let url_path = session.url().path(); + let url = session.url().clone(); + let url_path = url.path(); let path = Self::normalize_connection_path(url_path)?; - let mlog = mlog_path.and_then(|path| { - mlog::MlogWriter::new(path) + + let mlog = mlog_path.and_then(|p| { + mlog::MlogWriter::new(p) .map_err(|e| tracing::warn!("Failed to create mlog: {}", e)) .ok() }); + let control = session.open_bi().await?; let mut sender = Writer::new(control.0); let mut recver = Reader::new(control.1); - let versions: setup::Versions = [setup::Version::DRAFT_14].into(); - - // TODO SLG - make configurable? let mut params = KeyValuePairs::default(); - params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); - // Only send PATH in CLIENT_SETUP for raw QUIC connections. - // For WebTransport, the path is already carried in the HTTP/3 CONNECT URL. - if let Some(ref path) = path { - if transport == Transport::RawQuic { - params.set_bytesvalue(setup::ParameterType::Path.into(), path.as_bytes().to_vec()); + if transport == Transport::RawQuic { + // Draft-16 §9.3.1.1: send AUTHORITY for native QUIC. + if let Some(host) = url.host_str() { + let authority = if let Some(port) = url.port() { + format!("{}:{}", host, port) + } else { + host.to_string() + }; + params.set_bytesvalue( + setup::ParameterType::Authority.into(), + authority.into_bytes(), + ); + } + + // Draft-16 §9.3.1.2: send PATH (path + optional query) for native QUIC. + let path_and_query = match url.query() { + Some(q) => format!("{}?{}", url_path, q), + None => url_path.to_string(), + }; + if !path_and_query.is_empty() && path_and_query != "/" { + params.set_bytesvalue( + setup::ParameterType::Path.into(), + path_and_query.into_bytes(), + ); } } - let client = setup::Client { - versions: versions.clone(), - params, - }; + // The MAX_REQUEST_ID we advertise to the server. + // TODO(itzmanish): make configurable. + let our_max_request_id: u64 = 100; + params.set_intvalue( + setup::ParameterType::MaxRequestId.into(), + our_max_request_id, + ); + + let client = setup::Client { params }; tracing::debug!( target: "moq_transport::control", direction = "sent", msg_type = "CLIENT_SETUP", - versions = ?client.versions, ?transport, path = path.as_deref(), "MoQT control message" ); sender.encode(&client).await?; - // TODO: emit client_setup_created event when we add that - let server: setup::Server = recver.decode().await?; tracing::debug!( target: "moq_transport::control", direction = "recv", msg_type = "SERVER_SETUP", - version = ?server.version, "MoQT control message" ); - // TODO: emit server_setup_parsed event - - // We are the client, so the first request id is 0 - let session = Session::new(session, sender, recver, 0, mlog, transport, path); + let peer_max = max_request_id_from_params(&server.params); + // Client sends even IDs (0); peer server sends odd IDs (1). + let request_id = RequestId::new(0, peer_max, our_max_request_id, 1); + let session = Session::new(session, sender, recver, mlog, transport, path, request_id); Ok((session.0, session.1.unwrap(), session.2.unwrap())) } - /// Accepts an inbound/server QUIC connection, by accepting a bi-directional QUIC stream for - /// MOQT control messaging. Performs SETUP messaging and version negotiation. + /// Accept an inbound server connection. + /// + /// Waits for the bidirectional control stream, decodes CLIENT_SETUP, + /// sends SERVER_SETUP with parameters only. Version is already agreed + /// via ALPN before this is called. pub async fn accept( session: web_transport::Session, mlog_path: Option, transport: Transport, ) -> Result<(Session, Option, Option), SessionError> { - let mut mlog = mlog_path.and_then(|path| { - mlog::MlogWriter::new(path) + let mut mlog = mlog_path.and_then(|p| { + mlog::MlogWriter::new(p) .map_err(|e| tracing::warn!("Failed to create mlog: {}", e)) .ok() }); + let control = session.accept_bi().await?; let mut sender = Writer::new(control.0); let mut recver = Reader::new(control.1); @@ -622,27 +572,20 @@ impl Session { target: "moq_transport::control", direction = "recv", msg_type = "CLIENT_SETUP", - versions = ?client.versions, "MoQT control message" ); - // Extract WebTransport URL path from the underlying session. - // For WebTransport connections, this comes from the HTTP/3 CONNECT :path. - // For raw QUIC, this is the placeholder URL ("moqt://localhost") and has no meaningful path. + // For WebTransport the path arrives in the HTTP/3 CONNECT :path. + // For raw QUIC the PATH setup parameter carries it instead. let wt_url_path = session.url().path(); let wt_path = Self::normalize_connection_path(wt_url_path)?; - // Extract CLIENT_SETUP PATH parameter (key 0x1, BytesValue). - // Used for raw QUIC connections where there's no HTTP CONNECT URL. let client_setup_path = if wt_path.is_none() { Self::decode_client_setup_path(&client.params)? } else { None }; - // Combine: WebTransport URL path takes precedence over CLIENT_SETUP PATH. - // WebTransport connections always have the path in the CONNECT URL. - // Raw QUIC connections only have CLIENT_SETUP PATH. let connection_path = wt_path.or(client_setup_path); if connection_path.is_some() { @@ -652,55 +595,49 @@ impl Session { ); } - // Emit mlog event for CLIENT_SETUP parsed if let Some(ref mut mlog) = mlog { let event = mlog::events::client_setup_parsed(mlog.elapsed_ms(), 0, &client); let _ = mlog.add_event(event); } - let server_versions = setup::Versions(vec![setup::Version::DRAFT_14]); + let peer_max = max_request_id_from_params(&client.params); - if let Some(largest_common_version) = - Self::largest_common(&server_versions, &client.versions) - { - // TODO SLG - make configurable? - let mut params = KeyValuePairs::default(); - params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); + // The MAX_REQUEST_ID we advertise to the client. + // TODO(itzmanish): make configurable. + let our_max_request_id: u64 = 100; + let mut params = KeyValuePairs::default(); + params.set_intvalue( + setup::ParameterType::MaxRequestId.into(), + our_max_request_id, + ); - let server = setup::Server { - version: largest_common_version, - params, - }; + let server = setup::Server { params }; - tracing::debug!( - target: "moq_transport::control", - direction = "sent", - msg_type = "SERVER_SETUP", - version = ?server.version, - "MoQT control message" - ); + tracing::debug!( + target: "moq_transport::control", + direction = "sent", + msg_type = "SERVER_SETUP", + "MoQT control message" + ); - // Emit mlog event for SERVER_SETUP created - if let Some(ref mut mlog) = mlog { - let event = mlog::events::server_setup_created(mlog.elapsed_ms(), 0, &server); - let _ = mlog.add_event(event); - } + if let Some(ref mut mlog) = mlog { + let event = mlog::events::server_setup_created(mlog.elapsed_ms(), 0, &server); + let _ = mlog.add_event(event); + } - sender.encode(&server).await?; + sender.encode(&server).await?; - // We are the server, so the first request id is 1 - Ok(Session::new( - session, - sender, - recver, - 1, - mlog, - transport, - connection_path, - )) - } else { - Err(SessionError::Version(client.versions, server_versions)) - } + // Server sends odd IDs (1); peer client sends even IDs (0). + let request_id = RequestId::new(1, peer_max, our_max_request_id, 0); + Ok(Session::new( + session, + sender, + recver, + mlog, + transport, + connection_path, + request_id, + )) } /// Run Tasks for the session, including sending of control messages, receiving and processing @@ -708,7 +645,7 @@ impl Session { /// and receiving and processing QUIC datagrams received pub async fn run(self) -> Result<(), SessionError> { tokio::select! { - res = Self::run_recv(self.recver, self.publisher, self.subscriber.clone(), self.mlog.clone()) => res, + res = Self::run_recv(self.recver, self.publisher, self.subscriber.clone(), self.mlog.clone(), self.request_id.clone(), self.outgoing.clone()) => res, res = Self::run_send(self.sender, self.outgoing, self.mlog.clone()) => res, res = Self::run_streams(self.webtransport.clone(), self.subscriber.clone()) => res, res = Self::run_datagrams(self.webtransport, self.subscriber) => res, @@ -739,21 +676,12 @@ impl Session { Message::SubscribeOk(m) => { Some(mlog::events::subscribe_ok_created(time, stream_id, m)) } - Message::SubscribeError(m) => { - Some(mlog::events::subscribe_error_created(time, stream_id, m)) - } Message::Unsubscribe(m) => { Some(mlog::events::unsubscribe_created(time, stream_id, m)) } Message::PublishNamespace(m) => { Some(mlog::events::publish_namespace_created(time, stream_id, m)) } - Message::PublishNamespaceOk(m) => Some( - mlog::events::publish_namespace_ok_created(time, stream_id, m), - ), - Message::PublishNamespaceError(m) => Some( - mlog::events::publish_namespace_error_created(time, stream_id, m), - ), Message::GoAway(m) => { Some(mlog::events::go_away_created(time, stream_id, m)) } @@ -775,14 +703,19 @@ impl Session { /// Receives inbound messages from the control stream reader/receiver. Analyzes if the message /// is to be handled by Subscriber or Publisher logic and calls recv_message on either the /// Publisher or Subscriber. - /// Note: Should also be handling messages common to both roles, ie: GOAWAY, MAX_REQUEST_ID and - /// REQUESTS_BLOCKED + /// Receives and dispatches control messages. + /// Handles session-level messages (GOAWAY, MAX_REQUEST_ID, REQUESTS_BLOCKED) + /// directly and routes role-specific messages to Publisher or Subscriber. async fn run_recv( mut recver: Reader, mut publisher: Option, mut subscriber: Option, mlog: Option>>, + request_id: RequestId, + _outgoing: Queue, ) -> Result<(), SessionError> { + let mut goaway_received = false; + loop { let msg: message::Message = recver.decode().await?; @@ -803,21 +736,12 @@ impl Session { Message::SubscribeOk(m) => { Some(mlog::events::subscribe_ok_parsed(time, stream_id, m)) } - Message::SubscribeError(m) => { - Some(mlog::events::subscribe_error_parsed(time, stream_id, m)) - } Message::Unsubscribe(m) => { Some(mlog::events::unsubscribe_parsed(time, stream_id, m)) } Message::PublishNamespace(m) => { Some(mlog::events::publish_namespace_parsed(time, stream_id, m)) } - Message::PublishNamespaceOk(m) => Some( - mlog::events::publish_namespace_ok_parsed(time, stream_id, m), - ), - Message::PublishNamespaceError(m) => Some( - mlog::events::publish_namespace_error_parsed(time, stream_id, m), - ), Message::GoAway(m) => { Some(mlog::events::go_away_parsed(time, stream_id, m)) } @@ -830,6 +754,10 @@ impl Session { } } + if let Some(id) = msg.sequenced_request_id() { + request_id.validate_incoming(id)?; + } + let msg = match TryInto::::try_into(msg) { Ok(msg) => { subscriber @@ -852,12 +780,48 @@ impl Session { Err(msg) => msg, }; - // TODO GOAWAY, MAX_REQUEST_ID, REQUESTS_BLOCKED - tracing::warn!("Unimplemented message type received: {:?}", msg); - return Err(SessionError::unimplemented(&format!( - "message type {:?}", - msg - ))); + // Session-level messages handled here (not role-specific). + match msg { + Message::GoAway(ref m) => { + // Draft-16 §9.4: receiving a second GOAWAY is PROTOCOL_VIOLATION. + if goaway_received { + return Err(SessionError::ProtocolViolation( + "received multiple GOAWAY messages".to_string(), + )); + } + goaway_received = true; + tracing::info!( + target: "moq_transport::control", + new_uri = %m.uri.0, + "received GOAWAY" + ); + // TODO(itzmanish): trigger session migration. + } + Message::MaxRequestId(ref m) => { + request_id.apply_max_request_id(m)?; + tracing::debug!( + target: "moq_transport::control", + max_request_id = m.request_id, + "received MAX_REQUEST_ID" + ); + } + Message::RequestsBlocked(ref m) => { + tracing::debug!( + target: "moq_transport::control", + max_request_id = m.max_request_id, + "received REQUESTS_BLOCKED" + ); + // REQUESTS_BLOCKED tells us the peer's send budget is exhausted. + request_id.handle_requests_blocked(m)?; + } + other => { + tracing::warn!(msg_type = other.name(), "received unhandled message type"); + return Err(SessionError::unimplemented(&format!( + "message type {}", + other.name() + ))); + } + } } } diff --git a/moq-transport/src/session/announce.rs b/moq-transport/src/session/publish_namespace.rs similarity index 75% rename from moq-transport/src/session/announce.rs rename to moq-transport/src/session/publish_namespace.rs index 8e2a0a90..1b92921f 100644 --- a/moq-transport/src/session/announce.rs +++ b/moq-transport/src/session/publish_namespace.rs @@ -10,20 +10,21 @@ use crate::{message, serve::ServeError}; use super::{Publisher, Subscribed, TrackStatusRequested}; +/// Information about an outbound PUBLISH_NAMESPACE request. #[derive(Debug, Clone)] -pub struct AnnounceInfo { +pub struct PublishNamespaceInfo { pub request_id: u64, pub namespace: TrackNamespace, } -struct AnnounceState { +struct PublishNamespaceState { subscribers: VecDeque, track_statuses_requested: VecDeque, ok: bool, closed: Result<(), ServeError>, } -impl Default for AnnounceState { +impl Default for PublishNamespaceState { fn default() -> Self { Self { subscribers: Default::default(), @@ -34,33 +35,36 @@ impl Default for AnnounceState { } } -impl Drop for AnnounceState { +impl Drop for PublishNamespaceState { fn drop(&mut self) { for subscriber in self.subscribers.drain(..) { subscriber .close(ServeError::not_found_ctx( - "announce dropped before subscription handled", + "publish_namespace dropped before subscription handled", )) .ok(); } } } -#[must_use = "unannounce on drop"] -pub struct Announce { +/// Represents an outbound PUBLISH_NAMESPACE sent by a publisher. +/// +/// Dropped with PUBLISH_NAMESPACE_DONE unless already closed with an error. +#[must_use = "send PUBLISH_NAMESPACE_DONE on drop"] +pub struct PublishNamespace { publisher: Publisher, - state: State, + state: State, - pub info: AnnounceInfo, + pub info: PublishNamespaceInfo, } -impl Announce { +impl PublishNamespace { pub(super) fn new( mut publisher: Publisher, request_id: u64, namespace: TrackNamespace, - ) -> (Announce, AnnounceRecv) { - let info = AnnounceInfo { + ) -> (PublishNamespace, PublishNamespaceRecv) { + let info = PublishNamespaceInfo { request_id, namespace: namespace.clone(), }; @@ -78,7 +82,7 @@ impl Announce { info, state: send, }; - let recv = AnnounceRecv { + let recv = PublishNamespaceRecv { state: recv, request_id, }; @@ -86,7 +90,7 @@ impl Announce { (send, recv) } - // Run until we get an error + /// Wait until the namespace publish is closed (error or peer disconnect). pub async fn closed(&self) -> Result<(), ServeError> { loop { { @@ -102,7 +106,7 @@ impl Announce { } } - /// Wait until a subscriber is received + /// Wait until a subscriber arrives for this namespace. pub async fn subscribed(&self) -> Result, ServeError> { loop { { @@ -123,6 +127,7 @@ impl Announce { } } + /// Wait until a TRACK_STATUS request arrives for this namespace. pub async fn track_status_requested(&self) -> Result, ServeError> { loop { { @@ -143,7 +148,7 @@ impl Announce { } } - // Wait until an OK is received + /// Wait until the peer has sent REQUEST_OK for this namespace. pub async fn ok(&self) -> Result<(), ServeError> { loop { { @@ -163,32 +168,38 @@ impl Announce { } } -impl Drop for Announce { +impl Drop for PublishNamespace { fn drop(&mut self) { if self.state.lock().closed.is_err() { return; } + // Draft-16 §9.22: PUBLISH_NAMESPACE_DONE carries the Request ID, + // not the namespace. self.publisher.send_message(message::PublishNamespaceDone { - track_namespace: self.namespace.clone(), + id: self.info.request_id, }); } } -impl ops::Deref for Announce { - type Target = AnnounceInfo; +impl ops::Deref for PublishNamespace { + type Target = PublishNamespaceInfo; fn deref(&self) -> &Self::Target { &self.info } } -pub(super) struct AnnounceRecv { - state: State, - pub request_id: u64, // TODO SLG - Announcements need to be looked up by both request_id and namespace, consider 2 hashmaps in publisher instead of this +/// Peer-facing handle for tracking a PUBLISH_NAMESPACE request. +pub(super) struct PublishNamespaceRecv { + state: State, + /// Request ID of the outbound PUBLISH_NAMESPACE. + // Namespace lookup alone is insufficient: both request_id and namespace + // are needed, so Publisher holds a second index by request_id. + pub request_id: u64, } -impl AnnounceRecv { +impl PublishNamespaceRecv { pub fn recv_ok(&mut self) -> Result<(), ServeError> { if let Some(mut state) = self.state.lock_mut() { if state.ok { diff --git a/moq-transport/src/session/published_namespace.rs b/moq-transport/src/session/published_namespace.rs new file mode 100644 index 00000000..d836885e --- /dev/null +++ b/moq-transport/src/session/published_namespace.rs @@ -0,0 +1,177 @@ +// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc., Luke Curley, Mike English and contributors +// SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors +// SPDX-License-Identifier: MIT OR Apache-2.0 + +use std::ops; + +use crate::coding::{ReasonPhrase, TrackNamespace}; +use crate::message::RequestErrorCode; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::{PublishNamespaceInfo, Subscriber}; + +// Tracks whether the publisher has cleanly completed this namespace publish. +#[derive(Default)] +struct PublishedNamespaceState { + done: bool, +} + +/// Represents an inbound PUBLISH_NAMESPACE received by a subscriber. +/// +/// On drop, revokes an accepted namespace with PUBLISH_NAMESPACE_CANCEL, or +/// rejects an unaccepted namespace with REQUEST_ERROR. +pub struct PublishedNamespace { + session: Subscriber, + state: State, + + pub info: PublishNamespaceInfo, + + ok: bool, + error: Option, +} + +impl PublishedNamespace { + pub(super) fn new( + session: Subscriber, + request_id: u64, + namespace: TrackNamespace, + ) -> (PublishedNamespace, PublishedNamespaceRecv) { + let info = PublishNamespaceInfo { + request_id, + namespace, + }; + + let (send, recv) = State::default().split(); + let send = Self { + session, + info, + ok: false, + error: None, + state: send, + }; + let recv = PublishedNamespaceRecv { + state: recv, + request_id, + }; + + (send, recv) + } + + /// Accept the PUBLISH_NAMESPACE by sending REQUEST_OK (draft-16 §9.7). + pub fn ok(&mut self) -> Result<(), ServeError> { + if self.ok { + return Err(ServeError::Duplicate); + } + + // Draft-16 §6.2: acceptance is signalled with REQUEST_OK, not the + // legacy PUBLISH_NAMESPACE_OK. + self.session.send_request_ok( + "publish_namespace", + message::RequestOk { + id: self.info.request_id, + params: Default::default(), + }, + ); + + self.ok = true; + + Ok(()) + } + + /// Wait until the peer closes the namespace publish (PUBLISH_NAMESPACE_DONE). + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + let Some(modified) = self.state.lock().modified() else { + return Ok(()); + }; + + modified.await; + } + } + + /// Reject the PUBLISH_NAMESPACE; the error is sent on drop. + pub fn close(mut self, err: ServeError) -> Result<(), ServeError> { + self.error = Some(err); + Ok(()) + } +} + +impl ops::Deref for PublishedNamespace { + type Target = PublishNamespaceInfo; + + fn deref(&self) -> &PublishNamespaceInfo { + &self.info + } +} + +impl Drop for PublishedNamespace { + fn drop(&mut self) { + let err = self.error.clone().unwrap_or(ServeError::Done); + + if self.state.lock().done { + return; + } + + if self.ok { + // Accepted: send PUBLISH_NAMESPACE_CANCEL to revoke acceptance + // (draft-16 §9.24). Carries Request ID, not the namespace. + self.session.send_message(message::PublishNamespaceCancel { + id: self.info.request_id, + error_code: err.code(), + reason_phrase: ReasonPhrase(err.to_string()), + }); + } else { + // Never accepted: send REQUEST_ERROR (draft-16 §9.8). + self.session.send_request_error( + "publish_namespace", + message::RequestError { + id: self.info.request_id, + error_code: RequestErrorCode::Uninterested as u64, + retry_interval: 0, + reason: ReasonPhrase(err.to_string()), + }, + ); + } + } +} + +pub(super) struct PublishedNamespaceRecv { + state: State, + /// Request ID of the corresponding PUBLISH_NAMESPACE, used for O(1) lookup + /// when PUBLISH_NAMESPACE_DONE or PUBLISH_NAMESPACE_CANCEL arrives. + pub request_id: u64, +} + +impl PublishedNamespaceRecv { + pub fn recv_done(self) -> Result<(), ServeError> { + if let Some(mut state) = self.state.lock_mut() { + state.done = true; + } + + // Dropping the state signals the PublishedNamespace that the peer is done. + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn recv_done_marks_namespace_done_before_drop() { + let state = State::::default(); + let (send_state, recv_state) = state.split(); + let recv = PublishedNamespaceRecv { + state: recv_state, + request_id: 0, + }; + + assert!(!send_state.lock().done); + + recv.recv_done().unwrap(); + + assert!(send_state.lock().done); + assert!(send_state.lock().modified().is_none()); + } +} diff --git a/moq-transport/src/session/publisher.rs b/moq-transport/src/session/publisher.rs index 8b8c5085..c643541e 100644 --- a/moq-transport/src/session/publisher.rs +++ b/moq-transport/src/session/publisher.rs @@ -4,7 +4,7 @@ use std::{ collections::{hash_map, HashMap}, - sync::{atomic, Arc, Mutex}, + sync::{Arc, Mutex}, }; use futures::{stream::FuturesUnordered, StreamExt}; @@ -13,45 +13,48 @@ use crate::{ coding::TrackNamespace, message::{self, Message}, mlog, - serve::{ServeError, TracksReader}, + serve::{FullTrackName, ServeError, TracksReader}, }; use crate::watch::Queue; use super::{ - Announce, AnnounceRecv, Session, SessionError, Subscribed, SubscribedRecv, TrackStatusRequested, + PublishNamespace, PublishNamespaceRecv, RequestId, RequestIdAllocation, Session, SessionError, + Subscribed, SubscribedRecv, TrackStatusRequested, }; +use crate::message::RequestErrorCode; // TODO remove Clone. #[derive(Clone)] pub struct Publisher { webtransport: web_transport::Session, - /// When the announce method is used, a new entry is added to this HashMap to track outbound announcement - announces: Arc>>, + /// Active outbound PUBLISH_NAMESPACE requests, keyed by namespace. + publish_namespaces: Arc>>, - /// When a Subscribe is received and we have a previous announce for the namespace, then a new entry is - /// added to this HashMap to track the inbound subscription + /// When a Subscribe is received and we have a matching publish_namespace entry, the + /// subscription is routed to that PublishNamespaceRecv. Otherwise it goes here. subscribeds: Arc>>, - /// When a Subscribe is received and we DO NOT have a previous announce for the namespace, then a new entry is - /// added to this Queue to track the inbound subscription + /// Active inbound SUBSCRIBEs keyed by Full Track Name. + subscribed_names: Arc>>, + + /// Subscriptions for namespaces that have no matching PUBLISH_NAMESPACE. unknown_subscribed: Queue, - /// When a TrackStatus is received and we DO NOT have a previous announce for the namespace, then a new entry is - /// added to this Queue to track the inbound track status request + /// TRACK_STATUS requests for namespaces that have no matching PUBLISH_NAMESPACE. unknown_track_status_requested: Queue, - /// The queue we will write any outbound control messages we want to sent, the session run_send task - /// will process the queue and send the message on the control stream. + /// Queue for outbound control messages; processed by the session run_send task. outgoing: Queue, - /// When we need a new Request Id for sending a request, we can get it from here. Note: The instance - /// of AtomicU64 is shared with the Subscriber, so the session uses unique request ids for all requests - /// generated. Note: If we initiated the QUIC connection then request id's start at 0 and increment by 2 - /// for each request (even numbers). If we accepted an inbound QUIC connection then request id's start at 1 and - /// increment by 2 for each request (odd numbers). - next_requestid: Arc, + /// Shared with Subscriber so all requests within a session use unique IDs. + /// When we need a new Request Id for sending a request, we can get it from here. + /// The manager is shared with the Subscriber, so the session uses unique request ids + /// for all requests generated. If we initiated the QUIC connection then request + /// IDs start at 0 and increment by 2 (even numbers). If we accepted an inbound + /// QUIC connection then request IDs start at 1 and increment by 2 (odd numbers). + request_id: RequestId, /// Optional mlog writer for logging transport events mlog: Option>>, @@ -61,17 +64,18 @@ impl Publisher { pub(crate) fn new( outgoing: Queue, webtransport: web_transport::Session, - next_requestid: Arc, mlog: Option>>, + request_id: RequestId, ) -> Self { Self { webtransport, - announces: Default::default(), + publish_namespaces: Default::default(), subscribeds: Default::default(), + subscribed_names: Default::default(), unknown_subscribed: Default::default(), unknown_track_status_requested: Default::default(), outgoing, - next_requestid, + request_id, mlog, } } @@ -92,26 +96,31 @@ impl Publisher { Ok((session, publisher)) } - /// Announce a namespace and serve tracks using the provided [serve::TracksReader]. - /// The caller uses [serve::TracksWriter] for static tracks and [serve::TracksRequest] for dynamic tracks. - pub async fn announce(&mut self, tracks: TracksReader) -> Result<(), SessionError> { - // Check if annouce for this namespace already exists or not, and if not, then create a new Announce - let announce = match self - .announces + /// Send a PUBLISH_NAMESPACE for a namespace and serve tracks using the provided + /// [serve::TracksReader]. Blocks until the namespace is unannounced or an error occurs. + pub async fn publish_namespace(&mut self, tracks: TracksReader) -> Result<(), SessionError> { + let publish_ns = match self + .publish_namespaces .lock() - .unwrap() + .map_err(|_| SessionError::Internal)? .entry(tracks.namespace.clone()) { - // Namespace already exists in HashMap (has already been announced) - return Duplicate error + // Duplicate PUBLISH_NAMESPACE for the same namespace is a protocol error. hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), - // This is a new announce, send announce message to peer. hash_map::Entry::Vacant(entry) => { - // Get the current next request id to use and increment the value for by 2 for the next request - let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); - + // Allocate a request ID, enforcing the peer-advertised maximum. + let request_id = match self.request_id.allocate()? { + RequestIdAllocation::Allocated(id) => id, + blocked @ RequestIdAllocation::Blocked { .. } => { + if let Some(msg) = blocked.requests_blocked() { + let _ = self.outgoing.push(msg.into()); + } + return Err(SessionError::TooManyRequests); + } + }; let (send, recv) = - Announce::new(self.clone(), request_id, tracks.namespace.clone()); + PublishNamespace::new(self.clone(), request_id, tracks.namespace.clone()); entry.insert(recv); send } @@ -122,49 +131,47 @@ impl Publisher { let mut subscribe_done = false; let mut status_done = false; - // The code enters an infinite loop and waits for one of several events: - // - A new subscription arrives. - // - A new track status request arrives. - // - One of the spawned subscription-handling tasks completes. - // - One of the spawned status-handling tasks completes. - // Exit the loop when all input streams are done (None), and all tasks have completed loop { tokio::select! { - // Get next subscription to this announce - res = announce.subscribed(), if !subscribe_done => { + res = publish_ns.subscribed(), if !subscribe_done => { match res? { Some(subscribed) => { let tracks = tracks.clone(); - subscribe_tasks.push(async move { let info = subscribed.info.clone(); if let Err(err) = Self::serve_subscribe(subscribed, tracks).await { - tracing::warn!("failed serving subscribe: {:?}, error: {}", info, err) + tracing::warn!( + subscribe_info = ?info, + error = %err, + "failed serving subscribe" + ); } }); - }, + } None => subscribe_done = true, } - }, - res = announce.track_status_requested(), if !status_done => { + res = publish_ns.track_status_requested(), if !status_done => { match res? { Some(status) => { let tracks = tracks.clone(); - status_tasks.push(async move { let request_msg = status.request_msg.clone(); if let Err(err) = Self::serve_track_status(status, tracks).await { - tracing::warn!("failed serving track status request: {:?}, error: {}", request_msg, err) + tracing::warn!( + request = ?request_msg, + error = %err, + "failed serving track status request" + ); } }); - }, + } None => status_done = true, } }, Some(res) = subscribe_tasks.next() => res, Some(res) = status_tasks.next() => res, - else => return Ok(()) + else => return Ok(()), } } } @@ -212,92 +219,142 @@ impl Publisher { Ok(()) } - // Returns subscriptions that do not map to an active announce. + /// Returns the next subscription that did not match any active PUBLISH_NAMESPACE. pub async fn subscribed(&mut self) -> Option { self.unknown_subscribed.pop().await } - // Returns track_status requests that do not map to an active announce. + /// Returns the next TRACK_STATUS request that did not match any active PUBLISH_NAMESPACE. pub async fn track_status_requested(&mut self) -> Option { self.unknown_track_status_requested.pop().await } + fn add_mlog_event(&self, make_event: F) + where + F: FnOnce(f64) -> mlog::Event, + { + if let Some(ref mlog) = self.mlog { + if let Ok(mut mlog) = mlog.lock() { + let event = make_event(mlog.elapsed_ms()); + let _ = mlog.add_event(event); + } + } + } + + fn log_request_ok_parsed(&self, request_kind: &str, msg: &message::RequestOk) { + self.add_mlog_event(|time| mlog::events::request_ok_parsed(time, 0, request_kind, msg)); + } + + fn log_request_error_parsed(&self, request_kind: &str, msg: &message::RequestError) { + self.add_mlog_event(|time| mlog::events::request_error_parsed(time, 0, request_kind, msg)); + } + + fn log_request_error_created(&self, request_kind: &str, msg: &message::RequestError) { + self.add_mlog_event(|time| mlog::events::request_error_created(time, 0, request_kind, msg)); + } + + pub(super) fn send_request_ok(&mut self, request_kind: &str, msg: message::RequestOk) { + self.add_mlog_event(|time| mlog::events::request_ok_created(time, 0, request_kind, &msg)); + self.send_message(msg); + } + + pub(super) fn send_request_error(&mut self, request_kind: &str, msg: message::RequestError) { + self.log_request_error_created(request_kind, &msg); + self.send_message(msg); + } + pub(crate) fn recv_message(&mut self, msg: message::Subscriber) -> Result<(), SessionError> { - let res = match msg { - message::Subscriber::Subscribe(msg) => self.recv_subscribe(msg), - message::Subscriber::SubscribeUpdate(msg) => self.recv_subscribe_update(msg), - message::Subscriber::Unsubscribe(msg) => self.recv_unsubscribe(msg), - message::Subscriber::Fetch(_msg) => Err(SessionError::unimplemented("FETCH")), - message::Subscriber::FetchCancel(_msg) => { - Err(SessionError::unimplemented("FETCH_CANCEL")) + match msg { + message::Subscriber::Subscribe(msg) => self.recv_subscribe(msg)?, + // REQUEST_UPDATE: not yet implemented — send REQUEST_ERROR NOT_SUPPORTED (§4). + message::Subscriber::RequestUpdate(msg) => { + self.send_not_supported(msg.id, "request_update"); } - message::Subscriber::TrackStatus(msg) => self.recv_track_status(msg), - message::Subscriber::SubscribeNamespace(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE")) + // Draft-16: REQUEST_OK from subscriber is acceptance of PUBLISH_NAMESPACE. + message::Subscriber::RequestOk(msg) => self.recv_publish_namespace_ok(msg)?, + // Draft-16: REQUEST_ERROR from subscriber is rejection of PUBLISH_NAMESPACE. + message::Subscriber::RequestError(msg) => self.recv_publish_namespace_error(msg)?, + message::Subscriber::Unsubscribe(msg) => self.recv_unsubscribe(msg)?, + // FETCH not yet implemented — send REQUEST_ERROR NOT_SUPPORTED (§4). + message::Subscriber::Fetch(msg) => { + self.send_not_supported(msg.id, "fetch"); } - message::Subscriber::UnsubscribeNamespace(_msg) => { - Err(SessionError::unimplemented("UNSUBSCRIBE_NAMESPACE")) + // FETCH_CANCEL references an existing request; log and ignore. + message::Subscriber::FetchCancel(msg) => { + tracing::debug!( + target: "moq_transport::control", + request_id = msg.id, + "received FETCH_CANCEL for unsupported FETCH — ignoring" + ); } - message::Subscriber::PublishNamespaceCancel(msg) => { - self.recv_publish_namespace_cancel(msg) + message::Subscriber::TrackStatus(msg) => self.recv_track_status(msg)?, + // SUBSCRIBE_NAMESPACE not yet implemented — send REQUEST_ERROR NOT_SUPPORTED (§4). + message::Subscriber::SubscribeNamespace(msg) => { + self.send_not_supported(msg.id, "subscribe_namespace"); } - message::Subscriber::PublishNamespaceOk(msg) => self.recv_publish_namespace_ok(msg), - message::Subscriber::PublishNamespaceError(msg) => { - self.recv_publish_namespace_error(msg) + message::Subscriber::PublishNamespaceCancel(msg) => { + self.recv_publish_namespace_cancel(msg)?; } - message::Subscriber::PublishOk(_msg) => Err(SessionError::unimplemented("PUBLISH_OK")), - message::Subscriber::PublishError(_msg) => { - Err(SessionError::unimplemented("PUBLISH_ERROR")) + // PUBLISH_OK is for publisher-initiated subscriptions, which are not + // yet implemented — log and ignore. + message::Subscriber::PublishOk(msg) => { + tracing::debug!( + target: "moq_transport::control", + request_id = msg.id, + "received PUBLISH_OK for unsupported PUBLISH — ignoring" + ); } - }; - - if let Err(err) = res { - tracing::warn!("failed to process message: {}", err); } Ok(()) } - fn recv_publish_namespace_ok( - &mut self, - msg: message::PublishNamespaceOk, - ) -> Result<(), SessionError> { - // We need to find the announce request using the request id, however the self.announces data structure - // is a HashMap indexed by Namespace (which is needed for handling PUBLISH_NAMESPACE_CANCEL). TODO - make more efficient. - // For now iterate through all self.annouces until we find the matching id. - let mut announces = self.announces.lock().unwrap(); - let announce = announces.iter_mut().find(|(_k, v)| v.request_id == msg.id); - - if let Some(announce) = announce { - announce.1.recv_ok()?; + /// Send REQUEST_ERROR NOT_SUPPORTED for an incoming request we do not implement. + /// + /// Draft-16 §4: limited endpoints SHOULD respond with NOT_SUPPORTED rather + /// than ignoring unsupported request types. + fn send_not_supported(&mut self, request_id: u64, request_kind: &str) { + tracing::debug!( + target: "moq_transport::control", + request_id, + "sending REQUEST_ERROR NOT_SUPPORTED for unimplemented request" + ); + self.send_request_error( + request_kind, + message::RequestError { + id: request_id, + error_code: RequestErrorCode::NotSupported as u64, + retry_interval: 0, + reason: crate::coding::ReasonPhrase("not supported".to_string()), + }, + ); + } + + /// Handle REQUEST_OK from subscriber — acceptance of our PUBLISH_NAMESPACE (draft-16 §9.7). + fn recv_publish_namespace_ok(&mut self, msg: message::RequestOk) -> Result<(), SessionError> { + self.log_request_ok_parsed("publish_namespace", &msg); + // The publish_namespaces map is keyed by namespace; we must search by request_id. + // TODO(itzmanish): maintain a second index keyed by request_id to make this O(1). + let mut namespaces = self + .publish_namespaces + .lock() + .map_err(|_| SessionError::Internal)?; + if let Some(entry) = namespaces.iter_mut().find(|(_k, v)| v.request_id == msg.id) { + entry.1.recv_ok()?; } Ok(()) } + /// Handle REQUEST_ERROR from subscriber — rejection of our PUBLISH_NAMESPACE (draft-16 §9.8). fn recv_publish_namespace_error( &mut self, - msg: message::PublishNamespaceError, + msg: message::RequestError, ) -> Result<(), SessionError> { - // We need to find the announce request using the request id, however the self.announces data structure - // is a HashMap indexed by Namespace (which is needed for handling PUBLISH_NAMESPACE_CANCEL). TODO - make more efficient. - // For now iterate through all self.annouces until we find the matching id. - let mut announces = self.announces.lock().unwrap(); - - // Find the key first (immutable borrow only) - let key_opt = announces - .iter() - .find(|(_k, v)| v.request_id == msg.id) - .map(|(k, _)| k.clone()); - - // Remove from HashMap and take ownership - if let Some(key) = key_opt { - if let Some((_ns, v)) = announces.remove_entry(&key) { - // Step 3: call recv_error, consuming v - v.recv_error(ServeError::Closed(msg.error_code))?; - } + self.log_request_error_parsed("publish_namespace", &msg); + if let Some(recv) = self.drop_publish_namespace(msg.id) { + recv.recv_error(ServeError::Closed(msg.error_code))?; } - Ok(()) } @@ -305,44 +362,82 @@ impl Publisher { &mut self, msg: message::PublishNamespaceCancel, ) -> Result<(), SessionError> { - // TODO: If a publisher receives new subscriptions for that namespace after receiving an ANNOUNCE_CANCEL, - // it SHOULD close the session as a 'Protocol Violation'. - if let Some(announce) = self.announces.lock().unwrap().remove(&msg.track_namespace) { - announce.recv_error(ServeError::Cancel)?; + // Draft-16 §9.24: PUBLISH_NAMESPACE_CANCEL now carries Request ID. + if let Some(recv) = self.drop_publish_namespace(msg.id) { + recv.recv_error(ServeError::Cancel)?; } - Ok(()) } fn recv_subscribe(&mut self, msg: message::Subscribe) -> Result<(), SessionError> { let namespace = msg.track_namespace.clone(); + let full_name = FullTrackName { + namespace: msg.track_namespace.clone(), + name: msg.track_name.clone(), + }; let subscribed = { - let mut subscribeds = self.subscribeds.lock().unwrap(); + let mut subscribeds = self + .subscribeds + .lock() + .map_err(|_| SessionError::Internal)?; + + if subscribeds.contains_key(&msg.id) { + let id = msg.id; + drop(subscribeds); + // Draft-16 §5.1: duplicate SUBSCRIBE for the same request ID + // MUST be rejected with DUPLICATE_SUBSCRIPTION, not a session close. + self.send_request_error( + "subscribe", + message::RequestError { + id, + error_code: RequestErrorCode::DuplicateSubscription as u64, + retry_interval: 0, + reason: crate::coding::ReasonPhrase("duplicate subscription".to_string()), + }, + ); + return Ok(()); + } - // See if entry exists for this request id already, if so error out - let entry = match subscribeds.entry(msg.id) { - hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), - hash_map::Entry::Vacant(entry) => entry, - }; + let mut subscribed_names = self + .subscribed_names + .lock() + .map_err(|_| SessionError::Internal)?; + if subscribed_names.contains_key(&full_name) { + let id = msg.id; + drop(subscribed_names); + drop(subscribeds); + self.send_request_error( + "subscribe", + message::RequestError { + id, + error_code: RequestErrorCode::DuplicateSubscription as u64, + retry_interval: 0, + reason: crate::coding::ReasonPhrase("duplicate subscription".to_string()), + }, + ); + return Ok(()); + } - // Create new Subscribed entry and add to HashMap - let (send, recv) = Subscribed::new(self.clone(), msg, self.mlog.clone()); - entry.insert(recv); + let (send, recv) = Subscribed::new(self.clone(), msg, self.mlog.clone())?; + subscribed_names.insert(full_name, send.info.id); + subscribeds.insert(send.info.id, recv); send }; - // If we have an announce, route the subscribe to it. - if let Some(announce) = self.announces.lock().unwrap().get_mut(&namespace) { - return announce.recv_subscribe(subscribed).map_err(Into::into); + // Route to an active PUBLISH_NAMESPACE if present. + if let Some(ns) = self + .publish_namespaces + .lock() + .map_err(|_| SessionError::Internal)? + .get_mut(&namespace) + { + return ns.recv_subscribe(subscribed).map_err(Into::into); } - // Otherwise, put it in the unknown queue. - // TODO Have some way to detect if the application is not reading from the unknown queue, - // then send SubscribeError. + // Otherwise, surface it to the application via the unknown queue. if let Err(err) = self.unknown_subscribed.push(subscribed) { - // Default to closing with a not found error I guess. err.close(ServeError::not_found_ctx(format!( "unknown_subscribed queue full for namespace {:?}", namespace @@ -352,50 +447,54 @@ impl Publisher { Ok(()) } - fn recv_subscribe_update( - &mut self, - _msg: message::SubscribeUpdate, - ) -> Result<(), SessionError> { - // TODO: Implement updating subscriptions. - Err(SessionError::unimplemented("SUBSCRIBE_UPDATE")) - } - fn recv_track_status(&mut self, msg: message::TrackStatus) -> Result<(), SessionError> { let namespace = msg.track_namespace.clone(); - // Create TrackStatusRequested to track this request let track_status_requested = TrackStatusRequested::new(self.clone(), msg); - // If we have an announce, route the track_status to it. - if let Some(announce) = self.announces.lock().unwrap().get_mut(&namespace) { - return announce + if let Some(ns) = self + .publish_namespaces + .lock() + .map_err(|_| SessionError::Internal)? + .get_mut(&namespace) + { + return ns .recv_track_status_requested(track_status_requested) .map_err(Into::into); } - // Otherwise, put it in the unknown_track_status queue. - // TODO Have some way to detect if the application is not reading from the unknown_track_status queue, - // then send TrackStatusError. if let Err(mut err) = self .unknown_track_status_requested .push(track_status_requested) { - // push only fails if the queue is dropped, send TrackStatusError, Internal error - err.respond_error(0, "Internal error")?; + err.respond_error(RequestErrorCode::InternalError as u64, "internal error")?; } Ok(()) } fn recv_unsubscribe(&mut self, msg: message::Unsubscribe) -> Result<(), SessionError> { - if let Some(subscribed) = self.subscribeds.lock().unwrap().get_mut(&msg.id) { + { + let mut subscribeds = self + .subscribeds + .lock() + .map_err(|_| SessionError::Internal)?; + let subscribed = subscribeds.get_mut(&msg.id).ok_or_else(|| { + SessionError::ProtocolViolation(format!( + "UNSUBSCRIBE for unknown subscribe ID {}", + msg.id + )) + })?; + subscribed.recv_unsubscribe()?; } + self.remove_subscribe(msg.id)?; + Ok(()) } - /// Process a message before sending it, performing any necessary internal actions. + /// Pre-send hook: clean up internal state when terminal publisher messages are enqueued. fn act_on_message_to_send>( &mut self, msg: T, @@ -403,22 +502,23 @@ impl Publisher { let msg = msg.into(); match &msg { message::Publisher::PublishDone(m) => self.drop_subscribe(m.id), - message::Publisher::SubscribeError(m) => self.drop_subscribe(m.id), + // Draft-16: PUBLISH_NAMESPACE_DONE carries Request ID, not namespace. + // Dropping the recv state signals that the namespace is done. message::Publisher::PublishNamespaceDone(m) => { - self.drop_publish_namespace(&m.track_namespace); + let _ = self.drop_publish_namespace(m.id); } _ => {} } msg } - /// Send a message without waiting for it to be sent. + /// Enqueue a control message for sending (fire-and-forget). pub(super) fn send_message + Into>(&mut self, msg: T) { let msg = self.act_on_message_to_send(msg); self.outgoing.push(msg.into()).ok(); } - /// Send a message and wait until it is sent (or at least popped off the outgoing control message queue) + /// Enqueue a control message and wait until it has been dequeued for sending. pub(super) async fn send_message_and_wait + Into>( &mut self, msg: T, @@ -430,12 +530,41 @@ impl Publisher { .ok(); } - fn drop_subscribe(&mut self, id: u64) { - self.subscribeds.lock().unwrap().remove(&id); + pub(super) fn drop_subscribe(&mut self, id: u64) { + let _ = self.remove_subscribe(id); + } + + fn remove_subscribe(&mut self, id: u64) -> Result<(), SessionError> { + self.subscribeds + .lock() + .map_err(|_| SessionError::Internal)? + .remove(&id); + Self::drop_subscribed_name(&self.subscribed_names, id) + } + + fn drop_subscribed_name( + subscribed_names: &Arc>>, + id: u64, + ) -> Result<(), SessionError> { + subscribed_names + .lock() + .map_err(|_| SessionError::Internal)? + .retain(|_, request_id| *request_id != id); + + Ok(()) } - fn drop_publish_namespace(&mut self, namespace: &TrackNamespace) { - self.announces.lock().unwrap().remove(namespace); + fn drop_publish_namespace(&mut self, id: u64) -> Option { + if let Ok(mut ns) = self.publish_namespaces.lock() { + let key = ns + .iter() + .find(|(_k, v)| v.request_id == id) + .map(|(k, _)| k.clone()); + if let Some(key) = key { + return ns.remove(&key); + } + } + None } pub(super) async fn open_uni(&mut self) -> Result { @@ -446,3 +575,44 @@ impl Publisher { Ok(self.webtransport.send_datagram(data).await?) } } + +#[cfg(test)] +mod tests { + use std::{ + collections::HashMap, + sync::{Arc, Mutex}, + }; + + use crate::{ + coding::{TrackName, TrackNamespace}, + serve::FullTrackName, + }; + + use super::Publisher; + + fn full_track_name(namespace: &str, name: &str) -> FullTrackName { + FullTrackName { + namespace: TrackNamespace::from_utf8_path(namespace), + name: TrackName::from(name), + } + } + + #[test] + fn drop_subscribed_name_removes_only_matching_request_id() { + let subscribed_names = Arc::new(Mutex::new(HashMap::new())); + let unsubscribed_track = full_track_name("bb1", "video.m4s"); + let active_track = full_track_name("bb1", "audio.m4s"); + + { + let mut names = subscribed_names.lock().unwrap(); + names.insert(unsubscribed_track.clone(), 6); + names.insert(active_track.clone(), 8); + } + + Publisher::drop_subscribed_name(&subscribed_names, 6).unwrap(); + + let names = subscribed_names.lock().unwrap(); + assert!(!names.contains_key(&unsubscribed_track)); + assert_eq!(names.get(&active_track), Some(&8)); + } +} diff --git a/moq-transport/src/session/reader.rs b/moq-transport/src/session/reader.rs index b3326c3f..f92dcb76 100644 --- a/moq-transport/src/session/reader.rs +++ b/moq-transport/src/session/reader.rs @@ -38,7 +38,7 @@ impl Reader { Ok(msg) => { let consumed = cursor.position() as usize; self.buffer.advance(consumed); - tracing::debug!( + tracing::trace!( "[READER] decode: successfully decoded {} (consumed={} bytes, buffer_remaining={})", std::any::type_name::(), consumed, diff --git a/moq-transport/src/session/request_id.rs b/moq-transport/src/session/request_id.rs new file mode 100644 index 00000000..8ad0d0e4 --- /dev/null +++ b/moq-transport/src/session/request_id.rs @@ -0,0 +1,361 @@ +// SPDX-FileCopyrightText: 2024-2026 Cloudflare Inc. +// SPDX-License-Identifier: MIT OR Apache-2.0 + +//! Request ID flow control per draft-ietf-moq-transport-16 §9.1. +//! +//! This module intentionally exposes one cloneable session-level handle, +//! [`RequestId`]. Internally it owns two independent states under one shared +//! allocation: +//! +//! - `send`: outbound request ID allocation and peer-advertised max. +//! - `recv`: inbound request ID validation and our advertised max. +//! +//! The two states use separate mutexes. That keeps outbound allocation from +//! blocking inbound validation while still keeping all request-ID state tied to +//! one session handle. + +use std::sync::{Arc, Mutex}; + +use crate::coding::KeyValuePairs; +use crate::message::{MaxRequestId, RequestsBlocked}; +use crate::session::SessionError; + +#[derive(Clone, Debug)] +pub struct RequestId { + inner: Arc, +} + +#[derive(Debug)] +struct RequestIdInner { + send: Mutex, + recv: Mutex, +} + +#[derive(Debug)] +struct SendState { + next: u64, + peer_max: u64, + blocked_sent_for: Option, +} + +#[derive(Debug)] +struct RecvState { + next_expected: u64, + our_max: u64, +} + +#[derive(Debug, Eq, PartialEq)] +pub enum RequestIdAllocation { + Allocated(u64), + Blocked { + max_request_id: u64, + should_send_requests_blocked: bool, + }, +} + +impl RequestIdAllocation { + /// Return a REQUESTS_BLOCKED message if this allocation should emit one. + pub fn requests_blocked(&self) -> Option { + match self { + Self::Blocked { + max_request_id, + should_send_requests_blocked: true, + } => Some(RequestsBlocked { + max_request_id: *max_request_id, + }), + _ => None, + } + } +} + +impl RequestId { + /// Create a request-ID manager for one endpoint in a session. + /// + /// `local_first_id` is 0 for clients and 1 for servers. + /// `peer_max` is the peer-advertised MAX_REQUEST_ID from setup. + /// `our_max` is the MAX_REQUEST_ID we advertised in setup. + /// `peer_first_id` is 0 if the peer is client, 1 if the peer is server. + pub fn new(local_first_id: u64, peer_max: u64, our_max: u64, peer_first_id: u64) -> Self { + Self { + inner: Arc::new(RequestIdInner { + send: Mutex::new(SendState { + next: local_first_id, + peer_max, + blocked_sent_for: None, + }), + recv: Mutex::new(RecvState { + next_expected: peer_first_id, + our_max, + }), + }), + } + } + + /// Allocate the next outbound request ID. + /// + /// If the peer-advertised budget is exhausted, returns `Blocked` with + /// `should_send_requests_blocked=true` once per max value. + pub fn allocate(&self) -> Result { + let mut send = self.inner.send.lock().map_err(|_| SessionError::Internal)?; + + if send.next >= send.peer_max { + let should_send_requests_blocked = if send.blocked_sent_for == Some(send.peer_max) { + false + } else { + send.blocked_sent_for = Some(send.peer_max); + true + }; + + return Ok(RequestIdAllocation::Blocked { + max_request_id: send.peer_max, + should_send_requests_blocked, + }); + } + + let id = send.next; + send.next = send + .next + .checked_add(2) + .ok_or(SessionError::TooManyRequests)?; + send.blocked_sent_for = None; + Ok(RequestIdAllocation::Allocated(id)) + } + + /// Apply a peer MAX_REQUEST_ID update to the outbound allocator. + pub fn apply_max_request_id(&self, msg: &MaxRequestId) -> Result<(), SessionError> { + let mut send = self.inner.send.lock().map_err(|_| SessionError::Internal)?; + + if msg.request_id <= send.peer_max { + return Err(SessionError::ProtocolViolation( + "MAX_REQUEST_ID must be strictly increasing".to_string(), + )); + } + + send.peer_max = msg.request_id; + send.blocked_sent_for = None; + Ok(()) + } + + /// Validate an incoming new request ID from the peer. + pub fn validate_incoming(&self, id: u64) -> Result<(), SessionError> { + let mut recv = self.inner.recv.lock().map_err(|_| SessionError::Internal)?; + + if id != recv.next_expected { + return Err(SessionError::InvalidRequestId); + } + + if id >= recv.our_max { + return Err(SessionError::TooManyRequests); + } + + recv.next_expected = recv + .next_expected + .checked_add(2) + .ok_or(SessionError::InvalidRequestId)?; + Ok(()) + } + + /// Handle REQUESTS_BLOCKED from the peer. + /// + /// If the peer has consumed our current advertised maximum and reports that + /// same maximum as blocked, we currently ignore this. In the future, we may + /// advertise new incremented MAX_REQUEST_ID. + pub fn handle_requests_blocked(&self, msg: &RequestsBlocked) -> Result<(), SessionError> { + let recv = self.inner.recv.lock().map_err(|_| SessionError::Internal)?; + tracing::warn!( + "got requests blocked, peer max: {}, configured limit: {}, limit hit: {}, ignoring it", + msg.max_request_id, + recv.our_max, + msg.max_request_id == recv.our_max + ); + + Ok(()) + } +} + +/// Extract the MAX_REQUEST_ID value from setup parameters (0 if absent). +pub fn max_request_id_from_params(params: &KeyValuePairs) -> u64 { + use crate::coding::Value; + use crate::setup::ParameterType; + + params + .get(ParameterType::MaxRequestId.into()) + .and_then(|kvp| match &kvp.value { + Value::IntValue(v) => Some(*v), + _ => None, + }) + .unwrap_or(0) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn client_ids(peer_max: u64, our_max: u64) -> RequestId { + // local client sends even, peer server sends odd + RequestId::new(0, peer_max, our_max, 1) + } + + fn server_ids(peer_max: u64, our_max: u64) -> RequestId { + // local server sends odd, peer client sends even + RequestId::new(1, peer_max, our_max, 0) + } + + #[test] + fn client_allocates_even_ids() { + let ids = client_ids(10, 10); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(0)); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(2)); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(4)); + } + + #[test] + fn server_allocates_odd_ids() { + let ids = server_ids(10, 10); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(1)); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(3)); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(5)); + } + + #[test] + fn allocation_blocks_at_peer_max() { + let ids = client_ids(4, 10); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(0)); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(2)); + assert_eq!( + ids.allocate().unwrap(), + RequestIdAllocation::Blocked { + max_request_id: 4, + should_send_requests_blocked: true, + } + ); + } + + #[test] + fn requests_blocked_is_stable_for_same_limit() { + let ids = client_ids(2, 10); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(0)); + let first_block = ids.allocate().unwrap(); + assert_eq!( + first_block, + RequestIdAllocation::Blocked { + max_request_id: 2, + should_send_requests_blocked: true, + } + ); + assert_eq!(first_block.requests_blocked().unwrap().max_request_id, 2); + + let second_block = ids.allocate().unwrap(); + assert_eq!( + second_block, + RequestIdAllocation::Blocked { + max_request_id: 2, + should_send_requests_blocked: false, + } + ); + assert!(second_block.requests_blocked().is_none()); + } + + #[test] + fn max_request_id_must_increase() { + let ids = client_ids(10, 10); + assert!(matches!( + ids.apply_max_request_id(&MaxRequestId { request_id: 10 }) + .unwrap_err(), + SessionError::ProtocolViolation(_) + )); + assert!(matches!( + ids.apply_max_request_id(&MaxRequestId { request_id: 9 }) + .unwrap_err(), + SessionError::ProtocolViolation(_) + )); + } + + #[test] + fn max_request_id_increases_allocation_budget() { + let ids = client_ids(2, 10); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(0)); + assert!(matches!( + ids.allocate().unwrap(), + RequestIdAllocation::Blocked { .. } + )); + ids.apply_max_request_id(&MaxRequestId { request_id: 10 }) + .unwrap(); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(2)); + assert_eq!(ids.allocate().unwrap(), RequestIdAllocation::Allocated(4)); + } + + #[test] + fn validates_client_peer_sequence() { + let ids = server_ids(10, 10); + ids.validate_incoming(0).unwrap(); + ids.validate_incoming(2).unwrap(); + ids.validate_incoming(4).unwrap(); + } + + #[test] + fn validates_server_peer_sequence() { + let ids = client_ids(10, 10); + ids.validate_incoming(1).unwrap(); + ids.validate_incoming(3).unwrap(); + ids.validate_incoming(5).unwrap(); + } + + #[test] + fn rejects_wrong_first_id() { + let ids = server_ids(10, 10); + assert!(matches!( + ids.validate_incoming(2).unwrap_err(), + SessionError::InvalidRequestId + )); + } + + #[test] + fn rejects_skipped_id() { + let ids = server_ids(10, 10); + ids.validate_incoming(0).unwrap(); + assert!(matches!( + ids.validate_incoming(4).unwrap_err(), + SessionError::InvalidRequestId + )); + } + + #[test] + fn rejects_repeated_id() { + let ids = server_ids(10, 10); + ids.validate_incoming(0).unwrap(); + assert!(matches!( + ids.validate_incoming(0).unwrap_err(), + SessionError::InvalidRequestId + )); + } + + #[test] + fn rejects_id_at_our_max() { + let ids = server_ids(10, 4); + ids.validate_incoming(0).unwrap(); + ids.validate_incoming(2).unwrap(); + assert!(matches!( + ids.validate_incoming(4).unwrap_err(), + SessionError::TooManyRequests + )); + } + + #[test] + fn send_and_receive_state_do_not_block_each_other() { + let ids = client_ids(10, 10); + let send_ids = ids.clone(); + let recv_ids = ids.clone(); + + assert_eq!( + send_ids.allocate().unwrap(), + RequestIdAllocation::Allocated(0) + ); + recv_ids.validate_incoming(1).unwrap(); + assert_eq!( + send_ids.allocate().unwrap(), + RequestIdAllocation::Allocated(2) + ); + recv_ids.validate_incoming(3).unwrap(); + } +} diff --git a/moq-transport/src/session/subscribe.rs b/moq-transport/src/session/subscribe.rs index ce63e6c3..eca9ada0 100644 --- a/moq-transport/src/session/subscribe.rs +++ b/moq-transport/src/session/subscribe.rs @@ -5,22 +5,53 @@ use std::ops; use crate::{ - coding::{KeyValuePairs, Location, TrackNamespace}, + coding::{KeyValuePairs, Location, TrackName, TrackNamespace}, data, - message::{self, FilterType, GroupOrder}, + message::{self, FilterType, GroupOrder, SubscriptionFilter}, serve::{self, ServeError, TrackWriter, TrackWriterMode}, }; use crate::watch::State; +use super::SessionError; use super::Subscriber; +#[derive(Debug, Clone, Copy)] +pub struct DeliveryFilter { + pub forward: bool, + pub start_location: Option, + pub end_group_id: Option, +} + +impl DeliveryFilter { + pub fn allows(&self, group_id: u64, object_id: u64) -> bool { + if !self.forward { + return false; + } + + let location = Location::new(group_id, object_id); + if let Some(start) = self.start_location { + if location < start { + return false; + } + } + + if let Some(end_group_id) = self.end_group_id { + if group_id > end_group_id { + return false; + } + } + + true + } +} + // TODO rename to SubscriptionInfo when used for Publishes as well? #[derive(Debug, Clone)] pub struct SubscribeInfo { pub id: u64, pub track_namespace: TrackNamespace, - pub track_name: String, + pub track_name: TrackName, /// Subscriber Priority pub subscriber_priority: u8, @@ -37,6 +68,10 @@ pub struct SubscribeInfo { /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. pub end_group_id: Option, + /// None means the SUBSCRIPTION_FILTER parameter was omitted and the + /// subscription is unfiltered per draft-16 §9.2.2.5. + pub filter: Option, + /// Optional parameters pub params: KeyValuePairs, @@ -45,23 +80,74 @@ pub struct SubscribeInfo { } impl SubscribeInfo { - pub fn new_from_subscribe(msg: &message::Subscribe) -> Self { - Self { + pub fn new_from_subscribe(msg: &message::Subscribe) -> Result { + let filter = msg.params.subscription_filter()?; + let filter_type = filter + .as_ref() + .map(|filter| filter.filter_type) + .unwrap_or(FilterType::AbsoluteStart); + let start_location = filter.as_ref().and_then(|filter| filter.start_location); + let end_group_id = filter.as_ref().and_then(|filter| filter.end_group_id); + + Ok(Self { id: msg.id, track_namespace: msg.track_namespace.clone(), track_name: msg.track_name.clone(), - subscriber_priority: msg.subscriber_priority, - group_order: msg.group_order, - forward: msg.forward, - filter_type: msg.filter_type, - start_location: msg.start_location, - end_group_id: msg.end_group_id, + subscriber_priority: msg.params.subscriber_priority()?.unwrap_or(128), + group_order: msg.params.group_order()?.unwrap_or(GroupOrder::Publisher), + forward: msg.params.forward()?.unwrap_or(true), + filter_type, + start_location, + end_group_id, + filter, params: msg.params.clone(), track_status: false, + }) + } + + pub fn delivery_filter(&self, largest_location: Option) -> DeliveryFilter { + let Some(filter) = &self.filter else { + return DeliveryFilter { + forward: self.forward, + start_location: None, + end_group_id: None, + }; + }; + + let start_location = match filter.filter_type { + FilterType::LargestObject => Some(next_object_location(largest_location)), + FilterType::NextGroupStart => Some(next_group_location(largest_location)), + FilterType::AbsoluteStart | FilterType::AbsoluteRange => filter.start_location, + }; + + DeliveryFilter { + forward: self.forward, + start_location, + end_group_id: filter.end_group_id, } } } +fn next_object_location(largest_location: Option) -> Location { + let Some(location) = largest_location else { + return Location::new(0, 0); + }; + + if let Some(object_id) = location.object_id.checked_add(1) { + Location::new(location.group_id, object_id) + } else { + next_group_location(Some(location)) + } +} + +fn next_group_location(largest_location: Option) -> Location { + let Some(location) = largest_location else { + return Location::new(0, 0); + }; + + Location::new(location.group_id.saturating_add(1), 0) +} + struct SubscribeState { ok: bool, track_alias: Option, @@ -97,16 +183,25 @@ impl Subscribe { id: request_id, track_namespace: track.namespace.clone(), track_name: track.name.clone(), - // TODO add prioritization logic on the publisher side - subscriber_priority: 127, // default to mid value, see: https://github.com/moq-wg/moq-transport/issues/504 - group_order: GroupOrder::Publisher, // defer to publisher send order - forward: true, // default to forwarding objects - filter_type: FilterType::LargestObject, - start_location: None, - end_group_id: None, - params: Default::default(), + params: KeyValuePairs::default(), }; - let info = SubscribeInfo::new_from_subscribe(&subscribe_message); + let info = SubscribeInfo::new_from_subscribe(&subscribe_message).unwrap_or_else(|err| { + tracing::warn!(error = %err, "failed to decode outbound subscribe parameters"); + SubscribeInfo { + id: request_id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + subscriber_priority: 128, + group_order: GroupOrder::Publisher, + forward: true, + filter_type: FilterType::AbsoluteStart, + start_location: None, + end_group_id: None, + filter: None, + params: Default::default(), + track_status: false, + } + }); subscriber.send_message(subscribe_message); @@ -277,3 +372,73 @@ impl SubscribeRecv { } } } + +#[cfg(test)] +mod tests { + use super::*; + + fn subscribe_info_with(params: KeyValuePairs) -> SubscribeInfo { + SubscribeInfo::new_from_subscribe(&message::Subscribe { + id: 0, + track_namespace: TrackNamespace::from_utf8_path("test"), + track_name: "track".into(), + params, + }) + .unwrap() + } + + #[test] + fn omitted_subscription_filter_is_unfiltered() { + let info = subscribe_info_with(KeyValuePairs::default()); + let filter = info.delivery_filter(Some(Location::new(10, 20))); + + assert!(info.filter.is_none()); + assert!(filter.allows(0, 0)); + assert!(filter.allows(10, 20)); + assert!(filter.allows(100, 0)); + } + + #[test] + fn largest_object_filter_starts_after_largest_object() { + let mut params = KeyValuePairs::default(); + params + .set_subscription_filter(&SubscriptionFilter::largest_object()) + .unwrap(); + let info = subscribe_info_with(params); + let filter = info.delivery_filter(Some(Location::new(2, 3))); + + assert!(!filter.allows(2, 3)); + assert!(filter.allows(2, 4)); + assert!(filter.allows(3, 0)); + } + + #[test] + fn absolute_range_filter_limits_start_and_end_group() { + let mut params = KeyValuePairs::default(); + params + .set_subscription_filter(&SubscriptionFilter { + filter_type: FilterType::AbsoluteRange, + start_location: Some(Location::new(2, 3)), + end_group_id: Some(4), + }) + .unwrap(); + let info = subscribe_info_with(params); + let filter = info.delivery_filter(None); + + assert!(!filter.allows(2, 2)); + assert!(filter.allows(2, 3)); + assert!(filter.allows(4, 10)); + assert!(!filter.allows(5, 0)); + } + + #[test] + fn forward_false_blocks_delivery() { + let mut params = KeyValuePairs::default(); + params.set_forward(false); + let info = subscribe_info_with(params); + let filter = info.delivery_filter(None); + + assert!(!filter.allows(0, 0)); + assert!(!filter.allows(100, 100)); + } +} diff --git a/moq-transport/src/session/subscribed.rs b/moq-transport/src/session/subscribed.rs index 5847e6e4..867fe0e9 100644 --- a/moq-transport/src/session/subscribed.rs +++ b/moq-transport/src/session/subscribed.rs @@ -8,23 +8,32 @@ use std::sync::{Arc, Mutex}; use futures::stream::FuturesUnordered; use futures::StreamExt; -use crate::coding::{Encode, Location, ReasonPhrase}; +use crate::coding::{Encode, KeyValuePairs, Location, ReasonPhrase}; +use crate::message::RequestErrorCode; use crate::mlog; use crate::serve::{ServeError, TrackReaderMode}; use crate::watch::State; use crate::{data, message, serve}; -use super::{Publisher, SessionError, SubscribeInfo, Writer}; +use super::{DeliveryFilter, Publisher, SessionError, SubscribeInfo, Writer}; // This file defines Publisher handling of inbound Subscriptions #[derive(Debug)] struct SubscribedState { largest_location: Option, + stream_count: u64, + /// Set to true when UNSUBSCRIBE is received. When true, Drop skips sending + /// PUBLISH_DONE or REQUEST_ERROR because the subscriber already terminated. + unsubscribed: bool, closed: Result<(), ServeError>, } impl SubscribedState { + fn record_stream_opened(&mut self) { + self.stream_count = self.stream_count.saturating_add(1); + } + fn update_largest_location(&mut self, group_id: u64, object_id: u64) -> Result<(), ServeError> { if let Some(current_largest_location) = self.largest_location { let update_largest_location = Location::new(group_id, object_id); @@ -41,6 +50,8 @@ impl Default for SubscribedState { fn default() -> Self { Self { largest_location: None, + stream_count: 0, + unsubscribed: false, closed: Ok(()), } } @@ -57,7 +68,7 @@ pub struct Subscribed { state: State, /// Tracks if SubscribeOk has been sent yet or not. Used to send - /// SubscribeDone vs SubscribeError on drop. + /// PUBLISH_DONE vs REQUEST_ERROR on drop. ok: bool, /// Optional mlog writer for logging transport events @@ -69,9 +80,9 @@ impl Subscribed { publisher: Publisher, msg: message::Subscribe, mlog: Option>>, - ) -> (Self, SubscribedRecv) { + ) -> Result<(Self, SubscribedRecv), SessionError> { let (send, recv) = State::default().split(); - let info = SubscribeInfo::new_from_subscribe(&msg); + let info = SubscribeInfo::new_from_subscribe(&msg)?; let send = Self { publisher, state: send, @@ -83,7 +94,7 @@ impl Subscribed { // Prevents updates after being closed let recv = SubscribedRecv { state: recv }; - (send, recv) + Ok((send, recv)) } pub async fn serve(mut self, track: serve::TrackReader) -> Result<(), SessionError> { @@ -106,26 +117,36 @@ impl Subscribed { // Send SubscribeOk using send_message_and_wait to ensure it is sent at least to the QUIC stack before // we start serving the track. If a subscriber gets the stream before SubscribeOk // then they won't recognize the track_alias in the stream header. + let mut params = KeyValuePairs::default(); + if let Some(largest) = largest_location { + params + .set_largest_object(largest) + .map_err(|_| SessionError::Internal)?; + } + self.publisher .send_message_and_wait(message::SubscribeOk { id: self.info.id, track_alias: self.info.id, // use subscription id as track alias - expires: 0, // TODO SLG - group_order: message::GroupOrder::Descending, // TODO: resolve correct value from publisher / subscriber prefs - content_exists: largest_location.is_some(), - largest_location, - params: Default::default(), + params, + track_extensions: Default::default(), }) .await; self.ok = true; // So we send SubscribeDone on drop + let delivery_filter = self.info.delivery_filter(largest_location); + // Serve based on track mode match track.mode().await? { // TODO cancel track/datagrams on closed TrackReaderMode::Stream(_stream) => panic!("deprecated"), - TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups).await, - TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + TrackReaderMode::Subgroups(subgroups) => { + self.serve_subgroups(subgroups, delivery_filter).await + } + TrackReaderMode::Datagrams(datagrams) => { + self.serve_datagrams(datagrams, delivery_filter).await + } } } @@ -172,29 +193,77 @@ impl Drop for Subscribed { .err() .cloned() .unwrap_or(ServeError::Done); + let stream_count = state.stream_count; + let unsubscribed = state.unsubscribed; drop(state); // Important to avoid a deadlock + // Subscriber already sent UNSUBSCRIBE — no terminal message needed. + if unsubscribed { + return; + } + if self.ok { self.publisher.send_message(message::PublishDone { id: self.info.id, - status_code: err.code(), - stream_count: 0, // TODO SLG + status_code: Self::publish_done_code(&err), + stream_count, reason: ReasonPhrase(err.to_string()), }); } else { - self.publisher.send_message(message::SubscribeError { - id: self.info.id, - error_code: err.code(), - reason_phrase: ReasonPhrase(err.to_string()), - }); + // Draft-16 §9.8: subscription rejection uses REQUEST_ERROR, not the + // legacy SUBSCRIBE_ERROR. + self.publisher.send_request_error( + "subscribe", + message::RequestError { + id: self.info.id, + error_code: Self::request_error_code(&err), + retry_interval: 0, + reason: ReasonPhrase(err.to_string()), + }, + ); + self.publisher.drop_subscribe(self.info.id); }; } } impl Subscribed { + fn publish_done_code(err: &ServeError) -> u64 { + match err { + ServeError::Done => message::PublishDoneCode::TrackEnded as u64, + ServeError::Closed(code) => *code, + _ => message::PublishDoneCode::InternalError as u64, + } + } + + fn request_error_code(err: &ServeError) -> u64 { + match err { + ServeError::Closed(code) => *code, + ServeError::NotFound | ServeError::NotFoundWithId(_, _) => { + RequestErrorCode::DoesNotExist as u64 + } + ServeError::Duplicate => RequestErrorCode::DuplicateSubscription as u64, + ServeError::Cancel | ServeError::Done => RequestErrorCode::Uninterested as u64, + ServeError::Mode + | ServeError::Size + | ServeError::NotImplemented(_) + | ServeError::NotImplementedWithId(_, _) => RequestErrorCode::NotSupported as u64, + ServeError::Internal(_) | ServeError::InternalWithId(_, _) => { + RequestErrorCode::InternalError as u64 + } + } + } + + fn is_expected_serve_shutdown(err: &SessionError) -> bool { + matches!( + err, + SessionError::Serve(ServeError::Cancel | ServeError::Done) + ) + } + async fn serve_subgroups( &mut self, mut subgroups: serve::SubgroupsReader, + delivery_filter: DeliveryFilter, ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); let mut done: Option> = None; @@ -217,8 +286,12 @@ impl Subscribed { let mlog = self.mlog.clone(); tasks.push(async move { - if let Err(err) = Self::serve_subgroup(header, subgroup, publisher, state, mlog).await { - tracing::warn!("failed to serve subgroup: {:?}, error: {}", info, err); + if let Err(err) = Self::serve_subgroup(header, subgroup, publisher, state, mlog, delivery_filter).await { + if Self::is_expected_serve_shutdown(&err) { + tracing::debug!(subgroup_info = ?info, error = %err, "stopped serving subgroup"); + } else { + tracing::warn!(subgroup_info = ?info, error = %err, "failed to serve subgroup"); + } } }); }, @@ -238,47 +311,71 @@ impl Subscribed { mut publisher: Publisher, state: State, mlog: Option>>, + delivery_filter: DeliveryFilter, ) -> Result<(), SessionError> { - tracing::debug!( + tracing::trace!( "[PUBLISHER] serve_subgroup: starting - group_id={}, subgroup_id={:?}, priority={}", subgroup_reader.group_id, subgroup_reader.subgroup_id, subgroup_reader.priority ); - let mut send_stream = publisher.open_uni().await?; - tracing::trace!("[PUBLISHER] serve_subgroup: opened unidirectional stream"); + let mut writer: Option = None; + let mut object_count = 0; + while let Some(mut subgroup_object_reader) = subgroup_reader.next().await? { + if !delivery_filter.allows(subgroup_reader.group_id, subgroup_object_reader.object_id) { + tracing::trace!( + "[PUBLISHER] serve_subgroup: filtered object group_id={}, object_id={}", + subgroup_reader.group_id, + subgroup_object_reader.object_id + ); + continue; + } + + if writer.is_none() { + let mut send_stream = publisher.open_uni().await?; + tracing::trace!("[PUBLISHER] serve_subgroup: opened unidirectional stream"); - // TODO figure out u32 vs u64 priority - send_stream.set_priority(subgroup_reader.priority as i32); + state + .lock_mut() + .ok_or(ServeError::Done)? + .record_stream_opened(); - let mut writer = Writer::new(send_stream); + // TODO figure out u32 vs u64 priority + send_stream.set_priority(subgroup_reader.priority as i32); - tracing::debug!( - "[PUBLISHER] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={}, header_type={:?}", - header.track_alias, - header.group_id, - header.subgroup_id, - header.publisher_priority, - header.header_type - ); + let mut new_writer = Writer::new(send_stream); - writer.encode(&header).await?; + tracing::trace!( + "[PUBLISHER] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={}, header_type={:?}", + header.track_alias, + header.group_id, + header.subgroup_id, + header.publisher_priority, + header.header_type + ); - // Log subgroup header created/sent - if let Some(ref mlog) = mlog { - if let Ok(mut mlog_guard) = mlog.lock() { - let time = mlog_guard.elapsed_ms(); - let stream_id = 0; // TODO: Placeholder, need actual QUIC stream ID - let event = mlog::subgroup_header_created(time, stream_id, &header); - let _ = mlog_guard.add_event(event); + new_writer.encode(&header).await?; + + // Log subgroup header created/sent + if let Some(ref mlog) = mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; // TODO: Placeholder, need actual QUIC stream ID + let event = mlog::subgroup_header_created(time, stream_id, &header); + let _ = mlog_guard.add_event(event); + } + } + + writer = Some(new_writer); } - } - let mut object_count = 0; - while let Some(mut subgroup_object_reader) = subgroup_reader.next().await? { + let writer = writer.as_mut().ok_or(SessionError::Internal)?; let subgroup_object = data::SubgroupObjectExt { - object_id_delta: 0, // before delta logic, used to be subgroup_object_reader.object_id, + // TODO(itzmanish): compute real delta when the receive side uses object IDs + // for ordering. Both sender and receiver must agree on the same prev tracking + // semantics before this is meaningful. + object_id_delta: 0, extension_headers: subgroup_object_reader.extension_headers.clone(), // Pass through extension headers payload_length: subgroup_object_reader.size, status: if subgroup_object_reader.size == 0 { @@ -289,7 +386,7 @@ impl Subscribed { }, }; - tracing::debug!( + tracing::trace!( "[PUBLISHER] serve_subgroup: sending object #{} - object_id={}, object_id_delta={}, payload_length={}, status={:?}, extension_headers={:?}", object_count + 1, subgroup_object_reader.object_id, @@ -349,7 +446,7 @@ impl Subscribed { object_count += 1; } - tracing::info!( + tracing::trace!( "[PUBLISHER] serve_subgroup: completed subgroup (group_id={}, subgroup_id={:?}, {} objects sent)", subgroup_reader.group_id, subgroup_reader.subgroup_id, @@ -362,11 +459,21 @@ impl Subscribed { async fn serve_datagrams( &mut self, mut datagrams: serve::DatagramsReader, + delivery_filter: DeliveryFilter, ) -> Result<(), SessionError> { tracing::debug!("[PUBLISHER] serve_datagrams: starting"); let mut datagram_count = 0; while let Some(datagram) = datagrams.read().await? { + if !delivery_filter.allows(datagram.group_id, datagram.object_id) { + tracing::trace!( + "[PUBLISHER] serve_datagrams: filtered datagram group_id={}, object_id={}", + datagram.group_id, + datagram.object_id + ); + continue; + } + // Determine datagram type based on extension headers presence let has_extension_headers = !datagram.extension_headers.is_empty(); let datagram_type = if has_extension_headers { @@ -398,7 +505,7 @@ impl Subscribed { let mut buffer = bytes::BytesMut::with_capacity(payload_len + 100); encoded_datagram.encode(&mut buffer)?; - tracing::debug!( + tracing::trace!( "[PUBLISHER] serve_datagrams: sending datagram #{} - track_alias={}, group_id={}, object_id={}, priority={}, payload_len={}, extension_headers={:?}, total_encoded_len={}", datagram_count + 1, encoded_datagram.track_alias, @@ -436,7 +543,7 @@ impl Subscribed { datagram_count += 1; } - tracing::info!( + tracing::trace!( "[PUBLISHER] serve_datagrams: completed ({} datagrams sent)", datagram_count ); @@ -455,9 +562,106 @@ impl SubscribedRecv { state.closed.clone()?; if let Some(mut state) = state.into_mut() { + state.unsubscribed = true; state.closed = Err(ServeError::Cancel); } Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn subscribed_state_counts_opened_streams() { + let mut state = SubscribedState::default(); + assert_eq!(state.stream_count, 0); + + state.record_stream_opened(); + assert_eq!(state.stream_count, 1); + + state.record_stream_opened(); + assert_eq!(state.stream_count, 2); + } + + #[test] + fn recv_unsubscribe_marks_unsubscribed_and_closes() { + let state = State::::default(); + let (_send, recv_state) = state.split(); + let mut recv = SubscribedRecv { state: recv_state }; + + assert!(!recv.state.lock().unsubscribed); + + recv.recv_unsubscribe().unwrap(); + + let locked = recv.state.lock(); + assert!(locked.unsubscribed); + assert!(matches!(locked.closed, Err(ServeError::Cancel))); + } + + #[test] + fn publish_done_code_maps_done_to_track_ended() { + assert_eq!( + Subscribed::publish_done_code(&ServeError::Done), + message::PublishDoneCode::TrackEnded as u64 + ); + } + + #[test] + fn publish_done_code_passes_through_closed_code() { + assert_eq!( + Subscribed::publish_done_code(&ServeError::Closed(0x12)), + 0x12 + ); + } + + #[test] + fn publish_done_code_maps_other_errors_to_internal() { + assert_eq!( + Subscribed::publish_done_code(&ServeError::internal_ctx("test")), + message::PublishDoneCode::InternalError as u64 + ); + } + + #[test] + fn request_error_code_maps_rejection_reasons() { + assert_eq!( + Subscribed::request_error_code(&ServeError::NotFound), + RequestErrorCode::DoesNotExist as u64 + ); + assert_eq!( + Subscribed::request_error_code(&ServeError::Duplicate), + RequestErrorCode::DuplicateSubscription as u64 + ); + assert_eq!( + Subscribed::request_error_code(&ServeError::NotImplemented("fetch".to_string())), + RequestErrorCode::NotSupported as u64 + ); + assert_eq!( + Subscribed::request_error_code(&ServeError::Cancel), + RequestErrorCode::Uninterested as u64 + ); + assert_eq!( + Subscribed::request_error_code(&ServeError::Closed(0x42)), + 0x42 + ); + } + + #[test] + fn expected_serve_shutdown_is_only_cancel_or_done() { + assert!(Subscribed::is_expected_serve_shutdown( + &SessionError::Serve(ServeError::Cancel) + )); + assert!(Subscribed::is_expected_serve_shutdown( + &SessionError::Serve(ServeError::Done) + )); + assert!(!Subscribed::is_expected_serve_shutdown( + &SessionError::Serve(ServeError::NotFound) + )); + assert!(!Subscribed::is_expected_serve_shutdown( + &SessionError::Internal + )); + } +} diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index 2653abee..15d02d32 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -5,23 +5,26 @@ use std::{ collections::{hash_map, HashMap}, io, - sync::{atomic, Arc, Mutex}, + sync::{Arc, Mutex}, time::Duration, }; use tokio::sync::Notify; use crate::{ - coding::{Decode, TrackNamespace}, + coding::{Decode, TrackName, TrackNamespace}, data, - message::{self, FilterType, GroupOrder, Message}, + message::{self, Message}, mlog, serve::{self, ServeError}, }; use crate::watch::Queue; -use super::{Announced, AnnouncedRecv, Reader, Session, SessionError, Subscribe, SubscribeRecv}; +use super::{ + PublishedNamespace, PublishedNamespaceRecv, Reader, RequestId, RequestIdAllocation, Session, + SessionError, Subscribe, SubscribeRecv, +}; // Default timeout for waiting for subscribe aliases to become available via SUBSCRIBE_OK (1 second) const DEFAULT_ALIAS_WAIT_TIME_MS: u64 = 1000; @@ -29,11 +32,11 @@ const DEFAULT_ALIAS_WAIT_TIME_MS: u64 = 1000; // TODO remove Clone. #[derive(Clone)] pub struct Subscriber { - /// The currently active inbound announces, keyed by namespace. - announced: Arc>>, + /// Active inbound PUBLISH_NAMESPACE messages, keyed by namespace. + published_namespaces: Arc>>, - /// Queue of announced namespaces we have received from the Publisher, waiting to be processed. - announced_queue: Queue, + /// Queue of inbound PUBLISH_NAMESPACE events waiting to be consumed by the application. + published_namespace_queue: Queue, /// The currently active outbound subscribes, keyed by request id. subscribes: Arc>>, @@ -48,12 +51,13 @@ pub struct Subscriber { /// will process the queue and send the message on the control stream. outgoing: Queue, - /// When we need a new Request Id for sending a request, we can get it from here. Note: The instance - /// of AtomicU64 is shared with the Subscriber, so the session uses unique request ids for all requests - /// generated. Note: If we initiated the QUIC connection then request id's start at 0 and increment by 2 - /// for each request (even numbers). If we accepted an inbound QUIC connection then request id's start at 1 and - /// increment by 2 for each request (odd numbers). - next_requestid: Arc, + /// Shared with Publisher so all requests within a session use unique IDs. + /// When we need a new Request Id for sending a request, we can get it from here. + /// The manager is shared with the Publisher, so the session uses unique request ids + /// for all requests generated. If we initiated the QUIC connection then request + /// IDs start at 0 and increment by 2 (even numbers). If we accepted an inbound + /// QUIC connection then request IDs start at 1 and increment by 2 (odd numbers). + request_id: RequestId, /// Optional mlog writer for logging transport events mlog: Option>>, @@ -62,16 +66,16 @@ pub struct Subscriber { impl Subscriber { pub(super) fn new( outgoing: Queue, - next_requestid: Arc, mlog: Option>>, + request_id: RequestId, ) -> Self { Self { - announced: Default::default(), - announced_queue: Default::default(), + published_namespaces: Default::default(), + published_namespace_queue: Default::default(), subscribes: Default::default(), subscribe_alias_map: Default::default(), outgoing, - next_requestid, + request_id, mlog, subscribe_alias_notify: Arc::new(Notify::new()), } @@ -95,30 +99,80 @@ impl Subscriber { Ok((session, subscriber)) } - /// Wait for the next announced namespace from the publisher, if any. - pub async fn announced(&mut self) -> Option { - self.announced_queue.pop().await + /// Wait for the next inbound PUBLISH_NAMESPACE from the peer, if any. + pub async fn published_namespace(&mut self) -> Option { + self.published_namespace_queue.pop().await + } + + fn add_mlog_event(&self, make_event: F) + where + F: FnOnce(f64) -> mlog::Event, + { + if let Some(ref mlog) = self.mlog { + if let Ok(mut mlog) = mlog.lock() { + let event = make_event(mlog.elapsed_ms()); + let _ = mlog.add_event(event); + } + } + } + + fn log_request_ok_parsed(&self, request_kind: &str, msg: &message::RequestOk) { + self.add_mlog_event(|time| mlog::events::request_ok_parsed(time, 0, request_kind, msg)); } - /// Get the current next request id to use and increment the value for by 2 for the next request - fn get_next_request_id(&self) -> u64 { - self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed) + fn log_request_error_parsed(&self, request_kind: &str, msg: &message::RequestError) { + self.add_mlog_event(|time| mlog::events::request_error_parsed(time, 0, request_kind, msg)); } - pub fn track_status(&mut self, track_namespace: &TrackNamespace, track_name: &str) { + fn log_request_error_created(&self, request_kind: &str, msg: &message::RequestError) { + self.add_mlog_event(|time| mlog::events::request_error_created(time, 0, request_kind, msg)); + } + + pub(super) fn send_request_ok(&mut self, request_kind: &str, msg: message::RequestOk) { + self.add_mlog_event(|time| mlog::events::request_ok_created(time, 0, request_kind, &msg)); + self.send_message(msg); + } + + pub(super) fn send_request_error(&mut self, request_kind: &str, msg: message::RequestError) { + self.log_request_error_created(request_kind, &msg); + self.send_message(msg); + } + + /// Allocate the next outbound request ID, enforcing the peer-advertised maximum. + /// + /// Returns `Err(TooManyRequests)` if no budget remains and also sends + /// REQUESTS_BLOCKED if not already sent for this limit. + fn get_next_request_id(&mut self) -> Result { + match self.request_id.allocate()? { + RequestIdAllocation::Allocated(id) => Ok(id), + blocked @ RequestIdAllocation::Blocked { .. } => { + if let Some(msg) = blocked.requests_blocked() { + let _ = self.outgoing.push(msg.into()); + } + Err(SessionError::TooManyRequests) + } + } + } + + pub fn track_status( + &mut self, + track_namespace: &TrackNamespace, + track_name: impl Into, + ) { + let id = match self.get_next_request_id() { + Ok(id) => id, + Err(e) => { + tracing::warn!(error = %e, "could not send TRACK_STATUS: request ID limit reached"); + return; + } + }; self.send_message(message::TrackStatus { - id: self.get_next_request_id(), + id, track_namespace: track_namespace.clone(), - track_name: track_name.to_string(), - subscriber_priority: 127, // default to mid value, see: https://github.com/moq-wg/moq-transport/issues/504 - group_order: GroupOrder::Publisher, // defer to publisher send order - forward: true, // default to forwarding objects - filter_type: FilterType::LargestObject, - start_location: None, - end_group_id: None, + track_name: track_name.into(), params: Default::default(), }); - // TODO make async and wait for response? + // TODO(itzmanish): make async and wait for response? } /// Subscribe to a track by creating a new subscribe request to the publisher. Block until subscription is closed. @@ -132,9 +186,14 @@ impl Subscriber { &mut self, track: serve::TrackWriter, ) -> Result { - let request_id = self.get_next_request_id(); + let request_id = self + .get_next_request_id() + .map_err(|e| ServeError::internal_ctx(format!("request ID limit: {}", e)))?; let (send, recv) = Subscribe::new(self.clone(), request_id, track); - self.subscribes.lock().unwrap().insert(request_id, recv); + self.subscribes + .lock() + .map_err(|_| ServeError::internal_ctx("subscribe lock poisoned"))? + .insert(request_id, recv); send.ok().await?; Ok(send) } @@ -144,13 +203,10 @@ impl Subscriber { let msg = msg.into(); // Remove our entry on terminal state. - match &msg { - message::Subscriber::PublishNamespaceCancel(msg) => { - self.drop_publish_namespace(&msg.track_namespace) - } - // TODO SLG - there is no longer a namespace in the error, need to map via request id - message::Subscriber::PublishNamespaceError(_msg) => {} // Not implemented yet - need request id mapping - _ => {} + // Draft-16: PUBLISH_NAMESPACE_CANCEL carries Request ID, so look up + // the namespace by iterating the map. + if let message::Subscriber::PublishNamespaceCancel(msg) = &msg { + let _ = self.drop_publish_namespace(msg.id); } // TODO report dropped messages? @@ -159,52 +215,75 @@ impl Subscriber { /// Receive a message from the publisher via the control stream. pub(super) fn recv_message(&mut self, msg: message::Publisher) -> Result<(), SessionError> { - let res = match &msg { - message::Publisher::PublishNamespace(msg) => self.recv_publish_namespace(msg), - message::Publisher::PublishNamespaceDone(msg) => self.recv_publish_namespace_done(msg), - message::Publisher::Publish(_msg) => Err(SessionError::unimplemented("PUBLISH")), - message::Publisher::PublishDone(msg) => self.recv_publish_done(msg), - message::Publisher::SubscribeOk(msg) => self.recv_subscribe_ok(msg), - message::Publisher::SubscribeError(msg) => self.recv_subscribe_error(msg), - message::Publisher::TrackStatusOk(msg) => self.recv_track_status_ok(msg), - message::Publisher::TrackStatusError(_msg) => { - Err(SessionError::unimplemented("TRACK_STATUS_ERROR")) + match &msg { + message::Publisher::PublishNamespace(msg) => self.recv_publish_namespace(msg)?, + message::Publisher::PublishNamespaceDone(msg) => { + self.recv_publish_namespace_done(msg)?; } - message::Publisher::FetchOk(_msg) => Err(SessionError::unimplemented("FETCH_OK")), - message::Publisher::FetchError(_msg) => Err(SessionError::unimplemented("FETCH_ERROR")), - message::Publisher::SubscribeNamespaceOk(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE_OK")) + // PUBLISH (publisher-initiated subscription) not yet implemented. + // Send REQUEST_ERROR NOT_SUPPORTED so the publisher knows we cannot accept it. + message::Publisher::Publish(msg) => { + self.send_not_supported(msg.id, "publish"); } - message::Publisher::SubscribeNamespaceError(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE_ERROR")) + message::Publisher::PublishDone(msg) => self.recv_publish_done(msg)?, + message::Publisher::SubscribeOk(msg) => self.recv_subscribe_ok(msg)?, + // Draft-16 shared responses (REQUEST_OK / REQUEST_ERROR). + message::Publisher::RequestOk(msg) => self.recv_request_ok(msg)?, + message::Publisher::RequestError(msg) => self.recv_request_error(msg)?, + // FETCH_OK is part of draft-16, but FETCH is not implemented here yet. + message::Publisher::FetchOk(msg) => { + tracing::debug!( + target: "moq_transport::control", + request_id = msg.id, + "received FETCH_OK for unsupported FETCH — ignoring" + ); } - }; - - if let Err(SessionError::Serve(err)) = res { - tracing::debug!("failed to process message: {:?} {}", msg, err); - return Ok(()); } - res + Ok(()) + } + + /// Send REQUEST_ERROR NOT_SUPPORTED for an incoming request we do not implement. + /// + /// Draft-16 §4: limited endpoints SHOULD respond with NOT_SUPPORTED rather + /// than ignoring unsupported request types. + fn send_not_supported(&mut self, request_id: u64, request_kind: &str) { + tracing::debug!( + target: "moq_transport::control", + request_id, + "sending REQUEST_ERROR NOT_SUPPORTED for unimplemented request" + ); + self.send_request_error( + request_kind, + message::RequestError { + id: request_id, + error_code: crate::message::RequestErrorCode::NotSupported as u64, + retry_interval: 0, + reason: crate::coding::ReasonPhrase("not supported".to_string()), + }, + ); } - /// Handle the reception of a PublishNamespace message from the publisher. + /// Handle reception of an inbound PUBLISH_NAMESPACE from the publisher. fn recv_publish_namespace( &mut self, msg: &message::PublishNamespace, ) -> Result<(), SessionError> { - let mut announces = self.announced.lock().unwrap(); + let mut published_namespaces = self + .published_namespaces + .lock() + .map_err(|_| SessionError::Internal)?; - // Check for duplicate namespace announcement - let entry = match announces.entry(msg.track_namespace.clone()) { + // Duplicate PUBLISH_NAMESPACE for the same namespace within a session is invalid. + let entry = match published_namespaces.entry(msg.track_namespace.clone()) { hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), hash_map::Entry::Vacant(entry) => entry, }; - // Create the announced namespace and insert it into our map of active announces, and the announced queue. - let (announced, recv) = Announced::new(self.clone(), msg.id, msg.track_namespace.clone()); - if let Err(announced) = self.announced_queue.push(announced) { - announced.close(ServeError::Cancel)?; + let (published_ns, recv) = + PublishedNamespace::new(self.clone(), msg.id, msg.track_namespace.clone()); + if let Err(published_ns) = self.published_namespace_queue.push(published_ns) { + published_ns.close(ServeError::Cancel)?; return Ok(()); } entry.insert(recv); @@ -212,25 +291,30 @@ impl Subscriber { Ok(()) } - /// Handle the reception of a PublishNamespaceDone message from the publisher. + /// Handle reception of PUBLISH_NAMESPACE_DONE from the publisher. fn recv_publish_namespace_done( &mut self, msg: &message::PublishNamespaceDone, ) -> Result<(), SessionError> { - if let Some(announce) = self.announced.lock().unwrap().remove(&msg.track_namespace) { - announce.recv_unannounce()?; + // Draft-16 §9.22: PUBLISH_NAMESPACE_DONE carries Request ID, not namespace. + if let Some(recv) = self.drop_publish_namespace(msg.id) { + recv.recv_done()?; } - Ok(()) } /// Handle the reception of a SubscribeOk message from the publisher. fn recv_subscribe_ok(&mut self, msg: &message::SubscribeOk) -> Result<(), SessionError> { - if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&msg.id) { + if let Some(subscribe) = self + .subscribes + .lock() + .map_err(|_| SessionError::Internal)? + .get_mut(&msg.id) + { // Map track alias to subscription id for quick lookup when receiving streams/datagrams self.subscribe_alias_map .lock() - .unwrap() + .map_err(|_| SessionError::Internal)? .insert(msg.track_alias, msg.id); // Notify waiting tasks that the alias map has been updated @@ -245,27 +329,15 @@ impl Subscriber { /// Remove a subscribe from our map of active subscribes, and the alias map if present. pub(super) fn remove_subscribe(&mut self, id: u64) -> Option { - if let Some(subscribe) = self.subscribes.lock().unwrap().remove(&id) { - // Remove from alias map if present - if let Some(track_alias) = subscribe.track_alias() { - self.subscribe_alias_map - .lock() - .unwrap() - .remove(&track_alias); - }; - Some(subscribe) - } else { - None - } - } - - /// Handle the reception of a SubscribeError message from the publisher. - fn recv_subscribe_error(&mut self, msg: &message::SubscribeError) -> Result<(), SessionError> { - if let Some(subscribe) = self.remove_subscribe(msg.id) { - subscribe.error(ServeError::Closed(msg.error_code))?; + let subscribe = self.subscribes.lock().ok().and_then(|mut s| s.remove(&id)); + if let Some(ref sub) = subscribe { + if let Some(track_alias) = sub.track_alias() { + if let Ok(mut alias_map) = self.subscribe_alias_map.lock() { + alias_map.remove(&track_alias); + } + } } - - Ok(()) + subscribe } /// Handle the reception of a PublishDone message from the publisher. @@ -277,17 +349,59 @@ impl Subscriber { Ok(()) } - /// Handle the reception of a TrackStatusOk message from the publisher. - fn recv_track_status_ok(&mut self, _msg: &message::TrackStatusOk) -> Result<(), SessionError> { - // TODO: Expose this somehow? - // TODO: Also add a way to send a Track Status Request in the first place + /// Handle REQUEST_OK from the publisher. + /// + /// REQUEST_OK is the shared positive response for REQUEST_UPDATE, TRACK_STATUS, + /// SUBSCRIBE_NAMESPACE, and PUBLISH_NAMESPACE. SUBSCRIBE uses its own dedicated + /// SUBSCRIBE_OK message (§9.10) and does not come through this handler. + /// Full routing for the other request types is wired up (TODO itzmanish). + fn recv_request_ok(&mut self, msg: &message::RequestOk) -> Result<(), SessionError> { + self.log_request_ok_parsed("unknown", msg); + tracing::debug!( + target: "moq_transport::control", + request_id = msg.id, + "received REQUEST_OK" + ); + // TODO(itzmanish): route to the correct pending request type by ID. + Ok(()) + } + /// Handle REQUEST_ERROR from the publisher. + /// + /// Routes to the matching active subscribe (via request ID) if one + /// exists, otherwise logs and ignores. Full per-flow routing is + /// wired up (TODO itzmanish). + fn recv_request_error(&mut self, msg: &message::RequestError) -> Result<(), SessionError> { + // Route to a matching subscribe if present. + if let Some(subscribe) = self.remove_subscribe(msg.id) { + self.log_request_error_parsed("subscribe", msg); + subscribe.error(ServeError::Closed(msg.error_code))?; + } else { + self.log_request_error_parsed("unknown", msg); + } + + tracing::debug!( + target: "moq_transport::control", + request_id = msg.id, + error_code = msg.error_code, + retry_interval = msg.retry_interval, + reason = %msg.reason.0, + "received REQUEST_ERROR" + ); Ok(()) } - /// Remove an announced namespace from our map of active announces. - fn drop_publish_namespace(&mut self, namespace: &TrackNamespace) { - self.announced.lock().unwrap().remove(namespace); + fn drop_publish_namespace(&mut self, id: u64) -> Option { + if let Ok(mut ns) = self.published_namespaces.lock() { + let key = ns + .iter() + .find(|(_k, v)| v.request_id == id) + .map(|(k, _)| k.clone()); + if let Some(key) = key { + return ns.remove(&key); + } + } + None } /// Get a subscribe id by track alias, waiting up to the specified timeout if not present. @@ -296,18 +410,23 @@ impl Subscriber { &self, track_alias: u64, timeout_ms: Option, - ) -> Option { + ) -> Result, SessionError> { // If no timeout specified, don't wait let timeout_ms = match timeout_ms { Some(ms) => ms, None => { // Just check once - return self - .subscribe_alias_map - .lock() - .unwrap() - .get(&track_alias) - .cloned(); + return match self.subscribe_alias_map.lock() { + Ok(aliases) => Ok(aliases.get(&track_alias).cloned()), + Err(_) => { + tracing::error!( + target: "moq_transport::control", + track_alias, + "subscribe alias map lock poisoned" + ); + Err(SessionError::Internal) + } + }; } }; @@ -319,14 +438,20 @@ impl Subscriber { let notified = self.subscribe_alias_notify.notified(); // Check Map for alias - if let Some(id) = self - .subscribe_alias_map - .lock() - .unwrap() - .get(&track_alias) - .cloned() - { - return id; + let id = match self.subscribe_alias_map.lock() { + Ok(aliases) => aliases.get(&track_alias).cloned(), + Err(_) => { + tracing::error!( + target: "moq_transport::control", + track_alias, + "subscribe alias map lock poisoned" + ); + return Err(SessionError::Internal); + } + }; + + if let Some(id) = id { + return Ok(Some(id)); } // Alias not present yet, wait for notification @@ -334,7 +459,7 @@ impl Subscriber { } }) .await - .ok() + .unwrap_or(Ok(None)) } /// Handle reception of a new stream from the QUIC session. @@ -347,7 +472,7 @@ impl Subscriber { // Decode the stream header let stream_header: data::StreamHeader = reader.decode().await?; - tracing::debug!( + tracing::trace!( "[SUBSCRIBER] recv_stream: decoded stream header type={:?}", stream_header.header_type ); @@ -385,7 +510,7 @@ impl Subscriber { ); // The writer is closed, so we should terminate. // TODO it would be nice to do this immediately when the Writer is closed. - if let Some(subscribe_id) = self.get_subscribe_id_by_alias(track_alias, None).await { + if let Some(subscribe_id) = self.get_subscribe_id_by_alias(track_alias, None).await? { if let Some(subscribe) = self.remove_subscribe(subscribe_id) { subscribe.error(err.clone())?; } @@ -408,56 +533,27 @@ impl Subscriber { track_alias ); - // This is super silly, but I couldn't figure out a way to avoid the mutex guard across awaits. - enum Writer { - //Fetch(serve::FetchWriter), - Subgroup(serve::SubgroupWriter), - } - - let writer = { - // Look up the subscribe id for this track alias - if let Some(subscribe_id) = self - .get_subscribe_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) - .await - { - // Look up the subscribe by id - let mut subscribes = self.subscribes.lock().unwrap(); - let subscribe = subscribes.get_mut(&subscribe_id).ok_or_else(|| { - ServeError::not_found_ctx(format!( - "subscribe_id={} not found for track_alias={}", - subscribe_id, track_alias - )) - })?; - - // Create the appropriate writer based on the stream header type - if stream_header.header_type.is_subgroup() { - tracing::trace!("[SUBSCRIBER] recv_stream_inner: creating subgroup writer"); - Writer::Subgroup(subscribe.subgroup(stream_header.subgroup_header.unwrap())?) - } else { - return Err(SessionError::Serve(ServeError::internal_ctx(format!( - "unsupported stream header type={}", - stream_header.header_type - )))); - } - } else { - return Err(SessionError::Serve(ServeError::not_found_ctx(format!( - "subscription track_alias={} not found", - track_alias - )))); - } + let Some(subscribe_id) = self + .get_subscribe_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) + .await? + else { + return Err(SessionError::Serve(ServeError::not_found_ctx(format!( + "subscription track_alias={} not found", + track_alias + )))); }; - // Handle the stream based on the writer type - match writer { - //Writer::Fetch(fetch) => Self::recv_fetch(fetch, reader).await?, - Writer::Subgroup(subgroup_writer) => { - tracing::trace!("[SUBSCRIBER] recv_stream_inner: receiving subgroup data"); - Self::recv_subgroup(stream_header.header_type, subgroup_writer, reader, mlog) - .await? - } - }; + tracing::trace!("[SUBSCRIBER] recv_stream_inner: receiving subgroup data"); + self.recv_subgroup( + stream_header.header_type, + stream_header.subgroup_header.unwrap(), + subscribe_id, + reader, + mlog, + ) + .await?; - tracing::debug!( + tracing::trace!( "[SUBSCRIBER] recv_stream_inner: completed processing stream for track_alias={}", track_alias ); @@ -466,20 +562,23 @@ impl Subscriber { /// If new stream is a Subgroup stream, handle reception of subgroup objects and payloads. async fn recv_subgroup( + &mut self, stream_header_type: data::StreamHeaderType, - mut subgroup_writer: serve::SubgroupWriter, + mut subgroup_header: data::SubgroupHeader, + subscribe_id: u64, mut reader: Reader, mlog: Option>>, ) -> Result<(), SessionError> { - tracing::debug!( - "[SUBSCRIBER] recv_subgroup: starting - group_id={}, subgroup_id={}, priority={}", - subgroup_writer.info.group_id, - subgroup_writer.info.subgroup_id, - subgroup_writer.info.priority + tracing::trace!( + "[SUBSCRIBER] recv_subgroup: starting - group_id={}, subgroup_id={:?}, priority={}", + subgroup_header.group_id, + subgroup_header.subgroup_id, + subgroup_header.publisher_priority ); let mut object_count = 0; - let mut current_object_id = 0u64; + let mut previous_object_id: Option = None; + let mut subgroup_writer: Option = None; while !reader.done().await? { tracing::trace!( "[SUBSCRIBER] recv_subgroup: reading object #{} (has_ext_headers={})", @@ -493,7 +592,7 @@ impl Subscriber { match stream_header_type.has_extension_headers() { true => { let object = reader.decode::().await?; - tracing::debug!( + tracing::trace!( "[SUBSCRIBER] recv_subgroup: object #{} with extension headers - object_id_delta={}, payload_length={}, status={:?}, extension_headers={:?}", object_count + 1, object.object_id_delta, @@ -506,12 +605,12 @@ impl Subscriber { // Check for Immutable Extensions (type 0xB = 11) if object.extension_headers.has(0xB) { - tracing::info!( + tracing::trace!( "[SUBSCRIBER] recv_subgroup: object #{} contains IMMUTABLE EXTENSIONS (type 0xB) - will be forwarded", object_count + 1 ); if let Some(immutable_ext) = object.extension_headers.get(0xB) { - tracing::debug!( + tracing::trace!( "[SUBSCRIBER] recv_subgroup: immutable extension details: {:?}", immutable_ext ); @@ -520,12 +619,12 @@ impl Subscriber { // Check for Prior Group ID Gap (type 0x3C = 60) if object.extension_headers.has(0x3C) { - tracing::info!( + tracing::trace!( "[SUBSCRIBER] recv_subgroup: object #{} contains PRIOR GROUP ID GAP (type 0x3C)", object_count + 1 ); if let Some(gap_ext) = object.extension_headers.get(0x3C) { - tracing::debug!( + tracing::trace!( "[SUBSCRIBER] recv_subgroup: prior group id gap details: {:?}", gap_ext ); @@ -542,7 +641,7 @@ impl Subscriber { } false => { let object = reader.decode::().await?; - tracing::debug!( + tracing::trace!( "[SUBSCRIBER] recv_subgroup: object #{} - object_id_delta={}, payload_length={}, status={:?}", object_count + 1, object.object_id_delta, @@ -558,14 +657,48 @@ impl Subscriber { } }; - // Calculate absolute object_id from delta - current_object_id += object_id_delta; + let current_object_id = match previous_object_id { + Some(previous) => previous + .checked_add(object_id_delta) + .and_then(|value| value.checked_add(1)) + .ok_or_else(|| { + SessionError::ProtocolViolation("subgroup object id overflow".to_string()) + })?, + None => object_id_delta, + }; + previous_object_id = Some(current_object_id); // Extract extension headers if present let extension_headers = decoded_object .as_ref() .map(|obj| obj.extension_headers.clone()); + if status.is_some_and(|status| status != data::ObjectStatus::NormalObject) + && extension_headers + .as_ref() + .is_some_and(|headers| !headers.is_empty()) + { + return Err(SessionError::ProtocolViolation( + "non-normal object status with extension headers".to_string(), + )); + } + + if subgroup_writer.is_none() { + if stream_header_type.uses_first_object_id_as_subgroup_id() { + subgroup_header.subgroup_id = Some(current_object_id); + } + + let mut subscribes = self.subscribes.lock().map_err(|_| SessionError::Internal)?; + let subscribe = subscribes.get_mut(&subscribe_id).ok_or_else(|| { + ServeError::not_found_ctx(format!( + "subscribe_id={} not found for track_alias={}", + subscribe_id, subgroup_header.track_alias + )) + })?; + + subgroup_writer = Some(subscribe.subgroup(subgroup_header.clone())?); + } + // Log subgroup object parsed/received if let Some(ref mlog) = mlog { if let Ok(mut mlog_guard) = mlog.lock() { @@ -575,8 +708,8 @@ impl Subscriber { mlog::subgroup_object_ext_parsed( time, stream_id, - subgroup_writer.info.group_id, - subgroup_writer.info.subgroup_id, + subgroup_header.group_id, + subgroup_header.subgroup_id.unwrap_or(0), current_object_id, &obj_ext, ) @@ -590,8 +723,8 @@ impl Subscriber { mlog::subgroup_object_parsed( time, stream_id, - subgroup_writer.info.group_id, - subgroup_writer.info.subgroup_id, + subgroup_header.group_id, + subgroup_header.subgroup_id.unwrap_or(0), current_object_id, &temp_obj, ) @@ -603,6 +736,7 @@ impl Subscriber { // Pass extension headers through to the serve layer // TODO SLG - object_id_delta and object status are still being ignored + let subgroup_writer = subgroup_writer.as_mut().ok_or(SessionError::Internal)?; let mut object_writer = subgroup_writer.create(remaining_bytes, extension_headers)?; tracing::trace!( "[SUBSCRIBER] recv_subgroup: reading payload for object #{} ({} bytes)", @@ -643,10 +777,10 @@ impl Subscriber { object_count += 1; } - tracing::info!( + tracing::trace!( "[SUBSCRIBER] recv_subgroup: completed subgroup (group_id={}, subgroup_id={}, {} objects received)", - subgroup_writer.info.group_id, - subgroup_writer.info.subgroup_id, + subgroup_header.group_id, + subgroup_header.subgroup_id.unwrap_or(0), object_count ); @@ -669,7 +803,7 @@ impl Subscriber { // Check for extension headers in the datagram if let Some(ref ext_headers) = datagram.extension_headers { - tracing::debug!( + tracing::trace!( "[SUBSCRIBER] recv_datagram: datagram contains extension headers: {:?}", ext_headers ); @@ -678,11 +812,11 @@ impl Subscriber { // Check for Immutable Extensions (type 0xB = 11) if ext_headers.has(0xB) { - tracing::info!( + tracing::trace!( "[SUBSCRIBER] recv_datagram: datagram contains IMMUTABLE EXTENSIONS (type 0xB)" ); if let Some(immutable_ext) = ext_headers.get(0xB) { - tracing::debug!( + tracing::trace!( "[SUBSCRIBER] recv_datagram: immutable extension details: {:?}", immutable_ext ); @@ -691,11 +825,11 @@ impl Subscriber { // Check for Prior Group ID Gap (type 0x3C = 60) if ext_headers.has(0x3C) { - tracing::info!( + tracing::trace!( "[SUBSCRIBER] recv_datagram: datagram contains PRIOR GROUP ID GAP (type 0x3C)" ); if let Some(gap_ext) = ext_headers.get(0x3C) { - tracing::debug!( + tracing::trace!( "[SUBSCRIBER] recv_datagram: prior group id gap details: {:?}", gap_ext ); @@ -706,10 +840,16 @@ impl Subscriber { // Look up the subscribe id for this track alias if let Some(subscribe_id) = self .get_subscribe_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) - .await + .await? { // Look up the subscribe by id - if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&subscribe_id) { + if let Some(subscribe) = self + .subscribes + .lock() + .ok() + .as_mut() + .and_then(|s| s.get_mut(&subscribe_id)) + { tracing::trace!( "[SUBSCRIBER] recv_datagram: track_alias={}, group_id={}, object_id={}, publisher_priority={}, status={}, payload_length={}", datagram.track_alias, @@ -737,16 +877,14 @@ impl Subscriber { #[cfg(test)] mod tests { - use std::{sync::atomic, task::Poll}; + use std::task::Poll; use super::*; - use crate::{ - message::{self, GroupOrder}, - serve::Track, - }; + use crate::{message, serve::Track}; fn subscriber() -> Subscriber { - Subscriber::new(Queue::default(), Arc::new(atomic::AtomicU64::new(0)), None) + let request_id = RequestId::new(0, 100, 100, 0); + Subscriber::new(Queue::default(), None, request_id) } #[tokio::test] @@ -754,7 +892,7 @@ mod tests { let mut subscriber = subscriber(); let observer = subscriber.clone(); let (writer, _reader) = - Track::new(TrackNamespace::from_utf8_path("test"), "0.mp4".into()).produce(); + Track::new(TrackNamespace::from_utf8_path("test"), "0.mp4").produce(); { let subscribe = subscriber.subscribe_open(writer); @@ -773,7 +911,7 @@ mod tests { let mut subscriber = subscriber(); let observer = subscriber.clone(); let (writer, _reader) = - Track::new(TrackNamespace::from_utf8_path("test"), "0.mp4".into()).produce(); + Track::new(TrackNamespace::from_utf8_path("test"), "0.mp4").produce(); let subscribe = subscriber.subscribe_open(writer); futures::pin_mut!(subscribe); @@ -786,11 +924,8 @@ mod tests { .recv_subscribe_ok(&message::SubscribeOk { id: 0, track_alias: 10, - expires: 0, - group_order: GroupOrder::Publisher, - content_exists: false, - largest_location: None, params: Default::default(), + track_extensions: Default::default(), }) .unwrap(); diff --git a/moq-transport/src/session/track_status_requested.rs b/moq-transport/src/session/track_status_requested.rs index 8a5a5e5c..447c6205 100644 --- a/moq-transport/src/session/track_status_requested.rs +++ b/moq-transport/src/session/track_status_requested.rs @@ -2,8 +2,9 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use super::{Publisher, SessionError}; -use crate::coding::ReasonPhrase; +use crate::coding::{KeyValuePairs, ReasonPhrase}; use crate::message; +use crate::message::RequestOk; use crate::serve; pub struct TrackStatusRequested { @@ -19,31 +20,45 @@ impl TrackStatusRequested { } } + /// Reject the TRACK_STATUS request with REQUEST_ERROR (draft-16 §9.8). pub fn respond_error( &mut self, error_code: u64, error_message: &str, ) -> Result<(), SessionError> { - let status_error = message::TrackStatusError { - id: self.request_msg.id, - error_code, - reason_phrase: ReasonPhrase(error_message.to_string()), - }; - self.publisher.send_message(status_error); + self.publisher.send_request_error( + "track_status", + message::RequestError { + id: self.request_msg.id, + error_code, + retry_interval: 0, + reason: ReasonPhrase(error_message.to_string()), + }, + ); Ok(()) } + /// Accept the TRACK_STATUS request with REQUEST_OK (draft-16 §9.7). + /// + /// The response includes LARGEST_OBJECT when objects have been published. + /// No Track Alias is included — draft-16 §9.19 does not use one for + /// TRACK_STATUS responses. pub fn respond_ok(mut self, track: &serve::TrackReader) -> Result<(), SessionError> { - // Send TrackStatusOk - self.publisher.send_message(message::TrackStatusOk { - id: self.request_msg.id, - track_alias: self.request_msg.id, // TODO SLG does a track alias make sense in track_status response? Using track_status request id for now - expires: 0, // TODO SLG - group_order: message::GroupOrder::Ascending, // TODO: resolve correct value from publisher / subscriber prefs - content_exists: track.largest_location().is_some(), - largest_location: track.largest_location(), - params: Default::default(), - }); + let mut params = KeyValuePairs::default(); + + if let Some(largest) = track.largest_location() { + params + .set_largest_object(largest) + .map_err(|_| SessionError::Internal)?; + } + + self.publisher.send_request_ok( + "track_status", + RequestOk { + id: self.request_msg.id, + params, + }, + ); Ok(()) } diff --git a/moq-transport/src/session/writer.rs b/moq-transport/src/session/writer.rs index 47d9299f..2d3b1445 100644 --- a/moq-transport/src/session/writer.rs +++ b/moq-transport/src/session/writer.rs @@ -31,7 +31,7 @@ impl Writer { msg.encode(&mut self.buffer)?; let encoded_len = self.buffer.len(); - tracing::debug!( + tracing::trace!( "[WRITER] encode: encoded {} ({} bytes), sending to stream", std::any::type_name::(), encoded_len @@ -50,7 +50,7 @@ impl Writer { ); } - tracing::debug!( + tracing::trace!( "[WRITER] encode: finished sending {} ({} bytes total)", std::any::type_name::(), total_written @@ -85,7 +85,7 @@ impl Writer { ); } - tracing::debug!("[WRITER] write: finished writing {} bytes", total_written); + tracing::trace!("[WRITER] write: finished writing {} bytes", total_written); Ok(()) } diff --git a/moq-transport/src/setup/client.rs b/moq-transport/src/setup/client.rs index fcbe8870..7910ff9b 100644 --- a/moq-transport/src/setup/client.rs +++ b/moq-transport/src/setup/client.rs @@ -2,64 +2,57 @@ // SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 -use super::Versions; +//! CLIENT_SETUP message (draft-ietf-moq-transport-16 §9.3). +//! +//! From draft-16, version negotiation is performed via ALPN only. The +//! CLIENT_SETUP and SERVER_SETUP payloads carry setup parameters only; +//! they no longer contain a version list. + use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; +use bytes::Buf as _; -/// Sent by the client to setup the session. -/// This CLIENT_SETUP message is used by moq-transport draft versions 11 and later. -/// Id = 0x20 vs 0x40 for versions <= 10. +/// Sent by the client to set up the session. +/// +/// Message Type = 0x20 (unchanged from draft-11+). +/// The payload contains only setup parameters; version is agreed via ALPN. #[derive(Debug)] pub struct Client { - /// The list of supported versions in preferred order. - pub versions: Versions, - - /// Setup Parameters, ie: PATH, MAX_REQUEST_ID, - /// MAX_AUTH_TOKEN_CACHE_SIZE, AUTHORIZATION_TOKEN, etc. + /// Setup parameters (PATH, AUTHORITY, MAX_REQUEST_ID, + /// MAX_AUTH_TOKEN_CACHE_SIZE, AUTHORIZATION_TOKEN, MOQT_IMPLEMENTATION, …). pub params: KeyValuePairs, } impl Decode for Client { - /// Decode a client setup message. fn decode(r: &mut R) -> Result { let typ = u64::decode(r)?; if typ != 0x20 { - // CLIENT_SETUP message ID for draft versions 11 and later return Err(DecodeError::InvalidMessage(typ)); } - let _len = u16::decode(r)?; - // TODO: Check the length of the message. + let len = u16::decode(r)? as usize; + ::decode_remaining(r, len)?; + let mut payload = r.copy_to_bytes(len); - let versions = Versions::decode(r)?; - let params = KeyValuePairs::decode(r)?; + let params = KeyValuePairs::decode(&mut payload)?; + if payload.has_remaining() { + return Err(DecodeError::InvalidMessage(typ)); + } - Ok(Self { versions, params }) + Ok(Self { params }) } } impl Encode for Client { - /// Encode a server setup message. fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - (0x20_u64).encode(w)?; // CLIENT_SETUP message ID for draft versions 11 and later + (0x20_u64).encode(w)?; - // Find out the length of the message - // by encoding it into a buffer and then encoding the length. - // This is a bit wasteful, but it's the only way to know the length. - // TODO SLG - perhaps we can store the position of the Length field in the BufMut and - // write the length later, to avoid the copy of the message bytes? let mut buf = Vec::new(); + self.params.encode(&mut buf)?; - self.versions.encode(&mut buf).unwrap(); - self.params.encode(&mut buf).unwrap(); - - // Make sure buf.len() <= u16::MAX if buf.len() > u16::MAX as usize { return Err(EncodeError::MsgBoundsExceeded); } (buf.len() as u16).encode(w)?; - - // At least don't encode the message twice. - // Instead, write the buffer directly to the writer. Self::encode_remaining(w, buf.len())?; w.put_slice(&buf); @@ -74,32 +67,73 @@ mod tests { use bytes::BytesMut; #[test] - fn encode_decode() { + fn encode_decode_params_only() { let mut buf = BytesMut::new(); let mut params = KeyValuePairs::default(); - params.set_bytesvalue(ParameterType::Path.into(), "testpath".as_bytes().to_vec()); + // PATH is odd key (0x01) → bytes value; delta from 0 = 1 + params.set_bytesvalue(ParameterType::Path.into(), b"testpath".to_vec()); + // MAX_REQUEST_ID is even key (0x02) → int value; delta from 1 = 1 + params.set_intvalue(ParameterType::MaxRequestId.into(), 100); + + let client = Client { params }; + client.encode(&mut buf).unwrap(); + + // Wire layout: + // 0x20 type + // len (2B) 16-bit length of payload + // payload: + // 0x02 count = 2 params (varint) + // 0x01 delta=1 → abs_type=1 (PATH, odd→bytes) + // 0x08 length=8 + // "testpath" + // 0x01 delta=1 → abs_type=2 (MAX_REQUEST_ID, even→int) + // 0x40 0x64 value=100 (2-byte varint, 100 ≥ 64) + let bytes = buf.to_vec(); + assert_eq!(bytes[0], 0x20); // type + let payload_len = u16::from_be_bytes([bytes[1], bytes[2]]) as usize; + assert_eq!(bytes.len(), 3 + payload_len); + let decoded = Client::decode(&mut buf).unwrap(); + assert_eq!(decoded.params, client.params); + } + + #[test] + fn encode_decode_no_params() { + let mut buf = BytesMut::new(); let client = Client { - versions: [Version::DRAFT_13].into(), - params, + params: KeyValuePairs::default(), }; client.encode(&mut buf).unwrap(); - #[rustfmt::skip] - assert_eq!( - buf.to_vec(), - vec![ - 0x20, // Type - 0x00, 0x14, // Length - 0x01, // 1 Version - 0xC0, 0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x0D, // Version DRAFT_13 (0xff00000D) - 0x01, // 1 Param - 0x01, 0x08, 0x74, 0x65, 0x73, 0x74, 0x70, 0x61, 0x74, 0x68, // Key=1 (Path), Value="testpath" - ] - ); + // Wire: 0x20, 0x00 0x01 (length=1), 0x00 (count=0) + assert_eq!(buf[0], 0x20); + let payload_len = u16::from_be_bytes([buf[1], buf[2]]) as usize; + assert_eq!(payload_len, 1); // just the count varint (0x00) + let decoded = Client::decode(&mut buf).unwrap(); - assert_eq!(decoded.versions, client.versions); - assert_eq!(decoded.params, client.params); + assert!(decoded.params.0.is_empty()); + } + + #[test] + fn decode_rejects_overlong_payload() { + let mut buf = BytesMut::new(); + let client = Client { + params: KeyValuePairs::default(), + }; + client.encode(&mut buf).unwrap(); + buf[2] += 1; + buf.extend_from_slice(&[0x00]); + + assert!(matches!( + Client::decode(&mut buf).unwrap_err(), + DecodeError::InvalidMessage(0x20) + )); + } + + /// Confirm DRAFT_16 version constant is defined and has the right value. + #[test] + fn draft_16_version_constant() { + assert_eq!(Version::DRAFT_16.0, 0xff000010); } } diff --git a/moq-transport/src/setup/mod.rs b/moq-transport/src/setup/mod.rs index 3a16ead5..6f673df7 100644 --- a/moq-transport/src/setup/mod.rs +++ b/moq-transport/src/setup/mod.rs @@ -18,4 +18,6 @@ pub use param_types::*; pub use server::*; pub use version::*; -pub const ALPN: &[u8] = b"moq-00"; +/// ALPN identifier for draft-ietf-moq-transport-16. +/// Used for native QUIC connections; WebTransport uses the WT-Available-Protocols header instead. +pub const ALPN: &[u8] = b"moqt-16"; diff --git a/moq-transport/src/setup/server.rs b/moq-transport/src/setup/server.rs index 5a4b952f..12c3b860 100644 --- a/moq-transport/src/setup/server.rs +++ b/moq-transport/src/setup/server.rs @@ -2,63 +2,57 @@ // SPDX-FileCopyrightText: 2023-2024 Luke Curley and contributors // SPDX-License-Identifier: MIT OR Apache-2.0 -use super::Version; +//! SERVER_SETUP message (draft-ietf-moq-transport-16 §9.3). +//! +//! From draft-16, version negotiation is performed via ALPN only. The +//! SERVER_SETUP payload carries setup parameters only; it no longer echoes +//! the selected version. + use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; +use bytes::Buf as _; -/// Sent by the server in response to a client setup. -/// This SERVER_SETUP message is used by moq-transport draft versions 11 and later. -/// Id = 0x21 vs 0x41 for versions <= 10. +/// Sent by the server in response to CLIENT_SETUP. +/// +/// Message Type = 0x21 (unchanged from draft-11+). +/// The payload contains only setup parameters; version is agreed via ALPN. #[derive(Debug)] pub struct Server { - /// The list of supported versions in preferred order. - pub version: Version, - - /// Setup Parameters, ie: MAX_REQUEST_ID, MAX_AUTH_TOKEN_CACHE_SIZE, - /// AUTHORIZATION_TOKEN, etc. + /// Setup parameters (MAX_REQUEST_ID, MAX_AUTH_TOKEN_CACHE_SIZE, + /// AUTHORIZATION_TOKEN, MOQT_IMPLEMENTATION, …). pub params: KeyValuePairs, } impl Decode for Server { - /// Decode the server setup. fn decode(r: &mut R) -> Result { let typ = u64::decode(r)?; if typ != 0x21 { - // SERVER_SETUP message ID for draft versions 11 and later return Err(DecodeError::InvalidMessage(typ)); } - let _len = u16::decode(r)?; - // TODO: Check the length of the message. + let len = u16::decode(r)? as usize; + ::decode_remaining(r, len)?; + let mut payload = r.copy_to_bytes(len); - let version = Version::decode(r)?; - let params = KeyValuePairs::decode(r)?; + let params = KeyValuePairs::decode(&mut payload)?; + if payload.has_remaining() { + return Err(DecodeError::InvalidMessage(typ)); + } - Ok(Self { version, params }) + Ok(Self { params }) } } impl Encode for Server { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - (0x21_u64).encode(w)?; // SERVER_SETUP message ID for draft versions 11 and later + (0x21_u64).encode(w)?; - // Find out the length of the message - // by encoding it into a buffer and then encoding the length. - // This is a bit wasteful, but it's the only way to know the length. - // TODO SLG - perhaps we can store the position of the Length field in the BufMut and - // write the length later, to avoid the copy of the message bytes? let mut buf = Vec::new(); + self.params.encode(&mut buf)?; - self.version.encode(&mut buf).unwrap(); - self.params.encode(&mut buf).unwrap(); - - // Make sure buf.len() <= u16::MAX if buf.len() > u16::MAX as usize { return Err(EncodeError::MsgBoundsExceeded); } (buf.len() as u16).encode(w)?; - - // At least don't encode the message twice. - // Instead, write the buffer directly to the writer. Self::encode_remaining(w, buf.len())?; w.put_slice(&buf); @@ -73,33 +67,61 @@ mod tests { use bytes::BytesMut; #[test] - fn encode_decode() { + fn encode_decode_params_only() { let mut buf = BytesMut::new(); let mut params = KeyValuePairs::default(); + // MAX_REQUEST_ID is even key (0x02) → int value; delta from 0 = 2 params.set_intvalue(ParameterType::MaxRequestId.into(), 1000); + let server = Server { params }; + server.encode(&mut buf).unwrap(); + + // Wire layout: + // 0x21 type + // len (2B) 16-bit length of payload + // payload: + // 0x01 count = 1 param + // 0x02 delta=2 → abs_type=2 (MAX_REQUEST_ID, even→int) + // 0x43 0xe8 value=1000 (2-byte varint) + let bytes = buf.to_vec(); + assert_eq!(bytes[0], 0x21); + let payload_len = u16::from_be_bytes([bytes[1], bytes[2]]) as usize; + assert_eq!(bytes.len(), 3 + payload_len); + + let decoded = Server::decode(&mut buf).unwrap(); + assert_eq!(decoded.params, server.params); + } + + #[test] + fn encode_decode_no_params() { + let mut buf = BytesMut::new(); let server = Server { - version: Version::DRAFT_14, - params, + params: KeyValuePairs::default(), }; - server.encode(&mut buf).unwrap(); - #[rustfmt::skip] - assert_eq!( - buf.to_vec(), - vec![ - 0x21, // Type - 0x00, 0x0c, // Length - 0xC0, 0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x0E, // Version DRAFT_14 (0xff00000E) - 0x01, // 1 Param - 0x02, 0x43, 0xe8, // Key=2 (MaxRequestId), Value=1000 - ] - ); + assert_eq!(buf[0], 0x21); + let payload_len = u16::from_be_bytes([buf[1], buf[2]]) as usize; + assert_eq!(payload_len, 1); // just count=0 let decoded = Server::decode(&mut buf).unwrap(); - assert_eq!(decoded.version, server.version); - assert_eq!(decoded.params, server.params); + assert!(decoded.params.0.is_empty()); + } + + #[test] + fn decode_rejects_overlong_payload() { + let mut buf = BytesMut::new(); + let server = Server { + params: KeyValuePairs::default(), + }; + server.encode(&mut buf).unwrap(); + buf[2] += 1; + buf.extend_from_slice(&[0x00]); + + assert!(matches!( + Server::decode(&mut buf).unwrap_err(), + DecodeError::InvalidMessage(0x21) + )); } } diff --git a/moq-transport/src/setup/version.rs b/moq-transport/src/setup/version.rs index 20704d97..b62201cf 100644 --- a/moq-transport/src/setup/version.rs +++ b/moq-transport/src/setup/version.rs @@ -27,6 +27,9 @@ impl Version { /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-14.html pub const DRAFT_14: Version = Version(0xff00000e); + + /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-16.html + pub const DRAFT_16: Version = Version(0xff000010); } impl From for Version {