From de76583956d1c379ad3ade9ef0b3bc3949a9bc2e Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Thu, 2 Apr 2026 18:48:42 +0200 Subject: [PATCH 01/42] this is ok except for websocket server hella ugly --- bin/router/src/lib.rs | 2 +- bin/router/src/pipeline/http_callback.rs | 96 +++---- bin/router/src/pipeline/mod.rs | 214 +++++++++++---- bin/router/src/pipeline/websocket_server.rs | 156 ++++++++++- bin/router/src/schema_state.rs | 55 ++-- .../src/executors/active_subscriptions.rs | 258 ++++++++++++++++++ lib/executor/src/executors/http_callback.rs | 121 +++----- lib/executor/src/executors/map.rs | 21 +- lib/executor/src/executors/mod.rs | 1 + 9 files changed, 695 insertions(+), 229 deletions(-) create mode 100644 lib/executor/src/executors/active_subscriptions.rs diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index b0ab292b9..88bb4faa7 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -212,7 +212,7 @@ pub async fn router_entrypoint(plugin_registry: PluginRegistry) -> Result<(), Ro .await?; let shared_state_clone = shared_state.clone(); - let active_subs = schema_state.active_callback_subscriptions.clone(); + let active_subs = schema_state.active_subscriptions.clone(); // when `listen` is set, the callback route lives on a dedicated server bound to that address // otherwise, the callback route is mounted on the main server on the `callback_path` diff --git a/bin/router/src/pipeline/http_callback.rs b/bin/router/src/pipeline/http_callback.rs index c9e2de216..c3c214348 100644 --- a/bin/router/src/pipeline/http_callback.rs +++ b/bin/router/src/pipeline/http_callback.rs @@ -1,8 +1,11 @@ +use std::sync::Arc; + use bytes::Bytes as BytesLib; -use dashmap::mapref::one::Ref; +use hive_router_plan_executor::executors::active_subscriptions::{ + ActiveSubscriptionsRegistry, BroadcastItem, +}; use hive_router_plan_executor::executors::http_callback::{ - ActiveSubscription, ActiveSubscriptionsMap, CallbackMessage, CALLBACK_PROTOCOL_VERSION, - SUBSCRIPTION_PROTOCOL_HEADER, + CALLBACK_PROTOCOL_VERSION, SUBSCRIPTION_PROTOCOL_HEADER, }; use hive_router_plan_executor::response::graphql_error::GraphQLError; use http::StatusCode; @@ -142,16 +145,15 @@ fn validate_payload( Ok(()) } -fn handle_check(subscription_id: &str, subscription: &Ref<'_, String, ActiveSubscription>) { +fn handle_check(subscription_id: &str, registry: &ActiveSubscriptionsRegistry) { trace!(subscription_id = %subscription_id, "Received check message"); - subscription.record_heartbeat(); + registry.record_heartbeat(subscription_id); } fn handle_next( subscription_id: &str, payload: &CallbackPayload<'_>, - subscription: Ref<'_, String, ActiveSubscription>, - active_subscriptions: &ActiveSubscriptionsMap, + registry: &ActiveSubscriptionsRegistry, ) -> Result<(), CallbackError> { trace!(subscription_id = %subscription_id, "Received next message"); @@ -164,53 +166,38 @@ fn handle_next( } }; - match subscription - .sender - .try_send(CallbackMessage::Next { payload: data }) - { - Ok(()) => Ok(()), - Err(mpsc::error::TrySendError::Full(_)) => { - // if the channel is full it means the consuming client is too slow and unable to keep - // up. we terminate the subscription without an error message because it anyways cant go through - warn!(subscription_id = %subscription_id, "Subscription client is too slow"); - drop(subscription); - active_subscriptions.remove(subscription_id); - Err(CallbackError::ClientTooSlow { - subscription_id: subscription_id.to_string(), - }) - } - Err(mpsc::error::TrySendError::Closed(_)) => { - debug!(subscription_id = %subscription_id, "Subscription receiver dropped"); - drop(subscription); - active_subscriptions.remove(subscription_id); - Err(CallbackError::SubscriptionDropped { - subscription_id: subscription_id.to_string(), - }) - } + if !registry.send_event(subscription_id, BroadcastItem::Event(data)) { + debug!(subscription_id = %subscription_id, "Subscription receiver dropped"); + registry.remove(subscription_id); + return Err(CallbackError::SubscriptionDropped { + subscription_id: subscription_id.to_string(), + }); } + + // TODO: ClientTooSlow + + Ok(()) } fn handle_complete( subscription_id: &str, payload: &CallbackPayload<'_>, - subscription: Ref<'_, String, ActiveSubscription>, - active_subscriptions: &ActiveSubscriptionsMap, + registry: &ActiveSubscriptionsRegistry, ) { trace!(subscription_id = %subscription_id, "Received complete message"); - // if the buffer is full or closed we ignore and remove the subscription, we dont send - // the final error message because the client is already unable to consume - let _ = subscription.sender.try_send(CallbackMessage::Complete { - errors: payload.errors.clone(), - }); - drop(subscription); - active_subscriptions.remove(subscription_id); + if let Some(errors) = &payload.errors { + if !errors.is_empty() { + registry.send_event(subscription_id, BroadcastItem::Error(errors.clone())); + } + } + registry.remove(subscription_id); } pub async fn handler( req: HttpRequest, path: Path, body: Bytes, - active_subscriptions: web::types::State, + active_subscriptions: web::types::State>, ) -> Result { let subscription_id_from_path = path.into_inner(); @@ -220,29 +207,30 @@ pub async fn handler( validate_payload(&payload, &subscription_id_from_path)?; - let subscription = match active_subscriptions.get(&payload.id) { - Some(sub) => sub, - None => { - return Err(CallbackError::SubscriptionNotFound { - subscription_id: payload.id.clone(), - }); - } - }; + if !active_subscriptions.contains(&payload.id) { + return Err(CallbackError::SubscriptionNotFound { + subscription_id: payload.id.clone(), + }); + } - if subscription.verifier != payload.verifier { + let verifier = active_subscriptions + .get_callback_verifier(&payload.id) + .ok_or_else(|| CallbackError::SubscriptionNotFound { + subscription_id: payload.id.clone(), + })?; + + if verifier != payload.verifier { return Err(CallbackError::InvalidVerifier { subscription_id: payload.id.clone(), }); } match payload.action { - CallbackAction::Check => handle_check(&payload.id, &subscription), + CallbackAction::Check => handle_check(&payload.id, &active_subscriptions), CallbackAction::Next => { - handle_next(&payload.id, &payload, subscription, &active_subscriptions)?; - } - CallbackAction::Complete => { - handle_complete(&payload.id, &payload, subscription, &active_subscriptions) + handle_next(&payload.id, &payload, &active_subscriptions)?; } + CallbackAction::Complete => handle_complete(&payload.id, &payload, &active_subscriptions), }; Ok(HttpResponse::NoContent() diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 04d515a4d..a8b1773ae 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,13 +1,19 @@ use futures::Stream; +use futures::StreamExt; use std::{ collections::HashMap, hash::{Hash, Hasher}, sync::Arc, time::Instant, }; -use tracing::{error, Instrument}; +use tracing::{error, trace, Instrument}; use xxhash_rust::xxh3::Xxh3; +use hive_router_plan_executor::execution::plan::FailedExecutionResult; +use hive_router_plan_executor::executors::active_subscriptions::BroadcastItem; +use hive_router_plan_executor::executors::active_subscriptions::ListenerGuard; +use hive_router_plan_executor::headers::plan::ResponseHeaderAggregator; + use hive_router_internal::telemetry::traces::spans::{ graphql::GraphQLOperationSpan, http_request::HttpServerRequestSpan, }; @@ -250,10 +256,11 @@ pub async fn graphql_request_handler( let request_dedupe_enabled = shared_state.router_config.traffic_shaping.router.dedupe.enabled; - let planned_response = if request_dedupe_enabled + let fingerprint = if request_dedupe_enabled && matches!( normalize_payload.operation_for_plan.operation_kind, - Some(OperationKind::Query) | None + // same deduplication applies for queries and subscriptions + Some(OperationKind::Query) | Some(OperationKind::Subscription) | None ) { let variables_hash = hash_graphql_variables(&graphql_params.variables); let extensions_hash = graphql_params @@ -262,17 +269,50 @@ pub async fn graphql_request_handler( .map_or(0, hash_graphql_extensions); let schema_checksum = supergraph.schema_checksum(); - let fingerprint = inbound_request_fingerprint( - req, + Some(inbound_request_fingerprint( + req.method(), + req.path(), + req.headers(), &shared_state.in_flight_requests_header_policy, schema_checksum, normalize_payload.normalized_operation_hash, variables_hash, extensions_hash, - ); + )) + } else { + None + }; + + // subscription dedup, try to join an existing subscription + if is_subscription { + if let Some(fp) = fingerprint { + let registry = &schema_state.active_subscriptions; + if let Some((_sub_id, receiver, guard)) = + registry.try_join_by_fingerprint(fp) + { + let body_stream = broadcast_receiver_to_body_stream(receiver, guard); + let response = build_streaming_response( + body_stream, + response_mode, + None, + )?; + return Ok(response); + } + } + } + + // subscription fingerprint for leader registration (None disables broadcasting) + let subscription_fingerprint = if is_subscription { fingerprint } else { None }; + + let planned_response = if fingerprint.is_some() + && matches!( + normalize_payload.operation_for_plan.operation_kind, + Some(OperationKind::Query) | None + ) { + let fp = fingerprint.unwrap(); let (shared_response, _role) = shared_state .in_flight_requests - .claim(fingerprint) + .claim(fp) .get_or_try_init(|| async { match execute_planned_request( req, @@ -284,6 +324,7 @@ pub async fn graphql_request_handler( operation_span, plugin_req_state, response_mode, + subscription_fingerprint, ) .await? { @@ -306,6 +347,7 @@ pub async fn graphql_request_handler( operation_span, plugin_req_state, response_mode, + subscription_fingerprint, ) .await? }; @@ -378,6 +420,7 @@ async fn execute_planned_request<'exec>( operation_span: GraphQLOperationSpan, plugin_req_state: Option>, response_mode: &'exec ResponseMode, + subscription_fingerprint: Option, ) -> Result { let jwt_request_details = match &shared_state.jwt_auth_runtime { Some(jwt_auth_runtime) => match jwt_auth_runtime @@ -429,50 +472,31 @@ async fn execute_planned_request<'exec>( .await? { QueryPlanExecutionResult::Stream(result) => { - let stream_content_type = response_mode - .stream_content_type() - .ok_or(PipelineError::SubscriptionsTransportNotSupported)?; - - let content_type_header = match stream_content_type { - StreamContentType::IncrementalDelivery => { - http::HeaderValue::from_static(INCREMENTAL_DELIVERY_CONTENT_TYPE) - } - StreamContentType::SSE => http::HeaderValue::from_static("text/event-stream"), - StreamContentType::ApolloMultipartHTTP => { - http::HeaderValue::from_static(APOLLO_MULTIPART_HTTP_CONTENT_TYPE) - } - }; + let body_stream = if let Some(fingerprint) = subscription_fingerprint { + let registry = &schema_state.active_subscriptions; + let (handle, receiver, guard) = registry.register(Some(fingerprint), None); + + // spawn a task that reads from the upstream and broadcasts to all listeners. + // dropping the handle when the upstream ends removes the registry entry + let mut upstream = result.body; + tokio::spawn(async move { + while let Some(event) = upstream.next().await { + if !handle.send(BroadcastItem::Event(event.into())) { + break; + } + } + }); - // TODO: why exactly do we need a type cast here? - let body: std::pin::Pin< - Box> + Send>, - > = match stream_content_type { - StreamContentType::IncrementalDelivery => Box::pin( - multipart_subscribe::create_incremental_delivery_stream(result.body), - ), - StreamContentType::SSE => Box::pin(sse::create_stream( - result.body, - std::time::Duration::from_secs(10), - )), - StreamContentType::ApolloMultipartHTTP => { - Box::pin(multipart_subscribe::create_apollo_multipart_http_stream( - result.body, - std::time::Duration::from_secs(10), - )) - } + broadcast_receiver_to_body_stream(receiver, guard) + } else { + result.body }; - let mut response_builder = web::HttpResponse::Ok(); - - if let Some(response_headers_aggregator) = result.response_headers_aggregator { - response_headers_aggregator - .modify_client_response_headers(&mut response_builder)?; - } - - let response = response_builder - // .status(result.status) status codes in streaming responses should always be ok - .header(http::header::CONTENT_TYPE, content_type_header) - .streaming(body); + let response = build_streaming_response( + body_stream, + response_mode, + result.response_headers_aggregator, + )?; Ok(PlannedResponse::Direct { response }) } @@ -569,8 +593,87 @@ pub async fn execute_pipeline<'exec>( execute_plan(supergraph, shared_state, planned_request, operation_span).await } -fn inbound_request_fingerprint( - req: &HttpRequest, +/// converts a broadcast receiver into a BoxStream of serialized event bodies. +/// the listener guard is held for the lifetime of the stream to track listener count +fn broadcast_receiver_to_body_stream( + mut receiver: tokio::sync::broadcast::Receiver, + guard: ListenerGuard, +) -> futures::stream::BoxStream<'static, Vec> { + Box::pin(async_stream::stream! { + let _guard = guard; + loop { + match receiver.recv().await { + Ok(BroadcastItem::Event(data)) => { + yield data.to_vec(); + } + Ok(BroadcastItem::Error(errors)) => { + yield FailedExecutionResult { errors }.serialize(); + break; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + trace!(lagged = n, "broadcast receiver lagged, skipping missed messages"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + }) +} + +fn build_streaming_response( + body_stream: futures::stream::BoxStream<'static, Vec>, + response_mode: &ResponseMode, + response_headers_aggregator: Option, +) -> Result { + let stream_content_type = response_mode + .stream_content_type() + .ok_or(PipelineError::SubscriptionsTransportNotSupported)?; + + let content_type_header = match stream_content_type { + StreamContentType::IncrementalDelivery => { + http::HeaderValue::from_static(INCREMENTAL_DELIVERY_CONTENT_TYPE) + } + StreamContentType::SSE => http::HeaderValue::from_static("text/event-stream"), + StreamContentType::ApolloMultipartHTTP => { + http::HeaderValue::from_static(APOLLO_MULTIPART_HTTP_CONTENT_TYPE) + } + }; + + let body: std::pin::Pin< + Box> + Send>, + > = match stream_content_type { + StreamContentType::IncrementalDelivery => Box::pin( + multipart_subscribe::create_incremental_delivery_stream(body_stream), + ), + StreamContentType::SSE => Box::pin(sse::create_stream( + body_stream, + std::time::Duration::from_secs(10), + )), + StreamContentType::ApolloMultipartHTTP => { + Box::pin(multipart_subscribe::create_apollo_multipart_http_stream( + body_stream, + std::time::Duration::from_secs(10), + )) + } + }; + + let mut response_builder = web::HttpResponse::Ok(); + + if let Some(response_headers_aggregator) = response_headers_aggregator { + response_headers_aggregator.modify_client_response_headers(&mut response_builder)?; + } + + Ok(response_builder + .header(http::header::CONTENT_TYPE, content_type_header) + .streaming(body)) +} + +pub fn inbound_request_fingerprint( + method: &http::Method, + path: &str, + request_headers: &ntex::http::HeaderMap, dedupe_header_policy: &RouterRequestDedupeHeaderPolicy, schema_checksum: u64, normalized_operation_hash: u64, @@ -579,8 +682,7 @@ fn inbound_request_fingerprint( ) -> u64 { let mut hasher = Xxh3::new(); - let mut headers: Vec<(&str, &str)> = req - .headers() + let mut headers: Vec<(&str, &str)> = request_headers .iter() .filter(|(name, _)| dedupe_header_policy.should_include(name.as_str())) .filter_map(|(name, value)| value.to_str().ok().map(|v_str| (name.as_str(), v_str))) @@ -591,8 +693,8 @@ fn inbound_request_fingerprint( .then_with(|| left_value.cmp(right_value)) }); - req.method().hash(&mut hasher); - req.path().hash(&mut hasher); + method.hash(&mut hasher); + path.hash(&mut hasher); headers.hash(&mut hasher); schema_checksum.hash(&mut hasher); normalized_operation_hash.hash(&mut hasher); @@ -602,7 +704,7 @@ fn inbound_request_fingerprint( hasher.finish() } -fn hash_graphql_variables(variables: &HashMap) -> u64 { +pub fn hash_graphql_variables(variables: &HashMap) -> u64 { let mut hasher = Xxh3::new(); let mut keys: Vec<&str> = variables.keys().map(String::as_str).collect(); @@ -619,7 +721,7 @@ fn hash_graphql_variables(variables: &HashMap) -> u64 { hasher.finish() } -fn hash_graphql_extensions(extensions: &HashMap) -> u64 { +pub fn hash_graphql_extensions(extensions: &HashMap) -> u64 { // reused as hash_graphql_variables has the same function signature hash_graphql_variables(extensions) } diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 8b9cdd911..659a0715d 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -39,13 +39,17 @@ use crate::pipeline::coerce_variables::coerce_request_variables; use crate::pipeline::error::PipelineError; use crate::pipeline::execute_pipeline; use crate::pipeline::execution_request::GetQueryStr; -use crate::pipeline::normalize::normalize_request_with_cache; -use crate::pipeline::parser::parse_operation_with_cache; -use crate::pipeline::usage_reporting; -use crate::pipeline::validation::validate_operation_with_cache; +use crate::pipeline::{ + hash_graphql_extensions, hash_graphql_variables, inbound_request_fingerprint, + normalize::normalize_request_with_cache, parser::parse_operation_with_cache, usage_reporting, + validation::validate_operation_with_cache, +}; use crate::schema_state::SchemaState; use crate::shared_state::RouterSharedState; +use hive_router_plan_executor::execution::plan::FailedExecutionResult; +use hive_router_plan_executor::executors::active_subscriptions::BroadcastItem; + type WsStateRef = Rc>>>; pub async fn ws_index( @@ -480,6 +484,88 @@ async fn handle_text_frame( jwt: jwt_request_details, }.into(); + let is_subscription = matches!( + normalize_payload.operation_for_plan.operation_kind, + Some(OperationKind::Subscription) + ); + let request_dedupe_enabled = + shared_state.router_config.traffic_shaping.router.dedupe.enabled; + + // subscription dedup: try to join an existing subscription + if is_subscription && request_dedupe_enabled { + let variables_hash = hash_graphql_variables(&payload.variables); + let extensions_hash = payload + .extensions + .as_ref() + .map_or(0, hash_graphql_extensions); + let schema_checksum = supergraph.schema_checksum(); + let fingerprint = inbound_request_fingerprint( + &Method::POST, + ws_uri.path(), + &headers, + &shared_state.in_flight_requests_header_policy, + schema_checksum, + normalize_payload.normalized_operation_hash, + variables_hash, + extensions_hash, + ); + + let registry = &schema_state.active_subscriptions; + if let Some((_sub_id, mut receiver, guard)) = + registry.try_join_by_fingerprint(fingerprint) + { + let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); + state.borrow_mut().subscriptions.insert(id.clone(), cancel_tx); + + let _ws_guard = SubscriptionGuard { + state: state.clone(), + id: id.clone(), + }; + + trace!(id = %id, "Subscription joined via dedup"); + + let id_for_loop = id.clone(); + let mut cancelled = false; + loop { + tokio::select! { + recv_result = receiver.recv() => { + match recv_result { + Ok(BroadcastItem::Event(data)) => { + let _ = sink.send(ServerMessage::next(&id_for_loop, &data)).await; + } + Ok(BroadcastItem::Error(errors)) => { + let body = FailedExecutionResult { errors }.serialize(); + let _ = sink.send(ServerMessage::next(&id_for_loop, &body)).await; + break; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + trace!(id = %id_for_loop, lagged = n, "broadcast receiver lagged, skipping missed messages"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + _ = cancel_rx.recv() => { + cancelled = true; + break; + } + } + } + + drop(guard); + + return if cancelled { + trace!(id = %id, "Deduped subscription cancelled"); + None + } else { + trace!(id = %id, "Deduped subscription completed"); + Some(ServerMessage::complete(&id)) + }; + } + } + match execute_pipeline( &client_request_details, &normalize_payload, @@ -523,6 +609,67 @@ async fn handle_text_frame( Some(ServerMessage::complete(&id)) } Ok(QueryPlanExecutionResult::Stream(response)) => { + // register with the active subscriptions registry for + // dedup and close-with-error support + let subscription_fingerprint = if is_subscription && request_dedupe_enabled { + let variables_hash = hash_graphql_variables(&payload.variables); + let extensions_hash = payload + .extensions + .as_ref() + .map_or(0, hash_graphql_extensions); + let schema_checksum = supergraph.schema_checksum(); + Some(inbound_request_fingerprint( + &Method::POST, + ws_uri.path(), + &headers, + &shared_state.in_flight_requests_header_policy, + schema_checksum, + normalize_payload.normalized_operation_hash, + variables_hash, + extensions_hash, + )) + } else { + None + }; + + let (mut stream, _listener_guard) = if let Some(fp) = subscription_fingerprint { + let registry = &schema_state.active_subscriptions; + let (handle, receiver, guard) = registry.register(Some(fp), None); + + // spawn a task that reads from upstream and broadcasts + let mut upstream = response.body; + tokio::spawn(async move { + while let Some(event) = upstream.next().await { + if !handle.send(BroadcastItem::Event(event.into())) { + break; + } + } + }); + + let body_stream: futures::stream::BoxStream<'static, Vec> = + Box::pin(async_stream::stream! { + let mut receiver = receiver; + loop { + match receiver.recv().await { + Ok(BroadcastItem::Event(data)) => yield data.to_vec(), + Ok(BroadcastItem::Error(errors)) => { + yield FailedExecutionResult { errors }.serialize(); + break; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + trace!(lagged = n, "broadcast receiver lagged, skipping missed messages"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + }); + + (body_stream, Some(guard)) + } else { + (response.body, None) + }; + // we use mpsc::channel(1) instead of oneshot because oneshot::Receiver // is consumed on first await, which doesn't work in tokio::select! loops that // need to poll the receiver multiple times across iterations @@ -539,7 +686,6 @@ async fn handle_text_frame( id: id.clone(), }; - let mut stream = response.body; let mut cancelled = false; trace!(id = %id, "Subscription started"); diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index 7071fa42b..9905d83bb 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -1,7 +1,6 @@ use crate::pipeline::authorization::metadata::AuthorizationMetadataExt; use arc_swap::{ArcSwap, Guard}; use async_trait::async_trait; -use dashmap::DashMap; use graphql_tools::static_graphql::schema::Document; use graphql_tools::validation::utils::ValidationError; use hive_router_config::{supergraph::SupergraphSource, HiveRouterConfig}; @@ -10,10 +9,9 @@ use hive_router_internal::{ authorization::metadata::AuthorizationMetadata, background_tasks::{BackgroundTask, BackgroundTasksManager}, }; -use hive_router_plan_executor::response::graphql_error::GraphQLErrorExtensions; use hive_router_plan_executor::{ + executors::active_subscriptions::{ActiveSubscriptionsRegistry, BroadcastItem}, executors::error::SubgraphExecutorError, - executors::http_callback::{ActiveSubscriptionsMap, CallbackMessage}, hooks::on_supergraph_load::{ OnSupergraphLoadEndHookPayload, OnSupergraphLoadStartHookPayload, SupergraphData, }, @@ -49,7 +47,7 @@ pub struct SchemaState { pub validate_cache: Cache>>, pub normalize_cache: Cache>, pub telemetry_context: Arc, - pub active_callback_subscriptions: ActiveSubscriptionsMap, + pub active_subscriptions: Arc, } #[derive(Debug, thiserror::Error)] @@ -99,16 +97,16 @@ impl SchemaState { let plan_cache = cache_state.plan_cache.clone(); let validate_cache = cache_state.validate_cache.clone(); let normalize_cache = cache_state.normalize_cache.clone(); - let active_callback_subscriptions: ActiveSubscriptionsMap = Arc::new(DashMap::new()); + let active_subscriptions = Arc::new(ActiveSubscriptionsRegistry::new()); // This is cheap clone, as Cache is thread-safe and can be cloned without any performance penalty. let cache_state_for_invalidation = cache_state.clone(); - let active_callback_subscriptions_for_build_data = active_callback_subscriptions.clone(); + let active_subscriptions_for_build_data = active_subscriptions.clone(); // kick off subscriptions/subgraphs that are idling/timed out due to missed heartbeats if let Some(ref callback_config) = router_config.subscriptions.callback { if !callback_config.heartbeat_interval.is_zero() { - let enforcer_subs = active_callback_subscriptions.clone(); + let enforcer_subs = active_subscriptions.clone(); let heartbeat_interval = callback_config.heartbeat_interval; bg_tasks_manager.register_task(HeartbeatEnforcerTask { active_subscriptions: enforcer_subs, @@ -119,6 +117,7 @@ impl SchemaState { let metrics = telemetry_context.metrics.clone(); let task_telemetry = telemetry_context.clone(); + let active_subscriptions_for_reload = active_subscriptions.clone(); bg_tasks_manager.register_handle(async move { let supergraph_metrics = &metrics.supergraph; while let Some(new_sdl) = rx.recv().await { @@ -165,7 +164,7 @@ impl SchemaState { router_config.clone(), task_telemetry.clone(), new_ast, - active_callback_subscriptions_for_build_data.clone(), + active_subscriptions_for_build_data.clone(), ) }) { Ok(mut new_supergraph_data) => { @@ -205,6 +204,14 @@ impl SchemaState { new_supergraph_data = end_payload.new_supergraph_data; } + // close all active subscriptions before swapping supergraph data + active_subscriptions_for_reload.close_all_with_error(vec![ + GraphQLError::from_message_and_code( + "subscription has been closed due to a schema reload", + "SUBSCRIPTION_SCHEMA_RELOAD", + ), + ]); + swappable_data_spawn_clone.store(Arc::new(Some(new_supergraph_data))); debug!("Supergraph updated successfully"); @@ -226,7 +233,7 @@ impl SchemaState { validate_cache, normalize_cache, telemetry_context: telemetry_context.clone(), - active_callback_subscriptions, + active_subscriptions, }) } @@ -234,7 +241,7 @@ impl SchemaState { router_config: Arc, telemetry_context: Arc, parsed_supergraph_sdl: Document, - active_callback_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: Arc, ) -> Result { let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; let metadata = Arc::new(planner.consumer_schema.schema_metadata()); @@ -243,7 +250,7 @@ impl SchemaState { &planner.supergraph.subgraph_endpoint_map, router_config, telemetry_context, - active_callback_subscriptions, + active_subscriptions, )?); Ok(SupergraphData { @@ -335,7 +342,7 @@ impl BackgroundTask for SupergraphBackgroundLoaderTask { } struct HeartbeatEnforcerTask { - active_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: Arc, heartbeat_interval: Duration, } @@ -358,14 +365,14 @@ impl BackgroundTask for HeartbeatEnforcerTask { } let mut timed_out = Vec::new(); - for entry in self.active_subscriptions.iter() { - let last = *entry.value().last_heartbeat.lock().unwrap(); + for (id, last_heartbeat) in self.active_subscriptions.iter_callback_subscriptions() { + let last = *last_heartbeat.lock().unwrap(); if Instant::now().duration_since(last) > self.heartbeat_interval + // add a grace period if latency increases due to usage std::time::Duration::from_millis(500) { - timed_out.push(entry.key().clone()); + timed_out.push(id); } } @@ -375,16 +382,14 @@ impl BackgroundTask for HeartbeatEnforcerTask { subscription_id = %id, "terminating subscription due to missed heartbeat" ); - if let Some((_, sub)) = self.active_subscriptions.remove(&id) { - // we dont care about the result of this send, if it fails it means the client - // is already gone or too slow, either way we just terminate the subscription - let _ = sub.sender.try_send(CallbackMessage::Complete { - errors: Some(vec![GraphQLError::from_message_and_extensions( - "Subgraph gone due to heartbeat timeout".to_string(), - GraphQLErrorExtensions::new_from_code("SUBGRAPH_GONE"), - )]), - }); - } + self.active_subscriptions.send_event( + &id, + BroadcastItem::Error(vec![GraphQLError::from_message_and_code( + "Subgraph gone due to heartbeat timeout".to_string(), + "SUBGRAPH_GONE", + )]), + ); + self.active_subscriptions.remove(&id); } } } diff --git a/lib/executor/src/executors/active_subscriptions.rs b/lib/executor/src/executors/active_subscriptions.rs new file mode 100644 index 000000000..87db1a3c9 --- /dev/null +++ b/lib/executor/src/executors/active_subscriptions.rs @@ -0,0 +1,258 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::Instant; + +use bytes::Bytes; +use dashmap::DashMap; +use tracing::trace; +use uuid::Uuid; + +use crate::response::graphql_error::GraphQLError; + +pub type SubscriptionId = String; +pub type Fingerprint = u64; + +#[derive(Clone, Debug)] +pub enum BroadcastItem { + /// a normal subscription event from the upstream, already serialized. + /// uses Bytes for zero-copy cloning across broadcast receivers + Event(Bytes), + /// a terminal error pushed externally (e.g. supergraph reload, shutdown). + /// consumers should yield this as the final event and then stop + Error(Vec), +} + +/// state specific to http callback subscriptions +pub struct CallbackState { + pub verifier: String, + pub last_heartbeat: Arc>, +} + +impl CallbackState { + pub fn record_heartbeat(&self) { + *self.last_heartbeat.lock().unwrap() = Instant::now(); + } +} + +struct ActiveSubscriptionEntry { + sender: tokio::sync::broadcast::Sender, + listener_count: Arc, + fingerprint: Option, + callback_state: Option, +} + +pub struct ActiveSubscriptionsRegistry { + subscriptions: DashMap, + fingerprints: DashMap, +} + +// tokio::sync::broadcast capacity. subscription events are typically low-frequency, +// so a small buffer is fine. if backpressure becomes an issue, we should rethink +// this (e.g. per-consumer mpsc channels with a fan-out task) +const BROADCAST_CAPACITY: usize = 32; + +impl ActiveSubscriptionsRegistry { + pub fn new() -> Self { + Self { + subscriptions: DashMap::new(), + fingerprints: DashMap::new(), + } + } + + /// try to join an existing subscription by fingerprint. + /// returns the subscription id, a broadcast receiver, and a listener guard if a match exists + pub fn try_join_by_fingerprint( + self: &Arc, + fingerprint: Fingerprint, + ) -> Option<( + SubscriptionId, + tokio::sync::broadcast::Receiver, + ListenerGuard, + )> { + let sub_id = self.fingerprints.get(&fingerprint)?.value().clone(); + + let entry = self.subscriptions.get(&sub_id)?; + let receiver = entry.sender.subscribe(); + entry.listener_count.fetch_add(1, Ordering::AcqRel); + + let guard = ListenerGuard { + id: sub_id.clone(), + registry: Arc::clone(self), + listener_count: entry.listener_count.clone(), + fingerprint: entry.fingerprint, + }; + + trace!(subscription_id = %sub_id, fingerprint = fingerprint, "joined existing subscription via dedup"); + + Some((sub_id, receiver, guard)) + } + + /// register a brand new subscription. returns: + /// - a SubscriptionHandle for the upstream producer (dropping it removes the entry and closes the channel) + /// - a broadcast receiver for the first consumer + /// - a ListenerGuard that tracks this consumer's lifetime + pub fn register( + self: &Arc, + fingerprint: Option, + callback_state: Option, + ) -> ( + SubscriptionHandle, + tokio::sync::broadcast::Receiver, + ListenerGuard, + ) { + let id = Uuid::new_v4().to_string(); + let (sender, receiver) = tokio::sync::broadcast::channel(BROADCAST_CAPACITY); + let listener_count = Arc::new(AtomicUsize::new(1)); + + let entry = ActiveSubscriptionEntry { + sender, + listener_count: listener_count.clone(), + fingerprint, + callback_state, + }; + + self.subscriptions.insert(id.clone(), entry); + + if let Some(fp) = fingerprint { + self.fingerprints.insert(fp, id.clone()); + } + + let handle = SubscriptionHandle { + id: id.clone(), + registry: Arc::clone(self), + }; + + let guard = ListenerGuard { + id: id.clone(), + registry: Arc::clone(self), + listener_count, + fingerprint, + }; + + trace!(subscription_id = %id, "registered new subscription"); + + (handle, receiver, guard) + } + + /// check if a subscription exists + pub fn contains(&self, id: &str) -> bool { + self.subscriptions.contains_key(id) + } + + /// get the verifier for a callback subscription + pub fn get_callback_verifier(&self, id: &str) -> Option { + self.subscriptions + .get(id) + .and_then(|entry| entry.callback_state.as_ref().map(|cs| cs.verifier.clone())) + } + + /// record a heartbeat for a callback subscription + pub fn record_heartbeat(&self, id: &str) -> bool { + if let Some(entry) = self.subscriptions.get(id) { + if let Some(ref cs) = entry.callback_state { + cs.record_heartbeat(); + return true; + } + } + false + } + + /// send an event to a specific subscription's broadcast channel + pub fn send_event(&self, id: &str, item: BroadcastItem) -> bool { + if let Some(entry) = self.subscriptions.get(id) { + entry.sender.send(item).is_ok() + } else { + false + } + } + + /// remove a subscription entry and clean up its fingerprint mapping + pub fn remove(&self, id: &str) { + if let Some((_, entry)) = self.subscriptions.remove(id) { + if let Some(fp) = entry.fingerprint { + self.fingerprints.remove(&fp); + } + } + } + + /// close all active subscriptions with an error message + pub fn close_all_with_error(&self, errors: Vec) { + let item = BroadcastItem::Error(errors); + for entry in self.subscriptions.iter() { + let _ = entry.sender.send(item.clone()); + } + self.fingerprints.clear(); + self.subscriptions.clear(); + } + + /// iterate over all subscription ids and their callback state for heartbeat enforcement + pub fn iter_callback_subscriptions( + &self, + ) -> impl Iterator>)> + '_ { + self.subscriptions.iter().filter_map(|entry| { + entry + .callback_state + .as_ref() + .map(|cs| (entry.key().clone(), cs.last_heartbeat.clone())) + }) + } +} + +/// held by the upstream producer (the task that reads from the subgraph). +/// dropping this removes the subscription entry from the registry, which drops +/// the broadcast sender and closes the channel. all receivers will see Closed +/// and their streams will end naturally +pub struct SubscriptionHandle { + id: SubscriptionId, + registry: Arc, +} + +impl SubscriptionHandle { + pub fn id(&self) -> &str { + &self.id + } + + /// send an event to all listeners of this subscription + pub fn send(&self, item: BroadcastItem) -> bool { + self.registry.send_event(&self.id, item) + } +} + +impl Drop for SubscriptionHandle { + fn drop(&mut self) { + // removing the entry drops the broadcast sender inside it, closing the channel. + // all receivers will see Closed and their ListenerGuards will drop. + // we remove here (rather than in ListenerGuard) because the upstream is the + // authoritative source - when it's gone, the subscription is done + self.registry.remove(&self.id); + trace!(subscription_id = %self.id, "subscription handle dropped, upstream closed"); + } +} + +/// held by each consumer. on drop, decrements the listener count. +/// when the last listener drops and the subscription entry still exists +/// (upstream hasn't dropped yet), removes it - causing the upstream +/// task to see send() fail and exit +pub struct ListenerGuard { + id: SubscriptionId, + registry: Arc, + listener_count: Arc, + fingerprint: Option, +} + +impl Drop for ListenerGuard { + fn drop(&mut self) { + let prev = self.listener_count.fetch_sub(1, Ordering::AcqRel); + if prev == 1 { + // last listener gone, clean up. this also drops the sender, + // causing the upstream producer's send() to return false + self.registry.subscriptions.remove(&self.id); + if let Some(fp) = self.fingerprint { + self.registry.fingerprints.remove(&fp); + } + trace!(subscription_id = %self.id, "last listener dropped, subscription removed"); + } else { + trace!(subscription_id = %self.id, remaining = prev - 1, "listener dropped"); + } + } +} diff --git a/lib/executor/src/executors/http_callback.rs b/lib/executor/src/executors/http_callback.rs index 14a6fcb40..2b33d0cab 100644 --- a/lib/executor/src/executors/http_callback.rs +++ b/lib/executor/src/executors/http_callback.rs @@ -3,63 +3,26 @@ use std::time::{Duration, Instant}; use async_trait::async_trait; use bytes::Bytes; -use dashmap::DashMap; use futures::stream::BoxStream; use http::{HeaderMap, HeaderValue}; use http_body_util::BodyExt; use http_body_util::Full; use hyper::Version; -use tokio::sync::mpsc; use tracing::{debug, error, trace}; use uuid::Uuid; -use crate::executors::common::{ - SubgraphExecutionRequest, SubgraphExecutor, SUBSCRIPTION_EVENT_BUFFER_CAPACITY, +use crate::executors::active_subscriptions::{ + ActiveSubscriptionsRegistry, BroadcastItem, CallbackState, }; +use crate::executors::common::{SubgraphExecutionRequest, SubgraphExecutor}; use crate::executors::error::SubgraphExecutorError; use crate::executors::http::{build_request_body, HttpClient}; use crate::plugin_context::PluginRequestState; -use crate::response::graphql_error::GraphQLError; use crate::response::subgraph_response::SubgraphResponse; pub const CALLBACK_PROTOCOL_VERSION: &str = "callback/1.0"; pub const SUBSCRIPTION_PROTOCOL_HEADER: &str = "subscription-protocol"; -type SubscriptionId = String; - -#[derive(Clone)] -pub struct ActiveSubscription { - pub verifier: String, - pub sender: mpsc::Sender, - pub last_heartbeat: Arc>, -} - -impl ActiveSubscription { - pub fn record_heartbeat(&self) { - *self.last_heartbeat.lock().unwrap() = Instant::now(); - } -} - -#[derive(Debug)] -pub enum CallbackMessage { - Next { payload: Bytes }, - Complete { errors: Option> }, -} - -pub type ActiveSubscriptionsMap = Arc>; - -struct SubscriptionGuard { - subscription_id: SubscriptionId, - active_subscriptions: ActiveSubscriptionsMap, -} - -impl Drop for SubscriptionGuard { - fn drop(&mut self) { - self.active_subscriptions.remove(&self.subscription_id); - trace!(subscription_id = %self.subscription_id, "HTTP callback subscription entry removed from active subscriptions"); - } -} - pub struct HttpCallbackSubgraphExecutor { pub subgraph_name: String, pub endpoint: http::Uri, @@ -67,7 +30,7 @@ pub struct HttpCallbackSubgraphExecutor { pub header_map: HeaderMap, pub callback_base_url: String, pub heartbeat_interval_ms: u64, - pub active_subscriptions: ActiveSubscriptionsMap, + pub active_subscriptions: Arc, } impl HttpCallbackSubgraphExecutor { @@ -77,7 +40,7 @@ impl HttpCallbackSubgraphExecutor { http_client: Arc, callback_base_url: String, heartbeat_interval_ms: u64, - active_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: Arc, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -154,33 +117,27 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { BoxStream<'static, Result, SubgraphExecutorError>>, SubgraphExecutorError, > { - let subscription_id = Uuid::new_v4().to_string(); let verifier = Uuid::new_v4().to_string(); - let body = self.build_request_body(&mut execution_request, &subscription_id, &verifier)?; + let callback_state = CallbackState { + verifier: verifier.clone(), + // initialize last_heartbeat to now + heartbeat_interval so the enforcer + // won't evict the subscription before the subgraph's initial check arrives. + // the initial check from the subgraph can take up to heartbeat_interval to + // arrive (due to network latency), and without this head start the enforcer + // would evict the subscription before the first heartbeat is recorded. + last_heartbeat: Arc::new(Mutex::new( + Instant::now() + Duration::from_millis(self.heartbeat_interval_ms), + )), + }; - let (tx, mut rx) = mpsc::channel::(SUBSCRIPTION_EVENT_BUFFER_CAPACITY); - self.active_subscriptions.insert( - subscription_id.clone(), - ActiveSubscription { - verifier, - sender: tx, - // initialize last_heartbeat to now + heartbeat_interval so the enforcer - // won't evict the subscription before the subgraph's initial check arrives. - // the initial check from the subgraph can take up to heartbeat_interval to - // arrive (due to network latency), and without this head start the enforcer - // would evict the subscription before the first heartbeat is recorded. - last_heartbeat: Arc::new(Mutex::new( - Instant::now() + Duration::from_millis(self.heartbeat_interval_ms), - )), - }, - ); + let (handle, mut receiver, guard) = self + .active_subscriptions + .register(None, Some(callback_state)); - // guard removes the entry from `active_subscriptions` when dropped - let guard = SubscriptionGuard { - subscription_id: subscription_id.clone(), - active_subscriptions: self.active_subscriptions.clone(), - }; + let subscription_id = handle.id().to_string(); + + let body = self.build_request_body(&mut execution_request, &subscription_id, &verifier)?; let mut req = hyper::Request::builder() .method(http::Method::POST) @@ -244,14 +201,15 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { } Ok(Box::pin(async_stream::stream! { - // `guard` is held here; dropping the stream drops `guard`, removing the map entry. + // hold the handle and guard so the subscription entry is removed when the stream ends + let _handle = handle; let _guard = guard; trace!(subscription_id = %subscription_id, "HTTP callback subscription stream started"); - while let Some(msg) = rx.recv().await { - match msg { - CallbackMessage::Next { payload } => { + loop { + match receiver.recv().await { + Ok(BroadcastItem::Event(payload)) => { trace!(subscription_id = %subscription_id, "received next payload"); match SubgraphResponse::deserialize_from_bytes(payload) { Ok(response) => yield Ok(response), @@ -266,18 +224,25 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { } } } - CallbackMessage::Complete { errors } => { - trace!(subscription_id = %subscription_id, "received complete"); - if let Some(errors) = errors { - if !errors.is_empty() { - yield Ok(SubgraphResponse { - errors: Some(errors), - ..Default::default() - }); - } + Ok(BroadcastItem::Error(errors)) => { + trace!(subscription_id = %subscription_id, "received close with error"); + if !errors.is_empty() { + yield Ok(SubgraphResponse { + errors: Some(errors), + ..Default::default() + }); } break; } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + // slow consumer, skip missed messages and continue + trace!(subscription_id = %subscription_id, lagged = n, "broadcast receiver lagged, skipping missed messages"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + trace!(subscription_id = %subscription_id, "broadcast channel closed"); + break; + } } } diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 6e04d575c..c5e267749 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -27,10 +27,11 @@ use tokio::sync::Semaphore; use crate::{ execution::client_request_details::ClientRequestDetails, executors::{ + active_subscriptions::ActiveSubscriptionsRegistry, common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient, SubgraphHttpResponse}, - http_callback::{ActiveSubscriptionsMap, HttpCallbackSubgraphExecutor}, + http_callback::HttpCallbackSubgraphExecutor, websocket::WsSubgraphExecutor, }, hooks::on_subgraph_execute::{ @@ -74,8 +75,8 @@ pub struct SubgraphExecutorMap { max_connections_per_host: usize, in_flight_requests: InflightRequestsMap, telemetry_context: Arc, - /// Shared map of active HTTP callback subscriptions - active_callback_subscriptions: ActiveSubscriptionsMap, + /// Shared registry of all active subscriptions (http streaming, websocket, http callback) + active_subscriptions: Arc, } fn build_https_executor() -> Result, SubgraphExecutorError> { @@ -112,7 +113,7 @@ impl SubgraphExecutorMap { timeouts_by_subgraph: Default::default(), global_timeout, telemetry_context, - active_callback_subscriptions: Arc::new(DashMap::new()), + active_subscriptions: Arc::new(ActiveSubscriptionsRegistry::new()), }) } @@ -120,7 +121,7 @@ impl SubgraphExecutorMap { subgraph_endpoint_map: &HashMap, config: Arc, telemetry_context: Arc, - active_callback_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: Arc, ) -> Result { let global_timeout = DurationOrProgram::compile( &config.traffic_shaping.all.request_timeout, @@ -131,7 +132,7 @@ impl SubgraphExecutorMap { })?; let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), global_timeout, telemetry_context)?; - subgraph_executor_map.active_callback_subscriptions = active_callback_subscriptions; + subgraph_executor_map.active_subscriptions = active_subscriptions; for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.iter() { let endpoint_config = config @@ -156,9 +157,9 @@ impl SubgraphExecutorMap { Ok(subgraph_executor_map) } - /// Returns the shared active callback subscriptions map for use by callback handlers. - pub fn active_callback_subscriptions(&self) -> ActiveSubscriptionsMap { - self.active_callback_subscriptions.clone() + /// Returns the shared active subscriptions registry. + pub fn active_subscriptions(&self) -> Arc { + self.active_subscriptions.clone() } pub async fn execute<'exec>( @@ -505,7 +506,7 @@ impl SubgraphExecutorMap { self.client.clone(), callback_config.public_url.to_string(), heartbeat_interval_ms, - self.active_callback_subscriptions.clone(), + self.active_subscriptions.clone(), ) .to_boxed_arc(); diff --git a/lib/executor/src/executors/mod.rs b/lib/executor/src/executors/mod.rs index 283e331e3..efc39c23d 100644 --- a/lib/executor/src/executors/mod.rs +++ b/lib/executor/src/executors/mod.rs @@ -1,3 +1,4 @@ +pub mod active_subscriptions; pub mod common; pub mod dedupe; pub mod error; From 9b787a6302686da75351757acd02592eaf95e43b Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Thu, 2 Apr 2026 19:04:04 +0200 Subject: [PATCH 02/42] of course I didnt run tests --- e2e/src/testkit/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/e2e/src/testkit/mod.rs b/e2e/src/testkit/mod.rs index d38cab9ca..ad5632610 100644 --- a/e2e/src/testkit/mod.rs +++ b/e2e/src/testkit/mod.rs @@ -707,7 +707,7 @@ impl TestRouter { let serv_shared_state = shared_state.clone(); let serv_schema_state = schema_state.clone(); - let serv_active_subs = schema_state.active_callback_subscriptions.clone(); + let serv_active_subs = schema_state.active_subscriptions.clone(); let serv_graphql_path = self.graphql_path.clone(); let serv_websocket_path = self.websocket_path.clone(); @@ -721,7 +721,7 @@ impl TestRouter { }) => { let cb_path = path.to_string(); let cb_addr = listen.to_string(); - let cb_active_subs = schema_state.active_callback_subscriptions.clone(); + let cb_active_subs = schema_state.active_subscriptions.clone(); let server = web::HttpServer::new(async move || { let active_subs = cb_active_subs.clone(); From 6d64f8e1caa8672f3d274ae1bc65fa27e3c8f83f Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Thu, 2 Apr 2026 19:27:52 +0200 Subject: [PATCH 03/42] dedupe what can be deduped --- bin/router/src/pipeline/mod.rs | 53 +++++---- bin/router/src/pipeline/websocket_server.rs | 117 ++++++-------------- 2 files changed, 64 insertions(+), 106 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index a8b1773ae..384a1abb4 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -472,25 +472,11 @@ async fn execute_planned_request<'exec>( .await? { QueryPlanExecutionResult::Stream(result) => { - let body_stream = if let Some(fingerprint) = subscription_fingerprint { - let registry = &schema_state.active_subscriptions; - let (handle, receiver, guard) = registry.register(Some(fingerprint), None); - - // spawn a task that reads from the upstream and broadcasts to all listeners. - // dropping the handle when the upstream ends removes the registry entry - let mut upstream = result.body; - tokio::spawn(async move { - while let Some(event) = upstream.next().await { - if !handle.send(BroadcastItem::Event(event.into())) { - break; - } - } - }); - - broadcast_receiver_to_body_stream(receiver, guard) - } else { - result.body - }; + let body_stream = register_subscription_leader( + result.body, + subscription_fingerprint, + schema_state, + ); let response = build_streaming_response( body_stream, @@ -593,9 +579,36 @@ pub async fn execute_pipeline<'exec>( execute_plan(supergraph, shared_state, planned_request, operation_span).await } +/// registers upstream as a subscription leader, spawning a broadcast task and +/// returning a stream backed by the broadcast receiver. if fingerprint is None, +/// the upstream stream is returned as-is with no registration. +pub(crate) fn register_subscription_leader( + upstream: futures::stream::BoxStream<'static, Vec>, + fingerprint: Option, + schema_state: &SchemaState, +) -> futures::stream::BoxStream<'static, Vec> { + let Some(fingerprint) = fingerprint else { + return upstream; + }; + + let registry = &schema_state.active_subscriptions; + let (handle, receiver, guard) = registry.register(Some(fingerprint), None); + + let mut upstream = upstream; + tokio::spawn(async move { + while let Some(event) = upstream.next().await { + if !handle.send(BroadcastItem::Event(event.into())) { + break; + } + } + }); + + broadcast_receiver_to_body_stream(receiver, guard) +} + /// converts a broadcast receiver into a BoxStream of serialized event bodies. /// the listener guard is held for the lifetime of the stream to track listener count -fn broadcast_receiver_to_body_stream( +pub(crate) fn broadcast_receiver_to_body_stream( mut receiver: tokio::sync::broadcast::Receiver, guard: ListenerGuard, ) -> futures::stream::BoxStream<'static, Vec> { diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 659a0715d..87ecb6995 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -422,6 +422,29 @@ async fn handle_text_frame( return Some(PipelineError::SubscriptionsNotSupported.into_server_message(&id)); } + let request_dedupe_enabled = + shared_state.router_config.traffic_shaping.router.dedupe.enabled; + + let fingerprint = if is_subscription && request_dedupe_enabled { + let variables_hash = hash_graphql_variables(&payload.variables); + let extensions_hash = payload + .extensions + .as_ref() + .map_or(0, hash_graphql_extensions); + Some(inbound_request_fingerprint( + &Method::POST, + ws_uri.path(), + &headers, + &shared_state.in_flight_requests_header_policy, + supergraph.schema_checksum(), + normalize_payload.normalized_operation_hash, + variables_hash, + extensions_hash, + )) + } else { + None + }; + let jwt_request_details = match &shared_state.jwt_auth_runtime { Some(jwt_auth_runtime) => match jwt_auth_runtime .validate_headers(&headers, &shared_state.jwt_claims_cache) @@ -484,35 +507,11 @@ async fn handle_text_frame( jwt: jwt_request_details, }.into(); - let is_subscription = matches!( - normalize_payload.operation_for_plan.operation_kind, - Some(OperationKind::Subscription) - ); - let request_dedupe_enabled = - shared_state.router_config.traffic_shaping.router.dedupe.enabled; - // subscription dedup: try to join an existing subscription - if is_subscription && request_dedupe_enabled { - let variables_hash = hash_graphql_variables(&payload.variables); - let extensions_hash = payload - .extensions - .as_ref() - .map_or(0, hash_graphql_extensions); - let schema_checksum = supergraph.schema_checksum(); - let fingerprint = inbound_request_fingerprint( - &Method::POST, - ws_uri.path(), - &headers, - &shared_state.in_flight_requests_header_policy, - schema_checksum, - normalize_payload.normalized_operation_hash, - variables_hash, - extensions_hash, - ); - + if let Some(fp) = fingerprint { let registry = &schema_state.active_subscriptions; if let Some((_sub_id, mut receiver, guard)) = - registry.try_join_by_fingerprint(fingerprint) + registry.try_join_by_fingerprint(fp) { let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); state.borrow_mut().subscriptions.insert(id.clone(), cancel_tx); @@ -609,66 +608,12 @@ async fn handle_text_frame( Some(ServerMessage::complete(&id)) } Ok(QueryPlanExecutionResult::Stream(response)) => { - // register with the active subscriptions registry for - // dedup and close-with-error support - let subscription_fingerprint = if is_subscription && request_dedupe_enabled { - let variables_hash = hash_graphql_variables(&payload.variables); - let extensions_hash = payload - .extensions - .as_ref() - .map_or(0, hash_graphql_extensions); - let schema_checksum = supergraph.schema_checksum(); - Some(inbound_request_fingerprint( - &Method::POST, - ws_uri.path(), - &headers, - &shared_state.in_flight_requests_header_policy, - schema_checksum, - normalize_payload.normalized_operation_hash, - variables_hash, - extensions_hash, - )) - } else { - None - }; - - let (mut stream, _listener_guard) = if let Some(fp) = subscription_fingerprint { - let registry = &schema_state.active_subscriptions; - let (handle, receiver, guard) = registry.register(Some(fp), None); - - // spawn a task that reads from upstream and broadcasts - let mut upstream = response.body; - tokio::spawn(async move { - while let Some(event) = upstream.next().await { - if !handle.send(BroadcastItem::Event(event.into())) { - break; - } - } - }); - - let body_stream: futures::stream::BoxStream<'static, Vec> = - Box::pin(async_stream::stream! { - let mut receiver = receiver; - loop { - match receiver.recv().await { - Ok(BroadcastItem::Event(data)) => yield data.to_vec(), - Ok(BroadcastItem::Error(errors)) => { - yield FailedExecutionResult { errors }.serialize(); - break; - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { - trace!(lagged = n, "broadcast receiver lagged, skipping missed messages"); - continue; - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => break, - } - } - }); - - (body_stream, Some(guard)) - } else { - (response.body, None) - }; + let subscription_fingerprint = if is_subscription { fingerprint } else { None }; + let mut stream = crate::pipeline::register_subscription_leader( + response.body, + subscription_fingerprint, + schema_state, + ); // we use mpsc::channel(1) instead of oneshot because oneshot::Receiver // is consumed on first await, which doesn't work in tokio::select! loops that From 00dbae2d151bea5b464b795adb242dc3d2e564ea Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Thu, 2 Apr 2026 19:38:12 +0200 Subject: [PATCH 04/42] haha --- bin/router/src/schema_state.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index 9905d83bb..ef3c7b674 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -206,6 +206,8 @@ impl SchemaState { // close all active subscriptions before swapping supergraph data active_subscriptions_for_reload.close_all_with_error(vec![ + // this is litearaly the same message apollo sends - reasoning is + // drop-in-replacement - is that oke? should we have our own? GraphQLError::from_message_and_code( "subscription has been closed due to a schema reload", "SUBSCRIPTION_SCHEMA_RELOAD", From 8377ac6418d5a60f264fe9e28e02a65c7c526a0f Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Thu, 2 Apr 2026 19:58:21 +0200 Subject: [PATCH 05/42] broadcast capacity --- bin/router/src/schema_state.rs | 2 +- .../src/executors/active_subscriptions.rs | 12 +++++------- lib/executor/src/executors/map.rs | 3 ++- lib/router-config/src/subscriptions.rs | 18 ++++++++++++++++++ 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index ef3c7b674..0e1b851e8 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -97,7 +97,7 @@ impl SchemaState { let plan_cache = cache_state.plan_cache.clone(); let validate_cache = cache_state.validate_cache.clone(); let normalize_cache = cache_state.normalize_cache.clone(); - let active_subscriptions = Arc::new(ActiveSubscriptionsRegistry::new()); + let active_subscriptions = Arc::new(ActiveSubscriptionsRegistry::new(router_config.subscriptions.broadcast_capacity)); // This is cheap clone, as Cache is thread-safe and can be cloned without any performance penalty. let cache_state_for_invalidation = cache_state.clone(); diff --git a/lib/executor/src/executors/active_subscriptions.rs b/lib/executor/src/executors/active_subscriptions.rs index 87db1a3c9..d525b3fb0 100644 --- a/lib/executor/src/executors/active_subscriptions.rs +++ b/lib/executor/src/executors/active_subscriptions.rs @@ -44,18 +44,16 @@ struct ActiveSubscriptionEntry { pub struct ActiveSubscriptionsRegistry { subscriptions: DashMap, fingerprints: DashMap, + // capacity of the broadcast channel per subscription, see router config `subscriptions.broadcast_capacity` + broadcast_capacity: usize, } -// tokio::sync::broadcast capacity. subscription events are typically low-frequency, -// so a small buffer is fine. if backpressure becomes an issue, we should rethink -// this (e.g. per-consumer mpsc channels with a fan-out task) -const BROADCAST_CAPACITY: usize = 32; - impl ActiveSubscriptionsRegistry { - pub fn new() -> Self { + pub fn new(broadcast_capacity: usize) -> Self { Self { subscriptions: DashMap::new(), fingerprints: DashMap::new(), + broadcast_capacity, } } @@ -101,7 +99,7 @@ impl ActiveSubscriptionsRegistry { ListenerGuard, ) { let id = Uuid::new_v4().to_string(); - let (sender, receiver) = tokio::sync::broadcast::channel(BROADCAST_CAPACITY); + let (sender, receiver) = tokio::sync::broadcast::channel(self.broadcast_capacity); let listener_count = Arc::new(AtomicUsize::new(1)); let entry = ActiveSubscriptionEntry { diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index c5e267749..ab8594647 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -99,6 +99,7 @@ impl SubgraphExecutorMap { .build(build_https_executor()?); let max_connections_per_host = config.traffic_shaping.max_connections_per_host; + let broadcast_capacity = config.subscriptions.broadcast_capacity; Ok(SubgraphExecutorMap { http_executors_by_subgraph: Default::default(), @@ -113,7 +114,7 @@ impl SubgraphExecutorMap { timeouts_by_subgraph: Default::default(), global_timeout, telemetry_context, - active_subscriptions: Arc::new(ActiveSubscriptionsRegistry::new()), + active_subscriptions: Arc::new(ActiveSubscriptionsRegistry::new(broadcast_capacity)), }) } diff --git a/lib/router-config/src/subscriptions.rs b/lib/router-config/src/subscriptions.rs index f3c727362..c922780fd 100644 --- a/lib/router-config/src/subscriptions.rs +++ b/lib/router-config/src/subscriptions.rs @@ -15,6 +15,20 @@ pub struct SubscriptionsConfig { /// You can override this setting by setting the `SUBSCRIPTIONS_ENABLED` environment variable to `true` or `false`. #[serde(default)] pub enabled: bool, + /// The capacity of the broadcast channel used to fan out subscription events to all active listeners. + /// + /// Each active subscription has its own broadcast channel. This value controls how many events + /// can be buffered in that channel before slow consumers start lagging. If a consumer falls too + /// far behind and the buffer is full, it will skip the missed messages and continue from the + /// latest available event. + /// + /// Subscription events are typically low-frequency, so the default of 32 is sufficient for most + /// use cases. Increase this value if you expect bursts of events or have slow consumers that + /// need more headroom to catch up. + /// + /// Defaults to 32. + #[serde(default = "default_broadcast_capacity")] + pub broadcast_capacity: usize, /// Configuration for subgraphs using the HTTP Callback protocol. #[serde(default, skip_serializing_if = "Option::is_none")] pub callback: Option, @@ -61,6 +75,10 @@ pub struct CallbackConfig { pub subgraphs: HashSet, } +fn default_broadcast_capacity() -> usize { + 32 +} + fn default_callback_path() -> AbsolutePath { AbsolutePath::try_from("/callback").expect("default callback path is valid") } From d1e31607eb0e389a42163b7d8a182ec88f18af47 Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Thu, 2 Apr 2026 20:24:52 +0200 Subject: [PATCH 06/42] haha --- bin/router/src/pipeline/websocket_server.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 87ecb6995..4cecd63af 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -507,6 +507,8 @@ async fn handle_text_frame( jwt: jwt_request_details, }.into(); + // TODO: query dedupe + // subscription dedup: try to join an existing subscription if let Some(fp) = fingerprint { let registry = &schema_state.active_subscriptions; From 85e528d6d6cda897e667d71a48371f8dc77139f8 Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Thu, 2 Apr 2026 20:32:35 +0200 Subject: [PATCH 07/42] he he --- lib/router-config/src/traffic_shaping.rs | 29 +++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/lib/router-config/src/traffic_shaping.rs b/lib/router-config/src/traffic_shaping.rs index 8b362b6c6..5aa230ee4 100644 --- a/lib/router-config/src/traffic_shaping.rs +++ b/lib/router-config/src/traffic_shaping.rs @@ -185,10 +185,33 @@ pub struct TrafficShapingRouterConfig { #[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] #[serde(deny_unknown_fields)] pub struct TrafficShapingRouterDedupeConfig { - /// Enables/disables in-flight request deduplication at the router endpoint level. + /// Enables/disables in-flight request and active subscriptions deduplication at the router level. /// - /// When enabled, identical incoming GraphQL query requests that are processed at the same time - /// share the same in-flight execution result. + /// When enabled, the router deduplicates both queries and subscriptions using the same + /// fingerprint key (method, path, selected headers, schema checksum, normalized operation + /// hash, variables, and extensions). The `headers` configuration below controls which + /// headers participate in that key for all operation types. + /// + /// For queries, concurrent HTTP requests that produce the same fingerprint share a single + /// in-flight execution - only the first one runs, and the rest wait for and receive the + /// same result. + /// + /// For subscriptions, the mechanism is broadcast-based rather than request-sharing. The + /// first client with a given fingerprint becomes the leader: it runs the upstream subscription + /// and its events are fanned out through a broadcast channel backed by an active subscriptions + /// registry. Any subsequent client that arrives with an identical fingerprint while that subscription + /// is still active joins as a listener on the same broadcast channel instead of starting a new upstream + /// connection. When all listeners have dropped and the leader finishes, the entry is removed from the + /// registry. + /// + /// WebSocket connections participate in the same deduplication space as HTTP. Each + /// subscribe message is processed with a synthetic request assembled from the WebSocket + /// path and the headers derived from the `websocket.headers` config. The fingerprint is computed + /// from those synthetic headers using the same header policy, so a subscription started over HTTP + /// and an identical one started over WebSocket will deduplicate against each other. + /// + /// The deduplication is transport agnostic. A query over WebSocket would get deduplicated with an + /// identical query over HTTP if they arrive at the same time and have the same fingerprint. #[serde(default = "default_router_dedupe_enabled")] pub enabled: bool, From ced21955f7b840b80794d57d4dcb4d7f3c63eea4 Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Thu, 2 Apr 2026 20:46:10 +0200 Subject: [PATCH 08/42] format nad long lived client limit --- bin/router/src/lib.rs | 6 + .../src/pipeline/long_lived_client_limit.rs | 165 ++++++++++++++++++ bin/router/src/pipeline/mod.rs | 8 +- bin/router/src/schema_state.rs | 4 +- bin/router/src/shared_state.rs | 4 + lib/router-config/src/traffic_shaping.rs | 12 ++ 6 files changed, 193 insertions(+), 6 deletions(-) create mode 100644 bin/router/src/pipeline/long_lived_client_limit.rs diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 88bb4faa7..44baeab28 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -26,6 +26,7 @@ use crate::{ graphql_request_handler, header::ResponseMode, http_callback::handler, + long_lived_client_limit::LongLivedClientLimitService, request_extensions::{ read_graphql_operation_metric_identity, read_graphql_response_metric_status, read_request_body_size, write_graphql_response_metric_status, @@ -248,10 +249,15 @@ pub async fn router_entrypoint(plugin_registry: PluginRegistry) -> Result<(), Ro let paths = RouterPaths::new(graphql_path.clone(), websocket_path, callback_path); paths.detect_conflicts(&prometheus)?; + let long_lived_client_limit_service = + LongLivedClientLimitService::new(&shared_state.router_config); + let maybe_error = web::HttpServer::new(async move || { let landing_page_path = graphql_path.clone(); let prometheus = prometheus.clone(); + let long_lived_client_limit_service = long_lived_client_limit_service.clone(); web::App::new() + .middleware(long_lived_client_limit_service) .middleware(PluginService) .state(shared_state.clone()) .state(schema_state.clone()) diff --git a/bin/router/src/pipeline/long_lived_client_limit.rs b/bin/router/src/pipeline/long_lived_client_limit.rs new file mode 100644 index 000000000..8c2a1d221 --- /dev/null +++ b/bin/router/src/pipeline/long_lived_client_limit.rs @@ -0,0 +1,165 @@ +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +use http::{header, StatusCode}; +use ntex::{ + service::{Service, ServiceCtx}, + web::{self, DefaultError}, + Middleware, SharedCfg, +}; + +use crate::RouterSharedState; + +/// pre-resolved at app construction time so the per-request path is branch-free +#[derive(Clone)] +pub struct LongLivedClientLimitService { + /// false means the middleware is entirely bypassed on every request + enabled: bool, +} + +impl LongLivedClientLimitService { + pub fn new(router_config: &hive_router_config::HiveRouterConfig) -> Self { + let limit = router_config.traffic_shaping.router.max_long_lived_clients; + let has_long_lived = router_config.subscriptions.enabled || router_config.websocket.enabled; + Self { + enabled: limit > 0 && has_long_lived, + } + } +} + +impl Middleware for LongLivedClientLimitService { + type Service = LongLivedClientLimitMiddleware; + + fn create(&self, service: S, _cfg: SharedCfg) -> Self::Service { + LongLivedClientLimitMiddleware { + service, + enabled: self.enabled, + } + } +} + +pub struct LongLivedClientLimitMiddleware { + service: S, + enabled: bool, +} + +impl Service> for LongLivedClientLimitMiddleware +where + S: Service, Response = web::WebResponse, Error = web::Error>, +{ + type Response = web::WebResponse; + type Error = S::Error; + + ntex::forward_ready!(service); + + async fn call( + &self, + req: web::WebRequest, + ctx: ServiceCtx<'_, Self>, + ) -> Result { + if !self.enabled { + return ctx.call(&self.service, req).await; + } + + if !is_long_lived_request(req.headers()) { + return ctx.call(&self.service, req).await; + } + + let shared_state = match req.app_state::>() { + Some(s) => s, + None => return ctx.call(&self.service, req).await, + }; + + let limit = shared_state + .router_config + .traffic_shaping + .router + .max_long_lived_clients; + let counter = shared_state.long_lived_client_count.clone(); + + // try to reserve a slot; back off if we're at the limit + let prev = counter.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| { + if current < limit { + Some(current + 1) + } else { + None + } + }); + + if prev.is_err() { + let error_response = web::HttpResponse::build(StatusCode::SERVICE_UNAVAILABLE) + .header(header::RETRY_AFTER, "5") + .body("Too many long-lived clients"); + return Ok(req.into_response(error_response)); + } + + let guard = LongLivedClientGuard(counter); + let response = ctx.call(&self.service, req).await?; + drop(guard); + + Ok(response) + } +} + +/// decrements the counter when dropped +struct LongLivedClientGuard(Arc); + +impl Drop for LongLivedClientGuard { + fn drop(&mut self) { + self.0.fetch_sub(1, Ordering::AcqRel); + } +} + +/// returns true if the request is a websocket upgrade or an http streaming request. +/// +/// deliberately ordered cheapest check first: +/// 1. upgrade: websocket - two header lookups, no parsing +/// 2. accept streaming - one header lookup + fast substring pre-filter, full parse only if needed +#[inline] +fn is_long_lived_request(headers: &ntex::http::HeaderMap) -> bool { + // websocket: Connection: Upgrade + Upgrade: websocket + // both headers must be present and contain the expected values (case-insensitive) + if headers + .get(header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .is_some_and(|v| v.eq_ignore_ascii_case("websocket")) + && headers + .get(header::CONNECTION) + .and_then(|v| v.to_str().ok()) + .is_some_and(|v| v.to_ascii_lowercase().contains("upgrade")) + { + return true; + } + + // http streaming: Accept header contains a known streaming content type. + // we do a fast substring scan before handing off to the full Accept parser + // to avoid the parse cost on the hot path for regular requests. + let accept = match headers.get(header::ACCEPT).and_then(|v| v.to_str().ok()) { + Some(v) if !v.is_empty() => v, + _ => return false, + }; + + if !looks_like_streaming_accept(accept) { + return false; + } + + use crate::pipeline::header::StreamContentType; + use headers_accept::Accept; + use std::str::FromStr; + + Accept::from_str(accept) + .ok() + .and_then(|a| a.negotiate(StreamContentType::media_types().iter())) + .is_some() +} + +/// fast pre-filter: returns true if the raw Accept string contains any substring +/// that could match a known streaming content type, avoiding the full parse on +/// the vast majority of regular (application/json) requests. +#[inline] +fn looks_like_streaming_accept(accept: &str) -> bool { + // covers: multipart/mixed, text/event-stream + accept.contains("multipart") || accept.contains("event-stream") +} diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 384a1abb4..25117bbfe 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -74,6 +74,7 @@ pub mod execution_request; pub mod header; pub mod http_callback; pub mod introspection_policy; +pub mod long_lived_client_limit; pub mod multipart_subscribe; pub mod normalize; pub mod parser; @@ -472,11 +473,8 @@ async fn execute_planned_request<'exec>( .await? { QueryPlanExecutionResult::Stream(result) => { - let body_stream = register_subscription_leader( - result.body, - subscription_fingerprint, - schema_state, - ); + let body_stream = + register_subscription_leader(result.body, subscription_fingerprint, schema_state); let response = build_streaming_response( body_stream, diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index 0e1b851e8..a8789e6f5 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -97,7 +97,9 @@ impl SchemaState { let plan_cache = cache_state.plan_cache.clone(); let validate_cache = cache_state.validate_cache.clone(); let normalize_cache = cache_state.normalize_cache.clone(); - let active_subscriptions = Arc::new(ActiveSubscriptionsRegistry::new(router_config.subscriptions.broadcast_capacity)); + let active_subscriptions = Arc::new(ActiveSubscriptionsRegistry::new( + router_config.subscriptions.broadcast_capacity, + )); // This is cheap clone, as Cache is thread-safe and can be cloned without any performance penalty. let cache_state_for_invalidation = cache_state.clone(); diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 76aef6673..8a3aa3050 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -17,6 +17,7 @@ use moka::future::Cache; use moka::Expiry; use ntex::web; use ntex::{http::HeaderMap, util::Bytes}; +use std::sync::atomic::AtomicUsize; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{collections::HashSet, sync::Arc}; @@ -153,6 +154,8 @@ pub struct RouterSharedState { pub plugins: Option>>, pub in_flight_requests: RouterInflightRequestsMap, pub in_flight_requests_header_policy: RouterRequestDedupeHeaderPolicy, + /// tracks the number of active long-lived clients (websockets + http streams) + pub long_lived_client_count: Arc, } impl RouterSharedState { @@ -195,6 +198,7 @@ impl RouterSharedState { .dedupe .headers) .into(), + long_lived_client_count: Arc::new(AtomicUsize::new(0)), }) } } diff --git a/lib/router-config/src/traffic_shaping.rs b/lib/router-config/src/traffic_shaping.rs index 5aa230ee4..25d1cfc17 100644 --- a/lib/router-config/src/traffic_shaping.rs +++ b/lib/router-config/src/traffic_shaping.rs @@ -180,6 +180,13 @@ pub struct TrafficShapingRouterConfig { )] #[schemars(with = "String")] pub request_timeout: Duration, + + /// Maximum number of concurrent long-lived clients (WebSocket connections and HTTP streaming responses). + /// Regular non-streaming requests are not counted toward this limit. + /// When the limit is reached, new WebSocket and streaming HTTP requests are rejected with 503. + /// If both WebSockets and Subscriptions are disabled, this setting has no effect. + #[serde(default = "default_max_long_lived_clients")] + pub max_long_lived_clients: usize, } #[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] @@ -261,11 +268,16 @@ fn default_router_request_timeout() -> Duration { Duration::from_secs(60) } +fn default_max_long_lived_clients() -> usize { + 128 +} + impl Default for TrafficShapingRouterConfig { fn default() -> Self { Self { dedupe: Default::default(), request_timeout: default_router_request_timeout(), + max_long_lived_clients: default_max_long_lived_clients(), } } } From 4e732d3fa38cbda0d062b9da59ea8e08f18f4f51 Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Fri, 3 Apr 2026 11:48:41 +0200 Subject: [PATCH 09/42] of course race condition --- e2e/src/subscriptions.rs | 58 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/e2e/src/subscriptions.rs b/e2e/src/subscriptions.rs index 7951dcf67..bb0e213df 100644 --- a/e2e/src/subscriptions.rs +++ b/e2e/src/subscriptions.rs @@ -1394,4 +1394,62 @@ mod subscriptions_e2e_tests { event: complete "#); } + + #[ntex::test] + async fn active_subscriptions_deduplication() { + let subgraphs = TestSubgraphs::builder().build().start().await; + let router = TestRouter::builder() + .with_subgraphs(&subgraphs) + .inline_config( + r#" + supergraph: + source: file + path: supergraph.graphql + subscriptions: + enabled: true + traffic_shaping: + router: + dedupe: + enabled: true + "#, + ) + .build() + .start() + .await; + + let query = r#" + subscription { + reviewAdded(intervalInMs: 100) { + id + product { + name + } + } + } + "#; + let headers = some_header_map! { + http::header::ACCEPT => "text/event-stream" + }; + + let (sub1, sub2, sub3) = tokio::join!( + router.send_graphql_request(query, None, headers.clone()), + router.send_graphql_request(query, None, headers.clone()), + router.send_graphql_request(query, None, headers.clone()), + ); + + for sub in [&sub1, &sub2, &sub3] { + let body = sub.string_body().await; + assert!( + body.contains("event: next") && body.contains("event: complete"), + "Expected subscription to receive events and complete" + ); + } + + let reviews_requests = subgraphs.get_requests_log("reviews").unwrap_or_default(); + assert_eq!( + reviews_requests.len(), + 1, + "Expected requests to reviews subgraph to be deduplicated" + ); + } } From 6f8e441a063727c1e8e5cfdee4da87200cd95d44 Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Fri, 3 Apr 2026 17:02:16 +0200 Subject: [PATCH 10/42] ok attempt to clean this up but lets rebase first --- bin/router/src/pipeline/mod.rs | 89 ++++++------- bin/router/src/pipeline/websocket_server.rs | 25 ++-- .../src/executors/active_subscriptions.rs | 118 ++++++++++++------ lib/executor/src/executors/http_callback.rs | 2 +- 4 files changed, 141 insertions(+), 93 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 25117bbfe..cfd46d12c 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -10,8 +10,9 @@ use tracing::{error, trace, Instrument}; use xxhash_rust::xxh3::Xxh3; use hive_router_plan_executor::execution::plan::FailedExecutionResult; -use hive_router_plan_executor::executors::active_subscriptions::BroadcastItem; -use hive_router_plan_executor::executors::active_subscriptions::ListenerGuard; +use hive_router_plan_executor::executors::active_subscriptions::{ + BroadcastItem, FingerprintGuard, SubscriptionHandle, +}; use hive_router_plan_executor::headers::plan::ResponseHeaderAggregator; use hive_router_internal::telemetry::traces::spans::{ @@ -284,26 +285,20 @@ pub async fn graphql_request_handler( None }; - // subscription dedup, try to join an existing subscription - if is_subscription { - if let Some(fp) = fingerprint { - let registry = &schema_state.active_subscriptions; - if let Some((_sub_id, receiver, guard)) = - registry.try_join_by_fingerprint(fp) - { - let body_stream = broadcast_receiver_to_body_stream(receiver, guard); - let response = build_streaming_response( - body_stream, - response_mode, - None, - )?; - return Ok(response); - } + // subscription dedup: join or create a fingerprinted subscription. + // joiners get a receiver from the existing broadcast channel. + // the leader gets a handle to feed upstream events into later + let sub_dedup = if let Some(fp) = fingerprint.filter(|_| is_subscription) { + let (handle, receiver, guard) = + schema_state.active_subscriptions.dedupe_by_fingerprint(fp); + if handle.is_none() { + let body_stream = broadcast_receiver_to_body_stream(receiver, guard); + return build_streaming_response(body_stream, response_mode, None); } - } - - // subscription fingerprint for leader registration (None disables broadcasting) - let subscription_fingerprint = if is_subscription { fingerprint } else { None }; + Some((handle, receiver, guard)) + } else { + None + }; let planned_response = if fingerprint.is_some() && matches!( @@ -325,7 +320,7 @@ pub async fn graphql_request_handler( operation_span, plugin_req_state, response_mode, - subscription_fingerprint, + None, ) .await? { @@ -348,7 +343,7 @@ pub async fn graphql_request_handler( operation_span, plugin_req_state, response_mode, - subscription_fingerprint, + sub_dedup, ) .await? }; @@ -421,7 +416,11 @@ async fn execute_planned_request<'exec>( operation_span: GraphQLOperationSpan, plugin_req_state: Option>, response_mode: &'exec ResponseMode, - subscription_fingerprint: Option, + sub_dedup: Option<( + Option, + tokio::sync::broadcast::Receiver, + FingerprintGuard, + )>, ) -> Result { let jwt_request_details = match &shared_state.jwt_auth_runtime { Some(jwt_auth_runtime) => match jwt_auth_runtime @@ -473,8 +472,7 @@ async fn execute_planned_request<'exec>( .await? { QueryPlanExecutionResult::Stream(result) => { - let body_stream = - register_subscription_leader(result.body, subscription_fingerprint, schema_state); + let body_stream = start_subscription_leader(result.body, sub_dedup); let response = build_streaming_response( body_stream, @@ -577,29 +575,32 @@ pub async fn execute_pipeline<'exec>( execute_plan(supergraph, shared_state, planned_request, operation_span).await } -/// registers upstream as a subscription leader, spawning a broadcast task and -/// returning a stream backed by the broadcast receiver. if fingerprint is None, -/// the upstream stream is returned as-is with no registration. -pub(crate) fn register_subscription_leader( +/// if sub_dedup is provided, spawns a task that feeds upstream events into +/// the broadcast channel and returns a stream backed by the broadcast receiver. +/// otherwise returns the upstream stream directly +pub(crate) fn start_subscription_leader( upstream: futures::stream::BoxStream<'static, Vec>, - fingerprint: Option, - schema_state: &SchemaState, + sub_dedup: Option<( + Option, + tokio::sync::broadcast::Receiver, + FingerprintGuard, + )>, ) -> futures::stream::BoxStream<'static, Vec> { - let Some(fingerprint) = fingerprint else { + let Some((handle, receiver, guard)) = sub_dedup else { return upstream; }; - let registry = &schema_state.active_subscriptions; - let (handle, receiver, guard) = registry.register(Some(fingerprint), None); - - let mut upstream = upstream; - tokio::spawn(async move { - while let Some(event) = upstream.next().await { - if !handle.send(BroadcastItem::Event(event.into())) { - break; + // handle is always Some for the leader, but satisfy the type + if let Some(handle) = handle { + let mut upstream = upstream; + tokio::spawn(async move { + while let Some(event) = upstream.next().await { + if !handle.send(BroadcastItem::Event(event.into())) { + break; + } } - } - }); + }); + } broadcast_receiver_to_body_stream(receiver, guard) } @@ -608,7 +609,7 @@ pub(crate) fn register_subscription_leader( /// the listener guard is held for the lifetime of the stream to track listener count pub(crate) fn broadcast_receiver_to_body_stream( mut receiver: tokio::sync::broadcast::Receiver, - guard: ListenerGuard, + guard: FingerprintGuard, ) -> futures::stream::BoxStream<'static, Vec> { Box::pin(async_stream::stream! { let _guard = guard; diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 4cecd63af..3253527be 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -509,12 +509,13 @@ async fn handle_text_frame( // TODO: query dedupe - // subscription dedup: try to join an existing subscription - if let Some(fp) = fingerprint { - let registry = &schema_state.active_subscriptions; - if let Some((_sub_id, mut receiver, guard)) = - registry.try_join_by_fingerprint(fp) - { + // subscription dedup: join or create a fingerprinted subscription. + // joiners read from the existing broadcast channel. + // the leader gets a handle to feed upstream events into later + let sub_dedup = if let Some(fp) = fingerprint { + let (handle, receiver, guard) = + schema_state.active_subscriptions.dedupe_by_fingerprint(fp); + if handle.is_none() { let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); state.borrow_mut().subscriptions.insert(id.clone(), cancel_tx); @@ -527,6 +528,7 @@ async fn handle_text_frame( let id_for_loop = id.clone(); let mut cancelled = false; + let mut receiver = receiver; loop { tokio::select! { recv_result = receiver.recv() => { @@ -565,7 +567,10 @@ async fn handle_text_frame( Some(ServerMessage::complete(&id)) }; } - } + Some((handle, receiver, guard)) + } else { + None + }; match execute_pipeline( &client_request_details, @@ -610,11 +615,9 @@ async fn handle_text_frame( Some(ServerMessage::complete(&id)) } Ok(QueryPlanExecutionResult::Stream(response)) => { - let subscription_fingerprint = if is_subscription { fingerprint } else { None }; - let mut stream = crate::pipeline::register_subscription_leader( + let mut stream = crate::pipeline::start_subscription_leader( response.body, - subscription_fingerprint, - schema_state, + sub_dedup, ); // we use mpsc::channel(1) instead of oneshot because oneshot::Receiver diff --git a/lib/executor/src/executors/active_subscriptions.rs b/lib/executor/src/executors/active_subscriptions.rs index d525b3fb0..7955b705b 100644 --- a/lib/executor/src/executors/active_subscriptions.rs +++ b/lib/executor/src/executors/active_subscriptions.rs @@ -42,6 +42,9 @@ struct ActiveSubscriptionEntry { } pub struct ActiveSubscriptionsRegistry { + // two maps becuase http callbacks need uuids as ids which is the + // true map of all subscriptions, and then the fingerprint map is a + // secondary map only for deduplication subscriptions: DashMap, fingerprints: DashMap, // capacity of the broadcast channel per subscription, see router config `subscriptions.broadcast_capacity` @@ -57,38 +60,79 @@ impl ActiveSubscriptionsRegistry { } } - /// try to join an existing subscription by fingerprint. - /// returns the subscription id, a broadcast receiver, and a listener guard if a match exists - pub fn try_join_by_fingerprint( + /// Atomically joins or creates a subscription for the given fingerprint. + /// + /// The broadcast entry is created immediately so concurrent requests + /// subscribe to the same channel, even if the subscription has not + /// resolved yet. + /// + /// Returns a `SubscriptionHandle` only for the leader (first caller). + /// the leader uses it to feed upstream events into the broadcast channel. + /// + /// All callers (leader included) get a receiver and a `FingerprintGuard`. + pub fn dedupe_by_fingerprint( self: &Arc, fingerprint: Fingerprint, - ) -> Option<( - SubscriptionId, + ) -> ( + Option, tokio::sync::broadcast::Receiver, - ListenerGuard, - )> { - let sub_id = self.fingerprints.get(&fingerprint)?.value().clone(); - - let entry = self.subscriptions.get(&sub_id)?; - let receiver = entry.sender.subscribe(); - entry.listener_count.fetch_add(1, Ordering::AcqRel); - - let guard = ListenerGuard { - id: sub_id.clone(), - registry: Arc::clone(self), - listener_count: entry.listener_count.clone(), - fingerprint: entry.fingerprint, - }; - - trace!(subscription_id = %sub_id, fingerprint = fingerprint, "joined existing subscription via dedup"); - - Some((sub_id, receiver, guard)) + FingerprintGuard, + ) { + use dashmap::mapref::entry::Entry; + match self.fingerprints.entry(fingerprint) { + Entry::Occupied(entry) => { + let sub_id = entry.get().clone(); + // unwrap: fingerprints and subscriptions are always in sync + let sub = self.subscriptions.get(&sub_id).unwrap(); + let receiver = sub.sender.subscribe(); + sub.listener_count.fetch_add(1, Ordering::AcqRel); + + let guard = FingerprintGuard { + id: sub_id.clone(), + registry: Arc::clone(self), + listener_count: sub.listener_count.clone(), + fingerprint: Some(fingerprint), + }; + + trace!(subscription_id = %sub_id, fingerprint, "joined existing subscription via dedup"); + + (None, receiver, guard) + } + Entry::Vacant(fp_slot) => { + let id = Uuid::new_v4().to_string(); // TODO: doesnt have to be a UUID + let (sender, receiver) = tokio::sync::broadcast::channel(self.broadcast_capacity); + let listener_count = Arc::new(AtomicUsize::new(1)); + + self.subscriptions.insert( + id.clone(), + ActiveSubscriptionEntry { + sender, + listener_count: listener_count.clone(), + fingerprint: Some(fingerprint), + callback_state: None, + }, + ); + fp_slot.insert(id.clone()); + + let handle = SubscriptionHandle { + id: id.clone(), + registry: Arc::clone(self), + }; + let guard = FingerprintGuard { + id: id.clone(), + registry: Arc::clone(self), + listener_count, + fingerprint: Some(fingerprint), + }; + + trace!(subscription_id = %id, fingerprint, "registered new fingerprinted subscription"); + + (Some(handle), receiver, guard) + } + } } - /// register a brand new subscription. returns: - /// - a SubscriptionHandle for the upstream producer (dropping it removes the entry and closes the channel) - /// - a broadcast receiver for the first consumer - /// - a ListenerGuard that tracks this consumer's lifetime + /// register a subscription without dedup (e.g. http callbacks) pub fn register( self: &Arc, fingerprint: Option, @@ -96,9 +140,9 @@ impl ActiveSubscriptionsRegistry { ) -> ( SubscriptionHandle, tokio::sync::broadcast::Receiver, - ListenerGuard, + FingerprintGuard, ) { - let id = Uuid::new_v4().to_string(); + let id = Uuid::new_v4().to_string(); // TODO: doesnt have to be a UUID let (sender, receiver) = tokio::sync::broadcast::channel(self.broadcast_capacity); let listener_count = Arc::new(AtomicUsize::new(1)); @@ -120,7 +164,7 @@ impl ActiveSubscriptionsRegistry { registry: Arc::clone(self), }; - let guard = ListenerGuard { + let guard = FingerprintGuard { id: id.clone(), registry: Arc::clone(self), listener_count, @@ -219,26 +263,26 @@ impl SubscriptionHandle { impl Drop for SubscriptionHandle { fn drop(&mut self) { // removing the entry drops the broadcast sender inside it, closing the channel. - // all receivers will see Closed and their ListenerGuards will drop. - // we remove here (rather than in ListenerGuard) because the upstream is the + // all receivers will see Closed and their FingerprintGuards will drop. + // we remove here (rather than in FingerprintGuard) because the upstream is the // authoritative source - when it's gone, the subscription is done self.registry.remove(&self.id); trace!(subscription_id = %self.id, "subscription handle dropped, upstream closed"); } } -/// held by each consumer. on drop, decrements the listener count. -/// when the last listener drops and the subscription entry still exists +/// held by each consumer of a subscription. on drop, decrements the listener +/// count. when the last guard drops and the subscription entry still exists /// (upstream hasn't dropped yet), removes it - causing the upstream -/// task to see send() fail and exit -pub struct ListenerGuard { +/// producer's send() to return false and exit +pub struct FingerprintGuard { id: SubscriptionId, registry: Arc, listener_count: Arc, fingerprint: Option, } -impl Drop for ListenerGuard { +impl Drop for FingerprintGuard { fn drop(&mut self) { let prev = self.listener_count.fetch_sub(1, Ordering::AcqRel); if prev == 1 { diff --git a/lib/executor/src/executors/http_callback.rs b/lib/executor/src/executors/http_callback.rs index 2b33d0cab..ab3e71df9 100644 --- a/lib/executor/src/executors/http_callback.rs +++ b/lib/executor/src/executors/http_callback.rs @@ -117,7 +117,7 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { BoxStream<'static, Result, SubgraphExecutorError>>, SubgraphExecutorError, > { - let verifier = Uuid::new_v4().to_string(); + let verifier = Uuid::new_v4().to_string(); // TODO: doesnt have to be a UUID let callback_state = CallbackState { verifier: verifier.clone(), From 2622dc80f2805f8c8d17a215b7a020d36c7d46b4 Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Fri, 3 Apr 2026 17:18:54 +0200 Subject: [PATCH 11/42] ok --- lib/executor/src/executors/active_subscriptions.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/executor/src/executors/active_subscriptions.rs b/lib/executor/src/executors/active_subscriptions.rs index 7955b705b..439240b1f 100644 --- a/lib/executor/src/executors/active_subscriptions.rs +++ b/lib/executor/src/executors/active_subscriptions.rs @@ -202,6 +202,9 @@ impl ActiveSubscriptionsRegistry { /// send an event to a specific subscription's broadcast channel pub fn send_event(&self, id: &str, item: BroadcastItem) -> bool { if let Some(entry) = self.subscriptions.get(id) { + // if the channel is closed or full it means the consuming client is gone or too slow and + // unable to keep up. in both cases, we dont emit an error messages because it anyways cant + // go through entry.sender.send(item).is_ok() } else { false From efccf7c1f02c7bbc418d93c1070318280aaf188c Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Fri, 3 Apr 2026 17:46:52 +0200 Subject: [PATCH 12/42] remove unused --- bin/router/src/pipeline/http_callback.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/bin/router/src/pipeline/http_callback.rs b/bin/router/src/pipeline/http_callback.rs index c3c214348..9b304b202 100644 --- a/bin/router/src/pipeline/http_callback.rs +++ b/bin/router/src/pipeline/http_callback.rs @@ -14,7 +14,6 @@ use ntex::web::WebResponseError; use ntex::web::{self, types::Path, HttpRequest, HttpResponse}; use serde::Deserialize; use strum::EnumString; -use tokio::sync::mpsc; use tracing::{debug, error, trace, warn}; #[derive(Debug, Deserialize, EnumString)] From 7403f606eaa81836ae52499f5832284f79c6f35b Mon Sep 17 00:00:00 2001 From: Denis Badurina Date: Fri, 3 Apr 2026 18:53:46 +0200 Subject: [PATCH 13/42] begin super unification the great joiner --- bin/router/src/pipeline/header.rs | 2 +- bin/router/src/pipeline/http_callback.rs | 12 +- bin/router/src/pipeline/mod.rs | 207 +++--------------- bin/router/src/pipeline/websocket_server.rs | 78 +------ bin/router/src/schema_state.rs | 12 +- bin/router/src/shared_state.rs | 115 +++++++++- .../src/executors/active_subscriptions.rs | 198 +++++------------ lib/executor/src/executors/http.rs | 2 +- lib/executor/src/executors/http_callback.rs | 8 +- lib/executor/src/executors/map.rs | 10 +- lib/internal/src/inflight.rs | 49 +++-- 11 files changed, 262 insertions(+), 431 deletions(-) diff --git a/bin/router/src/pipeline/header.rs b/bin/router/src/pipeline/header.rs index 116a282dc..5ce52e8d6 100644 --- a/bin/router/src/pipeline/header.rs +++ b/bin/router/src/pipeline/header.rs @@ -81,7 +81,7 @@ impl SingleContentType { // IMPORTANT: make sure that the serialized string representations are valid because // there is an unwrap in the StreamContentType::media_types() method. /// Streamable content types for GraphQL responses. -#[derive(PartialEq, Default, Debug, IntoStaticStr, EnumString, AsRefStr, EnumIter)] +#[derive(PartialEq, Default, Debug, IntoStaticStr, EnumString, AsRefStr, EnumIter, Clone)] pub enum StreamContentType { // The order of the variants here matters for negotiation with `Accept: */*`. /// Incremental Delivery over HTTP (`multipart/mixed`) diff --git a/bin/router/src/pipeline/http_callback.rs b/bin/router/src/pipeline/http_callback.rs index 9b304b202..6575a4141 100644 --- a/bin/router/src/pipeline/http_callback.rs +++ b/bin/router/src/pipeline/http_callback.rs @@ -1,8 +1,6 @@ -use std::sync::Arc; - use bytes::Bytes as BytesLib; use hive_router_plan_executor::executors::active_subscriptions::{ - ActiveSubscriptionsRegistry, BroadcastItem, + ActiveSubscriptionsMap, BroadcastItem, }; use hive_router_plan_executor::executors::http_callback::{ CALLBACK_PROTOCOL_VERSION, SUBSCRIPTION_PROTOCOL_HEADER, @@ -144,7 +142,7 @@ fn validate_payload( Ok(()) } -fn handle_check(subscription_id: &str, registry: &ActiveSubscriptionsRegistry) { +fn handle_check(subscription_id: &str, registry: &ActiveSubscriptionsMap) { trace!(subscription_id = %subscription_id, "Received check message"); registry.record_heartbeat(subscription_id); } @@ -152,7 +150,7 @@ fn handle_check(subscription_id: &str, registry: &ActiveSubscriptionsRegistry) { fn handle_next( subscription_id: &str, payload: &CallbackPayload<'_>, - registry: &ActiveSubscriptionsRegistry, + registry: &ActiveSubscriptionsMap, ) -> Result<(), CallbackError> { trace!(subscription_id = %subscription_id, "Received next message"); @@ -181,7 +179,7 @@ fn handle_next( fn handle_complete( subscription_id: &str, payload: &CallbackPayload<'_>, - registry: &ActiveSubscriptionsRegistry, + registry: &ActiveSubscriptionsMap, ) { trace!(subscription_id = %subscription_id, "Received complete message"); if let Some(errors) = &payload.errors { @@ -196,7 +194,7 @@ pub async fn handler( req: HttpRequest, path: Path, body: Bytes, - active_subscriptions: web::types::State>, + active_subscriptions: web::types::State, ) -> Result { let subscription_id_from_path = path.into_inner(); diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index cfd46d12c..4f621a445 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,5 +1,6 @@ use futures::Stream; use futures::StreamExt; +use hive_router_internal::inflight::InFlightCleanupGuard; use std::{ collections::HashMap, hash::{Hash, Hasher}, @@ -10,9 +11,7 @@ use tracing::{error, trace, Instrument}; use xxhash_rust::xxh3::Xxh3; use hive_router_plan_executor::execution::plan::FailedExecutionResult; -use hive_router_plan_executor::executors::active_subscriptions::{ - BroadcastItem, FingerprintGuard, SubscriptionHandle, -}; +use hive_router_plan_executor::executors::active_subscriptions::BroadcastItem; use hive_router_plan_executor::headers::plan::ResponseHeaderAggregator; use hive_router_internal::telemetry::traces::spans::{ @@ -58,7 +57,10 @@ use crate::{ validation::validate_operation_with_cache, }, schema_state::SchemaState, - shared_state::{RouterRequestDedupeHeaderPolicy, RouterSharedState, SharedRouterResponse}, + shared_state::{ + RouterRequestDedupeHeaderPolicy, RouterSharedState, SharedRouterResponse, + SharedRouterSingleResponse, + }, LABORATORY_HTML, }; @@ -251,7 +253,7 @@ pub async fn graphql_request_handler( if is_subscription && (!shared_state.router_config.subscriptions.enabled || !response_mode.can_stream()) { - // check early, even though we check again after pipeline execution below + // check early, even though we check again planned execution below return Err(PipelineError::SubscriptionsNotSupported); } @@ -262,7 +264,7 @@ pub async fn graphql_request_handler( && matches!( normalize_payload.operation_for_plan.operation_kind, // same deduplication applies for queries and subscriptions - Some(OperationKind::Query) | Some(OperationKind::Subscription) | None + None | Some(OperationKind::Query) | Some(OperationKind::Subscription) ) { let variables_hash = hash_graphql_variables(&graphql_params.variables); let extensions_hash = graphql_params @@ -285,32 +287,12 @@ pub async fn graphql_request_handler( None }; - // subscription dedup: join or create a fingerprinted subscription. - // joiners get a receiver from the existing broadcast channel. - // the leader gets a handle to feed upstream events into later - let sub_dedup = if let Some(fp) = fingerprint.filter(|_| is_subscription) { - let (handle, receiver, guard) = - schema_state.active_subscriptions.dedupe_by_fingerprint(fp); - if handle.is_none() { - let body_stream = broadcast_receiver_to_body_stream(receiver, guard); - return build_streaming_response(body_stream, response_mode, None); - } - Some((handle, receiver, guard)) - } else { - None - }; - - let planned_response = if fingerprint.is_some() - && matches!( - normalize_payload.operation_for_plan.operation_kind, - Some(OperationKind::Query) | None - ) { - let fp = fingerprint.unwrap(); - let (shared_response, _role) = shared_state + let shared_response = if let Some(fp) = fingerprint { + let (planned_response, _role) = shared_state .in_flight_requests .claim(fp) - .get_or_try_init(|| async { - match execute_planned_request( + .get_or_try_init(|guard| async { + execute_planned_request( req, graphql_params, &normalize_payload, @@ -320,18 +302,12 @@ pub async fn graphql_request_handler( operation_span, plugin_req_state, response_mode, - None, + Some(guard), ) - .await? - { - PlannedResponse::Shared(r) => Ok::(r), - // subscriptions are excluded from the dedup branch above, so this is unreachable - PlannedResponse::Direct { .. } => unreachable!("stream responses never enter the dedup path"), - } + .await }) .await?; - - PlannedResponse::Shared(Arc::unwrap_or_clone(shared_response)) + Arc::unwrap_or_clone(planned_response) } else { execute_planned_request( req, @@ -343,19 +319,11 @@ pub async fn graphql_request_handler( operation_span, plugin_req_state, response_mode, - sub_dedup, + None, ) .await? }; - let (response, error_count) = match planned_response { - PlannedResponse::Shared(shared_response) => { - let error_count = shared_response.error_count; - (shared_response.into(), error_count) - } - PlannedResponse::Direct { response, .. } => (response, 0), - }; - if let Some(hive_usage_agent) = &shared_state.hive_usage_agent { usage_reporting::collect_usage_report( supergraph.supergraph_schema.clone(), @@ -377,21 +345,21 @@ pub async fn graphql_request_handler( // Thus, this expect should never panic. "Expected Usage Reporting options to be present when Hive Usage Agent is initialized", ), - error_count, + shared_response.error_count(), ) .await; } write_graphql_response_metric_status( req, - if error_count > 0 { + if shared_response.error_count() > 0 { GraphQLResponseStatus::Error } else { GraphQLResponseStatus::Ok }, ); - Ok(response) + Ok(shared_response.into()) } .instrument(span_clone) .await @@ -400,11 +368,6 @@ pub async fn graphql_request_handler( }) } -enum PlannedResponse { - Shared(SharedRouterResponse), - Direct { response: web::HttpResponse }, -} - #[allow(clippy::too_many_arguments)] async fn execute_planned_request<'exec>( req: &'exec HttpRequest, @@ -416,12 +379,8 @@ async fn execute_planned_request<'exec>( operation_span: GraphQLOperationSpan, plugin_req_state: Option>, response_mode: &'exec ResponseMode, - sub_dedup: Option<( - Option, - tokio::sync::broadcast::Receiver, - FingerprintGuard, - )>, -) -> Result { + in_flight_cleanup_guard: Option>, +) -> Result { let jwt_request_details = match &shared_state.jwt_auth_runtime { Some(jwt_auth_runtime) => match jwt_auth_runtime .validate_headers(req.headers(), &shared_state.jwt_claims_cache) @@ -472,15 +431,11 @@ async fn execute_planned_request<'exec>( .await? { QueryPlanExecutionResult::Stream(result) => { - let body_stream = start_subscription_leader(result.body, sub_dedup); + let stream_content_type = response_mode + .stream_content_type() + .ok_or(PipelineError::SubscriptionsTransportNotSupported)?; - let response = build_streaming_response( - body_stream, - response_mode, - result.response_headers_aggregator, - )?; - - Ok(PlannedResponse::Direct { response }) + todo!(); } QueryPlanExecutionResult::Single(result) => { let single_content_type = response_mode. @@ -488,6 +443,9 @@ async fn execute_planned_request<'exec>( // TODO: streaming single responses ok_or(PipelineError::UnsupportedContentType)?; + // drop the inflight planned request as soon as the response is ready + let _query_guard = in_flight_cleanup_guard; + let error_count = result.error_count; let mut response_builder = web::HttpResponse::Ok(); @@ -503,7 +461,7 @@ async fn execute_planned_request<'exec>( .status(result.status_code) .body(body.clone()); - Ok(PlannedResponse::Shared(SharedRouterResponse { + Ok(SharedRouterResponse::Single(SharedRouterSingleResponse { body, headers: Arc::new(response.headers().clone()), status: response.status(), @@ -575,113 +533,6 @@ pub async fn execute_pipeline<'exec>( execute_plan(supergraph, shared_state, planned_request, operation_span).await } -/// if sub_dedup is provided, spawns a task that feeds upstream events into -/// the broadcast channel and returns a stream backed by the broadcast receiver. -/// otherwise returns the upstream stream directly -pub(crate) fn start_subscription_leader( - upstream: futures::stream::BoxStream<'static, Vec>, - sub_dedup: Option<( - Option, - tokio::sync::broadcast::Receiver, - FingerprintGuard, - )>, -) -> futures::stream::BoxStream<'static, Vec> { - let Some((handle, receiver, guard)) = sub_dedup else { - return upstream; - }; - - // handle is always Some for the leader, but satisfy the type - if let Some(handle) = handle { - let mut upstream = upstream; - tokio::spawn(async move { - while let Some(event) = upstream.next().await { - if !handle.send(BroadcastItem::Event(event.into())) { - break; - } - } - }); - } - - broadcast_receiver_to_body_stream(receiver, guard) -} - -/// converts a broadcast receiver into a BoxStream of serialized event bodies. -/// the listener guard is held for the lifetime of the stream to track listener count -pub(crate) fn broadcast_receiver_to_body_stream( - mut receiver: tokio::sync::broadcast::Receiver, - guard: FingerprintGuard, -) -> futures::stream::BoxStream<'static, Vec> { - Box::pin(async_stream::stream! { - let _guard = guard; - loop { - match receiver.recv().await { - Ok(BroadcastItem::Event(data)) => { - yield data.to_vec(); - } - Ok(BroadcastItem::Error(errors)) => { - yield FailedExecutionResult { errors }.serialize(); - break; - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { - trace!(lagged = n, "broadcast receiver lagged, skipping missed messages"); - continue; - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => { - break; - } - } - } - }) -} - -fn build_streaming_response( - body_stream: futures::stream::BoxStream<'static, Vec>, - response_mode: &ResponseMode, - response_headers_aggregator: Option, -) -> Result { - let stream_content_type = response_mode - .stream_content_type() - .ok_or(PipelineError::SubscriptionsTransportNotSupported)?; - - let content_type_header = match stream_content_type { - StreamContentType::IncrementalDelivery => { - http::HeaderValue::from_static(INCREMENTAL_DELIVERY_CONTENT_TYPE) - } - StreamContentType::SSE => http::HeaderValue::from_static("text/event-stream"), - StreamContentType::ApolloMultipartHTTP => { - http::HeaderValue::from_static(APOLLO_MULTIPART_HTTP_CONTENT_TYPE) - } - }; - - let body: std::pin::Pin< - Box> + Send>, - > = match stream_content_type { - StreamContentType::IncrementalDelivery => Box::pin( - multipart_subscribe::create_incremental_delivery_stream(body_stream), - ), - StreamContentType::SSE => Box::pin(sse::create_stream( - body_stream, - std::time::Duration::from_secs(10), - )), - StreamContentType::ApolloMultipartHTTP => { - Box::pin(multipart_subscribe::create_apollo_multipart_http_stream( - body_stream, - std::time::Duration::from_secs(10), - )) - } - }; - - let mut response_builder = web::HttpResponse::Ok(); - - if let Some(response_headers_aggregator) = response_headers_aggregator { - response_headers_aggregator.modify_client_response_headers(&mut response_builder)?; - } - - Ok(response_builder - .header(http::header::CONTENT_TYPE, content_type_header) - .streaming(body)) -} - pub fn inbound_request_fingerprint( method: &http::Method, path: &str, diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 3253527be..88b910515 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -47,6 +47,7 @@ use crate::pipeline::{ use crate::schema_state::SchemaState; use crate::shared_state::RouterSharedState; +use crate::shared_state::SharedRouterResponse; use hive_router_plan_executor::execution::plan::FailedExecutionResult; use hive_router_plan_executor::executors::active_subscriptions::BroadcastItem; @@ -425,7 +426,12 @@ async fn handle_text_frame( let request_dedupe_enabled = shared_state.router_config.traffic_shaping.router.dedupe.enabled; - let fingerprint = if is_subscription && request_dedupe_enabled { + let fingerprint = if request_dedupe_enabled + && matches!( + normalize_payload.operation_for_plan.operation_kind, + // same deduplication applies for queries and subscriptions + None | Some(OperationKind::Query) | Some(OperationKind::Subscription) + ) { let variables_hash = hash_graphql_variables(&payload.variables); let extensions_hash = payload .extensions @@ -507,70 +513,7 @@ async fn handle_text_frame( jwt: jwt_request_details, }.into(); - // TODO: query dedupe - - // subscription dedup: join or create a fingerprinted subscription. - // joiners read from the existing broadcast channel. - // the leader gets a handle to feed upstream events into later - let sub_dedup = if let Some(fp) = fingerprint { - let (handle, receiver, guard) = - schema_state.active_subscriptions.dedupe_by_fingerprint(fp); - if handle.is_none() { - let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); - state.borrow_mut().subscriptions.insert(id.clone(), cancel_tx); - - let _ws_guard = SubscriptionGuard { - state: state.clone(), - id: id.clone(), - }; - - trace!(id = %id, "Subscription joined via dedup"); - - let id_for_loop = id.clone(); - let mut cancelled = false; - let mut receiver = receiver; - loop { - tokio::select! { - recv_result = receiver.recv() => { - match recv_result { - Ok(BroadcastItem::Event(data)) => { - let _ = sink.send(ServerMessage::next(&id_for_loop, &data)).await; - } - Ok(BroadcastItem::Error(errors)) => { - let body = FailedExecutionResult { errors }.serialize(); - let _ = sink.send(ServerMessage::next(&id_for_loop, &body)).await; - break; - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { - trace!(id = %id_for_loop, lagged = n, "broadcast receiver lagged, skipping missed messages"); - continue; - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => { - break; - } - } - } - _ = cancel_rx.recv() => { - cancelled = true; - break; - } - } - } - - drop(guard); - - return if cancelled { - trace!(id = %id, "Deduped subscription cancelled"); - None - } else { - trace!(id = %id, "Deduped subscription completed"); - Some(ServerMessage::complete(&id)) - }; - } - Some((handle, receiver, guard)) - } else { - None - }; + // TODO: dedupe match execute_pipeline( &client_request_details, @@ -615,10 +558,7 @@ async fn handle_text_frame( Some(ServerMessage::complete(&id)) } Ok(QueryPlanExecutionResult::Stream(response)) => { - let mut stream = crate::pipeline::start_subscription_leader( - response.body, - sub_dedup, - ); + let mut stream = response.body; // we use mpsc::channel(1) instead of oneshot because oneshot::Receiver // is consumed on first await, which doesn't work in tokio::select! loops that diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index a8789e6f5..bfcf8f286 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -10,7 +10,7 @@ use hive_router_internal::{ background_tasks::{BackgroundTask, BackgroundTasksManager}, }; use hive_router_plan_executor::{ - executors::active_subscriptions::{ActiveSubscriptionsRegistry, BroadcastItem}, + executors::active_subscriptions::{ActiveSubscriptionsMap, BroadcastItem}, executors::error::SubgraphExecutorError, hooks::on_supergraph_load::{ OnSupergraphLoadEndHookPayload, OnSupergraphLoadStartHookPayload, SupergraphData, @@ -47,7 +47,7 @@ pub struct SchemaState { pub validate_cache: Cache>>, pub normalize_cache: Cache>, pub telemetry_context: Arc, - pub active_subscriptions: Arc, + pub active_subscriptions: ActiveSubscriptionsMap, } #[derive(Debug, thiserror::Error)] @@ -97,9 +97,9 @@ impl SchemaState { let plan_cache = cache_state.plan_cache.clone(); let validate_cache = cache_state.validate_cache.clone(); let normalize_cache = cache_state.normalize_cache.clone(); - let active_subscriptions = Arc::new(ActiveSubscriptionsRegistry::new( + let active_subscriptions = ActiveSubscriptionsMap::new( router_config.subscriptions.broadcast_capacity, - )); + ); // This is cheap clone, as Cache is thread-safe and can be cloned without any performance penalty. let cache_state_for_invalidation = cache_state.clone(); @@ -245,7 +245,7 @@ impl SchemaState { router_config: Arc, telemetry_context: Arc, parsed_supergraph_sdl: Document, - active_subscriptions: Arc, + active_subscriptions: ActiveSubscriptionsMap, ) -> Result { let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; let metadata = Arc::new(planner.consumer_schema.schema_metadata()); @@ -346,7 +346,7 @@ impl BackgroundTask for SupergraphBackgroundLoaderTask { } struct HeartbeatEnforcerTask { - active_subscriptions: Arc, + active_subscriptions: ActiveSubscriptionsMap, heartbeat_interval: Duration, } diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 8a3aa3050..86f682977 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -1,3 +1,4 @@ +use futures::Stream; use graphql_tools::validation::validate::ValidationPlan; use hive_console_sdk::agent::usage_agent::{AgentError, UsageAgent}; use hive_router_config::traffic_shaping::{ @@ -8,6 +9,7 @@ use hive_router_internal::expressions::values::boolean::BooleanOrProgram; use hive_router_internal::expressions::ExpressionCompileError; use hive_router_internal::inflight::InFlightMap; use hive_router_internal::telemetry::TelemetryContext; +use hive_router_plan_executor::execution::plan::FailedExecutionResult; use hive_router_plan_executor::headers::{ compile::compile_headers_plan, errors::HeaderRuleCompileError, plan::HeaderRulesPlan, }; @@ -20,14 +22,21 @@ use ntex::{http::HeaderMap, util::Bytes}; use std::sync::atomic::AtomicUsize; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{collections::HashSet, sync::Arc}; +use tracing::trace; use crate::cache_state::CacheState; use crate::jwt::context::JwtTokenPayload; use crate::jwt::JwtAuthRuntime; use crate::pipeline::cors::{CORSConfigError, Cors}; +use crate::pipeline::header::StreamContentType; use crate::pipeline::introspection_policy::compile_introspection_policy; +use crate::pipeline::multipart_subscribe::{ + self, APOLLO_MULTIPART_HTTP_CONTENT_TYPE, INCREMENTAL_DELIVERY_CONTENT_TYPE, +}; use crate::pipeline::parser::ParseCacheEntry; use crate::pipeline::progressive_override::{OverrideLabelsCompileError, OverrideLabelsEvaluator}; +use crate::pipeline::sse; +use hive_router_plan_executor::executors::active_subscriptions::BroadcastItem; pub type JwtClaimsCache = Cache>; pub type RouterInflightRequestsMap = InFlightMap; @@ -76,15 +85,39 @@ impl From<&TrafficShapingRouterDedupeHeadersConfig> for RouterRequestDedupeHeade } #[derive(Clone)] -pub struct SharedRouterResponse { +pub enum SharedRouterResponse { + Single(SharedRouterSingleResponse), + Stream(SharedRouterSingleResponse), +} + +impl SharedRouterResponse { + pub fn error_count(&self) -> usize { + match self { + SharedRouterResponse::Single(resp) => resp.error_count, + SharedRouterResponse::Stream(resp) => resp.error_count, + } + } +} + +impl From for web::HttpResponse { + fn from(shared_response: SharedRouterResponse) -> Self { + match shared_response { + SharedRouterResponse::Single(single) => single.into(), + SharedRouterResponse::Stream(stream) => stream.into(), + } + } +} + +#[derive(Clone)] +pub struct SharedRouterSingleResponse { pub body: Bytes, pub headers: Arc, pub status: StatusCode, pub error_count: usize, } -impl From for web::HttpResponse { - fn from(shared_response: SharedRouterResponse) -> Self { +impl From for web::HttpResponse { + fn from(shared_response: SharedRouterSingleResponse) -> Self { let mut response = web::HttpResponse::Ok(); response.status(shared_response.status); @@ -96,6 +129,82 @@ impl From for web::HttpResponse { } } +#[derive(Clone)] +pub struct SharedRouterStreamResponse { + // status is always 200 for streaming responses, errors are sent through the stream + pub body: tokio::sync::broadcast::Sender, + pub headers: Arc, + pub stream_content_type: StreamContentType, + pub error_count: usize, +} + +impl From for web::HttpResponse { + fn from(shared_response: SharedRouterStreamResponse) -> Self { + let mut receiver = shared_response.body.subscribe(); + + let stream = Box::pin(async_stream::stream! { + loop { + match receiver.recv().await { + Ok(BroadcastItem::Event(data)) => { + yield data.to_vec(); + } + Ok(BroadcastItem::Error(errors)) => { + yield FailedExecutionResult { errors }.serialize(); + break; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + trace!(lagged = n, "broadcast receiver lagged, skipping missed messages"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + }); + + let stream_content_type = shared_response.stream_content_type; + + let content_type_header = match stream_content_type { + StreamContentType::IncrementalDelivery => { + http::HeaderValue::from_static(INCREMENTAL_DELIVERY_CONTENT_TYPE) + } + StreamContentType::SSE => http::HeaderValue::from_static("text/event-stream"), + StreamContentType::ApolloMultipartHTTP => { + http::HeaderValue::from_static(APOLLO_MULTIPART_HTTP_CONTENT_TYPE) + } + }; + + let body: std::pin::Pin< + Box> + Send>, + > = match stream_content_type { + StreamContentType::IncrementalDelivery => Box::pin( + multipart_subscribe::create_incremental_delivery_stream(stream), + ), + StreamContentType::SSE => Box::pin(sse::create_stream( + stream, + std::time::Duration::from_secs(10), + )), + StreamContentType::ApolloMultipartHTTP => { + Box::pin(multipart_subscribe::create_apollo_multipart_http_stream( + stream, + std::time::Duration::from_secs(10), + )) + } + }; + + let mut response = web::HttpResponse::Ok(); + + for (header_name, header_value) in shared_response.headers.iter() { + response.set_header(header_name, header_value); + } + + response + .header(http::header::CONTENT_TYPE, content_type_header) + .streaming(body) + } +} + /// Default TTL for JWT claims cache entries (5 seconds) const DEFAULT_JWT_CACHE_TTL_SECS: u64 = 5; diff --git a/lib/executor/src/executors/active_subscriptions.rs b/lib/executor/src/executors/active_subscriptions.rs index 439240b1f..ef3c282d8 100644 --- a/lib/executor/src/executors/active_subscriptions.rs +++ b/lib/executor/src/executors/active_subscriptions.rs @@ -10,7 +10,6 @@ use uuid::Uuid; use crate::response::graphql_error::GraphQLError; pub type SubscriptionId = String; -pub type Fingerprint = u64; #[derive(Clone, Debug)] pub enum BroadcastItem { @@ -36,139 +35,62 @@ impl CallbackState { struct ActiveSubscriptionEntry { sender: tokio::sync::broadcast::Sender, - listener_count: Arc, - fingerprint: Option, callback_state: Option, } -pub struct ActiveSubscriptionsRegistry { - // two maps becuase http callbacks need uuids as ids which is the - // true map of all subscriptions, and then the fingerprint map is a - // secondary map only for deduplication +struct ActiveSubscriptionsInner { subscriptions: DashMap, - fingerprints: DashMap, // capacity of the broadcast channel per subscription, see router config `subscriptions.broadcast_capacity` broadcast_capacity: usize, } -impl ActiveSubscriptionsRegistry { +/// cheap to clone - all clones share the same inner state +#[derive(Clone)] +pub struct ActiveSubscriptionsMap { + inner: Arc, +} + +impl ActiveSubscriptionsMap { pub fn new(broadcast_capacity: usize) -> Self { Self { - subscriptions: DashMap::new(), - fingerprints: DashMap::new(), - broadcast_capacity, + inner: Arc::new(ActiveSubscriptionsInner { + subscriptions: DashMap::new(), + broadcast_capacity, + }), } } - /// Atomically joins or creates a subscription for the given fingerprint. - /// - /// The broadcast entry is created immediately so concurrent requests - /// subscribe to the same channel, even if the subscription has not - /// resolved yet. - /// - /// Returns a `SubscriptionHandle` only for the leader (first caller). - /// the leader uses it to feed upstream events into the broadcast channel. - /// - /// All callers (leader included) get a receiver and a `FingerprintGuard`. - pub fn dedupe_by_fingerprint( - self: &Arc, - fingerprint: Fingerprint, - ) -> ( - Option, - tokio::sync::broadcast::Receiver, - FingerprintGuard, - ) { - use dashmap::mapref::entry::Entry; - match self.fingerprints.entry(fingerprint) { - Entry::Occupied(entry) => { - let sub_id = entry.get().clone(); - // unwrap: fingerprints and subscriptions are always in sync - let sub = self.subscriptions.get(&sub_id).unwrap(); - let receiver = sub.sender.subscribe(); - sub.listener_count.fetch_add(1, Ordering::AcqRel); - - let guard = FingerprintGuard { - id: sub_id.clone(), - registry: Arc::clone(self), - listener_count: sub.listener_count.clone(), - fingerprint: Some(fingerprint), - }; - - trace!(subscription_id = %sub_id, fingerprint, "joined existing subscription via dedup"); - - (None, receiver, guard) - } - Entry::Vacant(fp_slot) => { - let id = Uuid::new_v4().to_string(); // TODO: doesnt have to be a UUID - let (sender, receiver) = tokio::sync::broadcast::channel(self.broadcast_capacity); - let listener_count = Arc::new(AtomicUsize::new(1)); - - self.subscriptions.insert( - id.clone(), - ActiveSubscriptionEntry { - sender, - listener_count: listener_count.clone(), - fingerprint: Some(fingerprint), - callback_state: None, - }, - ); - fp_slot.insert(id.clone()); - - let handle = SubscriptionHandle { - id: id.clone(), - registry: Arc::clone(self), - }; - let guard = FingerprintGuard { - id: id.clone(), - registry: Arc::clone(self), - listener_count, - fingerprint: Some(fingerprint), - }; - - trace!(subscription_id = %id, fingerprint, "registered new fingerprinted subscription"); - - (Some(handle), receiver, guard) - } - } - } - - /// register a subscription without dedup (e.g. http callbacks) + /// Register a new subscription (e.g. http callbacks). + /// Always creates a new entry. Deduplication for fingerprinted subscriptions is handled + /// by the inflight map in the request pipeline, not here. pub fn register( - self: &Arc, - fingerprint: Option, + &self, callback_state: Option, ) -> ( SubscriptionHandle, tokio::sync::broadcast::Receiver, - FingerprintGuard, + ListenerGuard, ) { - let id = Uuid::new_v4().to_string(); // TODO: doesnt have to be a UUID - let (sender, receiver) = tokio::sync::broadcast::channel(self.broadcast_capacity); + let id = Uuid::new_v4().to_string(); + let (sender, receiver) = tokio::sync::broadcast::channel(self.inner.broadcast_capacity); let listener_count = Arc::new(AtomicUsize::new(1)); - let entry = ActiveSubscriptionEntry { - sender, - listener_count: listener_count.clone(), - fingerprint, - callback_state, - }; - - self.subscriptions.insert(id.clone(), entry); - - if let Some(fp) = fingerprint { - self.fingerprints.insert(fp, id.clone()); - } + self.inner.subscriptions.insert( + id.clone(), + ActiveSubscriptionEntry { + sender, + callback_state, + }, + ); let handle = SubscriptionHandle { id: id.clone(), - registry: Arc::clone(self), + map: self.clone(), }; - - let guard = FingerprintGuard { + let guard = ListenerGuard { id: id.clone(), - registry: Arc::clone(self), + map: self.clone(), listener_count, - fingerprint, }; trace!(subscription_id = %id, "registered new subscription"); @@ -178,19 +100,20 @@ impl ActiveSubscriptionsRegistry { /// check if a subscription exists pub fn contains(&self, id: &str) -> bool { - self.subscriptions.contains_key(id) + self.inner.subscriptions.contains_key(id) } /// get the verifier for a callback subscription pub fn get_callback_verifier(&self, id: &str) -> Option { - self.subscriptions + self.inner + .subscriptions .get(id) .and_then(|entry| entry.callback_state.as_ref().map(|cs| cs.verifier.clone())) } /// record a heartbeat for a callback subscription pub fn record_heartbeat(&self, id: &str) -> bool { - if let Some(entry) = self.subscriptions.get(id) { + if let Some(entry) = self.inner.subscriptions.get(id) { if let Some(ref cs) = entry.callback_state { cs.record_heartbeat(); return true; @@ -201,7 +124,7 @@ impl ActiveSubscriptionsRegistry { /// send an event to a specific subscription's broadcast channel pub fn send_event(&self, id: &str, item: BroadcastItem) -> bool { - if let Some(entry) = self.subscriptions.get(id) { + if let Some(entry) = self.inner.subscriptions.get(id) { // if the channel is closed or full it means the consuming client is gone or too slow and // unable to keep up. in both cases, we dont emit an error messages because it anyways cant // go through @@ -211,30 +134,25 @@ impl ActiveSubscriptionsRegistry { } } - /// remove a subscription entry and clean up its fingerprint mapping + /// remove a subscription entry pub fn remove(&self, id: &str) { - if let Some((_, entry)) = self.subscriptions.remove(id) { - if let Some(fp) = entry.fingerprint { - self.fingerprints.remove(&fp); - } - } + self.inner.subscriptions.remove(id); } /// close all active subscriptions with an error message pub fn close_all_with_error(&self, errors: Vec) { let item = BroadcastItem::Error(errors); - for entry in self.subscriptions.iter() { + for entry in self.inner.subscriptions.iter() { let _ = entry.sender.send(item.clone()); } - self.fingerprints.clear(); - self.subscriptions.clear(); + self.inner.subscriptions.clear(); } /// iterate over all subscription ids and their callback state for heartbeat enforcement pub fn iter_callback_subscriptions( &self, ) -> impl Iterator>)> + '_ { - self.subscriptions.iter().filter_map(|entry| { + self.inner.subscriptions.iter().filter_map(|entry| { entry .callback_state .as_ref() @@ -243,13 +161,13 @@ impl ActiveSubscriptionsRegistry { } } -/// held by the upstream producer (the task that reads from the subgraph). -/// dropping this removes the subscription entry from the registry, which drops -/// the broadcast sender and closes the channel. all receivers will see Closed -/// and their streams will end naturally +/// Held by the upstream producer (the task that reads from the subgraph). +/// Dropping this removes the subscription entry from the registry, which drops +/// the broadcast sender and closes the channel. All receivers will see `Closed` +/// and their streams will end naturally. pub struct SubscriptionHandle { id: SubscriptionId, - registry: Arc, + map: ActiveSubscriptionsMap, } impl SubscriptionHandle { @@ -257,44 +175,36 @@ impl SubscriptionHandle { &self.id } - /// send an event to all listeners of this subscription pub fn send(&self, item: BroadcastItem) -> bool { - self.registry.send_event(&self.id, item) + self.map.send_event(&self.id, item) } } impl Drop for SubscriptionHandle { fn drop(&mut self) { // removing the entry drops the broadcast sender inside it, closing the channel. - // all receivers will see Closed and their FingerprintGuards will drop. - // we remove here (rather than in FingerprintGuard) because the upstream is the - // authoritative source - when it's gone, the subscription is done - self.registry.remove(&self.id); + // all receivers will see Closed and their streams will end naturally + self.map.remove(&self.id); trace!(subscription_id = %self.id, "subscription handle dropped, upstream closed"); } } -/// held by each consumer of a subscription. on drop, decrements the listener -/// count. when the last guard drops and the subscription entry still exists -/// (upstream hasn't dropped yet), removes it - causing the upstream -/// producer's send() to return false and exit -pub struct FingerprintGuard { +/// Held by each consumer of a subscription. On drop, decrements the listener count. +/// When the last guard drops and the subscription entry still exists (upstream hasn't dropped +/// yet), removes it - causing the upstream producer's `send()` to return `false` and exit. +pub struct ListenerGuard { id: SubscriptionId, - registry: Arc, + map: ActiveSubscriptionsMap, listener_count: Arc, - fingerprint: Option, } -impl Drop for FingerprintGuard { +impl Drop for ListenerGuard { fn drop(&mut self) { let prev = self.listener_count.fetch_sub(1, Ordering::AcqRel); if prev == 1 { // last listener gone, clean up. this also drops the sender, // causing the upstream producer's send() to return false - self.registry.subscriptions.remove(&self.id); - if let Some(fp) = self.fingerprint { - self.registry.fingerprints.remove(&fp); - } + self.map.inner.subscriptions.remove(&self.id); trace!(subscription_id = %self.id, "last listener dropped, subscription removed"); } else { trace!(subscription_id = %self.id, remaining = prev - 1, "listener dropped"); diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index ce4b853d2..6c7596a41 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -392,7 +392,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { let claim = self.in_flight_requests.claim(fingerprint); let mut leader_http_request_capture = None; let (shared_response, role) = claim - .get_or_try_init(|| async { + .get_or_try_init(|_| async { let res = { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. diff --git a/lib/executor/src/executors/http_callback.rs b/lib/executor/src/executors/http_callback.rs index ab3e71df9..fa9ed8ea3 100644 --- a/lib/executor/src/executors/http_callback.rs +++ b/lib/executor/src/executors/http_callback.rs @@ -12,7 +12,7 @@ use tracing::{debug, error, trace}; use uuid::Uuid; use crate::executors::active_subscriptions::{ - ActiveSubscriptionsRegistry, BroadcastItem, CallbackState, + ActiveSubscriptionsMap, BroadcastItem, CallbackState, }; use crate::executors::common::{SubgraphExecutionRequest, SubgraphExecutor}; use crate::executors::error::SubgraphExecutorError; @@ -30,7 +30,7 @@ pub struct HttpCallbackSubgraphExecutor { pub header_map: HeaderMap, pub callback_base_url: String, pub heartbeat_interval_ms: u64, - pub active_subscriptions: Arc, + pub active_subscriptions: ActiveSubscriptionsMap, } impl HttpCallbackSubgraphExecutor { @@ -40,7 +40,7 @@ impl HttpCallbackSubgraphExecutor { http_client: Arc, callback_base_url: String, heartbeat_interval_ms: u64, - active_subscriptions: Arc, + active_subscriptions: ActiveSubscriptionsMap, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -133,7 +133,7 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { let (handle, mut receiver, guard) = self .active_subscriptions - .register(None, Some(callback_state)); + .register(Some(callback_state)); let subscription_id = handle.id().to_string(); diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index ab8594647..708b8752b 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -27,7 +27,7 @@ use tokio::sync::Semaphore; use crate::{ execution::client_request_details::ClientRequestDetails, executors::{ - active_subscriptions::ActiveSubscriptionsRegistry, + active_subscriptions::ActiveSubscriptionsMap, common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient, SubgraphHttpResponse}, @@ -76,7 +76,7 @@ pub struct SubgraphExecutorMap { in_flight_requests: InflightRequestsMap, telemetry_context: Arc, /// Shared registry of all active subscriptions (http streaming, websocket, http callback) - active_subscriptions: Arc, + active_subscriptions: ActiveSubscriptionsMap, } fn build_https_executor() -> Result, SubgraphExecutorError> { @@ -114,7 +114,7 @@ impl SubgraphExecutorMap { timeouts_by_subgraph: Default::default(), global_timeout, telemetry_context, - active_subscriptions: Arc::new(ActiveSubscriptionsRegistry::new(broadcast_capacity)), + active_subscriptions: ActiveSubscriptionsMap::new(broadcast_capacity), }) } @@ -122,7 +122,7 @@ impl SubgraphExecutorMap { subgraph_endpoint_map: &HashMap, config: Arc, telemetry_context: Arc, - active_subscriptions: Arc, + active_subscriptions: ActiveSubscriptionsMap, ) -> Result { let global_timeout = DurationOrProgram::compile( &config.traffic_shaping.all.request_timeout, @@ -159,7 +159,7 @@ impl SubgraphExecutorMap { } /// Returns the shared active subscriptions registry. - pub fn active_subscriptions(&self) -> Arc { + pub fn active_subscriptions(&self) -> ActiveSubscriptionsMap { self.active_subscriptions.clone() } diff --git a/lib/internal/src/inflight.rs b/lib/internal/src/inflight.rs index 7ac3c2237..4c1367bcf 100644 --- a/lib/internal/src/inflight.rs +++ b/lib/internal/src/inflight.rs @@ -92,10 +92,22 @@ where K: Eq + Hash + Clone, S: BuildHasher + Clone, { + /// Initialises the cell if empty (leader) or waits for the existing value (joiner). + /// + /// The leader's `init` closure receives an `InFlightCleanupGuard`. Dropping the guard removes + /// the entry from the map. For short-lived work (queries) drop it immediately. For long-lived + /// work (subscriptions) move it into the task that owns the upstream so the entry stays + /// visible to joiners for the full lifetime of the stream. + /// + /// On init failure the entry is cleaned up automatically regardless of what the caller does + /// with the guard, so no entry is left dangling. + /// + /// Joiners do not invoke `init` - they share the already-initialised value and have no cleanup + /// responsibility. #[inline] pub async fn get_or_try_init(self, init: F) -> Result<(Arc, InFlightRole), E> where - F: FnOnce() -> Fut, + F: FnOnce(InFlightCleanupGuard) -> Fut, Fut: Future>, { let mut did_initialize = false; @@ -106,25 +118,39 @@ where .cell .get_or_try_init(|| { did_initialize = true; + let guard = InFlightCleanupGuard { + key: self.key.clone(), + map: self.map.clone(), + }; async { - let _cleanup = InFlightCleanupGuard { key, map }; - init().await.map(Arc::new) + match init(guard).await { + Ok(v) => Ok(Arc::new(v)), + Err(e) => { + // clean up immediately on failure so a future request can retry + map.remove(&key); + Err(e) + } + } } }) .await? .clone(); - let role = if did_initialize { - InFlightRole::Leader + if did_initialize { + Ok((value, InFlightRole::Leader)) } else { - InFlightRole::Joiner - }; - - Ok((value, role)) + Ok((value, InFlightRole::Joiner)) + } } } -struct InFlightCleanupGuard +/// Removes the entry from the inflight map when dropped. +/// +/// For queries, drop this immediately after `get_or_try_init` returns so subsequent requests +/// are not deduplicated against a completed response. +/// For subscriptions, move this into the upstream pump task so the entry remains in the map +/// (and joiners can find it) for the full lifetime of the stream. +pub struct InFlightCleanupGuard where K: Eq + Hash, S: BuildHasher + Clone, @@ -139,9 +165,6 @@ where S: BuildHasher + Clone, { fn drop(&mut self) { - // It's important to remove the entry from the map before returning the result. - // This ensures that once the OnceCell is set, no future requests can join it. - // The cache is for the lifetime of the in-flight request only. self.map.remove(&self.key); } } From 4a447976a4975d687159293c4b7b9dc968999a2b Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Fri, 3 Apr 2026 21:57:25 +0200 Subject: [PATCH 14/42] switch to ulid fast --- Cargo.lock | 2 +- lib/executor/Cargo.toml | 2 +- lib/executor/src/executors/active_subscriptions.rs | 4 ++-- lib/executor/src/executors/http_callback.rs | 9 ++++----- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e92ca027a..00208ce72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2465,7 +2465,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tracing", - "uuid", + "ulid", "xxhash-rust", ] diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index c5cadda65..abd74c042 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -56,7 +56,7 @@ bumpalo = "3.19.0" sonic-simd = "0.1.2" async-stream = "0.3.6" futures-util = "0.3.31" -uuid = { version = "1", features = ["v4"] } +ulid = "1.2.1" [dev-dependencies] subgraphs = { path = "../../bench/subgraphs" } diff --git a/lib/executor/src/executors/active_subscriptions.rs b/lib/executor/src/executors/active_subscriptions.rs index ef3c282d8..d157bb54b 100644 --- a/lib/executor/src/executors/active_subscriptions.rs +++ b/lib/executor/src/executors/active_subscriptions.rs @@ -5,7 +5,7 @@ use std::time::Instant; use bytes::Bytes; use dashmap::DashMap; use tracing::trace; -use uuid::Uuid; +use ulid::Ulid; use crate::response::graphql_error::GraphQLError; @@ -71,7 +71,7 @@ impl ActiveSubscriptionsMap { tokio::sync::broadcast::Receiver, ListenerGuard, ) { - let id = Uuid::new_v4().to_string(); + let id = Ulid::new().to_string(); let (sender, receiver) = tokio::sync::broadcast::channel(self.inner.broadcast_capacity); let listener_count = Arc::new(AtomicUsize::new(1)); diff --git a/lib/executor/src/executors/http_callback.rs b/lib/executor/src/executors/http_callback.rs index fa9ed8ea3..979870682 100644 --- a/lib/executor/src/executors/http_callback.rs +++ b/lib/executor/src/executors/http_callback.rs @@ -9,7 +9,7 @@ use http_body_util::BodyExt; use http_body_util::Full; use hyper::Version; use tracing::{debug, error, trace}; -use uuid::Uuid; +use ulid::Ulid; use crate::executors::active_subscriptions::{ ActiveSubscriptionsMap, BroadcastItem, CallbackState, @@ -117,7 +117,7 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { BoxStream<'static, Result, SubgraphExecutorError>>, SubgraphExecutorError, > { - let verifier = Uuid::new_v4().to_string(); // TODO: doesnt have to be a UUID + let verifier = Ulid::new().to_string(); let callback_state = CallbackState { verifier: verifier.clone(), @@ -131,9 +131,8 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { )), }; - let (handle, mut receiver, guard) = self - .active_subscriptions - .register(Some(callback_state)); + let (handle, mut receiver, guard) = + self.active_subscriptions.register(Some(callback_state)); let subscription_id = handle.id().to_string(); From d2192dfd4a4790d571c6e5ea4847870b4fb710f7 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Fri, 3 Apr 2026 22:29:35 +0200 Subject: [PATCH 15/42] the great enforcer 2 --- bin/router/src/pipeline/http_callback.rs | 10 +- bin/router/src/pipeline/mod.rs | 15 ++- bin/router/src/schema_state.rs | 13 ++- bin/router/src/shared_state.rs | 4 +- .../src/executors/active_subscriptions.rs | 97 +++++++++---------- lib/executor/src/executors/http_callback.rs | 13 ++- lib/executor/src/executors/map.rs | 10 +- 7 files changed, 84 insertions(+), 78 deletions(-) diff --git a/bin/router/src/pipeline/http_callback.rs b/bin/router/src/pipeline/http_callback.rs index 6575a4141..a8778416a 100644 --- a/bin/router/src/pipeline/http_callback.rs +++ b/bin/router/src/pipeline/http_callback.rs @@ -1,6 +1,6 @@ use bytes::Bytes as BytesLib; use hive_router_plan_executor::executors::active_subscriptions::{ - ActiveSubscriptionsMap, BroadcastItem, + ActiveSubscriptions, BroadcastItem, }; use hive_router_plan_executor::executors::http_callback::{ CALLBACK_PROTOCOL_VERSION, SUBSCRIPTION_PROTOCOL_HEADER, @@ -142,7 +142,7 @@ fn validate_payload( Ok(()) } -fn handle_check(subscription_id: &str, registry: &ActiveSubscriptionsMap) { +fn handle_check(subscription_id: &str, registry: &ActiveSubscriptions) { trace!(subscription_id = %subscription_id, "Received check message"); registry.record_heartbeat(subscription_id); } @@ -150,7 +150,7 @@ fn handle_check(subscription_id: &str, registry: &ActiveSubscriptionsMap) { fn handle_next( subscription_id: &str, payload: &CallbackPayload<'_>, - registry: &ActiveSubscriptionsMap, + registry: &ActiveSubscriptions, ) -> Result<(), CallbackError> { trace!(subscription_id = %subscription_id, "Received next message"); @@ -179,7 +179,7 @@ fn handle_next( fn handle_complete( subscription_id: &str, payload: &CallbackPayload<'_>, - registry: &ActiveSubscriptionsMap, + registry: &ActiveSubscriptions, ) { trace!(subscription_id = %subscription_id, "Received complete message"); if let Some(errors) = &payload.errors { @@ -194,7 +194,7 @@ pub async fn handler( req: HttpRequest, path: Path, body: Bytes, - active_subscriptions: web::types::State, + active_subscriptions: web::types::State, ) -> Result { let subscription_id_from_path = path.into_inner(); diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 4f621a445..521e848cf 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,6 +1,6 @@ use futures::Stream; use futures::StreamExt; -use hive_router_internal::inflight::InFlightCleanupGuard; +use graphql_tools::parser::schema; use std::{ collections::HashMap, hash::{Hash, Hasher}, @@ -59,7 +59,7 @@ use crate::{ schema_state::SchemaState, shared_state::{ RouterRequestDedupeHeaderPolicy, RouterSharedState, SharedRouterResponse, - SharedRouterSingleResponse, + SharedRouterResponseGuard, SharedRouterSingleResponse, }, LABORATORY_HTML, }; @@ -379,7 +379,7 @@ async fn execute_planned_request<'exec>( operation_span: GraphQLOperationSpan, plugin_req_state: Option>, response_mode: &'exec ResponseMode, - in_flight_cleanup_guard: Option>, + guard: Option, ) -> Result { let jwt_request_details = match &shared_state.jwt_auth_runtime { Some(jwt_auth_runtime) => match jwt_auth_runtime @@ -435,6 +435,13 @@ async fn execute_planned_request<'exec>( .stream_content_type() .ok_or(PipelineError::SubscriptionsTransportNotSupported)?; + // TODO: ugly AF, to use actual type - we must remove active_subscriptions + // from the executors and move it elsehwerhwe + schema_state.active_subscriptions.register( + guard.map(|g| Box::new(g) as Box), + None, + ); + todo!(); } QueryPlanExecutionResult::Single(result) => { @@ -444,7 +451,7 @@ async fn execute_planned_request<'exec>( ok_or(PipelineError::UnsupportedContentType)?; // drop the inflight planned request as soon as the response is ready - let _query_guard = in_flight_cleanup_guard; + let _query_guard = guard; let error_count = result.error_count; let mut response_builder = web::HttpResponse::Ok(); diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index bfcf8f286..1ed802dc3 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -10,7 +10,7 @@ use hive_router_internal::{ background_tasks::{BackgroundTask, BackgroundTasksManager}, }; use hive_router_plan_executor::{ - executors::active_subscriptions::{ActiveSubscriptionsMap, BroadcastItem}, + executors::active_subscriptions::{ActiveSubscriptions, BroadcastItem}, executors::error::SubgraphExecutorError, hooks::on_supergraph_load::{ OnSupergraphLoadEndHookPayload, OnSupergraphLoadStartHookPayload, SupergraphData, @@ -47,7 +47,7 @@ pub struct SchemaState { pub validate_cache: Cache>>, pub normalize_cache: Cache>, pub telemetry_context: Arc, - pub active_subscriptions: ActiveSubscriptionsMap, + pub active_subscriptions: ActiveSubscriptions, } #[derive(Debug, thiserror::Error)] @@ -97,9 +97,8 @@ impl SchemaState { let plan_cache = cache_state.plan_cache.clone(); let validate_cache = cache_state.validate_cache.clone(); let normalize_cache = cache_state.normalize_cache.clone(); - let active_subscriptions = ActiveSubscriptionsMap::new( - router_config.subscriptions.broadcast_capacity, - ); + let active_subscriptions = + ActiveSubscriptions::new(router_config.subscriptions.broadcast_capacity); // This is cheap clone, as Cache is thread-safe and can be cloned without any performance penalty. let cache_state_for_invalidation = cache_state.clone(); @@ -245,7 +244,7 @@ impl SchemaState { router_config: Arc, telemetry_context: Arc, parsed_supergraph_sdl: Document, - active_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: ActiveSubscriptions, ) -> Result { let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; let metadata = Arc::new(planner.consumer_schema.schema_metadata()); @@ -346,7 +345,7 @@ impl BackgroundTask for SupergraphBackgroundLoaderTask { } struct HeartbeatEnforcerTask { - active_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: ActiveSubscriptions, heartbeat_interval: Duration, } diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 86f682977..b47ea70f6 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -7,7 +7,7 @@ use hive_router_config::traffic_shaping::{ use hive_router_config::HiveRouterConfig; use hive_router_internal::expressions::values::boolean::BooleanOrProgram; use hive_router_internal::expressions::ExpressionCompileError; -use hive_router_internal::inflight::InFlightMap; +use hive_router_internal::inflight::{InFlightCleanupGuard, InFlightMap}; use hive_router_internal::telemetry::TelemetryContext; use hive_router_plan_executor::execution::plan::FailedExecutionResult; use hive_router_plan_executor::headers::{ @@ -84,6 +84,8 @@ impl From<&TrafficShapingRouterDedupeHeadersConfig> for RouterRequestDedupeHeade } } +pub type SharedRouterResponseGuard = InFlightCleanupGuard; + #[derive(Clone)] pub enum SharedRouterResponse { Single(SharedRouterSingleResponse), diff --git a/lib/executor/src/executors/active_subscriptions.rs b/lib/executor/src/executors/active_subscriptions.rs index d157bb54b..8b3e3a116 100644 --- a/lib/executor/src/executors/active_subscriptions.rs +++ b/lib/executor/src/executors/active_subscriptions.rs @@ -13,15 +13,14 @@ pub type SubscriptionId = String; #[derive(Clone, Debug)] pub enum BroadcastItem { - /// a normal subscription event from the upstream, already serialized. - /// uses Bytes for zero-copy cloning across broadcast receivers + /// A normal subscription event from the upstream, already serialized. + /// Uses Bytes for zero-copy cloning across broadcast receivers. Event(Bytes), - /// a terminal error pushed externally (e.g. supergraph reload, shutdown). - /// consumers should yield this as the final event and then stop + /// An error pushed externally (e.g. supergraph reload, shutdown). + /// Consumers should yield this as the final event and then stop. Error(Vec), } -/// state specific to http callback subscriptions pub struct CallbackState { pub verifier: String, pub last_heartbeat: Arc>, @@ -33,61 +32,60 @@ impl CallbackState { } } -struct ActiveSubscriptionEntry { +struct Subscription { sender: tokio::sync::broadcast::Sender, + /// The optional callback state that is only present for http callback subscriptions. callback_state: Option, } -struct ActiveSubscriptionsInner { - subscriptions: DashMap, - // capacity of the broadcast channel per subscription, see router config `subscriptions.broadcast_capacity` - broadcast_capacity: usize, -} - -/// cheap to clone - all clones share the same inner state #[derive(Clone)] -pub struct ActiveSubscriptionsMap { - inner: Arc, +pub struct ActiveSubscriptions { + // map of subscription ids to their sender + map: Arc>, + // capacity of the broadcast channel per subscription, + // see router config `subscriptions.broadcast_capacity` + broadcast_capacity: usize, } -impl ActiveSubscriptionsMap { +impl ActiveSubscriptions { pub fn new(broadcast_capacity: usize) -> Self { Self { - inner: Arc::new(ActiveSubscriptionsInner { - subscriptions: DashMap::new(), - broadcast_capacity, - }), + map: Arc::new(DashMap::new()), + broadcast_capacity, } } - /// Register a new subscription (e.g. http callbacks). - /// Always creates a new entry. Deduplication for fingerprinted subscriptions is handled - /// by the inflight map in the request pipeline, not here. + /// Register a new subscription to be used for broadcasting events to consuming clients. + /// Returns a handle for the producer to send events, a receiver for consumers to subscribe + /// to, and a guard that each consumer must hold onto while consuming. pub fn register( &self, + guard: Option>, callback_state: Option, ) -> ( - SubscriptionHandle, + ProducerHandle, tokio::sync::broadcast::Receiver, - ListenerGuard, + ConsumerGuard, ) { - let id = Ulid::new().to_string(); - let (sender, receiver) = tokio::sync::broadcast::channel(self.inner.broadcast_capacity); let listener_count = Arc::new(AtomicUsize::new(1)); + let (sender, receiver) = tokio::sync::broadcast::channel(self.broadcast_capacity); + + let id = Ulid::new().to_string(); - self.inner.subscriptions.insert( + self.map.insert( id.clone(), - ActiveSubscriptionEntry { + Subscription { sender, callback_state, }, ); - let handle = SubscriptionHandle { + let handle = ProducerHandle { id: id.clone(), map: self.clone(), + _guard: guard, }; - let guard = ListenerGuard { + let guard = ConsumerGuard { id: id.clone(), map: self.clone(), listener_count, @@ -100,20 +98,19 @@ impl ActiveSubscriptionsMap { /// check if a subscription exists pub fn contains(&self, id: &str) -> bool { - self.inner.subscriptions.contains_key(id) + self.map.contains_key(id) } /// get the verifier for a callback subscription pub fn get_callback_verifier(&self, id: &str) -> Option { - self.inner - .subscriptions + self.map .get(id) .and_then(|entry| entry.callback_state.as_ref().map(|cs| cs.verifier.clone())) } /// record a heartbeat for a callback subscription pub fn record_heartbeat(&self, id: &str) -> bool { - if let Some(entry) = self.inner.subscriptions.get(id) { + if let Some(entry) = self.map.get(id) { if let Some(ref cs) = entry.callback_state { cs.record_heartbeat(); return true; @@ -124,7 +121,7 @@ impl ActiveSubscriptionsMap { /// send an event to a specific subscription's broadcast channel pub fn send_event(&self, id: &str, item: BroadcastItem) -> bool { - if let Some(entry) = self.inner.subscriptions.get(id) { + if let Some(entry) = self.map.get(id) { // if the channel is closed or full it means the consuming client is gone or too slow and // unable to keep up. in both cases, we dont emit an error messages because it anyways cant // go through @@ -136,23 +133,23 @@ impl ActiveSubscriptionsMap { /// remove a subscription entry pub fn remove(&self, id: &str) { - self.inner.subscriptions.remove(id); + self.map.remove(id); } /// close all active subscriptions with an error message pub fn close_all_with_error(&self, errors: Vec) { let item = BroadcastItem::Error(errors); - for entry in self.inner.subscriptions.iter() { + for entry in self.map.iter() { let _ = entry.sender.send(item.clone()); } - self.inner.subscriptions.clear(); + self.map.clear(); } /// iterate over all subscription ids and their callback state for heartbeat enforcement pub fn iter_callback_subscriptions( &self, ) -> impl Iterator>)> + '_ { - self.inner.subscriptions.iter().filter_map(|entry| { + self.map.iter().filter_map(|entry| { entry .callback_state .as_ref() @@ -162,15 +159,17 @@ impl ActiveSubscriptionsMap { } /// Held by the upstream producer (the task that reads from the subgraph). +/// It is the actual subscription handle that can be used to send events to consumers. /// Dropping this removes the subscription entry from the registry, which drops /// the broadcast sender and closes the channel. All receivers will see `Closed` /// and their streams will end naturally. -pub struct SubscriptionHandle { +pub struct ProducerHandle { id: SubscriptionId, - map: ActiveSubscriptionsMap, + map: ActiveSubscriptions, + _guard: Option>, } -impl SubscriptionHandle { +impl ProducerHandle { pub fn id(&self) -> &str { &self.id } @@ -180,7 +179,7 @@ impl SubscriptionHandle { } } -impl Drop for SubscriptionHandle { +impl Drop for ProducerHandle { fn drop(&mut self) { // removing the entry drops the broadcast sender inside it, closing the channel. // all receivers will see Closed and their streams will end naturally @@ -189,22 +188,22 @@ impl Drop for SubscriptionHandle { } } -/// Held by each consumer of a subscription. On drop, decrements the listener count. +/// Held by each consumer of a subscription (producer). On drop, decrements the listener count. /// When the last guard drops and the subscription entry still exists (upstream hasn't dropped /// yet), removes it - causing the upstream producer's `send()` to return `false` and exit. -pub struct ListenerGuard { +pub struct ConsumerGuard { id: SubscriptionId, - map: ActiveSubscriptionsMap, + map: ActiveSubscriptions, listener_count: Arc, } -impl Drop for ListenerGuard { +impl Drop for ConsumerGuard { fn drop(&mut self) { let prev = self.listener_count.fetch_sub(1, Ordering::AcqRel); if prev == 1 { // last listener gone, clean up. this also drops the sender, // causing the upstream producer's send() to return false - self.map.inner.subscriptions.remove(&self.id); + self.map.map.remove(&self.id); trace!(subscription_id = %self.id, "last listener dropped, subscription removed"); } else { trace!(subscription_id = %self.id, remaining = prev - 1, "listener dropped"); diff --git a/lib/executor/src/executors/http_callback.rs b/lib/executor/src/executors/http_callback.rs index 979870682..09d756e5a 100644 --- a/lib/executor/src/executors/http_callback.rs +++ b/lib/executor/src/executors/http_callback.rs @@ -11,9 +11,7 @@ use hyper::Version; use tracing::{debug, error, trace}; use ulid::Ulid; -use crate::executors::active_subscriptions::{ - ActiveSubscriptionsMap, BroadcastItem, CallbackState, -}; +use crate::executors::active_subscriptions::{ActiveSubscriptions, BroadcastItem, CallbackState}; use crate::executors::common::{SubgraphExecutionRequest, SubgraphExecutor}; use crate::executors::error::SubgraphExecutorError; use crate::executors::http::{build_request_body, HttpClient}; @@ -30,7 +28,7 @@ pub struct HttpCallbackSubgraphExecutor { pub header_map: HeaderMap, pub callback_base_url: String, pub heartbeat_interval_ms: u64, - pub active_subscriptions: ActiveSubscriptionsMap, + pub active_subscriptions: ActiveSubscriptions, } impl HttpCallbackSubgraphExecutor { @@ -40,7 +38,7 @@ impl HttpCallbackSubgraphExecutor { http_client: Arc, callback_base_url: String, heartbeat_interval_ms: u64, - active_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: ActiveSubscriptions, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -131,8 +129,9 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { )), }; - let (handle, mut receiver, guard) = - self.active_subscriptions.register(Some(callback_state)); + let (handle, mut receiver, guard) = self + .active_subscriptions + .register(None, Some(callback_state)); let subscription_id = handle.id().to_string(); diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 708b8752b..6c834f4f4 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -27,7 +27,7 @@ use tokio::sync::Semaphore; use crate::{ execution::client_request_details::ClientRequestDetails, executors::{ - active_subscriptions::ActiveSubscriptionsMap, + active_subscriptions::ActiveSubscriptions, common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient, SubgraphHttpResponse}, @@ -76,7 +76,7 @@ pub struct SubgraphExecutorMap { in_flight_requests: InflightRequestsMap, telemetry_context: Arc, /// Shared registry of all active subscriptions (http streaming, websocket, http callback) - active_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: ActiveSubscriptions, } fn build_https_executor() -> Result, SubgraphExecutorError> { @@ -114,7 +114,7 @@ impl SubgraphExecutorMap { timeouts_by_subgraph: Default::default(), global_timeout, telemetry_context, - active_subscriptions: ActiveSubscriptionsMap::new(broadcast_capacity), + active_subscriptions: ActiveSubscriptions::new(broadcast_capacity), }) } @@ -122,7 +122,7 @@ impl SubgraphExecutorMap { subgraph_endpoint_map: &HashMap, config: Arc, telemetry_context: Arc, - active_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: ActiveSubscriptions, ) -> Result { let global_timeout = DurationOrProgram::compile( &config.traffic_shaping.all.request_timeout, @@ -159,7 +159,7 @@ impl SubgraphExecutorMap { } /// Returns the shared active subscriptions registry. - pub fn active_subscriptions(&self) -> ActiveSubscriptionsMap { + pub fn active_subscriptions(&self) -> ActiveSubscriptions { self.active_subscriptions.clone() } From 6923ad46307e54638402b6ccbf53e40f344553af Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Fri, 3 Apr 2026 22:49:12 +0200 Subject: [PATCH 16/42] excuse me please work its still great --- bin/router/src/pipeline/mod.rs | 86 ++++++++++++------- bin/router/src/shared_state.rs | 5 +- .../src/executors/active_subscriptions.rs | 9 +- lib/executor/src/executors/http_callback.rs | 2 +- 4 files changed, 65 insertions(+), 37 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 521e848cf..c456d8839 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -29,7 +29,10 @@ use hive_router_query_planner::{ state::supergraph_state::OperationKind, utils::cancellation::CancellationToken, }; use http::{header::CONTENT_TYPE, Method}; -use ntex::web::{self, HttpRequest}; +use ntex::{ + rt, + web::{self, HttpRequest}, +}; use sonic_rs::{JsonContainerTrait, JsonType, JsonValueTrait, Value}; use crate::{ @@ -59,7 +62,7 @@ use crate::{ schema_state::SchemaState, shared_state::{ RouterRequestDedupeHeaderPolicy, RouterSharedState, SharedRouterResponse, - SharedRouterResponseGuard, SharedRouterSingleResponse, + SharedRouterResponseGuard, SharedRouterSingleResponse, SharedRouterStreamResponse, }, LABORATORY_HTML, }; @@ -433,46 +436,67 @@ async fn execute_planned_request<'exec>( QueryPlanExecutionResult::Stream(result) => { let stream_content_type = response_mode .stream_content_type() - .ok_or(PipelineError::SubscriptionsTransportNotSupported)?; - - // TODO: ugly AF, to use actual type - we must remove active_subscriptions - // from the executors and move it elsehwerhwe - schema_state.active_subscriptions.register( - guard.map(|g| Box::new(g) as Box), - None, - ); + .ok_or(PipelineError::SubscriptionsTransportNotSupported)? + .clone(); + + let (producer_handle, sender, _receiver, _consumer_guard) = + schema_state.active_subscriptions.register( + // TODO: ugly AF (and unnecessary), to use actual type - we must + // remove active_subscriptions from the executors and move it elsehwerhwe + guard.map(|g| Box::new(g) as Box), + None, + ); + + let mut body_stream = result.body; + rt::spawn(async move { + while let Some(chunk) = body_stream.next().await { + if !producer_handle.send(BroadcastItem::Event(bytes::Bytes::from(chunk))) { + // all receivers gone, stop draining + break; + } + } + // dropping producer_handle closes the broadcast channel + }); - todo!(); + let headers = if let Some(aggregator) = result.response_headers_aggregator { + let mut builder = web::HttpResponse::Ok(); + aggregator.modify_client_response_headers(&mut builder)?; + Arc::new(builder.finish().headers().clone()) + } else { + Arc::new(ntex::http::HeaderMap::new()) + }; + + Ok(SharedRouterResponse::Stream(SharedRouterStreamResponse { + body: sender, + headers, + stream_content_type, + error_count: result.error_count, + })) } QueryPlanExecutionResult::Single(result) => { let single_content_type = response_mode. single_content_type(). // TODO: streaming single responses - ok_or(PipelineError::UnsupportedContentType)?; + ok_or(PipelineError::UnsupportedContentType)?. + clone(); - // drop the inflight planned request as soon as the response is ready + // drop the router shared request as soon as the response is ready let _query_guard = guard; - let error_count = result.error_count; - let mut response_builder = web::HttpResponse::Ok(); - - if let Some(response_headers_aggregator) = result.response_headers_aggregator { - response_headers_aggregator - .modify_client_response_headers(&mut response_builder)?; - } - - let body = ntex::util::Bytes::from(result.body); - - let response = response_builder - .content_type(single_content_type.as_ref()) - .status(result.status_code) - .body(body.clone()); + let headers = if let Some(aggregator) = result.response_headers_aggregator { + let mut builder = web::HttpResponse::Ok(); + aggregator.modify_client_response_headers(&mut builder)?; + Arc::new(builder.finish().headers().clone()) + } else { + Arc::new(ntex::http::HeaderMap::new()) + }; Ok(SharedRouterResponse::Single(SharedRouterSingleResponse { - body, - headers: Arc::new(response.headers().clone()), - status: response.status(), - error_count, + body: ntex::util::Bytes::from(result.body), + headers, + single_content_type, + status: result.status_code, + error_count: result.error_count, })) } } diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index b47ea70f6..d691c5912 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -28,7 +28,7 @@ use crate::cache_state::CacheState; use crate::jwt::context::JwtTokenPayload; use crate::jwt::JwtAuthRuntime; use crate::pipeline::cors::{CORSConfigError, Cors}; -use crate::pipeline::header::StreamContentType; +use crate::pipeline::header::{SingleContentType, StreamContentType}; use crate::pipeline::introspection_policy::compile_introspection_policy; use crate::pipeline::multipart_subscribe::{ self, APOLLO_MULTIPART_HTTP_CONTENT_TYPE, INCREMENTAL_DELIVERY_CONTENT_TYPE, @@ -89,7 +89,7 @@ pub type SharedRouterResponseGuard = InFlightCleanupGuard, pub status: StatusCode, + pub single_content_type: SingleContentType, pub error_count: usize, } diff --git a/lib/executor/src/executors/active_subscriptions.rs b/lib/executor/src/executors/active_subscriptions.rs index 8b3e3a116..16e143b72 100644 --- a/lib/executor/src/executors/active_subscriptions.rs +++ b/lib/executor/src/executors/active_subscriptions.rs @@ -56,19 +56,22 @@ impl ActiveSubscriptions { } /// Register a new subscription to be used for broadcasting events to consuming clients. - /// Returns a handle for the producer to send events, a receiver for consumers to subscribe - /// to, and a guard that each consumer must hold onto while consuming. + /// Returns a handle for the producer to send events, a sender that can be cloned to subscribe + /// new receivers, a receiver for consumers to subscribe to, and a guard that each consumer + /// must hold onto while consuming. pub fn register( &self, guard: Option>, callback_state: Option, ) -> ( ProducerHandle, + tokio::sync::broadcast::Sender, tokio::sync::broadcast::Receiver, ConsumerGuard, ) { let listener_count = Arc::new(AtomicUsize::new(1)); let (sender, receiver) = tokio::sync::broadcast::channel(self.broadcast_capacity); + let sender_clone = sender.clone(); let id = Ulid::new().to_string(); @@ -93,7 +96,7 @@ impl ActiveSubscriptions { trace!(subscription_id = %id, "registered new subscription"); - (handle, receiver, guard) + (handle, sender_clone, receiver, guard) } /// check if a subscription exists diff --git a/lib/executor/src/executors/http_callback.rs b/lib/executor/src/executors/http_callback.rs index 09d756e5a..d6f2dcefd 100644 --- a/lib/executor/src/executors/http_callback.rs +++ b/lib/executor/src/executors/http_callback.rs @@ -129,7 +129,7 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { )), }; - let (handle, mut receiver, guard) = self + let (handle, _sender, mut receiver, guard) = self .active_subscriptions .register(None, Some(callback_state)); From 0ff715692b092917ec2e61de9ea1db2ccfa42cba Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Fri, 3 Apr 2026 23:05:47 +0200 Subject: [PATCH 17/42] close the gap --- bin/router/src/pipeline/mod.rs | 7 ++++++- bin/router/src/shared_state.rs | 32 +++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index c456d8839..2d5fc2abd 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -439,7 +439,7 @@ async fn execute_planned_request<'exec>( .ok_or(PipelineError::SubscriptionsTransportNotSupported)? .clone(); - let (producer_handle, sender, _receiver, _consumer_guard) = + let (producer_handle, sender, bootstrap_receiver, consumer_guard) = schema_state.active_subscriptions.register( // TODO: ugly AF (and unnecessary), to use actual type - we must // remove active_subscriptions from the executors and move it elsehwerhwe @@ -449,6 +449,10 @@ async fn execute_planned_request<'exec>( let mut body_stream = result.body; rt::spawn(async move { + // keep consumer_guard alive for the duration of the pump so it doesn't + // decrement the listener count to 0 and remove the subscription entry + // from active_subscriptions before the producer has a chance to send + let _consumer_guard = consumer_guard; while let Some(chunk) = body_stream.next().await { if !producer_handle.send(BroadcastItem::Event(bytes::Bytes::from(chunk))) { // all receivers gone, stop draining @@ -471,6 +475,7 @@ async fn execute_planned_request<'exec>( headers, stream_content_type, error_count: result.error_count, + bootstrap_receiver: Some(bootstrap_receiver), })) } QueryPlanExecutionResult::Single(result) => { diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index d691c5912..51ebac334 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -132,19 +132,49 @@ impl From for web::HttpResponse { } } -#[derive(Clone)] pub struct SharedRouterStreamResponse { // status is always 200 for streaming responses, errors are sent through the stream pub body: tokio::sync::broadcast::Sender, pub headers: Arc, pub stream_content_type: StreamContentType, pub error_count: usize, + // only set for the leader (the request that actually opens the upstream). the pump task + // is spawned before this response is returned to the caller, so there is a window between + // the spawn and the leader eventually calling body.subscribe() where the pump could send + // events to a channel with no receivers. broadcast does buffer sent events in its internal + // ring buffer, but a receiver created via subscribe() only sees events sent after it was + // created - it cannot read anything already in the buffer. so even though the events are + // technically buffered, the real consumer would miss them. worse, if there are zero + // receivers at send time, broadcast returns an error and the event is gone entirely, not + // even buffered. this receiver is created before the pump spawns, keeping the receiver + // count above zero during that window. it is dropped inside From + // only after the real consumer receiver has been created via body.subscribe(), which + // guarantees no event is sent to an empty channel and no event is missed. None on clones + // because joiners are already late subscribers and create their own receiver when they + // call body.subscribe(). + pub bootstrap_receiver: Option>, +} + +impl Clone for SharedRouterStreamResponse { + fn clone(&self) -> Self { + Self { + body: self.body.clone(), + headers: self.headers.clone(), + stream_content_type: self.stream_content_type.clone(), + error_count: self.error_count, + bootstrap_receiver: None, + } + } } impl From for web::HttpResponse { fn from(shared_response: SharedRouterStreamResponse) -> Self { let mut receiver = shared_response.body.subscribe(); + // drop the bootstrap receiver only after the real consumer receiver is created above, + // closing the gap where events could be lost between pump spawn and subscribe() + drop(shared_response.bootstrap_receiver); + let stream = Box::pin(async_stream::stream! { loop { match receiver.recv().await { From c898f7d93889567d8a22a789f6df1c645df46a13 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Fri, 3 Apr 2026 23:16:00 +0200 Subject: [PATCH 18/42] WIP --- lib/executor/src/executors/http_callback.rs | 121 +++++++++++++------- lib/executor/src/executors/map.rs | 22 ++-- 2 files changed, 89 insertions(+), 54 deletions(-) diff --git a/lib/executor/src/executors/http_callback.rs b/lib/executor/src/executors/http_callback.rs index d6f2dcefd..fa6d75d9f 100644 --- a/lib/executor/src/executors/http_callback.rs +++ b/lib/executor/src/executors/http_callback.rs @@ -3,24 +3,63 @@ use std::time::{Duration, Instant}; use async_trait::async_trait; use bytes::Bytes; +use dashmap::DashMap; use futures::stream::BoxStream; use http::{HeaderMap, HeaderValue}; use http_body_util::BodyExt; use http_body_util::Full; use hyper::Version; +use tokio::sync::mpsc; use tracing::{debug, error, trace}; use ulid::Ulid; -use crate::executors::active_subscriptions::{ActiveSubscriptions, BroadcastItem, CallbackState}; -use crate::executors::common::{SubgraphExecutionRequest, SubgraphExecutor}; +use crate::executors::common::{ + SubgraphExecutionRequest, SubgraphExecutor, SUBSCRIPTION_EVENT_BUFFER_CAPACITY, +}; use crate::executors::error::SubgraphExecutorError; use crate::executors::http::{build_request_body, HttpClient}; use crate::plugin_context::PluginRequestState; +use crate::response::graphql_error::GraphQLError; use crate::response::subgraph_response::SubgraphResponse; pub const CALLBACK_PROTOCOL_VERSION: &str = "callback/1.0"; pub const SUBSCRIPTION_PROTOCOL_HEADER: &str = "subscription-protocol"; +type SubscriptionId = String; + +#[derive(Clone)] +pub struct ActiveSubscription { + pub verifier: String, + pub sender: mpsc::Sender, + pub last_heartbeat: Arc>, +} + +impl ActiveSubscription { + pub fn record_heartbeat(&self) { + *self.last_heartbeat.lock().unwrap() = Instant::now(); + } +} + +#[derive(Debug)] +pub enum CallbackMessage { + Next { payload: Bytes }, + Complete { errors: Option> }, +} + +pub type ActiveSubscriptionsMap = Arc>; + +struct SubscriptionGuard { + subscription_id: SubscriptionId, + active_subscriptions: ActiveSubscriptionsMap, +} + +impl Drop for SubscriptionGuard { + fn drop(&mut self) { + self.active_subscriptions.remove(&self.subscription_id); + trace!(subscription_id = %self.subscription_id, "HTTP callback subscription entry removed from active subscriptions"); + } +} + pub struct HttpCallbackSubgraphExecutor { pub subgraph_name: String, pub endpoint: http::Uri, @@ -28,7 +67,7 @@ pub struct HttpCallbackSubgraphExecutor { pub header_map: HeaderMap, pub callback_base_url: String, pub heartbeat_interval_ms: u64, - pub active_subscriptions: ActiveSubscriptions, + pub active_subscriptions: ActiveSubscriptionsMap, } impl HttpCallbackSubgraphExecutor { @@ -38,7 +77,7 @@ impl HttpCallbackSubgraphExecutor { http_client: Arc, callback_base_url: String, heartbeat_interval_ms: u64, - active_subscriptions: ActiveSubscriptions, + active_subscriptions: ActiveSubscriptionsMap, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -115,27 +154,33 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { BoxStream<'static, Result, SubgraphExecutorError>>, SubgraphExecutorError, > { + let subscription_id = Ulid::new().to_string(); let verifier = Ulid::new().to_string(); - let callback_state = CallbackState { - verifier: verifier.clone(), - // initialize last_heartbeat to now + heartbeat_interval so the enforcer - // won't evict the subscription before the subgraph's initial check arrives. - // the initial check from the subgraph can take up to heartbeat_interval to - // arrive (due to network latency), and without this head start the enforcer - // would evict the subscription before the first heartbeat is recorded. - last_heartbeat: Arc::new(Mutex::new( - Instant::now() + Duration::from_millis(self.heartbeat_interval_ms), - )), - }; - - let (handle, _sender, mut receiver, guard) = self - .active_subscriptions - .register(None, Some(callback_state)); + let body = self.build_request_body(&mut execution_request, &subscription_id, &verifier)?; - let subscription_id = handle.id().to_string(); + let (tx, mut rx) = mpsc::channel::(SUBSCRIPTION_EVENT_BUFFER_CAPACITY); + self.active_subscriptions.insert( + subscription_id.clone(), + ActiveSubscription { + verifier, + sender: tx, + // initialize last_heartbeat to now + heartbeat_interval so the enforcer + // won't evict the subscription before the subgraph's initial check arrives. + // the initial check from the subgraph can take up to heartbeat_interval to + // arrive (due to network latency), and without this head start the enforcer + // would evict the subscription before the first heartbeat is recorded. + last_heartbeat: Arc::new(Mutex::new( + Instant::now() + Duration::from_millis(self.heartbeat_interval_ms), + )), + }, + ); - let body = self.build_request_body(&mut execution_request, &subscription_id, &verifier)?; + // guard removes the entry from `active_subscriptions` when dropped + let guard = SubscriptionGuard { + subscription_id: subscription_id.clone(), + active_subscriptions: self.active_subscriptions.clone(), + }; let mut req = hyper::Request::builder() .method(http::Method::POST) @@ -199,15 +244,14 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { } Ok(Box::pin(async_stream::stream! { - // hold the handle and guard so the subscription entry is removed when the stream ends - let _handle = handle; + // `guard` is held here; dropping the stream drops `guard`, removing the map entry. let _guard = guard; trace!(subscription_id = %subscription_id, "HTTP callback subscription stream started"); - loop { - match receiver.recv().await { - Ok(BroadcastItem::Event(payload)) => { + while let Some(msg) = rx.recv().await { + match msg { + CallbackMessage::Next { payload } => { trace!(subscription_id = %subscription_id, "received next payload"); match SubgraphResponse::deserialize_from_bytes(payload) { Ok(response) => yield Ok(response), @@ -222,25 +266,18 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { } } } - Ok(BroadcastItem::Error(errors)) => { - trace!(subscription_id = %subscription_id, "received close with error"); - if !errors.is_empty() { - yield Ok(SubgraphResponse { - errors: Some(errors), - ..Default::default() - }); + CallbackMessage::Complete { errors } => { + trace!(subscription_id = %subscription_id, "received complete"); + if let Some(errors) = errors { + if !errors.is_empty() { + yield Ok(SubgraphResponse { + errors: Some(errors), + ..Default::default() + }); + } } break; } - Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { - // slow consumer, skip missed messages and continue - trace!(subscription_id = %subscription_id, lagged = n, "broadcast receiver lagged, skipping missed messages"); - continue; - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => { - trace!(subscription_id = %subscription_id, "broadcast channel closed"); - break; - } } } diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 6c834f4f4..6e04d575c 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -27,11 +27,10 @@ use tokio::sync::Semaphore; use crate::{ execution::client_request_details::ClientRequestDetails, executors::{ - active_subscriptions::ActiveSubscriptions, common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient, SubgraphHttpResponse}, - http_callback::HttpCallbackSubgraphExecutor, + http_callback::{ActiveSubscriptionsMap, HttpCallbackSubgraphExecutor}, websocket::WsSubgraphExecutor, }, hooks::on_subgraph_execute::{ @@ -75,8 +74,8 @@ pub struct SubgraphExecutorMap { max_connections_per_host: usize, in_flight_requests: InflightRequestsMap, telemetry_context: Arc, - /// Shared registry of all active subscriptions (http streaming, websocket, http callback) - active_subscriptions: ActiveSubscriptions, + /// Shared map of active HTTP callback subscriptions + active_callback_subscriptions: ActiveSubscriptionsMap, } fn build_https_executor() -> Result, SubgraphExecutorError> { @@ -99,7 +98,6 @@ impl SubgraphExecutorMap { .build(build_https_executor()?); let max_connections_per_host = config.traffic_shaping.max_connections_per_host; - let broadcast_capacity = config.subscriptions.broadcast_capacity; Ok(SubgraphExecutorMap { http_executors_by_subgraph: Default::default(), @@ -114,7 +112,7 @@ impl SubgraphExecutorMap { timeouts_by_subgraph: Default::default(), global_timeout, telemetry_context, - active_subscriptions: ActiveSubscriptions::new(broadcast_capacity), + active_callback_subscriptions: Arc::new(DashMap::new()), }) } @@ -122,7 +120,7 @@ impl SubgraphExecutorMap { subgraph_endpoint_map: &HashMap, config: Arc, telemetry_context: Arc, - active_subscriptions: ActiveSubscriptions, + active_callback_subscriptions: ActiveSubscriptionsMap, ) -> Result { let global_timeout = DurationOrProgram::compile( &config.traffic_shaping.all.request_timeout, @@ -133,7 +131,7 @@ impl SubgraphExecutorMap { })?; let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), global_timeout, telemetry_context)?; - subgraph_executor_map.active_subscriptions = active_subscriptions; + subgraph_executor_map.active_callback_subscriptions = active_callback_subscriptions; for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.iter() { let endpoint_config = config @@ -158,9 +156,9 @@ impl SubgraphExecutorMap { Ok(subgraph_executor_map) } - /// Returns the shared active subscriptions registry. - pub fn active_subscriptions(&self) -> ActiveSubscriptions { - self.active_subscriptions.clone() + /// Returns the shared active callback subscriptions map for use by callback handlers. + pub fn active_callback_subscriptions(&self) -> ActiveSubscriptionsMap { + self.active_callback_subscriptions.clone() } pub async fn execute<'exec>( @@ -507,7 +505,7 @@ impl SubgraphExecutorMap { self.client.clone(), callback_config.public_url.to_string(), heartbeat_interval_ms, - self.active_subscriptions.clone(), + self.active_callback_subscriptions.clone(), ) .to_boxed_arc(); From 7cbbc15790e8d11644f1a2316a229d543eb3e411 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Fri, 3 Apr 2026 23:51:39 +0200 Subject: [PATCH 19/42] now we're talking, the great divide --- bin/router/src/lib.rs | 13 ++- .../src/pipeline}/active_subscriptions.rs | 62 ++---------- bin/router/src/pipeline/http_callback.rs | 95 +++++++++++-------- bin/router/src/pipeline/mod.rs | 37 +++----- bin/router/src/pipeline/websocket_server.rs | 4 - bin/router/src/schema_state.rs | 63 ++++++------ bin/router/src/shared_state.rs | 8 +- e2e/src/testkit/mod.rs | 12 +-- lib/executor/src/executors/common.rs | 13 --- lib/executor/src/executors/http_callback.rs | 39 ++++---- lib/executor/src/executors/map.rs | 16 ++-- lib/executor/src/executors/mod.rs | 1 - lib/executor/src/executors/websocket.rs | 13 +-- 13 files changed, 164 insertions(+), 212 deletions(-) rename {lib/executor/src/executors => bin/router/src/pipeline}/active_subscriptions.rs (74%) diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 44baeab28..06bf96266 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -22,6 +22,7 @@ use crate::{ }, jwt::JwtAuthRuntime, pipeline::{ + active_subscriptions::ActiveSubscriptions, error::handle_pipeline_error, graphql_request_handler, header::ResponseMode, @@ -213,7 +214,7 @@ pub async fn router_entrypoint(plugin_registry: PluginRegistry) -> Result<(), Ro .await?; let shared_state_clone = shared_state.clone(); - let active_subs = schema_state.active_subscriptions.clone(); + let callback_subscriptions_for_handler = schema_state.callback_subscriptions.clone(); // when `listen` is set, the callback route lives on a dedicated server bound to that address // otherwise, the callback route is mounted on the main server on the `callback_path` @@ -225,12 +226,12 @@ pub async fn router_entrypoint(plugin_registry: PluginRegistry) -> Result<(), Ro }) => { let cb_path = path.to_string(); let cb_addr = listen.to_string(); - let cb_active_subs = active_subs.clone(); + let cb_subs = callback_subscriptions_for_handler.clone(); let cb_server = web::HttpServer::new(async move || { - let cb_active_subs = cb_active_subs.clone(); + let cb_subs = cb_subs.clone(); let cb_path = cb_path.clone(); web::App::new() - .state(cb_active_subs) + .state(cb_subs) .configure(move |m| add_callback_handler(m, &cb_path)) }) .bind(&cb_addr) @@ -316,6 +317,8 @@ pub async fn configure_app_from_config( }; let plugins_arc = plugin_registry.initialize_plugins(&router_config, bg_tasks_manager)?; + let active_subscriptions = + ActiveSubscriptions::new(router_config.subscriptions.broadcast_capacity); let router_config_arc = Arc::new(router_config); let telemetry_context_arc = Arc::new(telemetry_context); let cache_state = Arc::new(CacheState::new()); @@ -330,6 +333,7 @@ pub async fn configure_app_from_config( router_config_arc.clone(), plugins_arc.clone(), cache_state.clone(), + active_subscriptions.clone(), ) .await?; let schema_state_arc = Arc::new(schema_state); @@ -357,6 +361,7 @@ pub async fn configure_app_from_config( telemetry_context_arc, plugins_arc, cache_state, + active_subscriptions.clone(), )?); Ok((shared_state, schema_state_arc)) diff --git a/lib/executor/src/executors/active_subscriptions.rs b/bin/router/src/pipeline/active_subscriptions.rs similarity index 74% rename from lib/executor/src/executors/active_subscriptions.rs rename to bin/router/src/pipeline/active_subscriptions.rs index 16e143b72..212f2176b 100644 --- a/lib/executor/src/executors/active_subscriptions.rs +++ b/bin/router/src/pipeline/active_subscriptions.rs @@ -1,13 +1,13 @@ use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Mutex}; -use std::time::Instant; +use std::sync::Arc; use bytes::Bytes; use dashmap::DashMap; +use hive_router_plan_executor::response::graphql_error::GraphQLError; use tracing::trace; use ulid::Ulid; -use crate::response::graphql_error::GraphQLError; +use crate::shared_state::SharedRouterResponseGuard; pub type SubscriptionId = String; @@ -21,21 +21,8 @@ pub enum BroadcastItem { Error(Vec), } -pub struct CallbackState { - pub verifier: String, - pub last_heartbeat: Arc>, -} - -impl CallbackState { - pub fn record_heartbeat(&self) { - *self.last_heartbeat.lock().unwrap() = Instant::now(); - } -} - struct Subscription { sender: tokio::sync::broadcast::Sender, - /// The optional callback state that is only present for http callback subscriptions. - callback_state: Option, } #[derive(Clone)] @@ -61,8 +48,7 @@ impl ActiveSubscriptions { /// must hold onto while consuming. pub fn register( &self, - guard: Option>, - callback_state: Option, + guard: Option, ) -> ( ProducerHandle, tokio::sync::broadcast::Sender, @@ -75,13 +61,7 @@ impl ActiveSubscriptions { let id = Ulid::new().to_string(); - self.map.insert( - id.clone(), - Subscription { - sender, - callback_state, - }, - ); + self.map.insert(id.clone(), Subscription { sender }); let handle = ProducerHandle { id: id.clone(), @@ -104,24 +84,6 @@ impl ActiveSubscriptions { self.map.contains_key(id) } - /// get the verifier for a callback subscription - pub fn get_callback_verifier(&self, id: &str) -> Option { - self.map - .get(id) - .and_then(|entry| entry.callback_state.as_ref().map(|cs| cs.verifier.clone())) - } - - /// record a heartbeat for a callback subscription - pub fn record_heartbeat(&self, id: &str) -> bool { - if let Some(entry) = self.map.get(id) { - if let Some(ref cs) = entry.callback_state { - cs.record_heartbeat(); - return true; - } - } - false - } - /// send an event to a specific subscription's broadcast channel pub fn send_event(&self, id: &str, item: BroadcastItem) -> bool { if let Some(entry) = self.map.get(id) { @@ -147,18 +109,6 @@ impl ActiveSubscriptions { } self.map.clear(); } - - /// iterate over all subscription ids and their callback state for heartbeat enforcement - pub fn iter_callback_subscriptions( - &self, - ) -> impl Iterator>)> + '_ { - self.map.iter().filter_map(|entry| { - entry - .callback_state - .as_ref() - .map(|cs| (entry.key().clone(), cs.last_heartbeat.clone())) - }) - } } /// Held by the upstream producer (the task that reads from the subgraph). @@ -169,7 +119,7 @@ impl ActiveSubscriptions { pub struct ProducerHandle { id: SubscriptionId, map: ActiveSubscriptions, - _guard: Option>, + _guard: Option, } impl ProducerHandle { diff --git a/bin/router/src/pipeline/http_callback.rs b/bin/router/src/pipeline/http_callback.rs index a8778416a..66c13de74 100644 --- a/bin/router/src/pipeline/http_callback.rs +++ b/bin/router/src/pipeline/http_callback.rs @@ -1,9 +1,8 @@ use bytes::Bytes as BytesLib; -use hive_router_plan_executor::executors::active_subscriptions::{ - ActiveSubscriptions, BroadcastItem, -}; +use dashmap::mapref::one::Ref; use hive_router_plan_executor::executors::http_callback::{ - CALLBACK_PROTOCOL_VERSION, SUBSCRIPTION_PROTOCOL_HEADER, + CallbackMessage, CallbackSubscription, CallbackSubscriptionsMap, CALLBACK_PROTOCOL_VERSION, + SUBSCRIPTION_PROTOCOL_HEADER, }; use hive_router_plan_executor::response::graphql_error::GraphQLError; use http::StatusCode; @@ -12,6 +11,7 @@ use ntex::web::WebResponseError; use ntex::web::{self, types::Path, HttpRequest, HttpResponse}; use serde::Deserialize; use strum::EnumString; +use tokio::sync::mpsc; use tracing::{debug, error, trace, warn}; #[derive(Debug, Deserialize, EnumString)] @@ -142,15 +142,16 @@ fn validate_payload( Ok(()) } -fn handle_check(subscription_id: &str, registry: &ActiveSubscriptions) { +fn handle_check(subscription_id: &str, subscription: &Ref<'_, String, CallbackSubscription>) { trace!(subscription_id = %subscription_id, "Received check message"); - registry.record_heartbeat(subscription_id); + subscription.record_heartbeat(); } fn handle_next( subscription_id: &str, payload: &CallbackPayload<'_>, - registry: &ActiveSubscriptions, + subscription: Ref<'_, String, CallbackSubscription>, + callback_subscriptions: &CallbackSubscriptionsMap, ) -> Result<(), CallbackError> { trace!(subscription_id = %subscription_id, "Received next message"); @@ -163,38 +164,53 @@ fn handle_next( } }; - if !registry.send_event(subscription_id, BroadcastItem::Event(data)) { - debug!(subscription_id = %subscription_id, "Subscription receiver dropped"); - registry.remove(subscription_id); - return Err(CallbackError::SubscriptionDropped { - subscription_id: subscription_id.to_string(), - }); + match subscription + .sender + .try_send(CallbackMessage::Next { payload: data }) + { + Ok(()) => Ok(()), + Err(mpsc::error::TrySendError::Full(_)) => { + // if the channel is full it means the consuming client is too slow and unable to keep + // up. we terminate the subscription without an error message because it anyways cant go through + warn!(subscription_id = %subscription_id, "Subscription client is too slow"); + drop(subscription); + callback_subscriptions.remove(subscription_id); + Err(CallbackError::ClientTooSlow { + subscription_id: subscription_id.to_string(), + }) + } + Err(mpsc::error::TrySendError::Closed(_)) => { + debug!(subscription_id = %subscription_id, "Subscription receiver dropped"); + drop(subscription); + callback_subscriptions.remove(subscription_id); + Err(CallbackError::SubscriptionDropped { + subscription_id: subscription_id.to_string(), + }) + } } - - // TODO: ClientTooSlow - - Ok(()) } fn handle_complete( subscription_id: &str, payload: &CallbackPayload<'_>, - registry: &ActiveSubscriptions, + subscription: Ref<'_, String, CallbackSubscription>, + callback_subscriptions: &CallbackSubscriptionsMap, ) { trace!(subscription_id = %subscription_id, "Received complete message"); - if let Some(errors) = &payload.errors { - if !errors.is_empty() { - registry.send_event(subscription_id, BroadcastItem::Error(errors.clone())); - } - } - registry.remove(subscription_id); + // if the buffer is full or closed we ignore and remove the subscription, we dont send + // the final error message because the client is already unable to consume + let _ = subscription.sender.try_send(CallbackMessage::Complete { + errors: payload.errors.clone(), + }); + drop(subscription); + callback_subscriptions.remove(subscription_id); } pub async fn handler( req: HttpRequest, path: Path, body: Bytes, - active_subscriptions: web::types::State, + callback_subscriptions: web::types::State, ) -> Result { let subscription_id_from_path = path.into_inner(); @@ -204,30 +220,29 @@ pub async fn handler( validate_payload(&payload, &subscription_id_from_path)?; - if !active_subscriptions.contains(&payload.id) { - return Err(CallbackError::SubscriptionNotFound { - subscription_id: payload.id.clone(), - }); - } - - let verifier = active_subscriptions - .get_callback_verifier(&payload.id) - .ok_or_else(|| CallbackError::SubscriptionNotFound { - subscription_id: payload.id.clone(), - })?; + let subscription = match callback_subscriptions.get(&payload.id) { + Some(sub) => sub, + None => { + return Err(CallbackError::SubscriptionNotFound { + subscription_id: payload.id.clone(), + }); + } + }; - if verifier != payload.verifier { + if subscription.verifier != payload.verifier { return Err(CallbackError::InvalidVerifier { subscription_id: payload.id.clone(), }); } match payload.action { - CallbackAction::Check => handle_check(&payload.id, &active_subscriptions), + CallbackAction::Check => handle_check(&payload.id, &subscription), CallbackAction::Next => { - handle_next(&payload.id, &payload, &active_subscriptions)?; + handle_next(&payload.id, &payload, subscription, &callback_subscriptions)?; + } + CallbackAction::Complete => { + handle_complete(&payload.id, &payload, subscription, &callback_subscriptions) } - CallbackAction::Complete => handle_complete(&payload.id, &payload, &active_subscriptions), }; Ok(HttpResponse::NoContent() diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 2d5fc2abd..e807e0c52 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,19 +1,4 @@ -use futures::Stream; use futures::StreamExt; -use graphql_tools::parser::schema; -use std::{ - collections::HashMap, - hash::{Hash, Hasher}, - sync::Arc, - time::Instant, -}; -use tracing::{error, trace, Instrument}; -use xxhash_rust::xxh3::Xxh3; - -use hive_router_plan_executor::execution::plan::FailedExecutionResult; -use hive_router_plan_executor::executors::active_subscriptions::BroadcastItem; -use hive_router_plan_executor::headers::plan::ResponseHeaderAggregator; - use hive_router_internal::telemetry::traces::spans::{ graphql::GraphQLOperationSpan, http_request::HttpServerRequestSpan, }; @@ -34,9 +19,18 @@ use ntex::{ web::{self, HttpRequest}, }; use sonic_rs::{JsonContainerTrait, JsonType, JsonValueTrait, Value}; +use std::{ + collections::HashMap, + hash::{Hash, Hasher}, + sync::Arc, + time::Instant, +}; +use tracing::{error, Instrument}; +use xxhash_rust::xxh3::Xxh3; use crate::{ pipeline::{ + active_subscriptions::BroadcastItem, authorization::enforce_operation_authorization, body_read::read_body_stream, coerce_variables::{coerce_request_variables, CoerceVariablesPayload}, @@ -44,11 +38,8 @@ use crate::{ error::PipelineError, execution::{execute_plan, PlannedRequest}, execution_request::{deserialize_graphql_params, DeserializationResult, GetQueryStr}, - header::{RequestAccepts, ResponseMode, StreamContentType, TEXT_HTML_MIME}, + header::{RequestAccepts, ResponseMode, TEXT_HTML_MIME}, introspection_policy::handle_introspection_policy, - multipart_subscribe::{ - APOLLO_MULTIPART_HTTP_CONTENT_TYPE, INCREMENTAL_DELIVERY_CONTENT_TYPE, - }, normalize::{normalize_request_with_cache, GraphQLNormalizationPayload}, parser::{parse_operation_with_cache, ParseResult}, progressive_override::request_override_context, @@ -69,6 +60,7 @@ use crate::{ use hive_router_internal::telemetry::metrics::catalog::values::GraphQLResponseStatus; +pub mod active_subscriptions; pub mod authorization; pub mod body_read; pub mod coerce_variables; @@ -440,12 +432,7 @@ async fn execute_planned_request<'exec>( .clone(); let (producer_handle, sender, bootstrap_receiver, consumer_guard) = - schema_state.active_subscriptions.register( - // TODO: ugly AF (and unnecessary), to use actual type - we must - // remove active_subscriptions from the executors and move it elsehwerhwe - guard.map(|g| Box::new(g) as Box), - None, - ); + shared_state.active_subscriptions.register(guard); let mut body_stream = result.body; rt::spawn(async move { diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 88b910515..1daceceb3 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -47,10 +47,6 @@ use crate::pipeline::{ use crate::schema_state::SchemaState; use crate::shared_state::RouterSharedState; -use crate::shared_state::SharedRouterResponse; -use hive_router_plan_executor::execution::plan::FailedExecutionResult; -use hive_router_plan_executor::executors::active_subscriptions::BroadcastItem; - type WsStateRef = Rc>>>; pub async fn ws_index( diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index 1ed802dc3..7ee36084c 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -1,6 +1,8 @@ +use crate::pipeline::active_subscriptions::ActiveSubscriptions; use crate::pipeline::authorization::metadata::AuthorizationMetadataExt; use arc_swap::{ArcSwap, Guard}; use async_trait::async_trait; +use dashmap::DashMap; use graphql_tools::static_graphql::schema::Document; use graphql_tools::validation::utils::ValidationError; use hive_router_config::{supergraph::SupergraphSource, HiveRouterConfig}; @@ -9,8 +11,11 @@ use hive_router_internal::{ authorization::metadata::AuthorizationMetadata, background_tasks::{BackgroundTask, BackgroundTasksManager}, }; +use hive_router_plan_executor::executors::http_callback::{ + CallbackMessage, CallbackSubscriptionsMap, +}; +use hive_router_plan_executor::response::graphql_error::GraphQLErrorExtensions; use hive_router_plan_executor::{ - executors::active_subscriptions::{ActiveSubscriptions, BroadcastItem}, executors::error::SubgraphExecutorError, hooks::on_supergraph_load::{ OnSupergraphLoadEndHookPayload, OnSupergraphLoadStartHookPayload, SupergraphData, @@ -47,7 +52,7 @@ pub struct SchemaState { pub validate_cache: Cache>>, pub normalize_cache: Cache>, pub telemetry_context: Arc, - pub active_subscriptions: ActiveSubscriptions, + pub callback_subscriptions: CallbackSubscriptionsMap, } #[derive(Debug, thiserror::Error)] @@ -83,6 +88,7 @@ impl SchemaState { router_config: Arc, plugins: Option>>, cache_state: Arc, + active_subscriptions: ActiveSubscriptions, ) -> Result { let (tx, mut rx) = mpsc::channel::(1); let background_loader = SupergraphBackgroundLoader::new( @@ -97,20 +103,19 @@ impl SchemaState { let plan_cache = cache_state.plan_cache.clone(); let validate_cache = cache_state.validate_cache.clone(); let normalize_cache = cache_state.normalize_cache.clone(); - let active_subscriptions = - ActiveSubscriptions::new(router_config.subscriptions.broadcast_capacity); + let callback_subscriptions: CallbackSubscriptionsMap = Arc::new(DashMap::new()); // This is cheap clone, as Cache is thread-safe and can be cloned without any performance penalty. let cache_state_for_invalidation = cache_state.clone(); - let active_subscriptions_for_build_data = active_subscriptions.clone(); + let callback_subscriptions_for_build_data = callback_subscriptions.clone(); // kick off subscriptions/subgraphs that are idling/timed out due to missed heartbeats if let Some(ref callback_config) = router_config.subscriptions.callback { if !callback_config.heartbeat_interval.is_zero() { - let enforcer_subs = active_subscriptions.clone(); + let enforcer_subs = callback_subscriptions.clone(); let heartbeat_interval = callback_config.heartbeat_interval; - bg_tasks_manager.register_task(HeartbeatEnforcerTask { - active_subscriptions: enforcer_subs, + bg_tasks_manager.register_task(CallbackHeartbeatEnforcerTask { + callback_subscriptions: enforcer_subs, heartbeat_interval, }); } @@ -165,7 +170,7 @@ impl SchemaState { router_config.clone(), task_telemetry.clone(), new_ast, - active_subscriptions_for_build_data.clone(), + callback_subscriptions_for_build_data.clone(), ) }) { Ok(mut new_supergraph_data) => { @@ -208,7 +213,7 @@ impl SchemaState { // close all active subscriptions before swapping supergraph data active_subscriptions_for_reload.close_all_with_error(vec![ // this is litearaly the same message apollo sends - reasoning is - // drop-in-replacement - is that oke? should we have our own? + // drop in replacement - is that oke? should we have our own? GraphQLError::from_message_and_code( "subscription has been closed due to a schema reload", "SUBSCRIPTION_SCHEMA_RELOAD", @@ -236,7 +241,7 @@ impl SchemaState { validate_cache, normalize_cache, telemetry_context: telemetry_context.clone(), - active_subscriptions, + callback_subscriptions, }) } @@ -244,7 +249,7 @@ impl SchemaState { router_config: Arc, telemetry_context: Arc, parsed_supergraph_sdl: Document, - active_subscriptions: ActiveSubscriptions, + callback_subscriptions: CallbackSubscriptionsMap, ) -> Result { let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; let metadata = Arc::new(planner.consumer_schema.schema_metadata()); @@ -253,7 +258,7 @@ impl SchemaState { &planner.supergraph.subgraph_endpoint_map, router_config, telemetry_context, - active_subscriptions, + callback_subscriptions, )?); Ok(SupergraphData { @@ -344,13 +349,13 @@ impl BackgroundTask for SupergraphBackgroundLoaderTask { } } -struct HeartbeatEnforcerTask { - active_subscriptions: ActiveSubscriptions, +struct CallbackHeartbeatEnforcerTask { + callback_subscriptions: CallbackSubscriptionsMap, heartbeat_interval: Duration, } #[async_trait] -impl BackgroundTask for HeartbeatEnforcerTask { +impl BackgroundTask for CallbackHeartbeatEnforcerTask { fn id(&self) -> &str { "http-callback-heartbeat-enforcer" } @@ -368,14 +373,14 @@ impl BackgroundTask for HeartbeatEnforcerTask { } let mut timed_out = Vec::new(); - for (id, last_heartbeat) in self.active_subscriptions.iter_callback_subscriptions() { - let last = *last_heartbeat.lock().unwrap(); + for entry in self.callback_subscriptions.iter() { + let last = *entry.value().last_heartbeat.lock().unwrap(); if Instant::now().duration_since(last) > self.heartbeat_interval + // add a grace period if latency increases due to usage std::time::Duration::from_millis(500) { - timed_out.push(id); + timed_out.push(entry.key().clone()); } } @@ -383,16 +388,18 @@ impl BackgroundTask for HeartbeatEnforcerTask { for id in timed_out { debug!( subscription_id = %id, - "terminating subscription due to missed heartbeat" + "terminating subscription due to http callback subgraph missed heartbeat" ); - self.active_subscriptions.send_event( - &id, - BroadcastItem::Error(vec![GraphQLError::from_message_and_code( - "Subgraph gone due to heartbeat timeout".to_string(), - "SUBGRAPH_GONE", - )]), - ); - self.active_subscriptions.remove(&id); + if let Some((_, sub)) = self.callback_subscriptions.remove(&id) { + // we dont care about the result of this send, if it fails it means the client + // is already gone or too slow, either way we just terminate the subscription + let _ = sub.sender.try_send(CallbackMessage::Complete { + errors: Some(vec![GraphQLError::from_message_and_extensions( + "Subgraph gone due to heartbeat timeout".to_string(), + GraphQLErrorExtensions::new_from_code("SUBGRAPH_GONE"), + )]), + }); + } } } } diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 51ebac334..1ac586d2b 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -27,6 +27,7 @@ use tracing::trace; use crate::cache_state::CacheState; use crate::jwt::context::JwtTokenPayload; use crate::jwt::JwtAuthRuntime; +use crate::pipeline::active_subscriptions::{ActiveSubscriptions, BroadcastItem}; use crate::pipeline::cors::{CORSConfigError, Cors}; use crate::pipeline::header::{SingleContentType, StreamContentType}; use crate::pipeline::introspection_policy::compile_introspection_policy; @@ -36,7 +37,6 @@ use crate::pipeline::multipart_subscribe::{ use crate::pipeline::parser::ParseCacheEntry; use crate::pipeline::progressive_override::{OverrideLabelsCompileError, OverrideLabelsEvaluator}; use crate::pipeline::sse; -use hive_router_plan_executor::executors::active_subscriptions::BroadcastItem; pub type JwtClaimsCache = Cache>; pub type RouterInflightRequestsMap = InFlightMap; @@ -296,8 +296,10 @@ pub struct RouterSharedState { pub plugins: Option>>, pub in_flight_requests: RouterInflightRequestsMap, pub in_flight_requests_header_policy: RouterRequestDedupeHeaderPolicy, - /// tracks the number of active long-lived clients (websockets + http streams) + /// Tracks the number of active long-lived clients (websockets + http streams) pub long_lived_client_count: Arc, + /// Tracks all active subscriptions from clients to the router. + pub active_subscriptions: ActiveSubscriptions, } impl RouterSharedState { @@ -309,6 +311,7 @@ impl RouterSharedState { telemetry_context: Arc, plugins: Option>>, cache_state: Arc, + active_subscriptions: ActiveSubscriptions, ) -> Result { let parse_cache = cache_state.parse_cache.clone(); Ok(Self { @@ -341,6 +344,7 @@ impl RouterSharedState { .headers) .into(), long_lived_client_count: Arc::new(AtomicUsize::new(0)), + active_subscriptions, }) } } diff --git a/e2e/src/testkit/mod.rs b/e2e/src/testkit/mod.rs index ad5632610..37461fc2d 100644 --- a/e2e/src/testkit/mod.rs +++ b/e2e/src/testkit/mod.rs @@ -707,7 +707,7 @@ impl TestRouter { let serv_shared_state = shared_state.clone(); let serv_schema_state = schema_state.clone(); - let serv_active_subs = schema_state.active_subscriptions.clone(); + let serv_callback_subs = schema_state.callback_subscriptions.clone(); let serv_graphql_path = self.graphql_path.clone(); let serv_websocket_path = self.websocket_path.clone(); @@ -721,13 +721,13 @@ impl TestRouter { }) => { let cb_path = path.to_string(); let cb_addr = listen.to_string(); - let cb_active_subs = schema_state.active_subscriptions.clone(); + let cb_subs = schema_state.callback_subscriptions.clone(); let server = web::HttpServer::new(async move || { - let active_subs = cb_active_subs.clone(); + let cb_subs = cb_subs.clone(); let cb_path = cb_path.clone(); web::App::new() - .state(active_subs) + .state(cb_subs) .configure(move |m| add_callback_handler(m, &cb_path)) }) .bind(&cb_addr) @@ -761,7 +761,7 @@ impl TestRouter { let paths = serv_paths.clone(); let prometheus = serv_prometheus.clone(); let serv_callback_path = serv_callback_path.clone(); - let active_subs = serv_active_subs.clone(); + let callback_subs = serv_callback_subs.clone(); // set the tracing dispatch on the server thread. the guard is // intentionally leaked: dropping it would restore the no-op default @@ -778,7 +778,7 @@ impl TestRouter { .middleware(PluginService) .state(shared_state) .state(schema_state) - .state(active_subs) + .state(callback_subs) .configure(|m| configure_ntex_app(m, &paths, prometheus)) .configure(|m| { if let Some(ref callback) = serv_callback_path { diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index 3dcf87624..667515074 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -62,16 +62,3 @@ impl SubgraphExecutionRequest<'_> { .insert(key, value); } } - -// the channel capacity for buffering subscription events between websockets or the callback -// handler and the stream consumer. back-pressure flows like this: -// -// ntex h1 dispatcher (send to client) -// poll_flush() blocks when TCP send buffer is full (slow client) -// poll_next_chunk() is not called until flush completes -// rx.recv() in the async_stream is not polled -// channel fills up -// -// so this bound only triggers when the client reading from the router is too slow, hence -// backpressure comes from the client itself, not the router -pub const SUBSCRIPTION_EVENT_BUFFER_CAPACITY: usize = 256; diff --git a/lib/executor/src/executors/http_callback.rs b/lib/executor/src/executors/http_callback.rs index fa6d75d9f..66afe761e 100644 --- a/lib/executor/src/executors/http_callback.rs +++ b/lib/executor/src/executors/http_callback.rs @@ -13,9 +13,7 @@ use tokio::sync::mpsc; use tracing::{debug, error, trace}; use ulid::Ulid; -use crate::executors::common::{ - SubgraphExecutionRequest, SubgraphExecutor, SUBSCRIPTION_EVENT_BUFFER_CAPACITY, -}; +use crate::executors::common::{SubgraphExecutionRequest, SubgraphExecutor}; use crate::executors::error::SubgraphExecutorError; use crate::executors::http::{build_request_body, HttpClient}; use crate::plugin_context::PluginRequestState; @@ -25,16 +23,14 @@ use crate::response::subgraph_response::SubgraphResponse; pub const CALLBACK_PROTOCOL_VERSION: &str = "callback/1.0"; pub const SUBSCRIPTION_PROTOCOL_HEADER: &str = "subscription-protocol"; -type SubscriptionId = String; - #[derive(Clone)] -pub struct ActiveSubscription { +pub struct CallbackSubscription { pub verifier: String, pub sender: mpsc::Sender, pub last_heartbeat: Arc>, } -impl ActiveSubscription { +impl CallbackSubscription { pub fn record_heartbeat(&self) { *self.last_heartbeat.lock().unwrap() = Instant::now(); } @@ -46,16 +42,16 @@ pub enum CallbackMessage { Complete { errors: Option> }, } -pub type ActiveSubscriptionsMap = Arc>; +pub type CallbackSubscriptionsMap = Arc>; -struct SubscriptionGuard { - subscription_id: SubscriptionId, - active_subscriptions: ActiveSubscriptionsMap, +struct CallbackSubscriptionGuard { + subscription_id: String, + callback_subscriptions: CallbackSubscriptionsMap, } -impl Drop for SubscriptionGuard { +impl Drop for CallbackSubscriptionGuard { fn drop(&mut self) { - self.active_subscriptions.remove(&self.subscription_id); + self.callback_subscriptions.remove(&self.subscription_id); trace!(subscription_id = %self.subscription_id, "HTTP callback subscription entry removed from active subscriptions"); } } @@ -67,7 +63,7 @@ pub struct HttpCallbackSubgraphExecutor { pub header_map: HeaderMap, pub callback_base_url: String, pub heartbeat_interval_ms: u64, - pub active_subscriptions: ActiveSubscriptionsMap, + pub active_subscriptions: CallbackSubscriptionsMap, } impl HttpCallbackSubgraphExecutor { @@ -77,7 +73,7 @@ impl HttpCallbackSubgraphExecutor { http_client: Arc, callback_base_url: String, heartbeat_interval_ms: u64, - active_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: CallbackSubscriptionsMap, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -159,10 +155,15 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { let body = self.build_request_body(&mut execution_request, &subscription_id, &verifier)?; - let (tx, mut rx) = mpsc::channel::(SUBSCRIPTION_EVENT_BUFFER_CAPACITY); + // all subscriptions emit events into the shared active subscriptions broadcaster + // which itself handles back-pressure by dropping old events when the buffer is full, + // so we can use a small buffer here + // TODO: do we thererefore need to buffer at all? + let (tx, mut rx) = mpsc::channel::(16); + self.active_subscriptions.insert( subscription_id.clone(), - ActiveSubscription { + CallbackSubscription { verifier, sender: tx, // initialize last_heartbeat to now + heartbeat_interval so the enforcer @@ -177,9 +178,9 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { ); // guard removes the entry from `active_subscriptions` when dropped - let guard = SubscriptionGuard { + let guard = CallbackSubscriptionGuard { subscription_id: subscription_id.clone(), - active_subscriptions: self.active_subscriptions.clone(), + callback_subscriptions: self.active_subscriptions.clone(), }; let mut req = hyper::Request::builder() diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 6e04d575c..56eef932f 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -30,7 +30,7 @@ use crate::{ common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient, SubgraphHttpResponse}, - http_callback::{ActiveSubscriptionsMap, HttpCallbackSubgraphExecutor}, + http_callback::{CallbackSubscriptionsMap, HttpCallbackSubgraphExecutor}, websocket::WsSubgraphExecutor, }, hooks::on_subgraph_execute::{ @@ -75,7 +75,7 @@ pub struct SubgraphExecutorMap { in_flight_requests: InflightRequestsMap, telemetry_context: Arc, /// Shared map of active HTTP callback subscriptions - active_callback_subscriptions: ActiveSubscriptionsMap, + callback_subscriptions: CallbackSubscriptionsMap, } fn build_https_executor() -> Result, SubgraphExecutorError> { @@ -112,7 +112,7 @@ impl SubgraphExecutorMap { timeouts_by_subgraph: Default::default(), global_timeout, telemetry_context, - active_callback_subscriptions: Arc::new(DashMap::new()), + callback_subscriptions: Arc::new(DashMap::new()), }) } @@ -120,7 +120,7 @@ impl SubgraphExecutorMap { subgraph_endpoint_map: &HashMap, config: Arc, telemetry_context: Arc, - active_callback_subscriptions: ActiveSubscriptionsMap, + active_callback_subscriptions: CallbackSubscriptionsMap, ) -> Result { let global_timeout = DurationOrProgram::compile( &config.traffic_shaping.all.request_timeout, @@ -131,7 +131,7 @@ impl SubgraphExecutorMap { })?; let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), global_timeout, telemetry_context)?; - subgraph_executor_map.active_callback_subscriptions = active_callback_subscriptions; + subgraph_executor_map.callback_subscriptions = active_callback_subscriptions; for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.iter() { let endpoint_config = config @@ -157,8 +157,8 @@ impl SubgraphExecutorMap { } /// Returns the shared active callback subscriptions map for use by callback handlers. - pub fn active_callback_subscriptions(&self) -> ActiveSubscriptionsMap { - self.active_callback_subscriptions.clone() + pub fn callback_subscriptions(&self) -> CallbackSubscriptionsMap { + self.callback_subscriptions.clone() } pub async fn execute<'exec>( @@ -505,7 +505,7 @@ impl SubgraphExecutorMap { self.client.clone(), callback_config.public_url.to_string(), heartbeat_interval_ms, - self.active_callback_subscriptions.clone(), + self.callback_subscriptions.clone(), ) .to_boxed_arc(); diff --git a/lib/executor/src/executors/mod.rs b/lib/executor/src/executors/mod.rs index efc39c23d..283e331e3 100644 --- a/lib/executor/src/executors/mod.rs +++ b/lib/executor/src/executors/mod.rs @@ -1,4 +1,3 @@ -pub mod active_subscriptions; pub mod common; pub mod dedupe; pub mod error; diff --git a/lib/executor/src/executors/websocket.rs b/lib/executor/src/executors/websocket.rs index 4e1f53c07..4509e0eb9 100644 --- a/lib/executor/src/executors/websocket.rs +++ b/lib/executor/src/executors/websocket.rs @@ -7,9 +7,7 @@ use ntex::rt::Arbiter; use tokio::sync::mpsc; use tracing::{debug, warn}; -use crate::executors::common::{ - SubgraphExecutionRequest, SubgraphExecutor, SUBSCRIPTION_EVENT_BUFFER_CAPACITY, -}; +use crate::executors::common::{SubgraphExecutionRequest, SubgraphExecutor}; use crate::executors::error::SubgraphExecutorError; use crate::executors::graphql_transport_ws::build_subscribe_payload; use crate::executors::websocket_client::{connect, WsClient}; @@ -114,9 +112,12 @@ impl SubgraphExecutor for WsSubgraphExecutor { BoxStream<'static, Result, SubgraphExecutorError>>, SubgraphExecutorError, > { - let (tx, mut rx) = mpsc::channel::, SubgraphExecutorError>>( - SUBSCRIPTION_EVENT_BUFFER_CAPACITY, - ); + // all subscriptions emit events into the shared active subscriptions broadcaster + // which itself handles back-pressure by dropping old events when the buffer is full, + // so we can use a small buffer here + // TODO: do we thererefore need to buffer at all? + let (tx, mut rx) = + mpsc::channel::, SubgraphExecutorError>>(16); let endpoint = self.endpoint.clone(); let subgraph_name = self.subgraph_name.clone(); From 4fbaa848ae0b4fdd62abe367593f0e314ff7c16a Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Fri, 3 Apr 2026 23:55:24 +0200 Subject: [PATCH 20/42] lets talk naming --- .../src/pipeline/active_subscriptions.rs | 29 ++++++++----------- bin/router/src/pipeline/mod.rs | 4 +-- bin/router/src/shared_state.rs | 10 +++---- 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/bin/router/src/pipeline/active_subscriptions.rs b/bin/router/src/pipeline/active_subscriptions.rs index 212f2176b..69acb3865 100644 --- a/bin/router/src/pipeline/active_subscriptions.rs +++ b/bin/router/src/pipeline/active_subscriptions.rs @@ -12,17 +12,17 @@ use crate::shared_state::SharedRouterResponseGuard; pub type SubscriptionId = String; #[derive(Clone, Debug)] -pub enum BroadcastItem { +pub enum SubscriptionEvent { /// A normal subscription event from the upstream, already serialized. /// Uses Bytes for zero-copy cloning across broadcast receivers. - Event(Bytes), + Raw(Bytes), /// An error pushed externally (e.g. supergraph reload, shutdown). /// Consumers should yield this as the final event and then stop. Error(Vec), } struct Subscription { - sender: tokio::sync::broadcast::Sender, + sender: tokio::sync::broadcast::Sender, } #[derive(Clone)] @@ -51,8 +51,8 @@ impl ActiveSubscriptions { guard: Option, ) -> ( ProducerHandle, - tokio::sync::broadcast::Sender, - tokio::sync::broadcast::Receiver, + tokio::sync::broadcast::Sender, + tokio::sync::broadcast::Receiver, ConsumerGuard, ) { let listener_count = Arc::new(AtomicUsize::new(1)); @@ -65,7 +65,7 @@ impl ActiveSubscriptions { let handle = ProducerHandle { id: id.clone(), - map: self.clone(), + subscriptions: self.clone(), _guard: guard, }; let guard = ConsumerGuard { @@ -79,13 +79,8 @@ impl ActiveSubscriptions { (handle, sender_clone, receiver, guard) } - /// check if a subscription exists - pub fn contains(&self, id: &str) -> bool { - self.map.contains_key(id) - } - /// send an event to a specific subscription's broadcast channel - pub fn send_event(&self, id: &str, item: BroadcastItem) -> bool { + pub fn send(&self, id: &str, item: SubscriptionEvent) -> bool { if let Some(entry) = self.map.get(id) { // if the channel is closed or full it means the consuming client is gone or too slow and // unable to keep up. in both cases, we dont emit an error messages because it anyways cant @@ -103,7 +98,7 @@ impl ActiveSubscriptions { /// close all active subscriptions with an error message pub fn close_all_with_error(&self, errors: Vec) { - let item = BroadcastItem::Error(errors); + let item = SubscriptionEvent::Error(errors); for entry in self.map.iter() { let _ = entry.sender.send(item.clone()); } @@ -118,7 +113,7 @@ impl ActiveSubscriptions { /// and their streams will end naturally. pub struct ProducerHandle { id: SubscriptionId, - map: ActiveSubscriptions, + subscriptions: ActiveSubscriptions, _guard: Option, } @@ -127,8 +122,8 @@ impl ProducerHandle { &self.id } - pub fn send(&self, item: BroadcastItem) -> bool { - self.map.send_event(&self.id, item) + pub fn send(&self, item: SubscriptionEvent) -> bool { + self.subscriptions.send(&self.id, item) } } @@ -136,7 +131,7 @@ impl Drop for ProducerHandle { fn drop(&mut self) { // removing the entry drops the broadcast sender inside it, closing the channel. // all receivers will see Closed and their streams will end naturally - self.map.remove(&self.id); + self.subscriptions.remove(&self.id); trace!(subscription_id = %self.id, "subscription handle dropped, upstream closed"); } } diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index e807e0c52..7d4faf9b8 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -30,7 +30,7 @@ use xxhash_rust::xxh3::Xxh3; use crate::{ pipeline::{ - active_subscriptions::BroadcastItem, + active_subscriptions::SubscriptionEvent, authorization::enforce_operation_authorization, body_read::read_body_stream, coerce_variables::{coerce_request_variables, CoerceVariablesPayload}, @@ -441,7 +441,7 @@ async fn execute_planned_request<'exec>( // from active_subscriptions before the producer has a chance to send let _consumer_guard = consumer_guard; while let Some(chunk) = body_stream.next().await { - if !producer_handle.send(BroadcastItem::Event(bytes::Bytes::from(chunk))) { + if !producer_handle.send(SubscriptionEvent::Raw(bytes::Bytes::from(chunk))) { // all receivers gone, stop draining break; } diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 1ac586d2b..9b2e64f83 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -27,7 +27,7 @@ use tracing::trace; use crate::cache_state::CacheState; use crate::jwt::context::JwtTokenPayload; use crate::jwt::JwtAuthRuntime; -use crate::pipeline::active_subscriptions::{ActiveSubscriptions, BroadcastItem}; +use crate::pipeline::active_subscriptions::{ActiveSubscriptions, SubscriptionEvent}; use crate::pipeline::cors::{CORSConfigError, Cors}; use crate::pipeline::header::{SingleContentType, StreamContentType}; use crate::pipeline::introspection_policy::compile_introspection_policy; @@ -134,7 +134,7 @@ impl From for web::HttpResponse { pub struct SharedRouterStreamResponse { // status is always 200 for streaming responses, errors are sent through the stream - pub body: tokio::sync::broadcast::Sender, + pub body: tokio::sync::broadcast::Sender, pub headers: Arc, pub stream_content_type: StreamContentType, pub error_count: usize, @@ -152,7 +152,7 @@ pub struct SharedRouterStreamResponse { // guarantees no event is sent to an empty channel and no event is missed. None on clones // because joiners are already late subscribers and create their own receiver when they // call body.subscribe(). - pub bootstrap_receiver: Option>, + pub bootstrap_receiver: Option>, } impl Clone for SharedRouterStreamResponse { @@ -178,10 +178,10 @@ impl From for web::HttpResponse { let stream = Box::pin(async_stream::stream! { loop { match receiver.recv().await { - Ok(BroadcastItem::Event(data)) => { + Ok(SubscriptionEvent::Raw(data)) => { yield data.to_vec(); } - Ok(BroadcastItem::Error(errors)) => { + Ok(SubscriptionEvent::Error(errors)) => { yield FailedExecutionResult { errors }.serialize(); break; } From b428344be73a14cbc1afd7f6627cb1ff5fafd035 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 00:20:54 +0200 Subject: [PATCH 21/42] thats right keep that thang simple --- .../src/pipeline/active_subscriptions.rs | 112 ++++-------------- bin/router/src/pipeline/mod.rs | 14 +-- bin/router/src/shared_state.rs | 31 ++--- 3 files changed, 42 insertions(+), 115 deletions(-) diff --git a/bin/router/src/pipeline/active_subscriptions.rs b/bin/router/src/pipeline/active_subscriptions.rs index 69acb3865..f51d18c64 100644 --- a/bin/router/src/pipeline/active_subscriptions.rs +++ b/bin/router/src/pipeline/active_subscriptions.rs @@ -1,9 +1,9 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use bytes::Bytes; use dashmap::DashMap; use hive_router_plan_executor::response::graphql_error::GraphQLError; +use tokio::sync::broadcast; use tracing::trace; use ulid::Ulid; @@ -21,16 +21,9 @@ pub enum SubscriptionEvent { Error(Vec), } -struct Subscription { - sender: tokio::sync::broadcast::Sender, -} - #[derive(Clone)] pub struct ActiveSubscriptions { - // map of subscription ids to their sender - map: Arc>, - // capacity of the broadcast channel per subscription, - // see router config `subscriptions.broadcast_capacity` + map: Arc>>, broadcast_capacity: usize, } @@ -42,119 +35,64 @@ impl ActiveSubscriptions { } } - /// Register a new subscription to be used for broadcasting events to consuming clients. - /// Returns a handle for the producer to send events, a sender that can be cloned to subscribe - /// new receivers, a receiver for consumers to subscribe to, and a guard that each consumer - /// must hold onto while consuming. + /// Register a new active subscription. Returns a producer handle for the upstream pump + /// and a pre-subscribed receiver for the leader consumer. The pump task owns the handle + /// for the full lifetime of the upstream stream - when the handle drops (pump done or all + /// receivers gone) the broadcast channel closes and all consumer receivers terminate. pub fn register( &self, guard: Option, - ) -> ( - ProducerHandle, - tokio::sync::broadcast::Sender, - tokio::sync::broadcast::Receiver, - ConsumerGuard, - ) { - let listener_count = Arc::new(AtomicUsize::new(1)); - let (sender, receiver) = tokio::sync::broadcast::channel(self.broadcast_capacity); - let sender_clone = sender.clone(); - + ) -> (ProducerHandle, broadcast::Receiver) { + let (sender, receiver) = broadcast::channel(self.broadcast_capacity); let id = Ulid::new().to_string(); - - self.map.insert(id.clone(), Subscription { sender }); + self.map.insert(id.clone(), sender.clone()); let handle = ProducerHandle { id: id.clone(), - subscriptions: self.clone(), + map: self.map.clone(), + sender, _guard: guard, }; - let guard = ConsumerGuard { - id: id.clone(), - map: self.clone(), - listener_count, - }; trace!(subscription_id = %id, "registered new subscription"); - (handle, sender_clone, receiver, guard) + (handle, receiver) } - /// send an event to a specific subscription's broadcast channel - pub fn send(&self, id: &str, item: SubscriptionEvent) -> bool { - if let Some(entry) = self.map.get(id) { - // if the channel is closed or full it means the consuming client is gone or too slow and - // unable to keep up. in both cases, we dont emit an error messages because it anyways cant - // go through - entry.sender.send(item).is_ok() - } else { - false - } - } - - /// remove a subscription entry - pub fn remove(&self, id: &str) { - self.map.remove(id); - } - - /// close all active subscriptions with an error message + /// Close all active subscriptions with an error and clear the registry. pub fn close_all_with_error(&self, errors: Vec) { let item = SubscriptionEvent::Error(errors); for entry in self.map.iter() { - let _ = entry.sender.send(item.clone()); + let _ = entry.send(item.clone()); } self.map.clear(); } } -/// Held by the upstream producer (the task that reads from the subgraph). -/// It is the actual subscription handle that can be used to send events to consumers. -/// Dropping this removes the subscription entry from the registry, which drops -/// the broadcast sender and closes the channel. All receivers will see `Closed` -/// and their streams will end naturally. +/// Held by the upstream pump task for the full lifetime of the stream. Dropping it removes +/// the subscription from the registry, closes the broadcast channel, and drops the inflight +/// cleanup guard - which removes the dedupe entry so new requests start a fresh upstream. pub struct ProducerHandle { id: SubscriptionId, - subscriptions: ActiveSubscriptions, + map: Arc>>, + sender: broadcast::Sender, _guard: Option, } impl ProducerHandle { - pub fn id(&self) -> &str { - &self.id + pub fn sender(&self) -> &broadcast::Sender { + &self.sender } + /// Returns false when all consumers have gone and the event cannot be delivered. pub fn send(&self, item: SubscriptionEvent) -> bool { - self.subscriptions.send(&self.id, item) + self.sender.send(item).is_ok() } } impl Drop for ProducerHandle { fn drop(&mut self) { - // removing the entry drops the broadcast sender inside it, closing the channel. - // all receivers will see Closed and their streams will end naturally - self.subscriptions.remove(&self.id); - trace!(subscription_id = %self.id, "subscription handle dropped, upstream closed"); - } -} - -/// Held by each consumer of a subscription (producer). On drop, decrements the listener count. -/// When the last guard drops and the subscription entry still exists (upstream hasn't dropped -/// yet), removes it - causing the upstream producer's `send()` to return `false` and exit. -pub struct ConsumerGuard { - id: SubscriptionId, - map: ActiveSubscriptions, - listener_count: Arc, -} - -impl Drop for ConsumerGuard { - fn drop(&mut self) { - let prev = self.listener_count.fetch_sub(1, Ordering::AcqRel); - if prev == 1 { - // last listener gone, clean up. this also drops the sender, - // causing the upstream producer's send() to return false - self.map.map.remove(&self.id); - trace!(subscription_id = %self.id, "last listener dropped, subscription removed"); - } else { - trace!(subscription_id = %self.id, remaining = prev - 1, "listener dropped"); - } + self.map.remove(&self.id); + trace!(subscription_id = %self.id, "producer dropped, upstream closed"); } } diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 7d4faf9b8..fa35cb237 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -431,15 +431,15 @@ async fn execute_planned_request<'exec>( .ok_or(PipelineError::SubscriptionsTransportNotSupported)? .clone(); - let (producer_handle, sender, bootstrap_receiver, consumer_guard) = - shared_state.active_subscriptions.register(guard); + let (producer_handle, receiver) = shared_state.active_subscriptions.register(guard); + + // subscribe the sender before spawning the pump so the channel always has + // at least one receiver - prevents events from being lost in the window + // between spawn and the consumer calling subscribe() + let sender = producer_handle.sender().clone(); let mut body_stream = result.body; rt::spawn(async move { - // keep consumer_guard alive for the duration of the pump so it doesn't - // decrement the listener count to 0 and remove the subscription entry - // from active_subscriptions before the producer has a chance to send - let _consumer_guard = consumer_guard; while let Some(chunk) = body_stream.next().await { if !producer_handle.send(SubscriptionEvent::Raw(bytes::Bytes::from(chunk))) { // all receivers gone, stop draining @@ -462,7 +462,7 @@ async fn execute_planned_request<'exec>( headers, stream_content_type, error_count: result.error_count, - bootstrap_receiver: Some(bootstrap_receiver), + receiver: Some(receiver), })) } QueryPlanExecutionResult::Single(result) => { diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 9b2e64f83..91f7f6ce6 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -138,21 +138,10 @@ pub struct SharedRouterStreamResponse { pub headers: Arc, pub stream_content_type: StreamContentType, pub error_count: usize, - // only set for the leader (the request that actually opens the upstream). the pump task - // is spawned before this response is returned to the caller, so there is a window between - // the spawn and the leader eventually calling body.subscribe() where the pump could send - // events to a channel with no receivers. broadcast does buffer sent events in its internal - // ring buffer, but a receiver created via subscribe() only sees events sent after it was - // created - it cannot read anything already in the buffer. so even though the events are - // technically buffered, the real consumer would miss them. worse, if there are zero - // receivers at send time, broadcast returns an error and the event is gone entirely, not - // even buffered. this receiver is created before the pump spawns, keeping the receiver - // count above zero during that window. it is dropped inside From - // only after the real consumer receiver has been created via body.subscribe(), which - // guarantees no event is sent to an empty channel and no event is missed. None on clones - // because joiners are already late subscribers and create their own receiver when they - // call body.subscribe(). - pub bootstrap_receiver: Option>, + // the leader gets the receiver that was subscribed before the pump was spawned, so + // there is no window where the channel has zero receivers and events can be lost. + // joiners get None and subscribe via body.subscribe() when consumed. + pub receiver: Option>, } impl Clone for SharedRouterStreamResponse { @@ -162,18 +151,18 @@ impl Clone for SharedRouterStreamResponse { headers: self.headers.clone(), stream_content_type: self.stream_content_type.clone(), error_count: self.error_count, - bootstrap_receiver: None, + receiver: None, } } } impl From for web::HttpResponse { fn from(shared_response: SharedRouterStreamResponse) -> Self { - let mut receiver = shared_response.body.subscribe(); - - // drop the bootstrap receiver only after the real consumer receiver is created above, - // closing the gap where events could be lost between pump spawn and subscribe() - drop(shared_response.bootstrap_receiver); + // leader already has a pre-subscribed receiver to avoid missing + // any potential events emitted. joiners, on the other hand, subscribe + let mut receiver = shared_response + .receiver + .unwrap_or_else(|| shared_response.body.subscribe()); let stream = Box::pin(async_stream::stream! { loop { From 37f4a28f77eb52d6cd1b8e17306cffbb73e82116 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 00:26:42 +0200 Subject: [PATCH 22/42] promotion --- e2e/src/subscriptions.rs | 92 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/e2e/src/subscriptions.rs b/e2e/src/subscriptions.rs index bb0e213df..f66bdea3b 100644 --- a/e2e/src/subscriptions.rs +++ b/e2e/src/subscriptions.rs @@ -1452,4 +1452,96 @@ mod subscriptions_e2e_tests { "Expected requests to reviews subgraph to be deduplicated" ); } + + #[ntex::test] + async fn active_subscriptions_deduplication_promotion() { + use futures::StreamExt; + + let subgraphs = TestSubgraphs::builder().build().start().await; + let router = TestRouter::builder() + .with_subgraphs(&subgraphs) + .inline_config( + r#" + supergraph: + source: file + path: supergraph.graphql + subscriptions: + enabled: true + traffic_shaping: + router: + dedupe: + enabled: true + "#, + ) + .build() + .start() + .await; + + let query = r#" + subscription { + reviewAdded(intervalInMs: 100) { + id + } + } + "#; + let headers = some_header_map! { + http::header::ACCEPT => "text/event-stream" + }; + + let mut sub1 = router + .send_graphql_request(query, None, headers.clone()) + .await; + + assert!(sub1.status().is_success(), "Expected 200 OK"); + + // consume 2 events from sub1 to let the source stream advance + let chunk = sub1.next().await.unwrap().unwrap(); + assert!( + std::str::from_utf8(&chunk).unwrap().contains(r#""id":"1""#), + "Expected first event to be id=1" + ); + let chunk = sub1.next().await.unwrap().unwrap(); + assert!( + std::str::from_utf8(&chunk).unwrap().contains(r#""id":"2""#), + "Expected second event to be id=2" + ); + let chunk = sub1.next().await.unwrap().unwrap(); + assert!( + std::str::from_utf8(&chunk).unwrap().contains(r#""id":"3""#), + "Expected third event to be id=3" + ); + + // subscribe again with the same query - dedup promotes sub2 onto the live source + let sub2 = router + .send_graphql_request(query, None, headers.clone()) + .await; + + assert!(sub2.status().is_success(), "Expected 200 OK"); + + // drop sub1 now that sub2 is connected; sub2 must become the active subscriber + drop(sub1); + + // sub2 should receive the remainder of the stream from where the source left off + let body = sub2.string_body().await; + assert!( + body.contains("event: next") && body.contains("event: complete"), + "Expected sub2 to receive remaining events and complete, got: {body}" + ); + + // sub2 must not have received the first 3 events that were already consumed by sub1 + assert!( + !body.contains(r#""id":"1""#) + && !body.contains(r#""id":"2""#) + && !body.contains(r#""id":"3""#), + "Expected sub2 to not replay events already consumed by sub1, got: {body}" + ); + + // only one subgraph request should have been made + let reviews_requests = subgraphs.get_requests_log("reviews").unwrap_or_default(); + assert_eq!( + reviews_requests.len(), + 1, + "Expected requests to reviews subgraph to be deduplicated" + ); + } } From 6c2d8471f364f436b38b50852503cc1697b3b3bc Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 00:46:43 +0200 Subject: [PATCH 23/42] WS server dedupe start hehe --- bin/router/src/pipeline/mod.rs | 29 ++- bin/router/src/pipeline/websocket_server.rs | 254 ++++++-------------- 2 files changed, 88 insertions(+), 195 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index fa35cb237..8b648c3a1 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -15,6 +15,7 @@ use hive_router_query_planner::{ }; use http::{header::CONTENT_TYPE, Method}; use ntex::{ + http::HeaderMap, rt, web::{self, HttpRequest}, }; @@ -288,7 +289,9 @@ pub async fn graphql_request_handler( .claim(fp) .get_or_try_init(|guard| async { execute_planned_request( - req, + req.method(), + req.uri(), + req.headers(), graphql_params, &normalize_payload, supergraph, @@ -305,7 +308,9 @@ pub async fn graphql_request_handler( Arc::unwrap_or_clone(planned_response) } else { execute_planned_request( - req, + req.method(), + req.uri(), + req.headers(), graphql_params, &normalize_payload, supergraph, @@ -364,8 +369,10 @@ pub async fn graphql_request_handler( } #[allow(clippy::too_many_arguments)] -async fn execute_planned_request<'exec>( - req: &'exec HttpRequest, +pub async fn execute_planned_request<'exec>( + method: &'exec Method, + url: &'exec http::Uri, + headers: &'exec HeaderMap, mut graphql_params: GraphQLParams, normalize_payload: &Arc, supergraph: &'exec SupergraphData, @@ -378,7 +385,7 @@ async fn execute_planned_request<'exec>( ) -> Result { let jwt_request_details = match &shared_state.jwt_auth_runtime { Some(jwt_auth_runtime) => match jwt_auth_runtime - .validate_headers(req.headers(), &shared_state.jwt_claims_cache) + .validate_headers(headers, &shared_state.jwt_claims_cache) .await? { Some(jwt_context) => JwtRequestDetails::Authenticated { @@ -396,9 +403,9 @@ async fn execute_planned_request<'exec>( coerce_request_variables(supergraph, &mut graphql_params.variables, normalize_payload)?; let client_request_details = ClientRequestDetails { - method: req.method(), - url: req.uri(), - headers: req.headers(), + method, + url, + headers, operation: OperationDetails { name: normalize_payload.operation_for_plan.name.as_deref(), kind: match normalize_payload.operation_for_plan.operation_kind { @@ -454,7 +461,7 @@ async fn execute_planned_request<'exec>( aggregator.modify_client_response_headers(&mut builder)?; Arc::new(builder.finish().headers().clone()) } else { - Arc::new(ntex::http::HeaderMap::new()) + Arc::new(HeaderMap::new()) }; Ok(SharedRouterResponse::Stream(SharedRouterStreamResponse { @@ -480,7 +487,7 @@ async fn execute_planned_request<'exec>( aggregator.modify_client_response_headers(&mut builder)?; Arc::new(builder.finish().headers().clone()) } else { - Arc::new(ntex::http::HeaderMap::new()) + Arc::new(HeaderMap::new()) }; Ok(SharedRouterResponse::Single(SharedRouterSingleResponse { @@ -559,7 +566,7 @@ pub async fn execute_pipeline<'exec>( pub fn inbound_request_fingerprint( method: &http::Method, path: &str, - request_headers: &ntex::http::HeaderMap, + request_headers: &HeaderMap, dedupe_header_policy: &RouterRequestDedupeHeaderPolicy, schema_checksum: u64, normalized_operation_hash: u64, diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 1daceceb3..191359de2 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -1,16 +1,21 @@ +use http::Method; +use ntex::channel::oneshot; +use ntex::http::{header::HeaderName, header::HeaderValue, HeaderMap}; +use ntex::router::Path; +use ntex::service::{fn_factory_with_config, fn_service, fn_shutdown, Service}; +use ntex::web::{self, ws, Error, HttpRequest, HttpResponse}; +use ntex::{chain, rt}; +use sonic_rs::{JsonContainerTrait, JsonValueTrait, Value}; use std::cell::RefCell; use std::collections::HashMap; use std::io; use std::rc::Rc; use std::sync::Arc; use std::time::Instant; +use tokio::sync::mpsc; +use tracing::{debug, error, trace, warn, Instrument}; -use futures::StreamExt; use hive_router_internal::telemetry::traces::spans::graphql::GraphQLOperationSpan; -use hive_router_plan_executor::execution::client_request_details::{ - ClientRequestDetails, JwtRequestDetails, OperationDetails, -}; -use hive_router_plan_executor::execution::plan::QueryPlanExecutionResult; use hive_router_plan_executor::executors::graphql_transport_ws::{ ClientMessage, CloseCode, ConnectionInitPayload, ServerMessage, WS_SUBPROTOCOL, }; @@ -23,22 +28,11 @@ use hive_router_plan_executor::plugin_context::{ }; use hive_router_plan_executor::response::graphql_error::{GraphQLError, GraphQLErrorExtensions}; use hive_router_query_planner::state::supergraph_state::OperationKind; -use http::Method; -use ntex::channel::oneshot; -use ntex::http::{header::HeaderName, header::HeaderValue, HeaderMap}; -use ntex::router::Path; -use ntex::service::{fn_factory_with_config, fn_service, fn_shutdown, Service}; -use ntex::web::{self, ws, Error, HttpRequest, HttpResponse}; -use ntex::{chain, rt}; -use sonic_rs::{JsonContainerTrait, JsonValueTrait, Value}; -use tokio::sync::mpsc; -use tracing::{debug, error, trace, warn, Instrument}; use crate::jwt::errors::JwtError; -use crate::pipeline::coerce_variables::coerce_request_variables; use crate::pipeline::error::PipelineError; -use crate::pipeline::execute_pipeline; -use crate::pipeline::execution_request::GetQueryStr; +use crate::pipeline::execute_planned_request; +use crate::pipeline::header::{ResponseMode, SingleContentType, StreamContentType}; use crate::pipeline::{ hash_graphql_extensions, hash_graphql_variables, inbound_request_fingerprint, normalize::normalize_request_with_cache, parser::parse_operation_with_cache, usage_reporting, @@ -299,7 +293,7 @@ async fn handle_text_frame( } } - let mut payload = GraphQLParams { + let payload = GraphQLParams { query: Some(payload.query), operation_name: payload.operation_name, variables: payload.variables.unwrap_or_default(), @@ -423,11 +417,11 @@ async fn handle_text_frame( shared_state.router_config.traffic_shaping.router.dedupe.enabled; let fingerprint = if request_dedupe_enabled - && matches!( - normalize_payload.operation_for_plan.operation_kind, - // same deduplication applies for queries and subscriptions - None | Some(OperationKind::Query) | Some(OperationKind::Subscription) - ) { + && matches!( + normalize_payload.operation_for_plan.operation_kind, + // same deduplication applies for queries and subscriptions + None | Some(OperationKind::Query) | Some(OperationKind::Subscription) + ) { let variables_hash = hash_graphql_variables(&payload.variables); let extensions_hash = payload .extensions @@ -447,169 +441,61 @@ async fn handle_text_frame( None }; - let jwt_request_details = match &shared_state.jwt_auth_runtime { - Some(jwt_auth_runtime) => match jwt_auth_runtime - .validate_headers(&headers, &shared_state.jwt_claims_cache) - .await - { - Ok(Some(jwt_context)) => JwtRequestDetails::Authenticated { - scopes: jwt_context.extract_scopes(), - claims: match jwt_context - .get_claims_value() - .map_err(PipelineError::JwtForwardingError) - { - Ok(claims) => claims, - Err(e) => return Some(e.into_server_message(&id)), - }, - token: jwt_context.token_raw, - prefix: jwt_context.token_prefix, - }, - Ok(None) => JwtRequestDetails::Unauthenticated, - // jwt_auth_runtime.validate_headers() will error out only if - // authentication is required and has failed. we therefore use - // the JwtError conversion here to respond with proper error message - // close with Forbidden. - Err(e) => { - let _ = sink.send(e.clone().into_server_message(&id)).await; - // we report error as graphql error, but we also close the - // connection since we're dealing with auth so let's be safe - return Some(e.into_close_message()); - } - }, - None => JwtRequestDetails::Unauthenticated, - }; - - let variable_payload = match coerce_request_variables( - supergraph, - &mut payload.variables, - &normalize_payload, - ) { - Ok(payload) => payload, - Err(err) => return Some(err.into_server_message(&id)), - }; - - // synthetic client request details for plan executor - let client_request_details = ClientRequestDetails { - method: &Method::POST, - url: &WS_URI_PATH, - headers: &headers, - operation: OperationDetails { - name: normalize_payload.operation_for_plan.name.as_deref(), - kind: match normalize_payload.operation_for_plan.operation_kind { - Some(OperationKind::Query) => "query", - Some(OperationKind::Mutation) => "mutation", - Some(OperationKind::Subscription) => "subscription", - None => "query", - }, - query: match payload.get_query() { - Ok(q) => q, - Err(e) => return Some(e.into_server_message(&id)), - }, - }, - jwt: jwt_request_details, - }.into(); - - // TODO: dedupe - - match execute_pipeline( - &client_request_details, - &normalize_payload, - variable_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - ) - .await - { - Ok(QueryPlanExecutionResult::Single(response)) => { - if let Some(hive_usage_agent) = &shared_state.hive_usage_agent { - usage_reporting::collect_usage_report( - supergraph.supergraph_schema.clone(), - started_at.elapsed(), - client_name, - client_version, - normalize_payload.operation_for_plan.name.as_deref(), - &parser_payload.minified_document, - hive_usage_agent, - shared_state - .router_config - .telemetry - .hive - .as_ref() - .map(|c| &c.usage_reporting) - .expect( - // SAFETY: According to `configure_app_from_config` in `bin/router/src/lib.rs`, - // the UsageAgent is only created when usage reporting is enabled. - // Thus, this expect should never panic. - "Expected Usage Reporting options to be present when Hive Usage Agent is initialized", - ), - response.error_count, + // synthetic request details for plan executor + let shared_response = if let Some(fp) = fingerprint { + let (shared_response, _role) = match shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init(|guard| async { + execute_planned_request( + &Method::POST, + ws_uri, + &headers, + payload, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + &ResponseMode::Dual( + SingleContentType::default(), + StreamContentType::default(), + ), + Some(guard), ) - .await; - } - - let _ = sink.send(ServerMessage::next(&id, &response.body)).await; - Some(ServerMessage::complete(&id)) - } - Ok(QueryPlanExecutionResult::Stream(response)) => { - let mut stream = response.body; - - // we use mpsc::channel(1) instead of oneshot because oneshot::Receiver - // is consumed on first await, which doesn't work in tokio::select! loops that - // need to poll the receiver multiple times across iterations - let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); - - state - .borrow_mut() - .subscriptions - .insert(id.clone(), cancel_tx); - - // automatically remove the subscription from subscriptions when dropped - let _guard = SubscriptionGuard { - state: state.clone(), - id: id.clone(), + .await + }) + .await { + Ok(result) => result, + Err(err) => return Some(err.into_server_message(&id)), }; - - let mut cancelled = false; - - trace!(id = %id, "Subscription started"); - - let id_for_loop = id.clone(); - loop { - tokio::select! { - maybe_item = stream.next() => { - match maybe_item { - Some(body) => { - let _ = sink.send(ServerMessage::next(&id_for_loop, &body)).await; - } - None => { - break; // completed - } - } - } - _ = cancel_rx.recv() => { - cancelled = true; - break; // cancelled - } - } - } - - if cancelled { - trace!(id = %id, "Subscription cancelled"); - // we dont emit complete on cancelled subscriptions. - // they're either deliberately cancelled by the client - // or dropped due to connection close, either way - // we dont/cant inform the client with a complete message - None - } else { - trace!(id = %id, "Subscription completed"); - Some(ServerMessage::complete(&id)) - } + Arc::unwrap_or_clone(shared_response) + } else { + match execute_planned_request( + &Method::POST, + ws_uri, + &headers, + payload, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + &ResponseMode::Dual( + SingleContentType::default(), + StreamContentType::default(), + ), + None, + ) + .await { + Ok(result) => result, + Err(err) => return Some(err.into_server_message(&id)), } - Err(err) => Some(err.into_server_message(&id)), - } + }; + + todo!(); } .instrument(span_clone) .await; From 7e684a59689df677e35149b72c7e2435a710b04b Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 01:00:34 +0200 Subject: [PATCH 24/42] ok this looks good right --- bin/router/src/pipeline/websocket_server.rs | 93 ++++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 191359de2..eed7c2817 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -30,6 +30,7 @@ use hive_router_plan_executor::response::graphql_error::{GraphQLError, GraphQLEr use hive_router_query_planner::state::supergraph_state::OperationKind; use crate::jwt::errors::JwtError; +use crate::pipeline::active_subscriptions::SubscriptionEvent; use crate::pipeline::error::PipelineError; use crate::pipeline::execute_planned_request; use crate::pipeline::header::{ResponseMode, SingleContentType, StreamContentType}; @@ -39,7 +40,7 @@ use crate::pipeline::{ validation::validate_operation_with_cache, }; use crate::schema_state::SchemaState; -use crate::shared_state::RouterSharedState; +use crate::shared_state::{RouterSharedState, SharedRouterResponse}; type WsStateRef = Rc>>>; @@ -468,6 +469,12 @@ async fn handle_text_frame( }) .await { Ok(result) => result, + Err(PipelineError::JwtError(err)) => { + let _ = sink.send(err.clone().into_server_message(&id)).await; + // we report error as graphql error, but we also close the + // connection since we're dealing with auth so let's be safe + return Some(err.into_close_message()); + }, Err(err) => return Some(err.into_server_message(&id)), }; Arc::unwrap_or_clone(shared_response) @@ -495,7 +502,89 @@ async fn handle_text_frame( } }; - todo!(); + if let Some(hive_usage_agent) = &shared_state.hive_usage_agent { + usage_reporting::collect_usage_report( + supergraph.supergraph_schema.clone(), + started_at.elapsed(), + client_name, + client_version, + normalize_payload.operation_for_plan.name.as_deref(), + &parser_payload.minified_document, + hive_usage_agent, + shared_state + .router_config + .telemetry + .hive + .as_ref() + .map(|c| &c.usage_reporting) + .expect("Expected Usage Reporting options to be present when Hive Usage Agent is initialized"), + shared_response.error_count(), + ) + .await; + } + + match shared_response { + SharedRouterResponse::Single(response) => { + let _ = sink.send(ServerMessage::next(&id, &response.body)).await; + Some(ServerMessage::complete(&id)) + } + SharedRouterResponse::Stream(response) => { + let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); + + state + .borrow_mut() + .subscriptions + .insert(id.clone(), cancel_tx); + + let _guard = SubscriptionGuard { + state: state.clone(), + id: id.clone(), + }; + + let mut receiver = response + .receiver + .unwrap_or_else(|| response.body.subscribe()); + let mut cancelled = false; + + trace!(id = %id, "Subscription started"); + + let id_for_loop = id.clone(); + loop { + tokio::select! { + maybe_item = receiver.recv() => { + match maybe_item { + Ok(SubscriptionEvent::Raw(data)) => { + let _ = sink.send(ServerMessage::next(&id_for_loop, &data)).await; + } + Ok(SubscriptionEvent::Error(errors)) => { + let _ = sink.send(ServerMessage::error(&id_for_loop, &errors)).await; + break; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + trace!(id = %id_for_loop, lagged = n, "broadcast receiver lagged, skipping missed messages"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + _ = cancel_rx.recv() => { + cancelled = true; + break; + } + } + } + + if cancelled { + trace!(id = %id, "Subscription cancelled"); + None + } else { + trace!(id = %id, "Subscription completed"); + Some(ServerMessage::complete(&id)) + } + } + } } .instrument(span_clone) .await; From c5764b389982c2f5329fe0dcdd264a8b5ef8b573 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 01:05:38 +0200 Subject: [PATCH 25/42] lol of course --- bin/router/src/shared_state.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 91f7f6ce6..4e4bf8171 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -128,6 +128,8 @@ impl From for web::HttpResponse { response.set_header(header_name, header_value); } + response.content_type(shared_response.single_content_type.as_ref()); + response.body(shared_response.body) } } @@ -221,9 +223,9 @@ impl From for web::HttpResponse { response.set_header(header_name, header_value); } - response - .header(http::header::CONTENT_TYPE, content_type_header) - .streaming(body) + response.content_type(content_type_header); + + response.streaming(body) } } From 436a276478135eb6389d4828ce4684df3425a166 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 01:12:10 +0200 Subject: [PATCH 26/42] right, needs a spawn --- bin/router/src/pipeline/websocket_server.rs | 71 ++++++++++++--------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index eed7c2817..4b59c63bb 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -536,7 +536,7 @@ async fn handle_text_frame( .subscriptions .insert(id.clone(), cancel_tx); - let _guard = SubscriptionGuard { + let guard = SubscriptionGuard { state: state.clone(), id: id.clone(), }; @@ -544,45 +544,54 @@ async fn handle_text_frame( let mut receiver = response .receiver .unwrap_or_else(|| response.body.subscribe()); - let mut cancelled = false; trace!(id = %id, "Subscription started"); + let sink = sink.clone(); let id_for_loop = id.clone(); - loop { - tokio::select! { - maybe_item = receiver.recv() => { - match maybe_item { - Ok(SubscriptionEvent::Raw(data)) => { - let _ = sink.send(ServerMessage::next(&id_for_loop, &data)).await; - } - Ok(SubscriptionEvent::Error(errors)) => { - let _ = sink.send(ServerMessage::error(&id_for_loop, &errors)).await; - break; - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { - trace!(id = %id_for_loop, lagged = n, "broadcast receiver lagged, skipping missed messages"); - continue; - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => { - break; + // must be spawned - blocking the frame handler would prevent + // ClientMessage::Complete from being received and processed, + // making cancellation impossible + rt::spawn(async move { + let _guard = guard; + let mut cancelled = false; + + loop { + tokio::select! { + maybe_item = receiver.recv() => { + match maybe_item { + Ok(SubscriptionEvent::Raw(data)) => { + let _ = sink.send(ServerMessage::next(&id_for_loop, &data)).await; + } + Ok(SubscriptionEvent::Error(errors)) => { + let _ = sink.send(ServerMessage::error(&id_for_loop, &errors)).await; + break; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + trace!(id = %id_for_loop, lagged = n, "broadcast receiver lagged, skipping missed messages"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } } } - } - _ = cancel_rx.recv() => { - cancelled = true; - break; + _ = cancel_rx.recv() => { + cancelled = true; + break; + } } } - } - if cancelled { - trace!(id = %id, "Subscription cancelled"); - None - } else { - trace!(id = %id, "Subscription completed"); - Some(ServerMessage::complete(&id)) - } + if cancelled { + trace!(id = %id_for_loop, "Subscription cancelled"); + } else { + trace!(id = %id_for_loop, "Subscription completed"); + let _ = sink.send(ServerMessage::complete(&id_for_loop)).await; + } + }); + + None } } } From 50734f0f1e74a498deb67b431d5153287006a9e2 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 01:22:02 +0200 Subject: [PATCH 27/42] WS dedupe hehe --- e2e/src/subscriptions.rs | 199 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) diff --git a/e2e/src/subscriptions.rs b/e2e/src/subscriptions.rs index f66bdea3b..a10482250 100644 --- a/e2e/src/subscriptions.rs +++ b/e2e/src/subscriptions.rs @@ -1544,4 +1544,203 @@ mod subscriptions_e2e_tests { "Expected requests to reviews subgraph to be deduplicated" ); } + + #[ntex::test] + async fn active_across_transports_subscriptions_deduplication() { + use futures::StreamExt; + use hive_router_plan_executor::executors::{ + graphql_transport_ws::SubscribePayload, websocket_client::WsClient, + }; + + let subgraphs = TestSubgraphs::builder().build().start().await; + let router = TestRouter::builder() + .with_subgraphs(&subgraphs) + .inline_config( + r#" + supergraph: + source: file + path: supergraph.graphql + subscriptions: + enabled: true + websocket: + enabled: true + traffic_shaping: + router: + dedupe: + enabled: true + "#, + ) + .build() + .start() + .await; + + let query = r#" + subscription { + reviewAdded(intervalInMs: 100) { + id + product { + name + } + } + } + "#; + + let sse_headers = some_header_map! { + http::header::ACCEPT => "text/event-stream" + }; + let multipart_headers = some_header_map! { + http::header::ACCEPT => "multipart/mixed;subscriptionSpec=1.0" + }; + + let wsconn = router.ws().await; + let mut ws_client = WsClient::init(wsconn, None) + .await + .expect("Failed to init WsClient"); + let ws_payload = SubscribePayload { + query: query.into(), + ..Default::default() + }; + let mut ws_stream = ws_client.subscribe(ws_payload).await; + + let (sub_sse, sub_multipart) = tokio::join!( + router.send_graphql_request(query, None, sse_headers), + router.send_graphql_request(query, None, multipart_headers), + ); + + let sse_body = sub_sse.string_body().await; + assert!( + sse_body.contains("event: next") && sse_body.contains("event: complete"), + "Expected SSE subscription to receive events and complete" + ); + + let multipart_body = sub_multipart.string_body().await; + assert!( + multipart_body.contains("--graphql") && multipart_body.contains("--graphql--"), + "Expected multipart subscription to receive events and complete" + ); + + let mut ws_received = 0; + while let Some(response) = ws_stream.next().await { + assert!( + response.errors.is_none(), + "Expected no errors from WS subscription" + ); + assert!( + !response.data.is_null(), + "Expected data from WS subscription" + ); + ws_received += 1; + } + assert!( + ws_received > 0, + "Expected WS subscription to receive at least one event" + ); + + let reviews_requests = subgraphs.get_requests_log("reviews").unwrap_or_default(); + assert_eq!( + reviews_requests.len(), + 1, + "Expected requests to reviews subgraph to be deduplicated across transports" + ); + } + + #[ntex::test] + async fn active_across_transports_subscriptions_deduplication_promotion() { + use futures::StreamExt; + use hive_router_plan_executor::executors::{ + graphql_transport_ws::SubscribePayload, websocket_client::WsClient, + }; + + let subgraphs = TestSubgraphs::builder().build().start().await; + let router = TestRouter::builder() + .with_subgraphs(&subgraphs) + .inline_config( + r#" + supergraph: + source: file + path: supergraph.graphql + subscriptions: + enabled: true + websocket: + enabled: true + traffic_shaping: + router: + dedupe: + enabled: true + "#, + ) + .build() + .start() + .await; + + let query = r#" + subscription { + reviewAdded(intervalInMs: 100) { + id + } + } + "#; + + let wsconn = router.ws().await; + let mut ws_client = WsClient::init(wsconn, None) + .await + .expect("Failed to init WsClient"); + let ws_payload = SubscribePayload { + query: query.into(), + ..Default::default() + }; + let mut ws_stream = ws_client.subscribe(ws_payload).await; + + // consume 3 events from sub1 to let the source stream advance + let response = ws_stream.next().await.unwrap(); + assert!( + response.data.to_string().contains(r#""id":"1""#), + "Expected first event to be id=1" + ); + let response = ws_stream.next().await.unwrap(); + assert!( + response.data.to_string().contains(r#""id":"2""#), + "Expected second event to be id=2" + ); + let response = ws_stream.next().await.unwrap(); + assert!( + response.data.to_string().contains(r#""id":"3""#), + "Expected third event to be id=3" + ); + + // subscribe again with SSE - dedup promotes sub2 onto the live source + let sse_headers = some_header_map! { + http::header::ACCEPT => "text/event-stream" + }; + let sub2 = router.send_graphql_request(query, None, sse_headers).await; + + assert!(sub2.status().is_success(), "Expected 200 OK"); + + // drop the WS sub now that sub2 is connected; sub2 must become the active subscriber + drop(ws_stream); + drop(ws_client); + + // sub2 should receive the remainder of the stream from where the source left off + let body = sub2.string_body().await; + assert!( + body.contains("event: next") && body.contains("event: complete"), + "Expected sub2 to receive remaining events and complete, got: {body}" + ); + + // sub2 must not have received the first 3 events already consumed by the WS sub + assert!( + !body.contains(r#""id":"1""#) + && !body.contains(r#""id":"2""#) + && !body.contains(r#""id":"3""#), + "Expected sub2 to not replay events already consumed by the WS sub, got: {body}" + ); + + // only one subgraph request should have been made + let reviews_requests = subgraphs.get_requests_log("reviews").unwrap_or_default(); + assert_eq!( + reviews_requests.len(), + 1, + "Expected requests to reviews subgraph to be deduplicated across transports" + ); + } } From ec9f98a8e43489a4407657f9769a57abaf4c7e16 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 01:35:42 +0200 Subject: [PATCH 28/42] big man thing across transports --- e2e/src/subscriptions.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/e2e/src/subscriptions.rs b/e2e/src/subscriptions.rs index a10482250..c9ce728c0 100644 --- a/e2e/src/subscriptions.rs +++ b/e2e/src/subscriptions.rs @@ -1568,6 +1568,7 @@ mod subscriptions_e2e_tests { router: dedupe: enabled: true + headers: none "#, ) .build() @@ -1667,6 +1668,7 @@ mod subscriptions_e2e_tests { router: dedupe: enabled: true + headers: none "#, ) .build() @@ -1694,17 +1696,17 @@ mod subscriptions_e2e_tests { // consume 3 events from sub1 to let the source stream advance let response = ws_stream.next().await.unwrap(); assert!( - response.data.to_string().contains(r#""id":"1""#), + response.data.to_string().contains(r#""id": "1""#), "Expected first event to be id=1" ); let response = ws_stream.next().await.unwrap(); assert!( - response.data.to_string().contains(r#""id":"2""#), + response.data.to_string().contains(r#""id": "2""#), "Expected second event to be id=2" ); let response = ws_stream.next().await.unwrap(); assert!( - response.data.to_string().contains(r#""id":"3""#), + response.data.to_string().contains(r#""id": "3""#), "Expected third event to be id=3" ); From 91ff2dd2896dcb3d4f6575e3e705502461093e69 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 02:02:50 +0200 Subject: [PATCH 29/42] dedupe across boundaries fixed --- bin/router/src/pipeline/mod.rs | 26 +++++++++---------- bin/router/src/shared_state.rs | 46 ++++++++++++++++------------------ 2 files changed, 35 insertions(+), 37 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 8b648c3a1..8c943966c 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -359,7 +359,7 @@ pub async fn graphql_request_handler( }, ); - Ok(shared_response.into()) + shared_response.into_response(response_mode) } .instrument(span_clone) .await @@ -433,6 +433,10 @@ pub async fn execute_planned_request<'exec>( .await? { QueryPlanExecutionResult::Stream(result) => { + // we dont use the stream content type because subscriptions + // can be deduplicated across transports - but we do store + // the header value in the shared response because the user + // might choose to not deduplicate across transport boundaries let stream_content_type = response_mode .stream_content_type() .ok_or(PipelineError::SubscriptionsTransportNotSupported)? @@ -456,18 +460,16 @@ pub async fn execute_planned_request<'exec>( // dropping producer_handle closes the broadcast channel }); - let headers = if let Some(aggregator) = result.response_headers_aggregator { - let mut builder = web::HttpResponse::Ok(); + let mut builder = web::HttpResponse::Ok(); + if let Some(aggregator) = result.response_headers_aggregator { aggregator.modify_client_response_headers(&mut builder)?; - Arc::new(builder.finish().headers().clone()) - } else { - Arc::new(HeaderMap::new()) }; + builder.content_type(stream_content_type.as_ref()); + let headers = Arc::new(builder.finish().headers().clone()); Ok(SharedRouterResponse::Stream(SharedRouterStreamResponse { body: sender, headers, - stream_content_type, error_count: result.error_count, receiver: Some(receiver), })) @@ -482,18 +484,16 @@ pub async fn execute_planned_request<'exec>( // drop the router shared request as soon as the response is ready let _query_guard = guard; - let headers = if let Some(aggregator) = result.response_headers_aggregator { - let mut builder = web::HttpResponse::Ok(); + let mut builder = web::HttpResponse::Ok(); + if let Some(aggregator) = result.response_headers_aggregator { aggregator.modify_client_response_headers(&mut builder)?; - Arc::new(builder.finish().headers().clone()) - } else { - Arc::new(HeaderMap::new()) }; + builder.content_type(single_content_type.as_ref()); + let headers = Arc::new(builder.finish().headers().clone()); Ok(SharedRouterResponse::Single(SharedRouterSingleResponse { body: ntex::util::Bytes::from(result.body), headers, - single_content_type, status: result.status_code, error_count: result.error_count, })) diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 4e4bf8171..d00f26aec 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -29,7 +29,8 @@ use crate::jwt::context::JwtTokenPayload; use crate::jwt::JwtAuthRuntime; use crate::pipeline::active_subscriptions::{ActiveSubscriptions, SubscriptionEvent}; use crate::pipeline::cors::{CORSConfigError, Cors}; -use crate::pipeline::header::{SingleContentType, StreamContentType}; +use crate::pipeline::error::PipelineError; +use crate::pipeline::header::{ResponseMode, SingleContentType, StreamContentType}; use crate::pipeline::introspection_policy::compile_introspection_policy; use crate::pipeline::multipart_subscribe::{ self, APOLLO_MULTIPART_HTTP_CONTENT_TYPE, INCREMENTAL_DELIVERY_CONTENT_TYPE, @@ -99,13 +100,18 @@ impl SharedRouterResponse { SharedRouterResponse::Stream(resp) => resp.error_count, } } -} - -impl From for web::HttpResponse { - fn from(shared_response: SharedRouterResponse) -> Self { - match shared_response { - SharedRouterResponse::Single(single) => single.into(), - SharedRouterResponse::Stream(stream) => stream.into(), + pub fn into_response( + self, + response_mode: &ResponseMode, + ) -> Result { + match self { + SharedRouterResponse::Single(single) => Ok(single.into()), + SharedRouterResponse::Stream(stream) => { + let stream_content_type = response_mode + .stream_content_type() + .ok_or(PipelineError::SubscriptionsTransportNotSupported)?; + Ok(stream.into_response(stream_content_type)) + } } } } @@ -115,7 +121,6 @@ pub struct SharedRouterSingleResponse { pub body: Bytes, pub headers: Arc, pub status: StatusCode, - pub single_content_type: SingleContentType, pub error_count: usize, } @@ -123,22 +128,19 @@ impl From for web::HttpResponse { fn from(shared_response: SharedRouterSingleResponse) -> Self { let mut response = web::HttpResponse::Ok(); response.status(shared_response.status); - for (header_name, header_value) in shared_response.headers.iter() { response.set_header(header_name, header_value); } - - response.content_type(shared_response.single_content_type.as_ref()); - response.body(shared_response.body) } } +// status is always 200 for streaming responses, errors are sent through the stream. +// stream content type is not included because we can deduplicate subscriptions across +// different content types, the response format is decided when converting to response pub struct SharedRouterStreamResponse { - // status is always 200 for streaming responses, errors are sent through the stream pub body: tokio::sync::broadcast::Sender, pub headers: Arc, - pub stream_content_type: StreamContentType, pub error_count: usize, // the leader gets the receiver that was subscribed before the pump was spawned, so // there is no window where the channel has zero receivers and events can be lost. @@ -151,20 +153,17 @@ impl Clone for SharedRouterStreamResponse { Self { body: self.body.clone(), headers: self.headers.clone(), - stream_content_type: self.stream_content_type.clone(), error_count: self.error_count, receiver: None, } } } -impl From for web::HttpResponse { - fn from(shared_response: SharedRouterStreamResponse) -> Self { +impl SharedRouterStreamResponse { + pub fn into_response(self, stream_content_type: &StreamContentType) -> web::HttpResponse { // leader already has a pre-subscribed receiver to avoid missing // any potential events emitted. joiners, on the other hand, subscribe - let mut receiver = shared_response - .receiver - .unwrap_or_else(|| shared_response.body.subscribe()); + let mut receiver = self.receiver.unwrap_or_else(|| self.body.subscribe()); let stream = Box::pin(async_stream::stream! { loop { @@ -187,8 +186,6 @@ impl From for web::HttpResponse { } }); - let stream_content_type = shared_response.stream_content_type; - let content_type_header = match stream_content_type { StreamContentType::IncrementalDelivery => { http::HeaderValue::from_static(INCREMENTAL_DELIVERY_CONTENT_TYPE) @@ -219,10 +216,11 @@ impl From for web::HttpResponse { let mut response = web::HttpResponse::Ok(); - for (header_name, header_value) in shared_response.headers.iter() { + for (header_name, header_value) in self.headers.iter() { response.set_header(header_name, header_value); } + // we set content type after so that we can override the shared header response.content_type(content_type_header); response.streaming(body) From 12b97a59e3479b2d4863afbd8dc31f21a0d74d10 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 02:20:11 +0200 Subject: [PATCH 30/42] leave a comment --- lib/router-config/src/traffic_shaping.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/router-config/src/traffic_shaping.rs b/lib/router-config/src/traffic_shaping.rs index 25d1cfc17..7dc22492d 100644 --- a/lib/router-config/src/traffic_shaping.rs +++ b/lib/router-config/src/traffic_shaping.rs @@ -219,6 +219,14 @@ pub struct TrafficShapingRouterDedupeConfig { /// /// The deduplication is transport agnostic. A query over WebSocket would get deduplicated with an /// identical query over HTTP if they arrive at the same time and have the same fingerprint. + /// + /// Note: `content-type` is part of the fingerprint when `headers` includes it (e.g. `all`). + /// Since HTTP streaming clients send different `accept` headers than WebSocket clients, + /// cross-transport deduplication for subscriptions only applies when `content-type` (and + /// transport-specific headers) are excluded from the key. Configure `headers: none` or + /// `headers: { include: [] }` (or exclude the relevant headers) to enable true cross-transport + /// deduplication, where a WebSocket subscription and an SSE subscription with the same operation + /// share a single upstream connection and the events are fanned out to both. #[serde(default = "default_router_dedupe_enabled")] pub enabled: bool, From 843bbd7ece662427b7721f3f4714a90729abe7b1 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 02:24:03 +0200 Subject: [PATCH 31/42] sure clippy --- bin/router/src/shared_state.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index d00f26aec..acd293ba2 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -30,7 +30,7 @@ use crate::jwt::JwtAuthRuntime; use crate::pipeline::active_subscriptions::{ActiveSubscriptions, SubscriptionEvent}; use crate::pipeline::cors::{CORSConfigError, Cors}; use crate::pipeline::error::PipelineError; -use crate::pipeline::header::{ResponseMode, SingleContentType, StreamContentType}; +use crate::pipeline::header::{ResponseMode, StreamContentType}; use crate::pipeline::introspection_policy::compile_introspection_policy; use crate::pipeline::multipart_subscribe::{ self, APOLLO_MULTIPART_HTTP_CONTENT_TYPE, INCREMENTAL_DELIVERY_CONTENT_TYPE, From 3278e059e418b64865c132a6ad94c617b143c9e4 Mon Sep 17 00:00:00 2001 From: theguild-bot Date: Sat, 4 Apr 2026 00:24:10 +0000 Subject: [PATCH 32/42] docs: update documentation --- docs/README.md | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/docs/README.md b/docs/README.md index 740da4f4e..e4bf7fa64 100644 --- a/docs/README.md +++ b/docs/README.md @@ -18,10 +18,10 @@ |[**override\_subgraph\_urls**](#override_subgraph_urls)|`object`|Configuration for overriding subgraph URLs.
Default: `{}`
|| |[**plugins**](#plugins)|`object`|Configuration for custom plugins
|| |[**query\_planner**](#query_planner)|`object`|Query planning configuration.
Default: `{"allow_expose":false,"timeout":"10s"}`
|| -|[**subscriptions**](#subscriptions)|`object`|Configuration for subscriptions.
Default: `{"enabled":false}`
|| +|[**subscriptions**](#subscriptions)|`object`|Configuration for subscriptions.
Default: `{"broadcast_capacity":0,"enabled":false}`
|| |[**supergraph**](#supergraph)|`object`|Configuration for the Federation supergraph source. By default, the router will use a local file-based supergraph source (`./supergraph.graphql`).
|| |[**telemetry**](#telemetry)|`object`|Default: `{"client_identification":{"name_header":"graphql-client-name","version_header":"graphql-client-version"},"hive":null,"metrics":{"exporters":[],"instrumentation":{"common":{"histogram":{"aggregation":"explicit","bytes":{"buckets":[128,512,1024,2048,4096,8192,16384,32768,65536,131072,262144,524288,1048576,2097152,3145728,4194304,5242880],"record_min_max":false},"seconds":{"buckets":[0.005,0.01,0.025,0.05,0.075,0.1,0.25,0.5,0.75,1,2.5,5,7.5,10],"record_min_max":false}}},"instruments":{}}},"resource":{"attributes":{}},"tracing":{"collect":{"max_attributes_per_event":16,"max_attributes_per_link":32,"max_attributes_per_span":128,"max_events_per_span":128,"parent_based_sampler":false,"sampling":1},"exporters":[],"instrumentation":{"spans":{"mode":"spec_compliant"}},"propagation":{"b3":false,"baggage":false,"jaeger":false,"trace_context":true}}}`
|| -|[**traffic\_shaping**](#traffic_shaping)|`object`|Configuration for the traffic-shaping of the executor. Use these configurations to control how requests are being executed to subgraphs.
Default: `{"all":{"dedupe_enabled":true,"pool_idle_timeout":"50s","request_timeout":"30s"},"max_connections_per_host":100,"router":{"dedupe":{"enabled":false,"headers":"all"},"request_timeout":"1m"}}`
|| +|[**traffic\_shaping**](#traffic_shaping)|`object`|Configuration for the traffic-shaping of the executor. Use these configurations to control how requests are being executed to subgraphs.
Default: `{"all":{"dedupe_enabled":true,"pool_idle_timeout":"50s","request_timeout":"30s"},"max_connections_per_host":100,"router":{"dedupe":{"enabled":false,"headers":"all"},"max_long_lived_clients":128,"request_timeout":"1m"}}`
|| |[**websocket**](#websocket)|`object`|Configuration of router's WebSocket server.
Default: `{"enabled":false,"headers":{"persist":false,"source":"connection"},"path":null}`
|| **Additional Properties:** not allowed @@ -122,6 +122,7 @@ query_planner: allow_expose: false timeout: 10s subscriptions: + broadcast_capacity: 0 enabled: false supergraph: {} telemetry: @@ -202,6 +203,7 @@ traffic_shaping: dedupe: enabled: false headers: all + max_long_lived_clients: 128 request_timeout: 1m websocket: enabled: false @@ -1949,6 +1951,7 @@ Configuration for subscriptions. |Name|Type|Description|Required| |----|----|-----------|--------| +|**broadcast\_capacity**|`integer`|The capacity of the broadcast channel used to fan out subscription events to all active listeners.

Each active subscription has its own broadcast channel. This value controls how many events
can be buffered in that channel before slow consumers start lagging. If a consumer falls too
far behind and the buffer is full, it will skip the missed messages and continue from the
latest available event.

Subscription events are typically low-frequency, so the default of 32 is sufficient for most
use cases. Increase this value if you expect bursts of events or have slow consumers that
need more headroom to catch up.

Defaults to 32.
Default: `32`
Format: `"uint"`
Minimum: `0`
|| |[**callback**](#subscriptionscallback)|`object`, `null`|Configuration for subgraphs using the HTTP Callback protocol.
|yes| |**enabled**|`boolean`|Enables/disables subscriptions. By default, the subscriptions are disabled.

You can override this setting by setting the `SUBSCRIPTIONS_ENABLED` environment variable to `true` or `false`.
Default: `false`
|| |[**websocket**](#subscriptionswebsocket)|`object`, `null`|Configuration for subgraphs using WebSocket protocol.
|| @@ -1957,6 +1960,7 @@ Configuration for subscriptions. **Example** ```yaml +broadcast_capacity: 0 enabled: false ``` @@ -3038,7 +3042,7 @@ Configuration for the traffic-shaping of the executor. Use these configurations |----|----|-----------|--------| |[**all**](#traffic_shapingall)|`object`|The default configuration that will be applied to all subgraphs, unless overridden by a specific subgraph configuration.
Default: `{"dedupe_enabled":true,"pool_idle_timeout":"50s","request_timeout":"30s"}`
|| |**max\_connections\_per\_host**|`integer`|Limits the concurrent amount of requests/connections per host/subgraph.
Default: `100`
Format: `"uint"`
Minimum: `0`
|| -|[**router**](#traffic_shapingrouter)|`object`|Configuration for the router itself, e.g., for handling incoming requests, or other router-level traffic shaping configurations.
Default: `{"dedupe":{"enabled":false,"headers":"all"},"request_timeout":"1m"}`
|| +|[**router**](#traffic_shapingrouter)|`object`|Configuration for the router itself, e.g., for handling incoming requests, or other router-level traffic shaping configurations.
Default: `{"dedupe":{"enabled":false,"headers":"all"},"max_long_lived_clients":128,"request_timeout":"1m"}`
|| |[**subgraphs**](#traffic_shapingsubgraphs)|`object`|Optional per-subgraph configurations that will override the default configuration for specific subgraphs.
|| **Additional Properties:** not allowed @@ -3054,6 +3058,7 @@ router: dedupe: enabled: false headers: all + max_long_lived_clients: 128 request_timeout: 1m ``` @@ -3093,6 +3098,7 @@ Configuration for the router itself, e.g., for handling incoming requests, or ot |Name|Type|Description|Required| |----|----|-----------|--------| |[**dedupe**](#traffic_shapingrouterdedupe)|`object`|Default: `{"enabled":false,"headers":"all"}`
|| +|**max\_long\_lived\_clients**|`integer`|Maximum number of concurrent long-lived clients (WebSocket connections and HTTP streaming responses).
Regular non-streaming requests are not counted toward this limit.
When the limit is reached, new WebSocket and streaming HTTP requests are rejected with 503.
If both WebSockets and Subscriptions are disabled, this setting has no effect.
Default: `128`
Format: `"uint"`
Minimum: `0`
|| |**request\_timeout**|`string`|Optional timeout configuration for incoming requests to the router.
It starts from the moment the request is received by the router,
and includes the entire processing of the request (validation, execution, etc.) until a response is sent back to the client.
If a request takes longer than the specified duration, it will be aborted and a timeout error will be returned to the client.
Default: `"1m"`
|| **Additional Properties:** not allowed @@ -3102,6 +3108,7 @@ Configuration for the router itself, e.g., for handling incoming requests, or ot dedupe: enabled: false headers: all +max_long_lived_clients: 128 request_timeout: 1m ``` @@ -3113,7 +3120,7 @@ request_timeout: 1m |Name|Type|Description|Required| |----|----|-----------|--------| -|**enabled**|`boolean`|Enables/disables in-flight request deduplication at the router endpoint level.

When enabled, identical incoming GraphQL query requests that are processed at the same time
share the same in-flight execution result.
Default: `false`
|| +|**enabled**|`boolean`|Enables/disables in-flight request and active subscriptions deduplication at the router level.

When enabled, the router deduplicates both queries and subscriptions using the same
fingerprint key (method, path, selected headers, schema checksum, normalized operation
hash, variables, and extensions). The `headers` configuration below controls which
headers participate in that key for all operation types.

For queries, concurrent HTTP requests that produce the same fingerprint share a single
in-flight execution - only the first one runs, and the rest wait for and receive the
same result.

For subscriptions, the mechanism is broadcast-based rather than request-sharing. The
first client with a given fingerprint becomes the leader: it runs the upstream subscription
and its events are fanned out through a broadcast channel backed by an active subscriptions
registry. Any subsequent client that arrives with an identical fingerprint while that subscription
is still active joins as a listener on the same broadcast channel instead of starting a new upstream
connection. When all listeners have dropped and the leader finishes, the entry is removed from the
registry.

WebSocket connections participate in the same deduplication space as HTTP. Each
subscribe message is processed with a synthetic request assembled from the WebSocket
path and the headers derived from the `websocket.headers` config. The fingerprint is computed
from those synthetic headers using the same header policy, so a subscription started over HTTP
and an identical one started over WebSocket will deduplicate against each other.

The deduplication is transport agnostic. A query over WebSocket would get deduplicated with an
identical query over HTTP if they arrive at the same time and have the same fingerprint.

Note: `content-type` is part of the fingerprint when `headers` includes it (e.g. `all`).
Since HTTP streaming clients send different `accept` headers than WebSocket clients,
cross-transport deduplication for subscriptions only applies when `content-type` (and
transport-specific headers) are excluded from the key. Configure `headers: none` or
`headers: { include: [] }` (or exclude the relevant headers) to enable true cross-transport
deduplication, where a WebSocket subscription and an SSE subscription with the same operation
share a single upstream connection and the events are fanned out to both.
Default: `false`
|| |**headers**||Header configuration participating in the dedupe key.

Accepted forms:
- `all`
- `none`
- `{ include: ["authorization", "cookie"] }`

Header names are case-insensitive and validated as standard HTTP header names.
Default: `"all"`
|| **Additional Properties:** not allowed From 6b0cea087d6ad8e0bc68838725f2d4ffa54a6677 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 12:08:40 +0200 Subject: [PATCH 33/42] chill out clippy --- bin/router/src/pipeline/mod.rs | 1 + bin/router/src/shared_state.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 8c943966c..ea0ae52da 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -563,6 +563,7 @@ pub async fn execute_pipeline<'exec>( execute_plan(supergraph, shared_state, planned_request, operation_span).await } +#[allow(clippy::too_many_arguments)] pub fn inbound_request_fingerprint( method: &http::Method, path: &str, diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index acd293ba2..de96706d3 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -292,6 +292,7 @@ pub struct RouterSharedState { } impl RouterSharedState { + #[allow(clippy::too_many_arguments)] pub fn new( router_config: Arc, jwt_auth_runtime: Option, From 30260fd007621dcf58d3ec197bc701a37bef43da Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 12:41:53 +0200 Subject: [PATCH 34/42] long lived limit wrapping --- .../src/pipeline/long_lived_client_limit.rs | 66 ++++++++++----- e2e/src/subscriptions.rs | 81 +++++++++++++++++++ e2e/src/testkit/mod.rs | 4 + 3 files changed, 131 insertions(+), 20 deletions(-) diff --git a/bin/router/src/pipeline/long_lived_client_limit.rs b/bin/router/src/pipeline/long_lived_client_limit.rs index 8c2a1d221..fc65768b8 100644 --- a/bin/router/src/pipeline/long_lived_client_limit.rs +++ b/bin/router/src/pipeline/long_lived_client_limit.rs @@ -1,21 +1,26 @@ -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, +use std::{ + rc::Rc, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll}, }; use http::{header, StatusCode}; use ntex::{ + http::body::{Body, BodySize, MessageBody}, service::{Service, ServiceCtx}, + util::Bytes, web::{self, DefaultError}, Middleware, SharedCfg, }; use crate::RouterSharedState; -/// pre-resolved at app construction time so the per-request path is branch-free #[derive(Clone)] pub struct LongLivedClientLimitService { - /// false means the middleware is entirely bypassed on every request + // false means the middleware is entirely bypassed on every request enabled: bool, } @@ -79,7 +84,7 @@ where .max_long_lived_clients; let counter = shared_state.long_lived_client_count.clone(); - // try to reserve a slot; back off if we're at the limit + // try to reserve a slot, bail if at the limit let prev = counter.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| { if current < limit { Some(current + 1) @@ -87,7 +92,6 @@ where None } }); - if prev.is_err() { let error_response = web::HttpResponse::build(StatusCode::SERVICE_UNAVAILABLE) .header(header::RETRY_AFTER, "5") @@ -97,13 +101,21 @@ where let guard = LongLivedClientGuard(counter); let response = ctx.call(&self.service, req).await?; - drop(guard); + + // wrap the body so the guard lives until the stream is fully consumed + let response = response.map_body(|_head, body| { + let wrapped = GuardedBody { + inner: body.into_body().into(), + _guard: guard, + }; + Body::from_message(wrapped).into() + }); Ok(response) } } -/// decrements the counter when dropped +// decrements the counter when dropped struct LongLivedClientGuard(Arc); impl Drop for LongLivedClientGuard { @@ -112,11 +124,30 @@ impl Drop for LongLivedClientGuard { } } -/// returns true if the request is a websocket upgrade or an http streaming request. -/// -/// deliberately ordered cheapest check first: -/// 1. upgrade: websocket - two header lookups, no parsing -/// 2. accept streaming - one header lookup + fast substring pre-filter, full parse only if needed +// wraps the body and keeps the guard alive until it's fully consumed and dropped. +// one extra vtable call per chunk on top of the Box dispatch streaming bodies +// already go through - negligible next to actual I/O cost per chunk. +struct GuardedBody { + inner: Body, + _guard: LongLivedClientGuard, +} + +impl MessageBody for GuardedBody { + fn size(&self) -> BodySize { + self.inner.size() + } + + fn poll_next_chunk( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + self.inner.poll_next_chunk(cx) + } +} + +// cheapest check first: +// 1. upgrade: websocket - two header lookups, no parsing +// 2. accept: streaming - one header lookup + fast substring pre-filter, full parse only if needed #[inline] fn is_long_lived_request(headers: &ntex::http::HeaderMap) -> bool { // websocket: Connection: Upgrade + Upgrade: websocket @@ -133,9 +164,7 @@ fn is_long_lived_request(headers: &ntex::http::HeaderMap) -> bool { return true; } - // http streaming: Accept header contains a known streaming content type. - // we do a fast substring scan before handing off to the full Accept parser - // to avoid the parse cost on the hot path for regular requests. + // fast substring scan before full Accept parse to avoid cost on regular requests let accept = match headers.get(header::ACCEPT).and_then(|v| v.to_str().ok()) { Some(v) if !v.is_empty() => v, _ => return false, @@ -155,9 +184,6 @@ fn is_long_lived_request(headers: &ntex::http::HeaderMap) -> bool { .is_some() } -/// fast pre-filter: returns true if the raw Accept string contains any substring -/// that could match a known streaming content type, avoiding the full parse on -/// the vast majority of regular (application/json) requests. #[inline] fn looks_like_streaming_accept(accept: &str) -> bool { // covers: multipart/mixed, text/event-stream diff --git a/e2e/src/subscriptions.rs b/e2e/src/subscriptions.rs index c9ce728c0..bf6196be9 100644 --- a/e2e/src/subscriptions.rs +++ b/e2e/src/subscriptions.rs @@ -1745,4 +1745,85 @@ mod subscriptions_e2e_tests { "Expected requests to reviews subgraph to be deduplicated across transports" ); } + + #[ntex::test] + async fn max_long_lived_clients_rejects_over_limit() { + use futures::StreamExt; + + let subgraphs = TestSubgraphs::builder().build().start().await; + let router = TestRouter::builder() + .with_subgraphs(&subgraphs) + .inline_config( + r#" + supergraph: + source: file + path: supergraph.graphql + subscriptions: + enabled: true + traffic_shaping: + router: + max_long_lived_clients: 2 + "#, + ) + .build() + .start() + .await; + + let query = r#" + subscription { + reviewAdded(intervalInMs: 200) { + id + } + } + "#; + let headers = some_header_map! { + http::header::ACCEPT => "text/event-stream" + }; + + // open two subscriptions and keep them alive by reading the first event + let mut sub1 = router + .send_graphql_request(query, None, headers.clone()) + .await; + assert!(sub1.status().is_success(), "sub1 should be accepted"); + let _ = sub1.next().await; + + let mut sub2 = router + .send_graphql_request(query, None, headers.clone()) + .await; + assert!(sub2.status().is_success(), "sub2 should be accepted"); + let _ = sub2.next().await; + + // the third subscriber exceeds the limit and must be rejected + let sub3 = router + .send_graphql_request(query, None, headers.clone()) + .await; + assert_eq!( + sub3.status(), + reqwest::StatusCode::SERVICE_UNAVAILABLE, + "sub3 should be rejected with 503 when the limit is reached" + ); + let retry_after = sub3.header("retry-after"); + assert!( + retry_after.is_some(), + "rejected response should include a Retry-After header" + ); + let body = sub3.string_body().await; + assert_eq!(body, "Too many long-lived clients"); + + // release the two held subscriptions + drop(sub1); + drop(sub2); + + // wait briefly for the slots to be freed + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // a new subscriber should now be accepted again + let sub4 = router + .send_graphql_request(query, None, headers.clone()) + .await; + assert!( + sub4.status().is_success(), + "sub4 should be accepted after the previous slots were freed" + ); + } } diff --git a/e2e/src/testkit/mod.rs b/e2e/src/testkit/mod.rs index 37461fc2d..72804b642 100644 --- a/e2e/src/testkit/mod.rs +++ b/e2e/src/testkit/mod.rs @@ -27,6 +27,7 @@ use tracing::{info, warn}; use hive_router::{ add_callback_handler, background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app, init_rustls_crypto_provider, invoke_shutdown_hooks, + pipeline::long_lived_client_limit::LongLivedClientLimitService, plugins::plugins_service::PluginService, telemetry::Telemetry, PluginRegistry, RouterPaths, RouterSharedState, SchemaState, }; @@ -755,6 +756,7 @@ impl TestRouter { let serv_paths = paths.clone(); let serv_prometheus = prometheus.clone(); + let long_lived_limit = LongLivedClientLimitService::new(&shared_state.router_config); let serv = test::server_with(test::config().port(self.port), move || { let shared_state = serv_shared_state.clone(); let schema_state = serv_schema_state.clone(); @@ -762,6 +764,7 @@ impl TestRouter { let prometheus = serv_prometheus.clone(); let serv_callback_path = serv_callback_path.clone(); let callback_subs = serv_callback_subs.clone(); + let long_lived_limit = long_lived_limit.clone(); // set the tracing dispatch on the server thread. the guard is // intentionally leaked: dropping it would restore the no-op default @@ -775,6 +778,7 @@ impl TestRouter { async move { web::App::new() + .middleware(long_lived_limit) .middleware(PluginService) .state(shared_state) .state(schema_state) From 36043b260609fabed81e97ed0d02bcb5359e0370 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 13:19:15 +0200 Subject: [PATCH 35/42] revert inflight map, testing --- bin/router/src/pipeline/mod.rs | 4 +- bin/router/src/pipeline/websocket_server.rs | 4 +- lib/executor/src/executors/http.rs | 2 +- lib/internal/src/inflight.rs | 47 ++++++--------------- 4 files changed, 17 insertions(+), 40 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index ea0ae52da..b474f1a3d 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -287,7 +287,7 @@ pub async fn graphql_request_handler( let (planned_response, _role) = shared_state .in_flight_requests .claim(fp) - .get_or_try_init(|guard| async { + .get_or_try_init(|| async { execute_planned_request( req.method(), req.uri(), @@ -300,7 +300,7 @@ pub async fn graphql_request_handler( operation_span, plugin_req_state, response_mode, - Some(guard), + None, ) .await }) diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 4b59c63bb..c35d89806 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -447,7 +447,7 @@ async fn handle_text_frame( let (shared_response, _role) = match shared_state .in_flight_requests .claim(fp) - .get_or_try_init(|guard| async { + .get_or_try_init(|| async { execute_planned_request( &Method::POST, ws_uri, @@ -463,7 +463,7 @@ async fn handle_text_frame( SingleContentType::default(), StreamContentType::default(), ), - Some(guard), + None, ) .await }) diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 6c7596a41..ce4b853d2 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -392,7 +392,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { let claim = self.in_flight_requests.claim(fingerprint); let mut leader_http_request_capture = None; let (shared_response, role) = claim - .get_or_try_init(|_| async { + .get_or_try_init(|| async { let res = { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. diff --git a/lib/internal/src/inflight.rs b/lib/internal/src/inflight.rs index 4c1367bcf..4684c8db5 100644 --- a/lib/internal/src/inflight.rs +++ b/lib/internal/src/inflight.rs @@ -92,22 +92,10 @@ where K: Eq + Hash + Clone, S: BuildHasher + Clone, { - /// Initialises the cell if empty (leader) or waits for the existing value (joiner). - /// - /// The leader's `init` closure receives an `InFlightCleanupGuard`. Dropping the guard removes - /// the entry from the map. For short-lived work (queries) drop it immediately. For long-lived - /// work (subscriptions) move it into the task that owns the upstream so the entry stays - /// visible to joiners for the full lifetime of the stream. - /// - /// On init failure the entry is cleaned up automatically regardless of what the caller does - /// with the guard, so no entry is left dangling. - /// - /// Joiners do not invoke `init` - they share the already-initialised value and have no cleanup - /// responsibility. #[inline] pub async fn get_or_try_init(self, init: F) -> Result<(Arc, InFlightRole), E> where - F: FnOnce(InFlightCleanupGuard) -> Fut, + F: FnOnce() -> Fut, Fut: Future>, { let mut did_initialize = false; @@ -118,38 +106,24 @@ where .cell .get_or_try_init(|| { did_initialize = true; - let guard = InFlightCleanupGuard { - key: self.key.clone(), - map: self.map.clone(), - }; async { - match init(guard).await { - Ok(v) => Ok(Arc::new(v)), - Err(e) => { - // clean up immediately on failure so a future request can retry - map.remove(&key); - Err(e) - } - } + let _cleanup = InFlightCleanupGuard { key, map }; + init().await.map(Arc::new) } }) .await? .clone(); - if did_initialize { - Ok((value, InFlightRole::Leader)) + let role = if did_initialize { + InFlightRole::Leader } else { - Ok((value, InFlightRole::Joiner)) - } + InFlightRole::Joiner + }; + + Ok((value, role)) } } -/// Removes the entry from the inflight map when dropped. -/// -/// For queries, drop this immediately after `get_or_try_init` returns so subsequent requests -/// are not deduplicated against a completed response. -/// For subscriptions, move this into the upstream pump task so the entry remains in the map -/// (and joiners can find it) for the full lifetime of the stream. pub struct InFlightCleanupGuard where K: Eq + Hash, @@ -165,6 +139,9 @@ where S: BuildHasher + Clone, { fn drop(&mut self) { + // It's important to remove the entry from the map before returning the result. + // This ensures that once the OnceCell is set, no future requests can join it. + // The cache is for the lifetime of the in-flight request only. self.map.remove(&self.key); } } From e5509db2c0d58832741459db8ed2968e6b8f783e Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 18:41:31 +0200 Subject: [PATCH 36/42] no clone guard --- bin/router/src/pipeline/mod.rs | 4 ++-- bin/router/src/pipeline/websocket_server.rs | 4 ++-- lib/executor/src/executors/http.rs | 2 +- lib/internal/src/inflight.rs | 14 +++++--------- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index b474f1a3d..ea0ae52da 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -287,7 +287,7 @@ pub async fn graphql_request_handler( let (planned_response, _role) = shared_state .in_flight_requests .claim(fp) - .get_or_try_init(|| async { + .get_or_try_init(|guard| async { execute_planned_request( req.method(), req.uri(), @@ -300,7 +300,7 @@ pub async fn graphql_request_handler( operation_span, plugin_req_state, response_mode, - None, + Some(guard), ) .await }) diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index c35d89806..4b59c63bb 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -447,7 +447,7 @@ async fn handle_text_frame( let (shared_response, _role) = match shared_state .in_flight_requests .claim(fp) - .get_or_try_init(|| async { + .get_or_try_init(|guard| async { execute_planned_request( &Method::POST, ws_uri, @@ -463,7 +463,7 @@ async fn handle_text_frame( SingleContentType::default(), StreamContentType::default(), ), - None, + Some(guard), ) .await }) diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index ce4b853d2..6c7596a41 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -392,7 +392,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { let claim = self.in_flight_requests.claim(fingerprint); let mut leader_http_request_capture = None; let (shared_response, role) = claim - .get_or_try_init(|| async { + .get_or_try_init(|_| async { let res = { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. diff --git a/lib/internal/src/inflight.rs b/lib/internal/src/inflight.rs index 4684c8db5..3cbf32a6c 100644 --- a/lib/internal/src/inflight.rs +++ b/lib/internal/src/inflight.rs @@ -95,21 +95,17 @@ where #[inline] pub async fn get_or_try_init(self, init: F) -> Result<(Arc, InFlightRole), E> where - F: FnOnce() -> Fut, + F: FnOnce(InFlightCleanupGuard) -> Fut, Fut: Future>, { let mut did_initialize = false; - let key = self.key.clone(); - let map = self.map.clone(); + let InFlightClaim { key, map, cell } = self; - let value = self - .cell + let value = cell .get_or_try_init(|| { did_initialize = true; - async { - let _cleanup = InFlightCleanupGuard { key, map }; - init().await.map(Arc::new) - } + let guard = InFlightCleanupGuard { key, map }; + async { init(guard).await.map(Arc::new) } }) .await? .clone(); From 383b356d9edf00bae0f3a29c35bd975e8d751128 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 19:10:14 +0200 Subject: [PATCH 37/42] with or without guard --- bin/router/src/pipeline/mod.rs | 68 +++++++++----- bin/router/src/pipeline/websocket_server.rs | 98 ++++++++++++++------- lib/executor/src/executors/http.rs | 2 +- lib/internal/src/inflight.rs | 42 ++++++++- 4 files changed, 153 insertions(+), 57 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index ea0ae52da..1255b68f2 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -284,28 +284,52 @@ pub async fn graphql_request_handler( }; let shared_response = if let Some(fp) = fingerprint { - let (planned_response, _role) = shared_state - .in_flight_requests - .claim(fp) - .get_or_try_init(|guard| async { - execute_planned_request( - req.method(), - req.uri(), - req.headers(), - graphql_params, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - response_mode, - Some(guard), - ) - .await - }) - .await?; - Arc::unwrap_or_clone(planned_response) + let (shared_response, _role) = if is_subscription { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init_with_guard(|guard| async { + execute_planned_request( + req.method(), + req.uri(), + req.headers(), + graphql_params, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + response_mode, + Some(guard), + ) + .await + }) + .await? + } else { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init(|| async { + execute_planned_request( + req.method(), + req.uri(), + req.headers(), + graphql_params, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + response_mode, + None, + ) + .await + }) + .await? + }; + Arc::unwrap_or_clone(shared_response) } else { execute_planned_request( req.method(), diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 4b59c63bb..ad6ec0175 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -444,39 +444,67 @@ async fn handle_text_frame( // synthetic request details for plan executor let shared_response = if let Some(fp) = fingerprint { - let (shared_response, _role) = match shared_state - .in_flight_requests - .claim(fp) - .get_or_try_init(|guard| async { - execute_planned_request( - &Method::POST, - ws_uri, - &headers, - payload, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - &ResponseMode::Dual( - SingleContentType::default(), - StreamContentType::default(), - ), - Some(guard), - ) + let result = if is_subscription { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init_with_guard(|guard| async { + execute_planned_request( + &Method::POST, + ws_uri, + &headers, + payload, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + &ResponseMode::Dual( + SingleContentType::default(), + StreamContentType::default(), + ), + Some(guard), + ) + .await + }) .await - }) - .await { - Ok(result) => result, - Err(PipelineError::JwtError(err)) => { - let _ = sink.send(err.clone().into_server_message(&id)).await; - // we report error as graphql error, but we also close the - // connection since we're dealing with auth so let's be safe - return Some(err.into_close_message()); - }, - Err(err) => return Some(err.into_server_message(&id)), - }; + } else { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init(|| async { + execute_planned_request( + &Method::POST, + ws_uri, + &headers, + payload, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + &ResponseMode::Dual( + SingleContentType::default(), + StreamContentType::default(), + ), + None, + ) + .await + }) + .await + }; + let (shared_response, _role) = match result { + Ok(result) => result, + Err(PipelineError::JwtError(err)) => { + let _ = sink.send(err.clone().into_server_message(&id)).await; + // we report error as graphql error, but we also close the + // connection since we're dealing with auth so let's be safe + return Some(err.into_close_message()); + }, + Err(err) => return Some(err.into_server_message(&id)), + }; Arc::unwrap_or_clone(shared_response) } else { match execute_planned_request( @@ -498,6 +526,12 @@ async fn handle_text_frame( ) .await { Ok(result) => result, + Err(PipelineError::JwtError(err)) => { + let _ = sink.send(err.clone().into_server_message(&id)).await; + // we report error as graphql error, but we also close the + // connection since we're dealing with auth so let's be safe + return Some(err.into_close_message()); + }, Err(err) => return Some(err.into_server_message(&id)), } }; diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 6c7596a41..ce4b853d2 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -392,7 +392,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { let claim = self.in_flight_requests.claim(fingerprint); let mut leader_http_request_capture = None; let (shared_response, role) = claim - .get_or_try_init(|_| async { + .get_or_try_init(|| async { let res = { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. diff --git a/lib/internal/src/inflight.rs b/lib/internal/src/inflight.rs index 3cbf32a6c..d9fb5ce3e 100644 --- a/lib/internal/src/inflight.rs +++ b/lib/internal/src/inflight.rs @@ -94,6 +94,42 @@ where { #[inline] pub async fn get_or_try_init(self, init: F) -> Result<(Arc, InFlightRole), E> + where + F: FnOnce() -> Fut, + Fut: Future>, + { + let mut did_initialize = false; + let InFlightClaim { key, map, cell } = self; + + let value = cell + .get_or_try_init(|| { + did_initialize = true; + async { + let _cleanup = InFlightCleanupGuard { key, map }; + init().await.map(Arc::new) + } + }) + .await? + .clone(); + + let role = if did_initialize { + InFlightRole::Leader + } else { + InFlightRole::Joiner + }; + + Ok((value, role)) + } + + /// Like [`get_or_try_init`](Self::get_or_try_init), but passes the cleanup guard to the + /// caller. The caller is responsible for dropping the guard when deduplication should end. + /// This is useful for long-lived operations like subscriptions where the deduplication + /// window extends beyond the init future. + #[inline] + pub async fn get_or_try_init_with_guard( + self, + init: F, + ) -> Result<(Arc, InFlightRole), E> where F: FnOnce(InFlightCleanupGuard) -> Fut, Fut: Future>, @@ -104,8 +140,10 @@ where let value = cell .get_or_try_init(|| { did_initialize = true; - let guard = InFlightCleanupGuard { key, map }; - async { init(guard).await.map(Arc::new) } + async { + let guard = InFlightCleanupGuard { key, map }; + init(guard).await.map(Arc::new) + } }) .await? .clone(); From 0fa4d1e65888511320dba7c52fd58b481fdb904a Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 20:35:46 +0200 Subject: [PATCH 38/42] Revert "with or without guard" This reverts commit 383b356d9edf00bae0f3a29c35bd975e8d751128. --- bin/router/src/pipeline/mod.rs | 68 +++++--------- bin/router/src/pipeline/websocket_server.rs | 98 +++++++-------------- lib/executor/src/executors/http.rs | 2 +- lib/internal/src/inflight.rs | 42 +-------- 4 files changed, 57 insertions(+), 153 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 1255b68f2..ea0ae52da 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -284,52 +284,28 @@ pub async fn graphql_request_handler( }; let shared_response = if let Some(fp) = fingerprint { - let (shared_response, _role) = if is_subscription { - shared_state - .in_flight_requests - .claim(fp) - .get_or_try_init_with_guard(|guard| async { - execute_planned_request( - req.method(), - req.uri(), - req.headers(), - graphql_params, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - response_mode, - Some(guard), - ) - .await - }) - .await? - } else { - shared_state - .in_flight_requests - .claim(fp) - .get_or_try_init(|| async { - execute_planned_request( - req.method(), - req.uri(), - req.headers(), - graphql_params, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - response_mode, - None, - ) - .await - }) - .await? - }; - Arc::unwrap_or_clone(shared_response) + let (planned_response, _role) = shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init(|guard| async { + execute_planned_request( + req.method(), + req.uri(), + req.headers(), + graphql_params, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + response_mode, + Some(guard), + ) + .await + }) + .await?; + Arc::unwrap_or_clone(planned_response) } else { execute_planned_request( req.method(), diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index ad6ec0175..4b59c63bb 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -444,67 +444,39 @@ async fn handle_text_frame( // synthetic request details for plan executor let shared_response = if let Some(fp) = fingerprint { - let result = if is_subscription { - shared_state - .in_flight_requests - .claim(fp) - .get_or_try_init_with_guard(|guard| async { - execute_planned_request( - &Method::POST, - ws_uri, - &headers, - payload, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - &ResponseMode::Dual( - SingleContentType::default(), - StreamContentType::default(), - ), - Some(guard), - ) - .await - }) - .await - } else { - shared_state - .in_flight_requests - .claim(fp) - .get_or_try_init(|| async { - execute_planned_request( - &Method::POST, - ws_uri, - &headers, - payload, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - &ResponseMode::Dual( - SingleContentType::default(), - StreamContentType::default(), - ), - None, - ) - .await - }) + let (shared_response, _role) = match shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init(|guard| async { + execute_planned_request( + &Method::POST, + ws_uri, + &headers, + payload, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + &ResponseMode::Dual( + SingleContentType::default(), + StreamContentType::default(), + ), + Some(guard), + ) .await - }; - let (shared_response, _role) = match result { - Ok(result) => result, - Err(PipelineError::JwtError(err)) => { - let _ = sink.send(err.clone().into_server_message(&id)).await; - // we report error as graphql error, but we also close the - // connection since we're dealing with auth so let's be safe - return Some(err.into_close_message()); - }, - Err(err) => return Some(err.into_server_message(&id)), - }; + }) + .await { + Ok(result) => result, + Err(PipelineError::JwtError(err)) => { + let _ = sink.send(err.clone().into_server_message(&id)).await; + // we report error as graphql error, but we also close the + // connection since we're dealing with auth so let's be safe + return Some(err.into_close_message()); + }, + Err(err) => return Some(err.into_server_message(&id)), + }; Arc::unwrap_or_clone(shared_response) } else { match execute_planned_request( @@ -526,12 +498,6 @@ async fn handle_text_frame( ) .await { Ok(result) => result, - Err(PipelineError::JwtError(err)) => { - let _ = sink.send(err.clone().into_server_message(&id)).await; - // we report error as graphql error, but we also close the - // connection since we're dealing with auth so let's be safe - return Some(err.into_close_message()); - }, Err(err) => return Some(err.into_server_message(&id)), } }; diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index ce4b853d2..6c7596a41 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -392,7 +392,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { let claim = self.in_flight_requests.claim(fingerprint); let mut leader_http_request_capture = None; let (shared_response, role) = claim - .get_or_try_init(|| async { + .get_or_try_init(|_| async { let res = { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. diff --git a/lib/internal/src/inflight.rs b/lib/internal/src/inflight.rs index d9fb5ce3e..3cbf32a6c 100644 --- a/lib/internal/src/inflight.rs +++ b/lib/internal/src/inflight.rs @@ -94,42 +94,6 @@ where { #[inline] pub async fn get_or_try_init(self, init: F) -> Result<(Arc, InFlightRole), E> - where - F: FnOnce() -> Fut, - Fut: Future>, - { - let mut did_initialize = false; - let InFlightClaim { key, map, cell } = self; - - let value = cell - .get_or_try_init(|| { - did_initialize = true; - async { - let _cleanup = InFlightCleanupGuard { key, map }; - init().await.map(Arc::new) - } - }) - .await? - .clone(); - - let role = if did_initialize { - InFlightRole::Leader - } else { - InFlightRole::Joiner - }; - - Ok((value, role)) - } - - /// Like [`get_or_try_init`](Self::get_or_try_init), but passes the cleanup guard to the - /// caller. The caller is responsible for dropping the guard when deduplication should end. - /// This is useful for long-lived operations like subscriptions where the deduplication - /// window extends beyond the init future. - #[inline] - pub async fn get_or_try_init_with_guard( - self, - init: F, - ) -> Result<(Arc, InFlightRole), E> where F: FnOnce(InFlightCleanupGuard) -> Fut, Fut: Future>, @@ -140,10 +104,8 @@ where let value = cell .get_or_try_init(|| { did_initialize = true; - async { - let guard = InFlightCleanupGuard { key, map }; - init(guard).await.map(Arc::new) - } + let guard = InFlightCleanupGuard { key, map }; + async { init(guard).await.map(Arc::new) } }) .await? .clone(); From 0329884fadc3391d764d73e1a3d9a4f8b4aa67c3 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 20:37:21 +0200 Subject: [PATCH 39/42] `_` drop immed `_guard` bind and drop at end --- bin/router/src/pipeline/mod.rs | 3 +-- lib/executor/src/executors/http.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index ea0ae52da..608f540f4 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -481,8 +481,7 @@ pub async fn execute_planned_request<'exec>( ok_or(PipelineError::UnsupportedContentType)?. clone(); - // drop the router shared request as soon as the response is ready - let _query_guard = guard; + // drop the `guard` as soon as the response is ready let mut builder = web::HttpResponse::Ok(); if let Some(aggregator) = result.response_headers_aggregator { diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 6c7596a41..a8ba4cca8 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -392,7 +392,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { let claim = self.in_flight_requests.claim(fingerprint); let mut leader_http_request_capture = None; let (shared_response, role) = claim - .get_or_try_init(|_| async { + .get_or_try_init(|_guard| async { let res = { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. From 91e793892302a367b86007864aca632dee1a31de Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 22:07:36 +0200 Subject: [PATCH 40/42] Reapply "with or without guard" This reverts commit 0fa4d1e65888511320dba7c52fd58b481fdb904a. # Conflicts: # lib/executor/src/executors/http.rs --- bin/router/src/pipeline/mod.rs | 68 +++++++++----- bin/router/src/pipeline/websocket_server.rs | 98 ++++++++++++++------- lib/executor/src/executors/http.rs | 2 +- lib/internal/src/inflight.rs | 42 ++++++++- 4 files changed, 153 insertions(+), 57 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 608f540f4..ffbd548d0 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -284,28 +284,52 @@ pub async fn graphql_request_handler( }; let shared_response = if let Some(fp) = fingerprint { - let (planned_response, _role) = shared_state - .in_flight_requests - .claim(fp) - .get_or_try_init(|guard| async { - execute_planned_request( - req.method(), - req.uri(), - req.headers(), - graphql_params, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - response_mode, - Some(guard), - ) - .await - }) - .await?; - Arc::unwrap_or_clone(planned_response) + let (shared_response, _role) = if is_subscription { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init_with_guard(|guard| async { + execute_planned_request( + req.method(), + req.uri(), + req.headers(), + graphql_params, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + response_mode, + Some(guard), + ) + .await + }) + .await? + } else { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init(|| async { + execute_planned_request( + req.method(), + req.uri(), + req.headers(), + graphql_params, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + response_mode, + None, + ) + .await + }) + .await? + }; + Arc::unwrap_or_clone(shared_response) } else { execute_planned_request( req.method(), diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 4b59c63bb..ad6ec0175 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -444,39 +444,67 @@ async fn handle_text_frame( // synthetic request details for plan executor let shared_response = if let Some(fp) = fingerprint { - let (shared_response, _role) = match shared_state - .in_flight_requests - .claim(fp) - .get_or_try_init(|guard| async { - execute_planned_request( - &Method::POST, - ws_uri, - &headers, - payload, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - &ResponseMode::Dual( - SingleContentType::default(), - StreamContentType::default(), - ), - Some(guard), - ) + let result = if is_subscription { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init_with_guard(|guard| async { + execute_planned_request( + &Method::POST, + ws_uri, + &headers, + payload, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + &ResponseMode::Dual( + SingleContentType::default(), + StreamContentType::default(), + ), + Some(guard), + ) + .await + }) .await - }) - .await { - Ok(result) => result, - Err(PipelineError::JwtError(err)) => { - let _ = sink.send(err.clone().into_server_message(&id)).await; - // we report error as graphql error, but we also close the - // connection since we're dealing with auth so let's be safe - return Some(err.into_close_message()); - }, - Err(err) => return Some(err.into_server_message(&id)), - }; + } else { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init(|| async { + execute_planned_request( + &Method::POST, + ws_uri, + &headers, + payload, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + &ResponseMode::Dual( + SingleContentType::default(), + StreamContentType::default(), + ), + None, + ) + .await + }) + .await + }; + let (shared_response, _role) = match result { + Ok(result) => result, + Err(PipelineError::JwtError(err)) => { + let _ = sink.send(err.clone().into_server_message(&id)).await; + // we report error as graphql error, but we also close the + // connection since we're dealing with auth so let's be safe + return Some(err.into_close_message()); + }, + Err(err) => return Some(err.into_server_message(&id)), + }; Arc::unwrap_or_clone(shared_response) } else { match execute_planned_request( @@ -498,6 +526,12 @@ async fn handle_text_frame( ) .await { Ok(result) => result, + Err(PipelineError::JwtError(err)) => { + let _ = sink.send(err.clone().into_server_message(&id)).await; + // we report error as graphql error, but we also close the + // connection since we're dealing with auth so let's be safe + return Some(err.into_close_message()); + }, Err(err) => return Some(err.into_server_message(&id)), } }; diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index a8ba4cca8..ce4b853d2 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -392,7 +392,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { let claim = self.in_flight_requests.claim(fingerprint); let mut leader_http_request_capture = None; let (shared_response, role) = claim - .get_or_try_init(|_guard| async { + .get_or_try_init(|| async { let res = { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. diff --git a/lib/internal/src/inflight.rs b/lib/internal/src/inflight.rs index 3cbf32a6c..d9fb5ce3e 100644 --- a/lib/internal/src/inflight.rs +++ b/lib/internal/src/inflight.rs @@ -94,6 +94,42 @@ where { #[inline] pub async fn get_or_try_init(self, init: F) -> Result<(Arc, InFlightRole), E> + where + F: FnOnce() -> Fut, + Fut: Future>, + { + let mut did_initialize = false; + let InFlightClaim { key, map, cell } = self; + + let value = cell + .get_or_try_init(|| { + did_initialize = true; + async { + let _cleanup = InFlightCleanupGuard { key, map }; + init().await.map(Arc::new) + } + }) + .await? + .clone(); + + let role = if did_initialize { + InFlightRole::Leader + } else { + InFlightRole::Joiner + }; + + Ok((value, role)) + } + + /// Like [`get_or_try_init`](Self::get_or_try_init), but passes the cleanup guard to the + /// caller. The caller is responsible for dropping the guard when deduplication should end. + /// This is useful for long-lived operations like subscriptions where the deduplication + /// window extends beyond the init future. + #[inline] + pub async fn get_or_try_init_with_guard( + self, + init: F, + ) -> Result<(Arc, InFlightRole), E> where F: FnOnce(InFlightCleanupGuard) -> Fut, Fut: Future>, @@ -104,8 +140,10 @@ where let value = cell .get_or_try_init(|| { did_initialize = true; - let guard = InFlightCleanupGuard { key, map }; - async { init(guard).await.map(Arc::new) } + async { + let guard = InFlightCleanupGuard { key, map }; + init(guard).await.map(Arc::new) + } }) .await? .clone(); From 7e8c6aea41043d5fcde73ba46b8115e5e2cdd39e Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 22:44:33 +0200 Subject: [PATCH 41/42] =?UTF-8?q?=F0=9F=91=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From 8314410f8a17429289e0146ce4277dadba14507f Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Sat, 4 Apr 2026 22:57:38 +0200 Subject: [PATCH 42/42] chill out with redundancy --- bin/router/src/pipeline/mod.rs | 87 +++++++-------------- bin/router/src/pipeline/websocket_server.rs | 80 ++++++------------- 2 files changed, 50 insertions(+), 117 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index ffbd548d0..f8ce54571 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -283,69 +283,38 @@ pub async fn graphql_request_handler( None }; + let exec = |guard| execute_planned_request( + req.method(), + req.uri(), + req.headers(), + graphql_params, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + response_mode, + guard, + ); + let shared_response = if let Some(fp) = fingerprint { let (shared_response, _role) = if is_subscription { - shared_state - .in_flight_requests - .claim(fp) - .get_or_try_init_with_guard(|guard| async { - execute_planned_request( - req.method(), - req.uri(), - req.headers(), - graphql_params, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - response_mode, - Some(guard), - ) - .await - }) - .await? - } else { - shared_state - .in_flight_requests - .claim(fp) - .get_or_try_init(|| async { - execute_planned_request( - req.method(), - req.uri(), - req.headers(), - graphql_params, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - response_mode, - None, - ) - .await - }) - .await? - }; + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init_with_guard(|guard| exec(Some(guard))) + .await? + } else { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init(|| exec(None)) + .await? + }; Arc::unwrap_or_clone(shared_response) } else { - execute_planned_request( - req.method(), - req.uri(), - req.headers(), - graphql_params, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - response_mode, - None, - ) - .await? + exec(None).await? }; if let Some(hive_usage_agent) = &shared_state.hive_usage_agent { diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index ad6ec0175..03c3d4394 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -443,56 +443,37 @@ async fn handle_text_frame( }; // synthetic request details for plan executor + let response_mode = ResponseMode::Dual( + SingleContentType::default(), + StreamContentType::default(), + ); + let exec = |guard| execute_planned_request( + &Method::POST, + ws_uri, + &headers, + payload, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + &response_mode, + guard, + ); + let shared_response = if let Some(fp) = fingerprint { let result = if is_subscription { shared_state .in_flight_requests .claim(fp) - .get_or_try_init_with_guard(|guard| async { - execute_planned_request( - &Method::POST, - ws_uri, - &headers, - payload, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - &ResponseMode::Dual( - SingleContentType::default(), - StreamContentType::default(), - ), - Some(guard), - ) - .await - }) + .get_or_try_init_with_guard(|guard| exec(Some(guard))) .await } else { shared_state .in_flight_requests .claim(fp) - .get_or_try_init(|| async { - execute_planned_request( - &Method::POST, - ws_uri, - &headers, - payload, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - &ResponseMode::Dual( - SingleContentType::default(), - StreamContentType::default(), - ), - None, - ) - .await - }) + .get_or_try_init(|| exec(None)) .await }; let (shared_response, _role) = match result { @@ -507,24 +488,7 @@ async fn handle_text_frame( }; Arc::unwrap_or_clone(shared_response) } else { - match execute_planned_request( - &Method::POST, - ws_uri, - &headers, - payload, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - &ResponseMode::Dual( - SingleContentType::default(), - StreamContentType::default(), - ), - None, - ) - .await { + match exec(None).await { Ok(result) => result, Err(PipelineError::JwtError(err)) => { let _ = sink.send(err.clone().into_server_message(&id)).await;