diff --git a/Cargo.lock b/Cargo.lock index 490e2e56..cc6ddb79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1346,6 +1346,7 @@ dependencies = [ "serde_json", "thiserror 2.0.17", "tokio", + "tokio-util", "tower-http", "tracing", "tracing-subscriber", @@ -1463,9 +1464,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.1.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +checksum = "521739c6d2bac4aa25192232afe6841231376b2b26d4d9fae5ecf8ca5772e441" [[package]] name = "num-integer" @@ -2500,30 +2501,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.44" +version = "0.3.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde", + "serde_core", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "time-macros" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" dependencies = [ "num-conv", "time-core", diff --git a/dev/sub b/dev/sub index 8e6c114a..1b04863b 100755 --- a/dev/sub +++ b/dev/sub @@ -20,6 +20,6 @@ ADDR="${ADDR:-$HOST:$PORT}" NAME="${NAME:-bbb}" # Combine the host and name into a URL. 
-URL="${URL:-"https://$ADDR/$NAME"}" +URL="${URL:-"https://$ADDR"}" cargo run --bin moq-sub -- --name "$NAME" "$URL" "$@" | ffplay - diff --git a/moq-relay-ietf/Cargo.toml b/moq-relay-ietf/Cargo.toml index 780c176c..4fe345f4 100644 --- a/moq-relay-ietf/Cargo.toml +++ b/moq-relay-ietf/Cargo.toml @@ -33,7 +33,7 @@ url = "2" # Async stuff tokio = { version = "1", features = ["full"] } -# tokio-util = "0.7" +tokio-util = "0.7" futures = "0.3" async-trait = "0.1" @@ -70,9 +70,6 @@ thiserror = "2.0.17" metrics = "0.24" metrics-exporter-prometheus = { version = "0.16", optional = true } -# misc -#once_cell = "1.21.3" - [features] default = [] metrics-prometheus = ["dep:metrics-exporter-prometheus"] diff --git a/moq-relay-ietf/src/lib.rs b/moq-relay-ietf/src/lib.rs index 098fa047..c469a730 100644 --- a/moq-relay-ietf/src/lib.rs +++ b/moq-relay-ietf/src/lib.rs @@ -48,6 +48,6 @@ pub use coordinator::*; pub use local::*; pub use producer::*; pub use relay::*; -pub use remote::*; +pub use remote::RemoteManager; pub use session::*; pub use web::*; diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 15890a3c..9387b6a1 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -9,7 +9,7 @@ use moq_transport::{ use crate::{ metrics::{GaugeGuard, TimingGuard}, - Locals, RemotesConsumer, + Locals, RemoteManager, }; /// Producer of tracks to a remote Subscriber @@ -17,7 +17,7 @@ use crate::{ pub struct Producer { publisher: Publisher, locals: Locals, - remotes: Option, + remotes: RemoteManager, /// The resolved scope identity for this session, if any. /// Produced by `Coordinator::resolve_scope()` from the connection path. /// Passed to locals/remotes to isolate namespace lookups. @@ -28,7 +28,7 @@ impl Producer { pub fn new( publisher: Publisher, locals: Locals, - remotes: Option, + remotes: RemoteManager, scope: Option, ) -> Self { Self { @@ -46,7 +46,6 @@ impl Producer { /// Run the producer to serve subscribe requests. 
pub async fn run(self) -> Result<(), SessionError> { - //let mut tasks = FuturesUnordered::new(); let mut tasks: FuturesUnordered> = FuturesUnordered::new(); @@ -122,40 +121,40 @@ impl Producer { } } - if let Some(remotes) = self.remotes { - // Check remote tracks second, and serve from remote if possible - match remotes.route(self.scope.as_deref(), &namespace).await { - Ok(remote) => { - if let Some(remote) = remote { - if let Some(track) = remote.subscribe(&namespace, &track_name)? { - let ns = namespace.to_utf8_path(); - tracing::info!(namespace = %ns, track = %track_name, source = "remote", "serving subscribe from remote: {:?}", track.info); - // Update label to indicate remote source, timing recorded on drop - timing_guard.set_label("source", "remote"); - // Track active tracks - decrements when serve completes - let _track_guard = GaugeGuard::new("moq_relay_active_tracks"); - return Ok(subscribed.serve(track.reader).await?); - } - } - } - Err(e) => { - // Route error = infrastructure failure (couldn't reach coordinator/upstream) - // This is different from "not found" - we don't know if the track exists + // Check remote tracks second, and serve from remote if possible + match self + .remotes + .subscribe(self.scope.as_deref(), &namespace, &track_name) + .await + { + Ok(track) => { + if let Some(track) = track { let ns = namespace.to_utf8_path(); - tracing::error!(namespace = %ns, track = %track_name, error = %e, "failed to route to remote: {}", e); - timing_guard.set_label("source", "route_error"); - metrics::counter!("moq_relay_subscribe_route_errors_total").increment(1); - - // Return an internal error rather than "not found" since we couldn't check - // TODO: Consider returning a more specific error to the subscriber - let err = ServeError::internal_ctx(format!( - "route error for namespace '{}': {}", - namespace, e - )); - subscribed.close(err.clone())?; - return Err(err.into()); + tracing::info!(namespace = %ns, track = %track_name, source = "remote", 
"serving subscribe from remote: {:?}", track.info); + // Update label to indicate remote source, timing recorded on drop + timing_guard.set_label("source", "remote"); + // Track active tracks - decrements when serve completes + let _track_guard = GaugeGuard::new("moq_relay_active_tracks"); + return Ok(subscribed.serve(track).await?); } } + Err(e) => { + // Route error = infrastructure failure (couldn't reach coordinator/upstream) + // This is different from "not found" - we don't know if the track exists + let ns = namespace.to_utf8_path(); + tracing::error!(namespace = %ns, track = %track_name, error = %e, "failed to route to remote: {}", e); + timing_guard.set_label("source", "route_error"); + metrics::counter!("moq_relay_subscribe_route_errors_total").increment(1); + + // Return an internal error rather than "not found" since we couldn't check + // TODO: Consider returning a more specific error to the subscriber + let err = ServeError::internal_ctx(format!( + "route error for namespace '{}': {}", + namespace, e + )); + subscribed.close(err.clone())?; + return Err(err.into()); + } } // Track not found - we checked all sources and the track doesn't exist diff --git a/moq-relay-ietf/src/relay.rs b/moq-relay-ietf/src/relay.rs index 5beea1b8..06a43c9e 100644 --- a/moq-relay-ietf/src/relay.rs +++ b/moq-relay-ietf/src/relay.rs @@ -9,10 +9,7 @@ use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_native_ietf::quic::{self, Endpoint}; use url::Url; -use crate::{ - metrics::GaugeGuard, Consumer, Coordinator, Locals, Producer, Remotes, RemotesConsumer, - RemotesProducer, Session, -}; +use crate::{metrics::GaugeGuard, Consumer, Coordinator, Locals, Producer, RemoteManager, Session}; // A type alias for boxed future type ServerFuture = Pin< @@ -64,7 +61,7 @@ pub struct Relay { announce_url: Option, mlog_dir: Option, locals: Locals, - remotes: Option<(RemotesProducer, RemotesConsumer)>, + remotes: RemoteManager, coordinator: Arc, } @@ -109,260 +106,263 @@ 
impl Relay { .collect::>(); // Create remote manager - uses coordinator for namespace lookups - let remotes = Remotes { - coordinator: config.coordinator.clone(), - quic: remote_clients[0].clone(), - } - .produce(); + let remotes = RemoteManager::new(config.coordinator.clone(), remote_clients); Ok(Self { quic_endpoints: endpoints, announce_url: config.announce, mlog_dir: config.mlog_dir, locals, - remotes: Some(remotes), + remotes, coordinator: config.coordinator, }) } /// Run the relay server. pub async fn run(self) -> anyhow::Result<()> { - let mut tasks = FuturesUnordered::new(); - - // Split remotes producer/consumer and spawn producer task - let remotes = self.remotes.map(|(producer, consumer)| { - tasks.push(producer.run().boxed()); - consumer - }); - - // Start the forwarder, if any - let forward_producer = if let Some(url) = &self.announce_url { - tracing::info!("forwarding announces to {}", url); - - // Establish a QUIC connection to the forward URL - let (session, _quic_client_initial_cid, transport) = self.quic_endpoints[0] - .client - .connect(url, None) - .await - .context("failed to establish forward connection")?; - - // Create the MoQ session over the connection - let (session, publisher, subscriber) = - moq_transport::session::Session::connect(session, None, transport) - .await - .context("failed to establish forward session")?; - - // Use the connection path already validated and stored by Session::connect(). - // The forward session is scoped to whatever path the announce URL specifies. - // - // Note: the forward connection intentionally does not call - // coordinator.resolve_scope(). The announce URL is operator-configured - // (via --announce), not client-supplied, so it doesn't need the same - // auth/permission checks that incoming client connections get. The - // forward session always gets both Producer and Consumer (full - // read-write) since it's acting as a relay peer, not a client. 
- // - // Limitation: all incoming scopes are forwarded to this single upstream scope. - // Multi-scope forwarding (routing different incoming scopes to different - // upstream paths) would require per-scope forward connections. - let forward_scope = session.connection_path().map(|s| s.to_string()); - - let coordinator = self.coordinator.clone(); - let session = Session { - session, - producer: Some(Producer::new( - publisher, - self.locals.clone(), - remotes.clone(), - forward_scope.clone(), - )), - consumer: Some(Consumer::new( - subscriber, - self.locals.clone(), - coordinator, - None, - forward_scope, - )), - // Forward connections are always full read-write relay peers, - // so no reject loops needed. - reject_publishes: None, - reject_subscribes: None, - }; + let Self { + quic_endpoints, + announce_url, + mlog_dir, + locals, + remotes, + coordinator, + } = self; - let forward_producer = session.producer.clone(); + let run_result = async { + let mut tasks = FuturesUnordered::new(); - tasks.push(async move { session.run().await.context("forwarding failed") }.boxed()); + // Use the remote manager for routing to remote relays. + let remote_manager = remotes.clone(); - forward_producer - } else { - None - }; + // Start the forwarder, if any + let forward_producer = if let Some(url) = &announce_url { + tracing::info!("forwarding announces to {}", url); - let servers: Vec = self - .quic_endpoints - .into_iter() - .map(|endpoint| { - endpoint - .server - .context("missing TLS certificate for server") - }) - .collect::>()?; - - // This will hold the futures for all our listening servers. - let mut accepts: FuturesUnordered = FuturesUnordered::new(); - for mut server in servers { - tracing::info!("listening on {}", server.local_addr()?); - - // Create a future, box it, and push it to the collection. 
- accepts.push( - async move { - let conn = server.accept().await.context("accept failed"); - (conn, server) - } - .boxed(), - ); - } + // Establish a QUIC connection to the forward URL + let (session, _quic_client_initial_cid, transport) = quic_endpoints[0] + .client + .connect(url, None) + .await + .context("failed to establish forward connection")?; + + // Create the MoQ session over the connection + let (session, publisher, subscriber) = + moq_transport::session::Session::connect(session, None, transport) + .await + .context("failed to establish forward session")?; + + // Use the connection path already validated and stored by Session::connect(). + // The forward session is scoped to whatever path the announce URL specifies. + // + // Note: the forward connection intentionally does not call + // coordinator.resolve_scope(). The announce URL is operator-configured + // (via --announce), not client-supplied, so it doesn't need the same + // auth/permission checks that incoming client connections get. The + // forward session always gets both Producer and Consumer (full + // read-write) since it's acting as a relay peer, not a client. + // + // Limitation: all incoming scopes are forwarded to this single upstream scope. + // Multi-scope forwarding (routing different incoming scopes to different + // upstream paths) would require per-scope forward connections. + let forward_scope = session.connection_path().map(|s| s.to_string()); + + let forward_coordinator = coordinator.clone(); + let session = Session { + session, + producer: Some(Producer::new( + publisher, + locals.clone(), + remote_manager.clone(), + forward_scope.clone(), + )), + consumer: Some(Consumer::new( + subscriber, + locals.clone(), + forward_coordinator, + None, + forward_scope, + )), + // Forward connections are always full read-write relay peers, + // so no reject loops needed. 
+ reject_publishes: None, + reject_subscribes: None, + }; + + let forward_producer = session.producer.clone(); + + tasks.push(async move { session.run().await.context("forwarding failed") }.boxed()); + + forward_producer + } else { + None + }; + + let servers: Vec = quic_endpoints + .into_iter() + .map(|endpoint| endpoint.server.context("missing TLS certificate for server")) + .collect::>()?; + + // This will hold the futures for all our listening servers. + let mut accepts: FuturesUnordered = FuturesUnordered::new(); + for mut server in servers { + tracing::info!("listening on {}", server.local_addr()?); + + // Create a future, box it, and push it to the collection. + accepts.push( + async move { + let conn = server.accept().await.context("accept failed"); + (conn, server) + } + .boxed(), + ); + } - loop { - tokio::select! { - // This branch polls all the `accept` futures concurrently. - Some((conn_result, mut server)) = accepts.next() => { - // An accept operation has completed. - // First, immediately queue up the next accept() call for this server. 
- accepts.push( - async move { - let conn = server.accept().await.context("accept failed"); - (conn, server) - } - .boxed(), - ); - - let (conn, connection_id, transport) = conn_result.context("failed to accept QUIC connection")?; - - metrics::counter!("moq_relay_connections_total").increment(1); - - // Construct mlog path from connection ID if mlog directory is configured - let mlog_path = self.mlog_dir.as_ref() - .map(|dir| dir.join(format!("{}_server.mlog", connection_id))); - - let locals = self.locals.clone(); - let remotes = remotes.clone(); - let forward = forward_producer.clone(); - let coordinator = self.coordinator.clone(); - - // Spawn a new task to handle the connection - tasks.push(async move { - // Track active connections - decrements when task completes - let _conn_guard = GaugeGuard::new("moq_relay_active_connections"); - - // Clone the raw connection so we can close it with a proper - // error code if scope resolution fails after the MoQ handshake. - let raw_conn = conn.clone(); - - // Create the MoQ session over the connection (setup handshake etc) - let (session, publisher, subscriber) = match moq_transport::session::Session::accept(conn, mlog_path, transport).await { - Ok(session) => session, - Err(err) => { - tracing::warn!(error = %err, "failed to accept MoQ session: {}", err); - metrics::counter!("moq_relay_connection_errors_total", "stage" => "session_accept").increment(1); - // Maintain invariant: connections_total - connections_closed_total == active_connections - metrics::counter!("moq_relay_connections_closed_total").increment(1); - return Ok(()); + loop { + tokio::select! { + // This branch polls all the `accept` futures concurrently. + Some((conn_result, mut server)) = accepts.next() => { + // An accept operation has completed. + // First, immediately queue up the next accept() call for this server. 
+ accepts.push( + async move { + let conn = server.accept().await.context("accept failed"); + (conn, server) } - }; - - // Create our MoQ relay session - let moq_session = session; - - // Resolve the connection path to a scope (identity + permissions). - // This translates the raw transport-level path into an application-level - // scope_id and determines what the connection is allowed to do. - let scope_info = match coordinator.resolve_scope(moq_session.connection_path()).await { - Ok(info) => info, - Err(err) => { - tracing::warn!( + .boxed(), + ); + + let (conn, connection_id, transport) = conn_result.context("failed to accept QUIC connection")?; + + metrics::counter!("moq_relay_connections_total").increment(1); + + // Construct mlog path from connection ID if mlog directory is configured + let mlog_path = mlog_dir.as_ref() + .map(|dir| dir.join(format!("{}_server.mlog", connection_id))); + + let locals = locals.clone(); + let remotes = remote_manager.clone(); + let forward = forward_producer.clone(); + let coordinator = coordinator.clone(); + + // Spawn a new task to handle the connection + tasks.push(async move { + // Track active connections - decrements when task completes + let _conn_guard = GaugeGuard::new("moq_relay_active_connections"); + + // Clone the raw connection so we can close it with a proper + // error code if scope resolution fails after the MoQ handshake. 
+ let raw_conn = conn.clone(); + + // Create the MoQ session over the connection (setup handshake etc) + let (session, publisher, subscriber) = match moq_transport::session::Session::accept(conn, mlog_path, transport).await { + Ok(session) => session, + Err(err) => { + tracing::warn!(error = %err, "failed to accept MoQ session: {}", err); + metrics::counter!("moq_relay_connection_errors_total", "stage" => "session_accept").increment(1); + // Maintain invariant: connections_total - connections_closed_total == active_connections + metrics::counter!("moq_relay_connections_closed_total").increment(1); + return Ok(()); + } + }; + + // Create our MoQ relay session + let moq_session = session; + + // Resolve the connection path to a scope (identity + permissions). + // This translates the raw transport-level path into an application-level + // scope_id and determines what the connection is allowed to do. + let scope_info = match coordinator.resolve_scope(moq_session.connection_path()).await { + Ok(info) => info, + Err(err) => { + tracing::warn!( + connection_path = moq_session.connection_path(), + error = %err, + "scope resolution failed, rejecting session" + ); + // Close with PROTOCOL_VIOLATION (0x3) so the client + // gets a meaningful error instead of an abrupt reset. + // This is a QUIC APPLICATION_CLOSE, not a MoQT SESSION_CLOSE + // control message. Sending a proper SESSION_CLOSE would require + // running the MoQ session's send loop, which is not warranted + // for a pre-session rejection. The QUIC close code and reason + // string are visible to the client's transport layer. 
+ raw_conn.close(0x3, "scope resolution failed"); + metrics::counter!("moq_relay_connection_errors_total", "stage" => "scope_resolve").increment(1); + metrics::counter!("moq_relay_connections_closed_total").increment(1); + return Ok(()); + } + }; + + let scope_id = scope_info.as_ref().map(|s| s.scope_id.clone()); + let can_publish = scope_info.as_ref().is_none_or(|s| s.permissions.can_publish()); + let can_subscribe = scope_info.as_ref().is_none_or(|s| s.permissions.can_subscribe()); + + if let Some(ref info) = scope_info { + tracing::debug!( connection_path = moq_session.connection_path(), - error = %err, - "scope resolution failed, rejecting session" + scope_id = %info.scope_id, + permissions = ?info.permissions, + "scope resolved" ); - // Close with PROTOCOL_VIOLATION (0x3) so the client - // gets a meaningful error instead of an abrupt reset. - // This is a QUIC APPLICATION_CLOSE, not a MoQT SESSION_CLOSE - // control message. Sending a proper SESSION_CLOSE would require - // running the MoQ session's send loop, which is not warranted - // for a pre-session rejection. The QUIC close code and reason - // string are visible to the client's transport layer. - raw_conn.close(0x3, "scope resolution failed"); - metrics::counter!("moq_relay_connection_errors_total", "stage" => "scope_resolve").increment(1); - metrics::counter!("moq_relay_connections_closed_total").increment(1); - return Ok(()); - } - }; - - let scope_id = scope_info.as_ref().map(|s| s.scope_id.clone()); - let can_publish = scope_info.as_ref().is_none_or(|s| s.permissions.can_publish()); - let can_subscribe = scope_info.as_ref().is_none_or(|s| s.permissions.can_subscribe()); - - if let Some(ref info) = scope_info { - tracing::debug!( - connection_path = moq_session.connection_path(), - scope_id = %info.scope_id, - permissions = ?info.permissions, - "scope resolved" - ); - } - - // Gate Producer/Consumer creation on permissions. 
- // Note the intentional inversion: - // - Producer serves SUBSCRIBEs → gated on can_subscribe - // - Consumer handles PUBLISH_NAMESPACEs → gated on can_publish - // - // When a half is disabled, we pass its transport counterpart - // to the Session's reject fields so unauthorized messages get - // an explicit error response instead of being silently ignored. - let (producer, reject_subscribes) = if can_subscribe { - (publisher.map(|publisher| Producer::new(publisher, locals.clone(), remotes, scope_id.clone())), None) - } else { - (None, publisher) - }; - - let (consumer, reject_publishes) = if can_publish { - (subscriber.map(|subscriber| Consumer::new(subscriber, locals, coordinator, forward, scope_id)), None) - } else { - (None, subscriber) - }; - - let session = Session { - session: moq_session, - producer, - consumer, - reject_publishes, - reject_subscribes, - }; - - match session.run().await { - Ok(()) => { - // Session ended cleanly (uncommon - usually ends via close) - metrics::counter!("moq_relay_connections_closed_total").increment(1); } - Err(err) if err.is_graceful_close() => { - // Graceful close - peer sent APPLICATION_CLOSE with code 0 - tracing::debug!("MoQ session closed gracefully"); - metrics::counter!("moq_relay_connections_closed_total").increment(1); - } - Err(err) => { - // Actual error - protocol violation, timeout, etc. - tracing::warn!(error = %err, "MoQ session error: {}", err); - metrics::counter!("moq_relay_connection_errors_total", "stage" => "session_run").increment(1); - metrics::counter!("moq_relay_connections_closed_total").increment(1); + + // Gate Producer/Consumer creation on permissions. 
+ // Note the intentional inversion: + // - Producer serves SUBSCRIBEs → gated on can_subscribe + // - Consumer handles PUBLISH_NAMESPACEs → gated on can_publish + // + // When a half is disabled, we pass its transport counterpart + // to the Session's reject fields so unauthorized messages get + // an explicit error response instead of being silently ignored. + let (producer, reject_subscribes) = if can_subscribe { + (publisher.map(|publisher| Producer::new(publisher, locals.clone(), remotes, scope_id.clone())), None) + } else { + (None, publisher) + }; + + let (consumer, reject_publishes) = if can_publish { + (subscriber.map(|subscriber| Consumer::new(subscriber, locals, coordinator, forward, scope_id)), None) + } else { + (None, subscriber) + }; + + let session = Session { + session: moq_session, + producer, + consumer, + reject_publishes, + reject_subscribes, + }; + + match session.run().await { + Ok(()) => { + // Session ended cleanly (uncommon - usually ends via close) + metrics::counter!("moq_relay_connections_closed_total").increment(1); + } + Err(err) if err.is_graceful_close() => { + // Graceful close - peer sent APPLICATION_CLOSE with code 0 + tracing::debug!("MoQ session closed gracefully"); + metrics::counter!("moq_relay_connections_closed_total").increment(1); + } + Err(err) => { + // Actual error - protocol violation, timeout, etc. 
+ tracing::warn!(error = %err, "MoQ session error: {}", err); + metrics::counter!("moq_relay_connection_errors_total", "stage" => "session_run").increment(1); + metrics::counter!("moq_relay_connections_closed_total").increment(1); + } } - } - Ok(()) - }.boxed()); - }, - res = tasks.next(), if !tasks.is_empty() => res.unwrap()?, + Ok(()) + }.boxed()); + }, + res = tasks.next(), if !tasks.is_empty() => res.unwrap()?, + } } } + .await; + + remotes.shutdown().await; + run_result } } diff --git a/moq-relay-ietf/src/remote.rs b/moq-relay-ietf/src/remote.rs index 2adc8582..f86ba374 100644 --- a/moq-relay-ietf/src/remote.rs +++ b/moq-relay-ietf/src/remote.rs @@ -2,269 +2,271 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use std::collections::HashMap; - -use std::collections::VecDeque; -use std::fmt; use std::net::SocketAddr; -use std::ops; -use std::sync::Arc; -use std::sync::Weak; - -/// Cache key for upstream relay-to-relay connections. -/// -/// Keyed by both URL and destination address so that connections are -/// reused only when both match. This matters when a [`Coordinator`] -/// returns the same URL for different namespaces (e.g. a shared relay -/// hostname) but distinguishes destinations via [`NamespaceOrigin::addr`]. -/// Without the address in the key, all namespaces that share a URL -/// would be routed through a single cached connection to whichever -/// upstream host was contacted first. 
-type RemoteCacheKey = (Url, Option); +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Weak}; -use futures::stream::FuturesUnordered; -use futures::FutureExt; -use futures::StreamExt; use moq_native_ietf::quic; use moq_transport::coding::TrackNamespace; -use moq_transport::serve::{Track, TrackReader, TrackWriter}; -use moq_transport::watch::State; +use moq_transport::serve::{Track, TrackReader}; +use tokio::sync::Mutex; +use tokio_util::sync::CancellationToken; use url::Url; -use crate::{metrics::GaugeGuard, Coordinator}; +use crate::{metrics::GaugeGuard, Coordinator, CoordinatorError}; -/// Information about remote origins. -pub struct Remotes { - /// The client we use to fetch/store origin information. - pub coordinator: Arc, +/// Cache key for upstream relay-to-relay connections. +/// +/// Keyed by both URL and destination address so that connections are reused +/// only when both match. +type RemoteCacheKey = (Url, Option); +type RemoteSlot = Arc>>; +type TrackCacheKey = (TrackNamespace, String); +type TrackSlot = Arc>>; - // A QUIC endpoint we'll use to fetch from other origins. - pub quic: quic::Client, +/// Manages connections to remote relays. +/// +/// When a subscription request comes in for a namespace that isn't local, +/// RemoteManager uses the coordinator to find which remote relay serves it, +/// establishes a connection if needed, and subscribes to the track. +#[derive(Clone)] +pub struct RemoteManager { + coordinator: Arc, + clients: Vec, + remotes: Arc>>, } -impl Remotes { - pub fn produce(self) -> (RemotesProducer, RemotesConsumer) { - let (send, recv) = State::default().split(); - let info = Arc::new(self); - - let producer = RemotesProducer::new(info.clone(), send); - let consumer = RemotesConsumer::new(info, recv); - - (producer, consumer) +impl RemoteManager { + /// Create a new RemoteManager. 
+ pub fn new(coordinator: Arc, clients: Vec) -> Self { + Self { + coordinator, + clients, + remotes: Arc::new(Mutex::new(HashMap::new())), + } } -} - -#[derive(Default)] -struct RemotesState { - lookup: HashMap, - requested: VecDeque, -} -// Clone for convenience, but there should only be one instance of this -#[derive(Clone)] -pub struct RemotesProducer { - info: Arc, - state: State, -} + /// Subscribe to a track from a remote relay. + /// + /// `scope` is the resolved scope identity from `Coordinator::resolve_scope()`, + /// passed through to the coordinator's `lookup()` to scope the search. + /// + /// Returns None if the namespace isn't found in any remote relay. + pub async fn subscribe( + &self, + scope: Option<&str>, + namespace: &TrackNamespace, + track_name: &str, + ) -> anyhow::Result> { + let (origin, client) = match self.coordinator.lookup(scope, namespace).await { + Ok(result) => result, + Err(CoordinatorError::NamespaceNotFound) => return Ok(None), + Err(err) => return Err(err.into()), + }; -impl RemotesProducer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } - } + let url = origin.url(); + let cache_key = (url.clone(), origin.addr()); + + let remote = match self + .get_or_connect(cache_key.clone(), client.as_ref()) + .await + { + Ok(remote) => remote, + Err(err) => { + tracing::error!(remote_url = %url, error = %err, "failed to connect to remote relay: {}", err); + return Err(err); + } + }; - /// Block until the next remote requested by a consumer. - async fn next(&mut self) -> Option { - loop { - { - let state = self.state.lock(); - if !state.requested.is_empty() { - return state.into_mut()?.requested.pop_front(); - } + match remote + .subscribe(namespace.clone(), track_name.to_string()) + .await + { + Ok(reader) => Ok(reader), + Err(err) => { + tracing::warn!(remote_url = %url, error = %err, "remote subscribe failed, removing from cache"); + self.remove_if_same_remote(&cache_key, &remote).await; - state.modified()? 
+ Err(err) } - .await; } } - /// Run the remotes producer to serve remote requests. - pub async fn run(mut self) -> anyhow::Result<()> { - let mut tasks = FuturesUnordered::new(); + /// Get an existing remote connection or create a new one. + async fn get_or_connect( + &self, + cache_key: RemoteCacheKey, + client: Option<&quic::Client>, + ) -> anyhow::Result { + let client = match client { + Some(client) => client, + None => self.clients.first().ok_or_else(|| { + anyhow::anyhow!("no QUIC clients configured for remote connections") + })?, + }; loop { - tokio::select! { - Some(mut remote) = self.next() => { - let url = remote.url.clone(); - let cache_key = (url.clone(), remote.addr); + // The manager lock only protects the map. The per-key slot lock protects + // that key's connection state, so unrelated remotes can connect in parallel. + let slot = { + let mut remotes = self.remotes.lock().await; + remotes + .entry(cache_key.clone()) + .or_insert_with(|| Arc::new(Mutex::new(None))) + .clone() + }; - // Spawn a task to serve the remote - tasks.push(async move { - let info = remote.info.clone(); - let remote_url = url.to_string(); + let mut cached = slot.lock().await; - tracing::warn!(remote_url = %remote_url, "serving remote: {:?}", info); + let is_current_slot = { + let remotes = self.remotes.lock().await; + matches!(remotes.get(&cache_key), Some(current) if Arc::ptr_eq(current, &slot)) + }; - // Run the remote producer - if let Err(err) = remote.run().await { - tracing::warn!(remote_url = %remote_url, error = %err, "failed serving remote: {:?}, error: {}", info, err); - } + if !is_current_slot { + continue; + } - cache_key - }); + if let Some(remote) = cached.as_ref() { + if remote.is_connected() { + return Ok(remote.clone()); } - // Handle finished remote producers - res = tasks.next(), if !tasks.is_empty() => { - let cache_key = res.unwrap(); + tracing::info!(remote_url = %cache_key.0, "removing dead connection to remote relay"); + }; - if let Some(mut state) = 
self.state.lock_mut() { - state.lookup.remove(&cache_key); - } - }, - else => return Ok(()), + if let Some(remote) = cached.take() { + remote.shutdown().await; } - } - } -} - -impl ops::Deref for RemotesProducer { - type Target = Remotes; - - fn deref(&self) -> &Self::Target { - &self.info - } -} -#[derive(Clone)] -pub struct RemotesConsumer { - pub info: Arc, - state: State, -} + tracing::info!(remote_url = %cache_key.0, "connecting to remote relay"); + let remote = match Remote::connect( + cache_key.0.clone(), + cache_key.1, + client, + Arc::downgrade(&self.remotes), + cache_key.clone(), + Arc::downgrade(&slot), + ) + .await + { + Ok(remote) => remote, + Err(err) => { + drop(cached); + remove_empty_remote_slot(&self.remotes, &cache_key, &slot).await; + return Err(err); + } + }; -impl RemotesConsumer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } + *cached = Some(remote.clone()); + return Ok(remote); + } } - /// Route to a remote origin based on the namespace. - /// - /// `scope` is the resolved scope identity (from `Coordinator::resolve_scope()`), - /// passed through to the coordinator's `lookup()` to scope the search. - pub async fn route( - &self, - scope: Option<&str>, - namespace: &TrackNamespace, - ) -> anyhow::Result> { - // Always fetch the origin instead of using the (potentially invalid) cache. 
- let (origin, client) = self.coordinator.lookup(scope, namespace).await?; + async fn remove_if_same_remote(&self, cache_key: &RemoteCacheKey, remote: &Remote) { + let slot = { + let remotes = self.remotes.lock().await; + remotes.get(cache_key).cloned() + }; - let cache_key = (origin.url(), origin.addr()); + if let Some(slot) = slot { + let removed = { + let mut cached = slot.lock().await; + match cached.as_ref() { + Some(current) if current.is_same_connection(remote) => cached.take(), + _ => None, + } + }; - // Check if we already have a remote for this origin - let state = self.state.lock(); - if let Some(remote) = state.lookup.get(&cache_key).cloned() { - return Ok(Some(remote)); + if let Some(remote) = removed { + remote.shutdown().await; + remove_empty_remote_slot(&self.remotes, cache_key, &slot).await; + } } + } - // Create a new remote for this origin - let mut state = match state.into_mut() { - Some(state) => state, - None => return Ok(None), - }; - - let remote = Remote { - url: origin.url(), - remotes: self.info.clone(), - addr: origin.addr(), - client, + /// Shutdown all remote connections. 
+ pub(crate) async fn shutdown(&self) { + let remotes = { + let mut remotes = self.remotes.lock().await; + remotes.drain().collect::>() }; - // Produce the remote - let (writer, reader) = remote.produce(); - state.requested.push_back(writer); - - // Insert the remote into our Map, keyed by both URL and destination address - state.lookup.insert(cache_key, reader.clone()); - - Ok(Some(reader)) + for (cache_key, slot) in remotes { + tracing::info!(remote_url = %cache_key.0, "shutting down remote connection"); + let mut remote = slot.lock().await; + if let Some(remote) = remote.take() { + remote.shutdown().await; + } + } } } -impl ops::Deref for RemotesConsumer { - type Target = Remotes; - - fn deref(&self) -> &Self::Target { - &self.info +async fn remove_empty_remote_slot( + remotes: &Arc>>, + cache_key: &RemoteCacheKey, + slot: &RemoteSlot, +) { + let cached = slot.lock().await; + if cached.is_some() { + return; } -} -pub struct Remote { - pub remotes: Arc, - pub url: Url, - pub addr: Option, - pub client: Option, -} - -impl fmt::Debug for Remote { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Remote") - .field("url", &self.url.to_string()) - .finish() + let mut remotes = remotes.lock().await; + if matches!(remotes.get(cache_key), Some(current) if Arc::ptr_eq(current, slot)) { + remotes.remove(cache_key); } } -impl ops::Deref for Remote { - type Target = Remotes; - - fn deref(&self) -> &Self::Target { - &self.remotes +async fn remove_empty_track_slot( + tracks: &Arc>>, + key: &TrackCacheKey, + slot: &TrackSlot, +) { + let cached = slot.lock().await; + if cached.is_some() { + return; } -} - -impl Remote { - /// Create a new broadcast. 
- pub fn produce(self) -> (RemoteProducer, RemoteConsumer) { - let (send, recv) = State::default().split(); - let info = Arc::new(self); - let consumer = RemoteConsumer::new(info.clone(), recv); - let producer = RemoteProducer::new(info, send); - - (producer, consumer) + let mut tracks = tracks.lock().await; + if matches!(tracks.get(key), Some(current) if Arc::ptr_eq(current, slot)) { + tracks.remove(key); } } -#[derive(Default)] -struct RemoteState { - tracks: HashMap<(TrackNamespace, String), RemoteTrackWeak>, - requested: VecDeque, -} - -pub struct RemoteProducer { - pub info: Arc, - state: State, +/// A connection to a single remote relay with its own QUIC client. +#[derive(Clone)] +struct Remote { + url: Url, + subscriber: moq_transport::session::Subscriber, + /// Track subscriptions keyed by full track name. + tracks: Arc>>, + /// Flag indicating if the connection is still alive. + connected: Arc, + /// Cancellation token for the session task. + cancel: CancellationToken, } -impl RemoteProducer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } - } - - pub async fn run(&mut self) -> anyhow::Result<()> { - let client = if let Some(client) = &self.info.client { - client - } else { - &self.quic +impl Remote { + /// Connect to a remote relay with a dedicated QUIC client. 
+ async fn connect( + url: Url, + addr: Option, + client: &quic::Client, + remotes: Weak>>, + cache_key: RemoteCacheKey, + cache_slot: Weak>>, + ) -> anyhow::Result { + let (session, _quic_client_initial_cid, transport) = match client.connect(&url, addr).await + { + Ok(session) => session, + Err(err) => { + metrics::counter!("moq_relay_upstream_errors_total", "stage" => "connect") + .increment(1); + return Err(err); + } }; - // TODO reuse QUIC and MoQ sessions - let (session, _quic_client_initial_cid, transport) = - match client.connect(&self.url, self.addr).await { - Ok(session) => session, - Err(err) => { - metrics::counter!("moq_relay_upstream_errors_total", "stage" => "connect") - .increment(1); - return Err(err); - } - }; + let (session, subscriber) = match moq_transport::session::Subscriber::connect(session, transport).await { Ok(session) => session, @@ -275,190 +277,194 @@ impl RemoteProducer { } }; - // Track established upstream connections - decrements when this function returns. - // Placed after successful connect + session setup so the gauge only reflects - // connections that are actually serving, not in-flight attempts. - let _upstream_guard = GaugeGuard::new("moq_relay_upstream_connections"); - - // Run the session - let mut session = session.run().boxed(); - let mut tasks = FuturesUnordered::new(); + let connected = Arc::new(AtomicBool::new(true)); + let cancel = CancellationToken::new(); + let upstream_guard = GaugeGuard::new("moq_relay_upstream_connections"); - let mut done = None; + let session_url = url.clone(); + let session_connected = connected.clone(); + let session_cancel = cancel.clone(); - // Serve requested tracks - loop { + tokio::spawn(async move { + let _upstream_guard = upstream_guard; tokio::select! 
{ - track = self.next(), if done.is_none() => { - let track = match track { - Ok(Some(track)) => track, - Ok(None) => { done = Some(Ok(())); continue }, - Err(err) => { done = Some(Err(err)); continue }, - }; - - let info = track.info.clone(); - let mut subscriber = subscriber.clone(); - - tasks.push(async move { - if let Err(err) = subscriber.subscribe(track).await { - let namespace = info.namespace.to_utf8_path(); - let track_name = &info.name; - tracing::warn!(namespace = %namespace, track = %track_name, error = %err, "failed serving track: {:?}, error: {}", info, err); - } - }); + result = session.run() => { + if let Err(err) = result { + tracing::warn!(remote_url = %session_url, error = %err, "remote session closed: {}", err); + } else { + tracing::info!(remote_url = %session_url, "remote session closed normally"); + } + } + _ = session_cancel.cancelled() => { + tracing::info!(remote_url = %session_url, "remote session cancelled"); } - _ = tasks.next(), if !tasks.is_empty() => {}, - - // Keep running the session - res = &mut session, if !tasks.is_empty() || done.is_none() => return Ok(res?), - - else => return done.unwrap(), } - } - } - /// Block until the next track requested by a consumer. 
- async fn next(&self) -> anyhow::Result> { - loop { - let notify = { - let state = self.state.lock(); - - // Check if we have any requested tracks - if !state.requested.is_empty() { - return Ok(state - .into_mut() - .and_then(|mut state| state.requested.pop_front())); + session_connected.store(false, Ordering::Release); + + if let Some(cache_slot) = cache_slot.upgrade() { + let mut cleared = false; + let mut cached = cache_slot.lock().await; + if matches!(cached.as_ref(), Some(remote) if Arc::ptr_eq(&remote.connected, &session_connected)) + { + cached.take(); + cleared = true; + tracing::info!(remote_url = %session_url, "cleared closed remote connection from cache"); } + drop(cached); - match state.modified() { - Some(notified) => notified, - None => return Ok(None), + if cleared { + if let Some(remotes) = remotes.upgrade() { + remove_empty_remote_slot(&remotes, &cache_key, &cache_slot).await; + } } - }; + } + }); - notify.await - } + Ok(Self { + url, + subscriber, + tracks: Arc::new(Mutex::new(HashMap::new())), + connected, + cancel, + }) } -} - -impl ops::Deref for RemoteProducer { - type Target = Remote; - fn deref(&self) -> &Self::Target { - &self.info + /// Check if the connection is still alive. + fn is_connected(&self) -> bool { + self.connected.load(Ordering::Acquire) } -} -#[derive(Clone)] -pub struct RemoteConsumer { - pub info: Arc, - state: State, -} + fn is_same_connection(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.connected, &other.connected) + } -impl RemoteConsumer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } + /// Shutdown the remote connection. + async fn shutdown(&self) { + self.cancel.cancel(); + self.connected.store(false, Ordering::Release); + self.tracks.lock().await.clear(); } - /// Request a track from the broadcast. - pub fn subscribe( + /// Subscribe to a track on this remote relay. 
+ async fn subscribe( &self, - namespace: &TrackNamespace, - name: &str, - ) -> anyhow::Result> { - let key = (namespace.clone(), name.to_string()); - let state = self.state.lock(); - if let Some(track) = state.tracks.get(&key) { - if let Some(track) = track.upgrade() { - return Ok(Some(track)); + namespace: TrackNamespace, + track_name: String, + ) -> anyhow::Result> { + let key = (namespace.clone(), track_name.clone()); + + loop { + if !self.is_connected() { + anyhow::bail!("remote connection to {} is closed", self.url); } - } - let mut state = match state.into_mut() { - Some(state) => state, - None => return Ok(None), - }; + let slot = { + let mut tracks = self.tracks.lock().await; + tracks + .entry(key.clone()) + .or_insert_with(|| Arc::new(Mutex::new(None))) + .clone() + }; - let (writer, reader) = Track::new(namespace.clone(), name.to_string()).produce(); - let reader = RemoteTrackReader::new(reader, self.state.clone()); + let mut cached = slot.lock().await; - // Insert the track into our Map so we deduplicate future requests. 
- state.tracks.insert(key, reader.downgrade()); - state.requested.push_back(writer); + let is_current_slot = { + let tracks = self.tracks.lock().await; + matches!(tracks.get(&key), Some(current) if Arc::ptr_eq(current, &slot)) + }; - Ok(Some(reader)) - } -} + if !is_current_slot { + continue; + } -impl ops::Deref for RemoteConsumer { - type Target = Remote; + if let Some(reader) = cached.as_ref() { + if !reader.is_closed() { + return Ok(Some(reader.clone())); + } - fn deref(&self) -> &Self::Target { - &self.info - } -} + tracing::debug!(remote_url = %self.url, namespace = %key.0, track = %key.1, "removing closed remote track from cache"); + } -#[derive(Clone)] -pub struct RemoteTrackReader { - pub reader: TrackReader, - drop: Arc, -} + cached.take(); -impl RemoteTrackReader { - fn new(reader: TrackReader, parent: State) -> Self { - let drop = Arc::new(RemoteTrackDrop { - parent, - key: (reader.namespace.clone(), reader.name.clone()), - }); + let mut subscriber = self.subscriber.clone(); + let url = self.url.clone(); + let tracks = Arc::downgrade(&self.tracks); + let cancel = self.cancel.clone(); - Self { reader, drop } - } + tracing::info!(remote_url = %url, namespace = %key.0, track = %key.1, "subscribing to remote track"); - fn downgrade(&self) -> RemoteTrackWeak { - RemoteTrackWeak { - reader: self.reader.clone(), - drop: Arc::downgrade(&self.drop), - } - } -} + let (writer, reader) = Track::new(namespace.clone(), track_name.clone()).produce(); + let subscribe_result = tokio::select! 
{ + result = subscriber.subscribe_open(writer) => result, + _ = cancel.cancelled() => { + drop(cached); + remove_empty_track_slot(&self.tracks, &key, &slot).await; + anyhow::bail!("subscribe cancelled, remote connection to {} is closed", self.url); + } + }; + + let subscribe = match subscribe_result { + Ok(subscribe) => subscribe, + Err(err) => { + drop(cached); + remove_empty_track_slot(&self.tracks, &key, &slot).await; + return Err(err.into()); + } + }; -impl ops::Deref for RemoteTrackReader { - type Target = TrackReader; + if !self.is_connected() { + drop(cached); + remove_empty_track_slot(&self.tracks, &key, &slot).await; + anyhow::bail!("remote connection to {} is closed", self.url); + } - fn deref(&self) -> &Self::Target { - &self.reader - } -} + *cached = Some(reader.clone()); + drop(cached); + + let cleanup_key = key.clone(); + let cleanup_reader = reader.clone(); + let cleanup_slot = slot.clone(); + tokio::spawn(async move { + tokio::select! { + result = subscribe.closed() => { + match result { + Ok(()) => { + tracing::debug!(remote_url = %url, namespace = %cleanup_key.0, track = %cleanup_key.1, "remote track subscription ended"); + } + Err(err) => { + tracing::warn!(remote_url = %url, namespace = %cleanup_key.0, track = %cleanup_key.1, error = %err, "remote track subscription ended with error: {}", err); + } + } + } + _ = cancel.cancelled() => { + tracing::debug!(remote_url = %url, namespace = %cleanup_key.0, track = %cleanup_key.1, "remote track subscription cancelled"); + } + } -impl ops::DerefMut for RemoteTrackReader { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.reader - } -} + if let Some(tracks) = tracks.upgrade() { + let mut cached = cleanup_slot.lock().await; + if matches!(cached.as_ref(), Some(current) if Arc::ptr_eq(&current.info, &cleanup_reader.info)) + { + cached.take(); + } + drop(cached); -struct RemoteTrackWeak { - reader: TrackReader, - drop: Weak, -} + remove_empty_track_slot(&tracks, &cleanup_key, &cleanup_slot).await; + } +
}); -impl RemoteTrackWeak { - fn upgrade(&self) -> Option { - Some(RemoteTrackReader { - reader: self.reader.clone(), - drop: self.drop.upgrade()?, - }) + return Ok(Some(reader)); + } } } -struct RemoteTrackDrop { - parent: State, - key: (TrackNamespace, String), -} - -impl Drop for RemoteTrackDrop { - fn drop(&mut self) { - if let Some(mut parent) = self.parent.lock_mut() { - parent.tracks.remove(&self.key); - } +impl std::fmt::Debug for Remote { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Remote") + .field("url", &self.url.to_string()) + .field("connected", &self.is_connected()) + .finish() } } diff --git a/moq-sub/README.md b/moq-sub/README.md index c1340d57..695cc68d 100644 --- a/moq-sub/README.md +++ b/moq-sub/README.md @@ -2,9 +2,9 @@ A command line tool for subscribing to media via Media over QUIC (MoQ). -Takes an URL to MoQ relay with a broadcast name in the path part of the URL. It will connect to the relay, subscribe to -the broadcast, and dump the media segments of the first video and first audio track to STDOUT. +Takes a URL to a MoQ relay and a broadcast name via `--name`. It will connect to the relay, subscribe to the broadcast, +and dump the media segments of the first video and first audio track to STDOUT. 
``` -moq-sub https://localhost:4443/dev | ffplay - +moq-sub --name dev https://localhost:4443 | ffplay - ``` diff --git a/moq-transport/src/session/subscribe.rs b/moq-transport/src/session/subscribe.rs index c98e7370..ce63e6c3 100644 --- a/moq-transport/src/session/subscribe.rs +++ b/moq-transport/src/session/subscribe.rs @@ -140,12 +140,32 @@ impl Subscribe { .await; } } + + pub async fn ok(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + if state.ok { + return Ok(()); + } + + match state.modified() { + Some(notify) => notify, + None => return Err(ServeError::Done), + } + } + .await; + } + } } impl Drop for Subscribe { fn drop(&mut self) { self.subscriber .send_message(message::Unsubscribe { id: self.info.id }); + self.subscriber.remove_subscribe(self.info.id); } } diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index af3b51a4..2653abee 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -123,11 +123,20 @@ impl Subscriber { /// Subscribe to a track by creating a new subscribe request to the publisher. Block until subscription is closed. pub async fn subscribe(&mut self, track: serve::TrackWriter) -> Result<(), ServeError> { + let subscribe = self.subscribe_open(track).await?; + subscribe.closed().await + } + + /// Subscribe to a track and wait until the publisher acknowledges it. + pub async fn subscribe_open( + &mut self, + track: serve::TrackWriter, + ) -> Result { let request_id = self.get_next_request_id(); let (send, recv) = Subscribe::new(self.clone(), request_id, track); self.subscribes.lock().unwrap().insert(request_id, recv); - - send.closed().await + send.ok().await?; + Ok(send) } /// Send a message to the publisher via the control stream. @@ -235,7 +244,7 @@ impl Subscriber { } /// Remove a subscribe from our map of active subscribes, and the alias map if present. 
- fn remove_subscribe(&mut self, id: u64) -> Option { + pub(super) fn remove_subscribe(&mut self, id: u64) -> Option { if let Some(subscribe) = self.subscribes.lock().unwrap().remove(&id) { // Remove from alias map if present if let Some(track_alias) = subscribe.track_alias() { @@ -725,3 +734,86 @@ impl Subscriber { Ok(()) } } + +#[cfg(test)] +mod tests { + use std::{sync::atomic, task::Poll}; + + use super::*; + use crate::{ + message::{self, GroupOrder}, + serve::Track, + }; + + fn subscriber() -> Subscriber { + Subscriber::new(Queue::default(), Arc::new(atomic::AtomicU64::new(0)), None) + } + + #[tokio::test] + async fn subscribe_open_cleans_up_when_cancelled_before_ok() { + let mut subscriber = subscriber(); + let observer = subscriber.clone(); + let (writer, _reader) = + Track::new(TrackNamespace::from_utf8_path("test"), "0.mp4".into()).produce(); + + { + let subscribe = subscriber.subscribe_open(writer); + futures::pin_mut!(subscribe); + + assert!(matches!(futures::poll!(&mut subscribe), Poll::Pending)); + assert_eq!(observer.subscribes.lock().unwrap().len(), 1); + } + + assert!(observer.subscribes.lock().unwrap().is_empty()); + assert!(observer.subscribe_alias_map.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn dropping_open_subscribe_removes_recv_state() { + let mut subscriber = subscriber(); + let observer = subscriber.clone(); + let (writer, _reader) = + Track::new(TrackNamespace::from_utf8_path("test"), "0.mp4".into()).produce(); + + let subscribe = subscriber.subscribe_open(writer); + futures::pin_mut!(subscribe); + + assert!(matches!(futures::poll!(&mut subscribe), Poll::Pending)); + assert_eq!(observer.subscribes.lock().unwrap().len(), 1); + + let mut receiver = observer.clone(); + receiver + .recv_subscribe_ok(&message::SubscribeOk { + id: 0, + track_alias: 10, + expires: 0, + group_order: GroupOrder::Publisher, + content_exists: false, + largest_location: None, + params: Default::default(), + }) + .unwrap(); + + let subscribe = match 
futures::poll!(&mut subscribe) { + Poll::Ready(Ok(subscribe)) => subscribe, + Poll::Ready(Err(err)) => panic!("subscribe failed: {err}"), + Poll::Pending => panic!("subscribe remained pending after SubscribeOk"), + }; + + assert_eq!(observer.subscribes.lock().unwrap().len(), 1); + assert_eq!( + observer + .subscribe_alias_map + .lock() + .unwrap() + .get(&10) + .copied(), + Some(0) + ); + + drop(subscribe); + + assert!(observer.subscribes.lock().unwrap().is_empty()); + assert!(observer.subscribe_alias_map.lock().unwrap().is_empty()); + } +}