diff --git a/Cargo.lock b/Cargo.lock index e92ca027a..00208ce72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2465,7 +2465,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tracing", - "uuid", + "ulid", "xxhash-rust", ] diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index b0ab292b9..06bf96266 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -22,10 +22,12 @@ use crate::{ }, jwt::JwtAuthRuntime, pipeline::{ + active_subscriptions::ActiveSubscriptions, error::handle_pipeline_error, graphql_request_handler, header::ResponseMode, http_callback::handler, + long_lived_client_limit::LongLivedClientLimitService, request_extensions::{ read_graphql_operation_metric_identity, read_graphql_response_metric_status, read_request_body_size, write_graphql_response_metric_status, @@ -212,7 +214,7 @@ pub async fn router_entrypoint(plugin_registry: PluginRegistry) -> Result<(), Ro .await?; let shared_state_clone = shared_state.clone(); - let active_subs = schema_state.active_callback_subscriptions.clone(); + let callback_subscriptions_for_handler = schema_state.callback_subscriptions.clone(); // when `listen` is set, the callback route lives on a dedicated server bound to that address // otherwise, the callback route is mounted on the main server on the `callback_path` @@ -224,12 +226,12 @@ pub async fn router_entrypoint(plugin_registry: PluginRegistry) -> Result<(), Ro }) => { let cb_path = path.to_string(); let cb_addr = listen.to_string(); - let cb_active_subs = active_subs.clone(); + let cb_subs = callback_subscriptions_for_handler.clone(); let cb_server = web::HttpServer::new(async move || { - let cb_active_subs = cb_active_subs.clone(); + let cb_subs = cb_subs.clone(); let cb_path = cb_path.clone(); web::App::new() - .state(cb_active_subs) + .state(cb_subs) .configure(move |m| add_callback_handler(m, &cb_path)) }) .bind(&cb_addr) @@ -248,10 +250,15 @@ pub async fn router_entrypoint(plugin_registry: PluginRegistry) -> Result<(), Ro let paths = RouterPaths::new(graphql_path.clone(), websocket_path, callback_path); paths.detect_conflicts(&prometheus)?; + let long_lived_client_limit_service = + LongLivedClientLimitService::new(&shared_state.router_config); + let maybe_error = web::HttpServer::new(async move || { let landing_page_path = graphql_path.clone(); let prometheus = prometheus.clone(); + let long_lived_client_limit_service = long_lived_client_limit_service.clone(); web::App::new() + .middleware(long_lived_client_limit_service) .middleware(PluginService) .state(shared_state.clone()) .state(schema_state.clone()) @@ -310,6 +317,8 @@ pub async fn configure_app_from_config( }; let plugins_arc = plugin_registry.initialize_plugins(&router_config, bg_tasks_manager)?; + let active_subscriptions = + ActiveSubscriptions::new(router_config.subscriptions.broadcast_capacity); let router_config_arc = Arc::new(router_config); let telemetry_context_arc = Arc::new(telemetry_context); let cache_state = Arc::new(CacheState::new()); @@ -324,6 +333,7 @@ pub async fn configure_app_from_config( router_config_arc.clone(), plugins_arc.clone(), cache_state.clone(), + active_subscriptions.clone(), ) .await?; let schema_state_arc = Arc::new(schema_state); @@ -351,6 +361,7 @@ pub async fn configure_app_from_config( telemetry_context_arc, plugins_arc, cache_state, + active_subscriptions.clone(), )?); Ok((shared_state, schema_state_arc)) diff --git a/bin/router/src/pipeline/active_subscriptions.rs b/bin/router/src/pipeline/active_subscriptions.rs new file mode 100644 index 000000000..f51d18c64 --- /dev/null +++ b/bin/router/src/pipeline/active_subscriptions.rs @@ -0,0 +1,98 @@ +use std::sync::Arc; + +use bytes::Bytes; +use dashmap::DashMap; +use hive_router_plan_executor::response::graphql_error::GraphQLError; +use tokio::sync::broadcast; +use tracing::trace; +use ulid::Ulid; + +use crate::shared_state::SharedRouterResponseGuard; + +pub type SubscriptionId = String; + +#[derive(Clone, Debug)] +pub enum SubscriptionEvent { + /// A normal subscription event from the upstream, already serialized. + /// Uses Bytes for zero-copy cloning across broadcast receivers. + Raw(Bytes), + /// An error pushed externally (e.g. supergraph reload, shutdown). + /// Consumers should yield this as the final event and then stop. + Error(Vec), +} + +#[derive(Clone)] +pub struct ActiveSubscriptions { + map: Arc>>, + broadcast_capacity: usize, +} + +impl ActiveSubscriptions { + pub fn new(broadcast_capacity: usize) -> Self { + Self { + map: Arc::new(DashMap::new()), + broadcast_capacity, + } + } + + /// Register a new active subscription. Returns a producer handle for the upstream pump + /// and a pre-subscribed receiver for the leader consumer. The pump task owns the handle + /// for the full lifetime of the upstream stream - when the handle drops (pump done or all + /// receivers gone) the broadcast channel closes and all consumer receivers terminate. + pub fn register( + &self, + guard: Option, + ) -> (ProducerHandle, broadcast::Receiver) { + let (sender, receiver) = broadcast::channel(self.broadcast_capacity); + let id = Ulid::new().to_string(); + self.map.insert(id.clone(), sender.clone()); + + let handle = ProducerHandle { + id: id.clone(), + map: self.map.clone(), + sender, + _guard: guard, + }; + + trace!(subscription_id = %id, "registered new subscription"); + + (handle, receiver) + } + + /// Close all active subscriptions with an error and clear the registry. + pub fn close_all_with_error(&self, errors: Vec) { + let item = SubscriptionEvent::Error(errors); + for entry in self.map.iter() { + let _ = entry.send(item.clone()); + } + self.map.clear(); + } +} + +/// Held by the upstream pump task for the full lifetime of the stream. Dropping it removes +/// the subscription from the registry, closes the broadcast channel, and drops the inflight +/// cleanup guard - which removes the dedupe entry so new requests start a fresh upstream. +pub struct ProducerHandle { + id: SubscriptionId, + map: Arc>>, + sender: broadcast::Sender, + _guard: Option, +} + +impl ProducerHandle { + pub fn sender(&self) -> &broadcast::Sender { + &self.sender + } + + /// Returns false when all consumers have gone and the event cannot be delivered. + pub fn send(&self, item: SubscriptionEvent) -> bool { + self.sender.send(item).is_ok() + } +} + +impl Drop for ProducerHandle { + fn drop(&mut self) { + self.map.remove(&self.id); + trace!(subscription_id = %self.id, "producer dropped, upstream closed"); + } +} diff --git a/bin/router/src/pipeline/header.rs b/bin/router/src/pipeline/header.rs index 116a282dc..5ce52e8d6 100644 --- a/bin/router/src/pipeline/header.rs +++ b/bin/router/src/pipeline/header.rs @@ -81,7 +81,7 @@ impl SingleContentType { // IMPORTANT: make sure that the serialized string representations are valid because // there is an unwrap in the StreamContentType::media_types() method. /// Streamable content types for GraphQL responses. -#[derive(PartialEq, Default, Debug, IntoStaticStr, EnumString, AsRefStr, EnumIter)] +#[derive(PartialEq, Default, Debug, IntoStaticStr, EnumString, AsRefStr, EnumIter, Clone)] pub enum StreamContentType { // The order of the variants here matters for negotiation with `Accept: */*`. /// Incremental Delivery over HTTP (`multipart/mixed`) diff --git a/bin/router/src/pipeline/http_callback.rs b/bin/router/src/pipeline/http_callback.rs index c9e2de216..66c13de74 100644 --- a/bin/router/src/pipeline/http_callback.rs +++ b/bin/router/src/pipeline/http_callback.rs @@ -1,7 +1,7 @@ use bytes::Bytes as BytesLib; use dashmap::mapref::one::Ref; use hive_router_plan_executor::executors::http_callback::{ - ActiveSubscription, ActiveSubscriptionsMap, CallbackMessage, CALLBACK_PROTOCOL_VERSION, + CallbackMessage, CallbackSubscription, CallbackSubscriptionsMap, CALLBACK_PROTOCOL_VERSION, SUBSCRIPTION_PROTOCOL_HEADER, }; use hive_router_plan_executor::response::graphql_error::GraphQLError; @@ -142,7 +142,7 @@ fn validate_payload( Ok(()) } -fn handle_check(subscription_id: &str, subscription: &Ref<'_, String, ActiveSubscription>) { +fn handle_check(subscription_id: &str, subscription: &Ref<'_, String, CallbackSubscription>) { trace!(subscription_id = %subscription_id, "Received check message"); subscription.record_heartbeat(); } @@ -150,8 +150,8 @@ fn handle_check(subscription_id: &str, subscription: &Ref<'_, String, ActiveSubs fn handle_next( subscription_id: &str, payload: &CallbackPayload<'_>, - subscription: Ref<'_, String, ActiveSubscription>, - active_subscriptions: &ActiveSubscriptionsMap, + subscription: Ref<'_, String, CallbackSubscription>, + callback_subscriptions: &CallbackSubscriptionsMap, ) -> Result<(), CallbackError> { trace!(subscription_id = %subscription_id, "Received next message"); @@ -174,7 +174,7 @@ fn handle_next( // up. we terminate the subscription without an error message because it anyways cant go through warn!(subscription_id = %subscription_id, "Subscription client is too slow"); drop(subscription); - active_subscriptions.remove(subscription_id); + callback_subscriptions.remove(subscription_id); Err(CallbackError::ClientTooSlow { subscription_id: subscription_id.to_string(), }) @@ -182,7 +182,7 @@ fn handle_next( Err(mpsc::error::TrySendError::Closed(_)) => { debug!(subscription_id = %subscription_id, "Subscription receiver dropped"); drop(subscription); - active_subscriptions.remove(subscription_id); + callback_subscriptions.remove(subscription_id); Err(CallbackError::SubscriptionDropped { subscription_id: subscription_id.to_string(), }) @@ -193,8 +193,8 @@ fn handle_next( fn handle_complete( subscription_id: &str, payload: &CallbackPayload<'_>, - subscription: Ref<'_, String, ActiveSubscription>, - active_subscriptions: &ActiveSubscriptionsMap, + subscription: Ref<'_, String, CallbackSubscription>, + callback_subscriptions: &CallbackSubscriptionsMap, ) { trace!(subscription_id = %subscription_id, "Received complete message"); // if the buffer is full or closed we ignore and remove the subscription, we dont send @@ -203,14 +203,14 @@ fn handle_complete( errors: payload.errors.clone(), }); drop(subscription); - active_subscriptions.remove(subscription_id); + callback_subscriptions.remove(subscription_id); } pub async fn handler( req: HttpRequest, path: Path, body: Bytes, - active_subscriptions: web::types::State, + callback_subscriptions: web::types::State, ) -> Result { let subscription_id_from_path = path.into_inner(); @@ -220,7 +220,7 @@ pub async fn handler( validate_payload(&payload, &subscription_id_from_path)?; - let subscription = match active_subscriptions.get(&payload.id) { + let subscription = match callback_subscriptions.get(&payload.id) { Some(sub) => sub, None => { return Err(CallbackError::SubscriptionNotFound { @@ -238,10 +238,10 @@ pub async fn handler( match payload.action { CallbackAction::Check => handle_check(&payload.id, &subscription), CallbackAction::Next => { - handle_next(&payload.id, &payload, subscription, &active_subscriptions)?; + handle_next(&payload.id, &payload, subscription, &callback_subscriptions)?; } CallbackAction::Complete => { - handle_complete(&payload.id, &payload, subscription, &active_subscriptions) + handle_complete(&payload.id, &payload, subscription, &callback_subscriptions) } }; diff --git a/bin/router/src/pipeline/long_lived_client_limit.rs b/bin/router/src/pipeline/long_lived_client_limit.rs new file mode 100644 index 000000000..fc65768b8 --- /dev/null +++ b/bin/router/src/pipeline/long_lived_client_limit.rs @@ -0,0 +1,191 @@ +use std::{ + rc::Rc, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use http::{header, StatusCode}; +use ntex::{ + http::body::{Body, BodySize, MessageBody}, + service::{Service, ServiceCtx}, + util::Bytes, + web::{self, DefaultError}, + Middleware, SharedCfg, +}; + +use crate::RouterSharedState; + +#[derive(Clone)] +pub struct LongLivedClientLimitService { + // false means the middleware is entirely bypassed on every request + enabled: bool, +} + +impl LongLivedClientLimitService { + pub fn new(router_config: &hive_router_config::HiveRouterConfig) -> Self { + let limit = router_config.traffic_shaping.router.max_long_lived_clients; + let has_long_lived = router_config.subscriptions.enabled || router_config.websocket.enabled; + Self { + enabled: limit > 0 && has_long_lived, + } + } +} + +impl Middleware for LongLivedClientLimitService { + type Service = LongLivedClientLimitMiddleware; + + fn create(&self, service: S, _cfg: SharedCfg) -> Self::Service { + LongLivedClientLimitMiddleware { + service, + enabled: self.enabled, + } + } +} + +pub struct LongLivedClientLimitMiddleware { + service: S, + enabled: bool, +} + +impl Service> for LongLivedClientLimitMiddleware +where + S: Service, Response = web::WebResponse, Error = web::Error>, +{ + type Response = web::WebResponse; + type Error = S::Error; + + ntex::forward_ready!(service); + + async fn call( + &self, + req: web::WebRequest, + ctx: ServiceCtx<'_, Self>, + ) -> Result { + if !self.enabled { + return ctx.call(&self.service, req).await; + } + + if !is_long_lived_request(req.headers()) { + return ctx.call(&self.service, req).await; + } + + let shared_state = match req.app_state::>() { + Some(s) => s, + None => return ctx.call(&self.service, req).await, + }; + + let limit = shared_state + .router_config + .traffic_shaping + .router + .max_long_lived_clients; + let counter = shared_state.long_lived_client_count.clone(); + + // try to reserve a slot, bail if at the limit + let prev = counter.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| { + if current < limit { + Some(current + 1) + } else { + None + } + }); + if prev.is_err() { + let error_response = web::HttpResponse::build(StatusCode::SERVICE_UNAVAILABLE) + .header(header::RETRY_AFTER, "5") + .body("Too many long-lived clients"); + return Ok(req.into_response(error_response)); + } + + let guard = LongLivedClientGuard(counter); + let response = ctx.call(&self.service, req).await?; + + // wrap the body so the guard lives until the stream is fully consumed + let response = response.map_body(|_head, body| { + let wrapped = GuardedBody { + inner: body.into_body().into(), + _guard: guard, + }; + Body::from_message(wrapped).into() + }); + + Ok(response) + } +} + +// decrements the counter when dropped +struct LongLivedClientGuard(Arc); + +impl Drop for LongLivedClientGuard { + fn drop(&mut self) { + self.0.fetch_sub(1, Ordering::AcqRel); + } +} + +// wraps the body and keeps the guard alive until it's fully consumed and dropped. +// one extra vtable call per chunk on top of the Box dispatch streaming bodies +// already go through - negligible next to actual I/O cost per chunk. +struct GuardedBody { + inner: Body, + _guard: LongLivedClientGuard, +} + +impl MessageBody for GuardedBody { + fn size(&self) -> BodySize { + self.inner.size() + } + + fn poll_next_chunk( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + self.inner.poll_next_chunk(cx) + } +} + +// cheapest check first: +// 1. upgrade: websocket - two header lookups, no parsing +// 2. accept: streaming - one header lookup + fast substring pre-filter, full parse only if needed +#[inline] +fn is_long_lived_request(headers: &ntex::http::HeaderMap) -> bool { + // websocket: Connection: Upgrade + Upgrade: websocket + // both headers must be present and contain the expected values (case-insensitive) + if headers + .get(header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .is_some_and(|v| v.eq_ignore_ascii_case("websocket")) + && headers + .get(header::CONNECTION) + .and_then(|v| v.to_str().ok()) + .is_some_and(|v| v.to_ascii_lowercase().contains("upgrade")) + { + return true; + } + + // fast substring scan before full Accept parse to avoid cost on regular requests + let accept = match headers.get(header::ACCEPT).and_then(|v| v.to_str().ok()) { + Some(v) if !v.is_empty() => v, + _ => return false, + }; + + if !looks_like_streaming_accept(accept) { + return false; + } + + use crate::pipeline::header::StreamContentType; + use headers_accept::Accept; + use std::str::FromStr; + + Accept::from_str(accept) + .ok() + .and_then(|a| a.negotiate(StreamContentType::media_types().iter())) + .is_some() +} + +#[inline] +fn looks_like_streaming_accept(accept: &str) -> bool { + // covers: multipart/mixed, text/event-stream + accept.contains("multipart") || accept.contains("event-stream") +} diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 04d515a4d..f8ce54571 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,13 +1,4 @@ -use futures::Stream; -use std::{ - collections::HashMap, - hash::{Hash, Hasher}, - sync::Arc, - time::Instant, -}; -use tracing::{error, Instrument}; -use xxhash_rust::xxh3::Xxh3; - +use futures::StreamExt; use hive_router_internal::telemetry::traces::spans::{ graphql::GraphQLOperationSpan, http_request::HttpServerRequestSpan, }; @@ -23,11 +14,24 @@ use hive_router_query_planner::{ state::supergraph_state::OperationKind, utils::cancellation::CancellationToken, }; use http::{header::CONTENT_TYPE, Method}; -use ntex::web::{self, HttpRequest}; +use ntex::{ + http::HeaderMap, + rt, + web::{self, HttpRequest}, +}; use sonic_rs::{JsonContainerTrait, JsonType, JsonValueTrait, Value}; +use std::{ + collections::HashMap, + hash::{Hash, Hasher}, + sync::Arc, + time::Instant, +}; +use tracing::{error, Instrument}; +use xxhash_rust::xxh3::Xxh3; use crate::{ pipeline::{ + active_subscriptions::SubscriptionEvent, authorization::enforce_operation_authorization, body_read::read_body_stream, coerce_variables::{coerce_request_variables, CoerceVariablesPayload}, @@ -35,11 +39,8 @@ use crate::{ error::PipelineError, execution::{execute_plan, PlannedRequest}, execution_request::{deserialize_graphql_params, DeserializationResult, GetQueryStr}, - header::{RequestAccepts, ResponseMode, StreamContentType, TEXT_HTML_MIME}, + header::{RequestAccepts, ResponseMode, TEXT_HTML_MIME}, introspection_policy::handle_introspection_policy, - multipart_subscribe::{ - APOLLO_MULTIPART_HTTP_CONTENT_TYPE, INCREMENTAL_DELIVERY_CONTENT_TYPE, - }, normalize::{normalize_request_with_cache, GraphQLNormalizationPayload}, parser::{parse_operation_with_cache, ParseResult}, progressive_override::request_override_context, @@ -51,12 +52,16 @@ use crate::{ validation::validate_operation_with_cache, }, schema_state::SchemaState, - shared_state::{RouterRequestDedupeHeaderPolicy, RouterSharedState, SharedRouterResponse}, + shared_state::{ + RouterRequestDedupeHeaderPolicy, RouterSharedState, SharedRouterResponse, + SharedRouterResponseGuard, SharedRouterSingleResponse, SharedRouterStreamResponse, + }, LABORATORY_HTML, }; use hive_router_internal::telemetry::metrics::catalog::values::GraphQLResponseStatus; +pub mod active_subscriptions; pub mod authorization; pub mod body_read; pub mod coerce_variables; @@ -68,6 +73,7 @@ pub mod execution_request; pub mod header; pub mod http_callback; pub mod introspection_policy; +pub mod long_lived_client_limit; pub mod multipart_subscribe; pub mod normalize; pub mod parser; @@ -243,17 +249,18 @@ pub async fn graphql_request_handler( if is_subscription && (!shared_state.router_config.subscriptions.enabled || !response_mode.can_stream()) { - // check early, even though we check again after pipeline execution below + // check early, even though we check again planned execution below return Err(PipelineError::SubscriptionsNotSupported); } let request_dedupe_enabled = shared_state.router_config.traffic_shaping.router.dedupe.enabled; - let planned_response = if request_dedupe_enabled + let fingerprint = if request_dedupe_enabled && matches!( normalize_payload.operation_for_plan.operation_kind, - Some(OperationKind::Query) | None + // same deduplication applies for queries and subscriptions + None | Some(OperationKind::Query) | Some(OperationKind::Subscription) ) { let variables_hash = hash_graphql_variables(&graphql_params.variables); let extensions_hash = graphql_params @@ -262,60 +269,52 @@ pub async fn graphql_request_handler( .map_or(0, hash_graphql_extensions); let schema_checksum = supergraph.schema_checksum(); - let fingerprint = inbound_request_fingerprint( - req, + Some(inbound_request_fingerprint( + req.method(), + req.path(), + req.headers(), &shared_state.in_flight_requests_header_policy, schema_checksum, normalize_payload.normalized_operation_hash, variables_hash, extensions_hash, - ); - let (shared_response, _role) = shared_state - .in_flight_requests - .claim(fingerprint) - .get_or_try_init(|| async { - match execute_planned_request( - req, - graphql_params, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - response_mode, - ) - .await? - { - PlannedResponse::Shared(r) => Ok::(r), - // subscriptions are excluded from the dedup branch above, so this is unreachable - PlannedResponse::Direct { .. } => unreachable!("stream responses never enter the dedup path"), - } - }) - .await?; - - PlannedResponse::Shared(Arc::unwrap_or_clone(shared_response)) + )) } else { - execute_planned_request( - req, - graphql_params, - &normalize_payload, - supergraph, - shared_state, - schema_state, - operation_span, - plugin_req_state, - response_mode, - ) - .await? + None }; - let (response, error_count) = match planned_response { - PlannedResponse::Shared(shared_response) => { - let error_count = shared_response.error_count; - (shared_response.into(), error_count) - } - PlannedResponse::Direct { response, .. } => (response, 0), + let exec = |guard| execute_planned_request( + req.method(), + req.uri(), + req.headers(), + graphql_params, + &normalize_payload, + supergraph, + shared_state, + schema_state, + operation_span, + plugin_req_state, + response_mode, + guard, + ); + + let shared_response = if let Some(fp) = fingerprint { + let (shared_response, _role) = if is_subscription { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init_with_guard(|guard| exec(Some(guard))) + .await? + } else { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init(|| exec(None)) + .await? + }; + Arc::unwrap_or_clone(shared_response) + } else { + exec(None).await? }; if let Some(hive_usage_agent) = &shared_state.hive_usage_agent { @@ -339,21 +338,21 @@ pub async fn graphql_request_handler( // Thus, this expect should never panic. "Expected Usage Reporting options to be present when Hive Usage Agent is initialized", ), - error_count, + shared_response.error_count(), ) .await; } write_graphql_response_metric_status( req, - if error_count > 0 { + if shared_response.error_count() > 0 { GraphQLResponseStatus::Error } else { GraphQLResponseStatus::Ok }, ); - Ok(response) + shared_response.into_response(response_mode) } .instrument(span_clone) .await @@ -362,14 +361,11 @@ pub async fn graphql_request_handler( }) } -enum PlannedResponse { - Shared(SharedRouterResponse), - Direct { response: web::HttpResponse }, -} - #[allow(clippy::too_many_arguments)] -async fn execute_planned_request<'exec>( - req: &'exec HttpRequest, +pub async fn execute_planned_request<'exec>( + method: &'exec Method, + url: &'exec http::Uri, + headers: &'exec HeaderMap, mut graphql_params: GraphQLParams, normalize_payload: &Arc, supergraph: &'exec SupergraphData, @@ -378,10 +374,11 @@ async fn execute_planned_request<'exec>( operation_span: GraphQLOperationSpan, plugin_req_state: Option>, response_mode: &'exec ResponseMode, -) -> Result { + guard: Option, +) -> Result { let jwt_request_details = match &shared_state.jwt_auth_runtime { Some(jwt_auth_runtime) => match jwt_auth_runtime - .validate_headers(req.headers(), &shared_state.jwt_claims_cache) + .validate_headers(headers, &shared_state.jwt_claims_cache) .await? { Some(jwt_context) => JwtRequestDetails::Authenticated { @@ -399,9 +396,9 @@ async fn execute_planned_request<'exec>( coerce_request_variables(supergraph, &mut graphql_params.variables, normalize_payload)?; let client_request_details = ClientRequestDetails { - method: req.method(), - url: req.uri(), - headers: req.headers(), + method, + url, + headers, operation: OperationDetails { name: normalize_payload.operation_for_plan.name.as_deref(), kind: match normalize_payload.operation_for_plan.operation_kind { @@ -429,79 +426,68 @@ async fn execute_planned_request<'exec>( .await? { QueryPlanExecutionResult::Stream(result) => { + // we dont use the stream content type because subscriptions + // can be deduplicated across transports - but we do store + // the header value in the shared response because the user + // might choose to not deduplicate across transport boundaries let stream_content_type = response_mode .stream_content_type() - .ok_or(PipelineError::SubscriptionsTransportNotSupported)?; - - let content_type_header = match stream_content_type { - StreamContentType::IncrementalDelivery => { - http::HeaderValue::from_static(INCREMENTAL_DELIVERY_CONTENT_TYPE) - } - StreamContentType::SSE => http::HeaderValue::from_static("text/event-stream"), - StreamContentType::ApolloMultipartHTTP => { - http::HeaderValue::from_static(APOLLO_MULTIPART_HTTP_CONTENT_TYPE) + .ok_or(PipelineError::SubscriptionsTransportNotSupported)? + .clone(); + + let (producer_handle, receiver) = shared_state.active_subscriptions.register(guard); + + // subscribe the sender before spawning the pump so the channel always has + // at least one receiver - prevents events from being lost in the window + // between spawn and the consumer calling subscribe() + let sender = producer_handle.sender().clone(); + + let mut body_stream = result.body; + rt::spawn(async move { + while let Some(chunk) = body_stream.next().await { + if !producer_handle.send(SubscriptionEvent::Raw(bytes::Bytes::from(chunk))) { + // all receivers gone, stop draining + break; + } } - }; + // dropping producer_handle closes the broadcast channel + }); - // TODO: why exactly do we need a type cast here? - let body: std::pin::Pin< - Box> + Send>, - > = match stream_content_type { - StreamContentType::IncrementalDelivery => Box::pin( - multipart_subscribe::create_incremental_delivery_stream(result.body), - ), - StreamContentType::SSE => Box::pin(sse::create_stream( - result.body, - std::time::Duration::from_secs(10), - )), - StreamContentType::ApolloMultipartHTTP => { - Box::pin(multipart_subscribe::create_apollo_multipart_http_stream( - result.body, - std::time::Duration::from_secs(10), - )) - } + let mut builder = web::HttpResponse::Ok(); + if let Some(aggregator) = result.response_headers_aggregator { + aggregator.modify_client_response_headers(&mut builder)?; }; - - let mut response_builder = web::HttpResponse::Ok(); - - if let Some(response_headers_aggregator) = result.response_headers_aggregator { - response_headers_aggregator - .modify_client_response_headers(&mut response_builder)?; - } - - let response = response_builder - // .status(result.status) status codes in streaming responses should always be ok - .header(http::header::CONTENT_TYPE, content_type_header) - .streaming(body); - - Ok(PlannedResponse::Direct { response }) + builder.content_type(stream_content_type.as_ref()); + let headers = Arc::new(builder.finish().headers().clone()); + + Ok(SharedRouterResponse::Stream(SharedRouterStreamResponse { + body: sender, + headers, + error_count: result.error_count, + receiver: Some(receiver), + })) } QueryPlanExecutionResult::Single(result) => { let single_content_type = response_mode. single_content_type(). // TODO: streaming single responses - ok_or(PipelineError::UnsupportedContentType)?; + ok_or(PipelineError::UnsupportedContentType)?. + clone(); - let error_count = result.error_count; - let mut response_builder = web::HttpResponse::Ok(); + // drop the `guard` as soon as the response is ready - if let Some(response_headers_aggregator) = result.response_headers_aggregator { - response_headers_aggregator - .modify_client_response_headers(&mut response_builder)?; - } - - let body = ntex::util::Bytes::from(result.body); - - let response = response_builder - .content_type(single_content_type.as_ref()) - .status(result.status_code) - .body(body.clone()); - - Ok(PlannedResponse::Shared(SharedRouterResponse { - body, - headers: Arc::new(response.headers().clone()), - status: response.status(), - error_count, + let mut builder = web::HttpResponse::Ok(); + if let Some(aggregator) = result.response_headers_aggregator { + aggregator.modify_client_response_headers(&mut builder)?; + }; + builder.content_type(single_content_type.as_ref()); + let headers = Arc::new(builder.finish().headers().clone()); + + Ok(SharedRouterResponse::Single(SharedRouterSingleResponse { + body: ntex::util::Bytes::from(result.body), + headers, + status: result.status_code, + error_count: result.error_count, })) } } @@ -569,8 +555,11 @@ pub async fn execute_pipeline<'exec>( execute_plan(supergraph, shared_state, planned_request, operation_span).await } -fn inbound_request_fingerprint( - req: &HttpRequest, +#[allow(clippy::too_many_arguments)] +pub fn inbound_request_fingerprint( + method: &http::Method, + path: &str, + request_headers: &HeaderMap, dedupe_header_policy: &RouterRequestDedupeHeaderPolicy, schema_checksum: u64, normalized_operation_hash: u64, @@ -579,8 +568,7 @@ fn inbound_request_fingerprint( ) -> u64 { let mut hasher = Xxh3::new(); - let mut headers: Vec<(&str, &str)> = req - .headers() + let mut headers: Vec<(&str, &str)> = request_headers .iter() .filter(|(name, _)| dedupe_header_policy.should_include(name.as_str())) .filter_map(|(name, value)| value.to_str().ok().map(|v_str| (name.as_str(), v_str))) @@ -591,8 +579,8 @@ fn inbound_request_fingerprint( .then_with(|| left_value.cmp(right_value)) }); - req.method().hash(&mut hasher); - req.path().hash(&mut hasher); + method.hash(&mut hasher); + path.hash(&mut hasher); headers.hash(&mut hasher); schema_checksum.hash(&mut hasher); normalized_operation_hash.hash(&mut hasher); @@ -602,7 +590,7 @@ fn inbound_request_fingerprint( hasher.finish() } -fn hash_graphql_variables(variables: &HashMap) -> u64 { +pub fn hash_graphql_variables(variables: &HashMap) -> u64 { let mut hasher = Xxh3::new(); let mut keys: Vec<&str> = variables.keys().map(String::as_str).collect(); @@ -619,7 +607,7 @@ fn hash_graphql_variables(variables: &HashMap) -> u64 { hasher.finish() } -fn hash_graphql_extensions(extensions: &HashMap) -> u64 { +pub fn hash_graphql_extensions(extensions: &HashMap) -> u64 { // reused as hash_graphql_variables has the same function signature hash_graphql_variables(extensions) } diff --git a/bin/router/src/pipeline/websocket_server.rs b/bin/router/src/pipeline/websocket_server.rs index 8b9cdd911..03c3d4394 100644 --- a/bin/router/src/pipeline/websocket_server.rs +++ b/bin/router/src/pipeline/websocket_server.rs @@ -1,16 +1,21 @@ +use http::Method; +use ntex::channel::oneshot; +use ntex::http::{header::HeaderName, header::HeaderValue, HeaderMap}; +use ntex::router::Path; +use ntex::service::{fn_factory_with_config, fn_service, fn_shutdown, Service}; +use ntex::web::{self, ws, Error, HttpRequest, HttpResponse}; +use ntex::{chain, rt}; +use sonic_rs::{JsonContainerTrait, JsonValueTrait, Value}; use std::cell::RefCell; use std::collections::HashMap; use std::io; use std::rc::Rc; use std::sync::Arc; use std::time::Instant; +use tokio::sync::mpsc; +use tracing::{debug, error, trace, warn, Instrument}; -use futures::StreamExt; use hive_router_internal::telemetry::traces::spans::graphql::GraphQLOperationSpan; -use hive_router_plan_executor::execution::client_request_details::{ - ClientRequestDetails, JwtRequestDetails, OperationDetails, -}; -use hive_router_plan_executor::execution::plan::QueryPlanExecutionResult; use hive_router_plan_executor::executors::graphql_transport_ws::{ ClientMessage, CloseCode, ConnectionInitPayload, ServerMessage, WS_SUBPROTOCOL, }; @@ -23,28 +28,19 @@ use hive_router_plan_executor::plugin_context::{ }; use hive_router_plan_executor::response::graphql_error::{GraphQLError, GraphQLErrorExtensions}; use hive_router_query_planner::state::supergraph_state::OperationKind; -use http::Method; -use ntex::channel::oneshot; -use ntex::http::{header::HeaderName, header::HeaderValue, HeaderMap}; -use ntex::router::Path; -use ntex::service::{fn_factory_with_config, fn_service, fn_shutdown, Service}; -use ntex::web::{self, ws, Error, HttpRequest, HttpResponse}; -use ntex::{chain, rt}; -use sonic_rs::{JsonContainerTrait, JsonValueTrait, Value}; -use tokio::sync::mpsc; -use tracing::{debug, error, trace, warn, Instrument}; use crate::jwt::errors::JwtError; -use crate::pipeline::coerce_variables::coerce_request_variables; +use crate::pipeline::active_subscriptions::SubscriptionEvent; use crate::pipeline::error::PipelineError; -use crate::pipeline::execute_pipeline; -use crate::pipeline::execution_request::GetQueryStr; -use crate::pipeline::normalize::normalize_request_with_cache; -use crate::pipeline::parser::parse_operation_with_cache; -use crate::pipeline::usage_reporting; -use crate::pipeline::validation::validate_operation_with_cache; +use crate::pipeline::execute_planned_request; +use crate::pipeline::header::{ResponseMode, SingleContentType, StreamContentType}; +use crate::pipeline::{ + hash_graphql_extensions, hash_graphql_variables, inbound_request_fingerprint, + normalize::normalize_request_with_cache, parser::parse_operation_with_cache, usage_reporting, + validation::validate_operation_with_cache, +}; use crate::schema_state::SchemaState; -use crate::shared_state::RouterSharedState; +use crate::shared_state::{RouterSharedState, SharedRouterResponse}; type WsStateRef = Rc>>>; @@ -298,7 +294,7 @@ async fn handle_text_frame( } } - let mut payload = GraphQLParams { + let payload = GraphQLParams { query: Some(payload.query), operation_name: payload.operation_name, variables: payload.variables.unwrap_or_default(), @@ -418,114 +414,119 @@ async fn handle_text_frame( return Some(PipelineError::SubscriptionsNotSupported.into_server_message(&id)); } - let jwt_request_details = match &shared_state.jwt_auth_runtime { - Some(jwt_auth_runtime) => match jwt_auth_runtime - .validate_headers(&headers, &shared_state.jwt_claims_cache) - .await - { - Ok(Some(jwt_context)) => JwtRequestDetails::Authenticated { - scopes: jwt_context.extract_scopes(), - claims: match jwt_context - .get_claims_value() - .map_err(PipelineError::JwtForwardingError) - { - Ok(claims) => claims, - Err(e) => return Some(e.into_server_message(&id)), - }, - token: jwt_context.token_raw, - prefix: jwt_context.token_prefix, - }, - Ok(None) => JwtRequestDetails::Unauthenticated, - // jwt_auth_runtime.validate_headers() will error out only if - // authentication is required and has failed. we therefore use - // the JwtError conversion here to respond with proper error message - // close with Forbidden. - Err(e) => { - let _ = sink.send(e.clone().into_server_message(&id)).await; - // we report error as graphql error, but we also close the - // connection since we're dealing with auth so let's be safe - return Some(e.into_close_message()); - } - }, - None => JwtRequestDetails::Unauthenticated, - }; - - let variable_payload = match coerce_request_variables( - supergraph, - &mut payload.variables, - &normalize_payload, - ) { - Ok(payload) => payload, - Err(err) => return Some(err.into_server_message(&id)), + let request_dedupe_enabled = + shared_state.router_config.traffic_shaping.router.dedupe.enabled; + + let fingerprint = if request_dedupe_enabled + && matches!( + normalize_payload.operation_for_plan.operation_kind, + // same deduplication applies for queries and subscriptions + None | Some(OperationKind::Query) | Some(OperationKind::Subscription) + ) { + let variables_hash = hash_graphql_variables(&payload.variables); + let extensions_hash = payload + .extensions + .as_ref() + .map_or(0, hash_graphql_extensions); + Some(inbound_request_fingerprint( + &Method::POST, + ws_uri.path(), + &headers, + &shared_state.in_flight_requests_header_policy, + supergraph.schema_checksum(), + normalize_payload.normalized_operation_hash, + variables_hash, + extensions_hash, + )) + } else { + None }; - // synthetic client request details for plan executor - let client_request_details = ClientRequestDetails { - method: &Method::POST, - url: &WS_URI_PATH, - headers: &headers, - operation: OperationDetails { - name: normalize_payload.operation_for_plan.name.as_deref(), - kind: match normalize_payload.operation_for_plan.operation_kind { - Some(OperationKind::Query) => "query", - Some(OperationKind::Mutation) => "mutation", - Some(OperationKind::Subscription) => "subscription", - None => "query", - }, - query: match payload.get_query() { - Ok(q) => q, - Err(e) => return Some(e.into_server_message(&id)), - }, - }, - jwt: jwt_request_details, - }.into(); - - match execute_pipeline( - &client_request_details, + // synthetic request details for plan executor + let response_mode = ResponseMode::Dual( + SingleContentType::default(), + StreamContentType::default(), + ); + let exec = |guard| execute_planned_request( + &Method::POST, + ws_uri, + &headers, + payload, &normalize_payload, - variable_payload, supergraph, shared_state, schema_state, operation_span, plugin_req_state, - ) - .await - { - Ok(QueryPlanExecutionResult::Single(response)) => { - if let Some(hive_usage_agent) = &shared_state.hive_usage_agent { - usage_reporting::collect_usage_report( - supergraph.supergraph_schema.clone(), - started_at.elapsed(), - client_name, - client_version, - normalize_payload.operation_for_plan.name.as_deref(), - &parser_payload.minified_document, - hive_usage_agent, - shared_state - .router_config - .telemetry - .hive - .as_ref() - .map(|c| &c.usage_reporting) - .expect( - // SAFETY: According to `configure_app_from_config` in `bin/router/src/lib.rs`, - // the UsageAgent is only created when usage reporting is enabled. - // Thus, this expect should never panic. - "Expected Usage Reporting options to be present when Hive Usage Agent is initialized", - ), - response.error_count, - ) - .await; - } + &response_mode, + guard, + ); + + let shared_response = if let Some(fp) = fingerprint { + let result = if is_subscription { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init_with_guard(|guard| exec(Some(guard))) + .await + } else { + shared_state + .in_flight_requests + .claim(fp) + .get_or_try_init(|| exec(None)) + .await + }; + let (shared_response, _role) = match result { + Ok(result) => result, + Err(PipelineError::JwtError(err)) => { + let _ = sink.send(err.clone().into_server_message(&id)).await; + // we report error as graphql error, but we also close the + // connection since we're dealing with auth so let's be safe + return Some(err.into_close_message()); + }, + Err(err) => return Some(err.into_server_message(&id)), + }; + Arc::unwrap_or_clone(shared_response) + } else { + match exec(None).await { + Ok(result) => result, + Err(PipelineError::JwtError(err)) => { + let _ = sink.send(err.clone().into_server_message(&id)).await; + // we report error as graphql error, but we also close the + // connection since we're dealing with auth so let's be safe + return Some(err.into_close_message()); + }, + Err(err) => return Some(err.into_server_message(&id)), + } + }; + + if let Some(hive_usage_agent) = &shared_state.hive_usage_agent { + usage_reporting::collect_usage_report( + supergraph.supergraph_schema.clone(), + started_at.elapsed(), + client_name, + client_version, + normalize_payload.operation_for_plan.name.as_deref(), + &parser_payload.minified_document, + hive_usage_agent, + shared_state + .router_config + .telemetry + .hive + .as_ref() + .map(|c| &c.usage_reporting) + .expect("Expected Usage Reporting options to be present when Hive Usage Agent is initialized"), + shared_response.error_count(), + ) + .await; + } + match shared_response { + SharedRouterResponse::Single(response) => { let _ = sink.send(ServerMessage::next(&id, &response.body)).await; Some(ServerMessage::complete(&id)) } - Ok(QueryPlanExecutionResult::Stream(response)) => { - // we use mpsc::channel(1) instead of oneshot because oneshot::Receiver - // is consumed on first await, which doesn't work in tokio::select! loops that - // need to poll the receiver multiple times across iterations + SharedRouterResponse::Stream(response) => { let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); state @@ -533,50 +534,63 @@ async fn handle_text_frame( .subscriptions .insert(id.clone(), cancel_tx); - // automatically remove the subscription from subscriptions when dropped - let _guard = SubscriptionGuard { + let guard = SubscriptionGuard { state: state.clone(), id: id.clone(), }; - let mut stream = response.body; - let mut cancelled = false; + let mut receiver = response + .receiver + .unwrap_or_else(|| response.body.subscribe()); trace!(id = %id, "Subscription started"); + let sink = sink.clone(); let id_for_loop = id.clone(); - loop { - tokio::select! { - maybe_item = stream.next() => { - match maybe_item { - Some(body) => { - let _ = sink.send(ServerMessage::next(&id_for_loop, &body)).await; - } - None => { - break; // completed + // must be spawned - blocking the frame handler would prevent + // ClientMessage::Complete from being received and processed, + // making cancellation impossible + rt::spawn(async move { + let _guard = guard; + let mut cancelled = false; + + loop { + tokio::select! { + maybe_item = receiver.recv() => { + match maybe_item { + Ok(SubscriptionEvent::Raw(data)) => { + let _ = sink.send(ServerMessage::next(&id_for_loop, &data)).await; + } + Ok(SubscriptionEvent::Error(errors)) => { + let _ = sink.send(ServerMessage::error(&id_for_loop, &errors)).await; + break; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + trace!(id = %id_for_loop, lagged = n, "broadcast receiver lagged, skipping missed messages"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } } } - } - _ = cancel_rx.recv() => { - cancelled = true; - break; // cancelled + _ = cancel_rx.recv() => { + cancelled = true; + break; + } } } - } - if cancelled { - trace!(id = %id, "Subscription cancelled"); - // we dont emit complete on cancelled subscriptions. - // they're either deliberately cancelled by the client - // or dropped due to connection close, either way - // we dont/cant inform the client with a complete message - None - } else { - trace!(id = %id, "Subscription completed"); - Some(ServerMessage::complete(&id)) - } + if cancelled { + trace!(id = %id_for_loop, "Subscription cancelled"); + } else { + trace!(id = %id_for_loop, "Subscription completed"); + let _ = sink.send(ServerMessage::complete(&id_for_loop)).await; + } + }); + + None } - Err(err) => Some(err.into_server_message(&id)), } } .instrument(span_clone) diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index 7071fa42b..7ee36084c 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -1,3 +1,4 @@ +use crate::pipeline::active_subscriptions::ActiveSubscriptions; use crate::pipeline::authorization::metadata::AuthorizationMetadataExt; use arc_swap::{ArcSwap, Guard}; use async_trait::async_trait; @@ -10,10 +11,12 @@ use hive_router_internal::{ authorization::metadata::AuthorizationMetadata, background_tasks::{BackgroundTask, BackgroundTasksManager}, }; +use hive_router_plan_executor::executors::http_callback::{ + CallbackMessage, CallbackSubscriptionsMap, +}; use hive_router_plan_executor::response::graphql_error::GraphQLErrorExtensions; use hive_router_plan_executor::{ executors::error::SubgraphExecutorError, - executors::http_callback::{ActiveSubscriptionsMap, CallbackMessage}, hooks::on_supergraph_load::{ OnSupergraphLoadEndHookPayload, OnSupergraphLoadStartHookPayload, SupergraphData, }, @@ -49,7 +52,7 @@ pub struct SchemaState { pub validate_cache: Cache>>, pub normalize_cache: Cache>, pub telemetry_context: Arc, - pub active_callback_subscriptions: ActiveSubscriptionsMap, + pub callback_subscriptions: CallbackSubscriptionsMap, } #[derive(Debug, thiserror::Error)] @@ -85,6 +88,7 @@ impl SchemaState { router_config: Arc, plugins: Option>>, cache_state: Arc, + active_subscriptions: ActiveSubscriptions, ) -> Result { let (tx, mut rx) = mpsc::channel::(1); let background_loader = SupergraphBackgroundLoader::new( @@ -99,19 +103,19 @@ impl SchemaState { let plan_cache = cache_state.plan_cache.clone(); let validate_cache = cache_state.validate_cache.clone(); let normalize_cache = cache_state.normalize_cache.clone(); - let active_callback_subscriptions: ActiveSubscriptionsMap = Arc::new(DashMap::new()); + let callback_subscriptions: CallbackSubscriptionsMap = Arc::new(DashMap::new()); // This is cheap clone, as Cache is thread-safe and can be cloned without any performance penalty. let cache_state_for_invalidation = cache_state.clone(); - let active_callback_subscriptions_for_build_data = active_callback_subscriptions.clone(); + let callback_subscriptions_for_build_data = callback_subscriptions.clone(); // kick off subscriptions/subgraphs that are idling/timed out due to missed heartbeats if let Some(ref callback_config) = router_config.subscriptions.callback { if !callback_config.heartbeat_interval.is_zero() { - let enforcer_subs = active_callback_subscriptions.clone(); + let enforcer_subs = callback_subscriptions.clone(); let heartbeat_interval = callback_config.heartbeat_interval; - bg_tasks_manager.register_task(HeartbeatEnforcerTask { - active_subscriptions: enforcer_subs, + bg_tasks_manager.register_task(CallbackHeartbeatEnforcerTask { + callback_subscriptions: enforcer_subs, heartbeat_interval, }); } @@ -119,6 +123,7 @@ impl SchemaState { let metrics = telemetry_context.metrics.clone(); let task_telemetry = telemetry_context.clone(); + let active_subscriptions_for_reload = active_subscriptions.clone(); bg_tasks_manager.register_handle(async move { let supergraph_metrics = &metrics.supergraph; while let Some(new_sdl) = rx.recv().await { @@ -165,7 +170,7 @@ impl SchemaState { router_config.clone(), task_telemetry.clone(), new_ast, - active_callback_subscriptions_for_build_data.clone(), + callback_subscriptions_for_build_data.clone(), ) }) { Ok(mut new_supergraph_data) => { @@ -205,6 +210,16 @@ impl SchemaState { new_supergraph_data = end_payload.new_supergraph_data; } + // close all active subscriptions before swapping supergraph data + active_subscriptions_for_reload.close_all_with_error(vec![ + // this is litearaly the same message apollo sends - reasoning is + // drop in replacement - is that oke? should we have our own? + GraphQLError::from_message_and_code( + "subscription has been closed due to a schema reload", + "SUBSCRIPTION_SCHEMA_RELOAD", + ), + ]); + swappable_data_spawn_clone.store(Arc::new(Some(new_supergraph_data))); debug!("Supergraph updated successfully"); @@ -226,7 +241,7 @@ impl SchemaState { validate_cache, normalize_cache, telemetry_context: telemetry_context.clone(), - active_callback_subscriptions, + callback_subscriptions, }) } @@ -234,7 +249,7 @@ impl SchemaState { router_config: Arc, telemetry_context: Arc, parsed_supergraph_sdl: Document, - active_callback_subscriptions: ActiveSubscriptionsMap, + callback_subscriptions: CallbackSubscriptionsMap, ) -> Result { let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; let metadata = Arc::new(planner.consumer_schema.schema_metadata()); @@ -243,7 +258,7 @@ impl SchemaState { &planner.supergraph.subgraph_endpoint_map, router_config, telemetry_context, - active_callback_subscriptions, + callback_subscriptions, )?); Ok(SupergraphData { @@ -334,13 +349,13 @@ impl BackgroundTask for SupergraphBackgroundLoaderTask { } } -struct HeartbeatEnforcerTask { - active_subscriptions: ActiveSubscriptionsMap, +struct CallbackHeartbeatEnforcerTask { + callback_subscriptions: CallbackSubscriptionsMap, heartbeat_interval: Duration, } #[async_trait] -impl BackgroundTask for HeartbeatEnforcerTask { +impl BackgroundTask for CallbackHeartbeatEnforcerTask { fn id(&self) -> &str { "http-callback-heartbeat-enforcer" } @@ -358,7 +373,7 @@ impl BackgroundTask for HeartbeatEnforcerTask { } let mut timed_out = Vec::new(); - for entry in self.active_subscriptions.iter() { + for entry in self.callback_subscriptions.iter() { let last = *entry.value().last_heartbeat.lock().unwrap(); if Instant::now().duration_since(last) > self.heartbeat_interval + @@ -373,9 +388,9 @@ impl BackgroundTask for HeartbeatEnforcerTask { for id in timed_out { debug!( subscription_id = %id, - "terminating subscription due to missed heartbeat" + "terminating subscription due to http callback subgraph missed heartbeat" ); - if let Some((_, sub)) = self.active_subscriptions.remove(&id) { + if let Some((_, sub)) = self.callback_subscriptions.remove(&id) { // we dont care about the result of this send, if it fails it means the client // is already gone or too slow, either way we just terminate the subscription let _ = sub.sender.try_send(CallbackMessage::Complete { diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 76aef6673..de96706d3 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -1,3 +1,4 @@ +use futures::Stream; use graphql_tools::validation::validate::ValidationPlan; use hive_console_sdk::agent::usage_agent::{AgentError, UsageAgent}; use hive_router_config::traffic_shaping::{ @@ -6,8 +7,9 @@ use hive_router_config::traffic_shaping::{ use hive_router_config::HiveRouterConfig; use hive_router_internal::expressions::values::boolean::BooleanOrProgram; use hive_router_internal::expressions::ExpressionCompileError; -use hive_router_internal::inflight::InFlightMap; +use hive_router_internal::inflight::{InFlightCleanupGuard, InFlightMap}; use hive_router_internal::telemetry::TelemetryContext; +use hive_router_plan_executor::execution::plan::FailedExecutionResult; use hive_router_plan_executor::headers::{ compile::compile_headers_plan, errors::HeaderRuleCompileError, plan::HeaderRulesPlan, }; @@ -17,16 +19,25 @@ use moka::future::Cache; use moka::Expiry; use ntex::web; use ntex::{http::HeaderMap, util::Bytes}; +use std::sync::atomic::AtomicUsize; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{collections::HashSet, sync::Arc}; +use tracing::trace; use crate::cache_state::CacheState; use crate::jwt::context::JwtTokenPayload; use crate::jwt::JwtAuthRuntime; +use crate::pipeline::active_subscriptions::{ActiveSubscriptions, SubscriptionEvent}; use crate::pipeline::cors::{CORSConfigError, Cors}; +use crate::pipeline::error::PipelineError; +use crate::pipeline::header::{ResponseMode, StreamContentType}; use crate::pipeline::introspection_policy::compile_introspection_policy; +use crate::pipeline::multipart_subscribe::{ + self, APOLLO_MULTIPART_HTTP_CONTENT_TYPE, INCREMENTAL_DELIVERY_CONTENT_TYPE, +}; use crate::pipeline::parser::ParseCacheEntry; use crate::pipeline::progressive_override::{OverrideLabelsCompileError, OverrideLabelsEvaluator}; +use crate::pipeline::sse; pub type JwtClaimsCache = Cache>; pub type RouterInflightRequestsMap = InFlightMap; @@ -74,27 +85,148 @@ impl From<&TrafficShapingRouterDedupeHeadersConfig> for RouterRequestDedupeHeade } } +pub type SharedRouterResponseGuard = InFlightCleanupGuard; + +#[derive(Clone)] +pub enum SharedRouterResponse { + Single(SharedRouterSingleResponse), + Stream(SharedRouterStreamResponse), +} + +impl SharedRouterResponse { + pub fn error_count(&self) -> usize { + match self { + SharedRouterResponse::Single(resp) => resp.error_count, + SharedRouterResponse::Stream(resp) => resp.error_count, + } + } + pub fn into_response( + self, + response_mode: &ResponseMode, + ) -> Result { + match self { + SharedRouterResponse::Single(single) => Ok(single.into()), + SharedRouterResponse::Stream(stream) => { + let stream_content_type = response_mode + .stream_content_type() + .ok_or(PipelineError::SubscriptionsTransportNotSupported)?; + Ok(stream.into_response(stream_content_type)) + } + } + } +} + #[derive(Clone)] -pub struct SharedRouterResponse { +pub struct SharedRouterSingleResponse { pub body: Bytes, pub headers: Arc, pub status: StatusCode, pub error_count: usize, } -impl From for web::HttpResponse { - fn from(shared_response: SharedRouterResponse) -> Self { +impl From for web::HttpResponse { + fn from(shared_response: SharedRouterSingleResponse) -> Self { let mut response = web::HttpResponse::Ok(); response.status(shared_response.status); - for (header_name, header_value) in shared_response.headers.iter() { response.set_header(header_name, header_value); } - response.body(shared_response.body) } } +// status is always 200 for streaming responses, errors are sent through the stream. +// stream content type is not included because we can deduplicate subscriptions across +// different content types, the response format is decided when converting to response +pub struct SharedRouterStreamResponse { + pub body: tokio::sync::broadcast::Sender, + pub headers: Arc, + pub error_count: usize, + // the leader gets the receiver that was subscribed before the pump was spawned, so + // there is no window where the channel has zero receivers and events can be lost. + // joiners get None and subscribe via body.subscribe() when consumed. + pub receiver: Option>, +} + +impl Clone for SharedRouterStreamResponse { + fn clone(&self) -> Self { + Self { + body: self.body.clone(), + headers: self.headers.clone(), + error_count: self.error_count, + receiver: None, + } + } +} + +impl SharedRouterStreamResponse { + pub fn into_response(self, stream_content_type: &StreamContentType) -> web::HttpResponse { + // leader already has a pre-subscribed receiver to avoid missing + // any potential events emitted. joiners, on the other hand, subscribe + let mut receiver = self.receiver.unwrap_or_else(|| self.body.subscribe()); + + let stream = Box::pin(async_stream::stream! { + loop { + match receiver.recv().await { + Ok(SubscriptionEvent::Raw(data)) => { + yield data.to_vec(); + } + Ok(SubscriptionEvent::Error(errors)) => { + yield FailedExecutionResult { errors }.serialize(); + break; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + trace!(lagged = n, "broadcast receiver lagged, skipping missed messages"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + }); + + let content_type_header = match stream_content_type { + StreamContentType::IncrementalDelivery => { + http::HeaderValue::from_static(INCREMENTAL_DELIVERY_CONTENT_TYPE) + } + StreamContentType::SSE => http::HeaderValue::from_static("text/event-stream"), + StreamContentType::ApolloMultipartHTTP => { + http::HeaderValue::from_static(APOLLO_MULTIPART_HTTP_CONTENT_TYPE) + } + }; + + let body: std::pin::Pin< + Box> + Send>, + > = match stream_content_type { + StreamContentType::IncrementalDelivery => Box::pin( + multipart_subscribe::create_incremental_delivery_stream(stream), + ), + StreamContentType::SSE => Box::pin(sse::create_stream( + stream, + std::time::Duration::from_secs(10), + )), + StreamContentType::ApolloMultipartHTTP => { + Box::pin(multipart_subscribe::create_apollo_multipart_http_stream( + stream, + std::time::Duration::from_secs(10), + )) + } + }; + + let mut response = web::HttpResponse::Ok(); + + for (header_name, header_value) in self.headers.iter() { + response.set_header(header_name, header_value); + } + + // we set content type after so that we can override the shared header + response.content_type(content_type_header); + + response.streaming(body) + } +} + /// Default TTL for JWT claims cache entries (5 seconds) const DEFAULT_JWT_CACHE_TTL_SECS: u64 = 5; @@ -153,9 +285,14 @@ pub struct RouterSharedState { pub plugins: Option>>, pub in_flight_requests: RouterInflightRequestsMap, pub in_flight_requests_header_policy: RouterRequestDedupeHeaderPolicy, + /// Tracks the number of active long-lived clients (websockets + http streams) + pub long_lived_client_count: Arc, + /// Tracks all active subscriptions from clients to the router. + pub active_subscriptions: ActiveSubscriptions, } impl RouterSharedState { + #[allow(clippy::too_many_arguments)] pub fn new( router_config: Arc, jwt_auth_runtime: Option, @@ -164,6 +301,7 @@ impl RouterSharedState { telemetry_context: Arc, plugins: Option>>, cache_state: Arc, + active_subscriptions: ActiveSubscriptions, ) -> Result { let parse_cache = cache_state.parse_cache.clone(); Ok(Self { @@ -195,6 +333,8 @@ impl RouterSharedState { .dedupe .headers) .into(), + long_lived_client_count: Arc::new(AtomicUsize::new(0)), + active_subscriptions, }) } } diff --git a/docs/README.md b/docs/README.md index 740da4f4e..e4bf7fa64 100644 --- a/docs/README.md +++ b/docs/README.md @@ -18,10 +18,10 @@ |[**override\_subgraph\_urls**](#override_subgraph_urls)|`object`|Configuration for overriding subgraph URLs.
Default: `{}`
|| |[**plugins**](#plugins)|`object`|Configuration for custom plugins
|| |[**query\_planner**](#query_planner)|`object`|Query planning configuration.
Default: `{"allow_expose":false,"timeout":"10s"}`
|| -|[**subscriptions**](#subscriptions)|`object`|Configuration for subscriptions.
Default: `{"enabled":false}`
|| +|[**subscriptions**](#subscriptions)|`object`|Configuration for subscriptions.
Default: `{"broadcast_capacity":0,"enabled":false}`
|| |[**supergraph**](#supergraph)|`object`|Configuration for the Federation supergraph source. By default, the router will use a local file-based supergraph source (`./supergraph.graphql`).
|| |[**telemetry**](#telemetry)|`object`|Default: `{"client_identification":{"name_header":"graphql-client-name","version_header":"graphql-client-version"},"hive":null,"metrics":{"exporters":[],"instrumentation":{"common":{"histogram":{"aggregation":"explicit","bytes":{"buckets":[128,512,1024,2048,4096,8192,16384,32768,65536,131072,262144,524288,1048576,2097152,3145728,4194304,5242880],"record_min_max":false},"seconds":{"buckets":[0.005,0.01,0.025,0.05,0.075,0.1,0.25,0.5,0.75,1,2.5,5,7.5,10],"record_min_max":false}}},"instruments":{}}},"resource":{"attributes":{}},"tracing":{"collect":{"max_attributes_per_event":16,"max_attributes_per_link":32,"max_attributes_per_span":128,"max_events_per_span":128,"parent_based_sampler":false,"sampling":1},"exporters":[],"instrumentation":{"spans":{"mode":"spec_compliant"}},"propagation":{"b3":false,"baggage":false,"jaeger":false,"trace_context":true}}}`
|| -|[**traffic\_shaping**](#traffic_shaping)|`object`|Configuration for the traffic-shaping of the executor. Use these configurations to control how requests are being executed to subgraphs.
Default: `{"all":{"dedupe_enabled":true,"pool_idle_timeout":"50s","request_timeout":"30s"},"max_connections_per_host":100,"router":{"dedupe":{"enabled":false,"headers":"all"},"request_timeout":"1m"}}`
|| +|[**traffic\_shaping**](#traffic_shaping)|`object`|Configuration for the traffic-shaping of the executor. Use these configurations to control how requests are being executed to subgraphs.
Default: `{"all":{"dedupe_enabled":true,"pool_idle_timeout":"50s","request_timeout":"30s"},"max_connections_per_host":100,"router":{"dedupe":{"enabled":false,"headers":"all"},"max_long_lived_clients":128,"request_timeout":"1m"}}`
|| |[**websocket**](#websocket)|`object`|Configuration of router's WebSocket server.
Default: `{"enabled":false,"headers":{"persist":false,"source":"connection"},"path":null}`
|| **Additional Properties:** not allowed @@ -122,6 +122,7 @@ query_planner: allow_expose: false timeout: 10s subscriptions: + broadcast_capacity: 0 enabled: false supergraph: {} telemetry: @@ -202,6 +203,7 @@ traffic_shaping: dedupe: enabled: false headers: all + max_long_lived_clients: 128 request_timeout: 1m websocket: enabled: false @@ -1949,6 +1951,7 @@ Configuration for subscriptions. |Name|Type|Description|Required| |----|----|-----------|--------| +|**broadcast\_capacity**|`integer`|The capacity of the broadcast channel used to fan out subscription events to all active listeners.

Each active subscription has its own broadcast channel. This value controls how many events
can be buffered in that channel before slow consumers start lagging. If a consumer falls too
far behind and the buffer is full, it will skip the missed messages and continue from the
latest available event.

Subscription events are typically low-frequency, so the default of 32 is sufficient for most
use cases. Increase this value if you expect bursts of events or have slow consumers that
need more headroom to catch up.

Defaults to 32.
Default: `32`
Format: `"uint"`
Minimum: `0`
|| |[**callback**](#subscriptionscallback)|`object`, `null`|Configuration for subgraphs using the HTTP Callback protocol.
|yes| |**enabled**|`boolean`|Enables/disables subscriptions. By default, the subscriptions are disabled.

You can override this setting by setting the `SUBSCRIPTIONS_ENABLED` environment variable to `true` or `false`.
Default: `false`
|| |[**websocket**](#subscriptionswebsocket)|`object`, `null`|Configuration for subgraphs using WebSocket protocol.
|| @@ -1957,6 +1960,7 @@ Configuration for subscriptions. **Example** ```yaml +broadcast_capacity: 0 enabled: false ``` @@ -3038,7 +3042,7 @@ Configuration for the traffic-shaping of the executor. Use these configurations |----|----|-----------|--------| |[**all**](#traffic_shapingall)|`object`|The default configuration that will be applied to all subgraphs, unless overridden by a specific subgraph configuration.
Default: `{"dedupe_enabled":true,"pool_idle_timeout":"50s","request_timeout":"30s"}`
|| |**max\_connections\_per\_host**|`integer`|Limits the concurrent amount of requests/connections per host/subgraph.
Default: `100`
Format: `"uint"`
Minimum: `0`
|| -|[**router**](#traffic_shapingrouter)|`object`|Configuration for the router itself, e.g., for handling incoming requests, or other router-level traffic shaping configurations.
Default: `{"dedupe":{"enabled":false,"headers":"all"},"request_timeout":"1m"}`
|| +|[**router**](#traffic_shapingrouter)|`object`|Configuration for the router itself, e.g., for handling incoming requests, or other router-level traffic shaping configurations.
Default: `{"dedupe":{"enabled":false,"headers":"all"},"max_long_lived_clients":128,"request_timeout":"1m"}`
|| |[**subgraphs**](#traffic_shapingsubgraphs)|`object`|Optional per-subgraph configurations that will override the default configuration for specific subgraphs.
|| **Additional Properties:** not allowed @@ -3054,6 +3058,7 @@ router: dedupe: enabled: false headers: all + max_long_lived_clients: 128 request_timeout: 1m ``` @@ -3093,6 +3098,7 @@ Configuration for the router itself, e.g., for handling incoming requests, or ot |Name|Type|Description|Required| |----|----|-----------|--------| |[**dedupe**](#traffic_shapingrouterdedupe)|`object`|Default: `{"enabled":false,"headers":"all"}`
|| +|**max\_long\_lived\_clients**|`integer`|Maximum number of concurrent long-lived clients (WebSocket connections and HTTP streaming responses).
Regular non-streaming requests are not counted toward this limit.
When the limit is reached, new WebSocket and streaming HTTP requests are rejected with 503.
If both WebSockets and Subscriptions are disabled, this setting has no effect.
Default: `128`
Format: `"uint"`
Minimum: `0`
|| |**request\_timeout**|`string`|Optional timeout configuration for incoming requests to the router.
It starts from the moment the request is received by the router,
and includes the entire processing of the request (validation, execution, etc.) until a response is sent back to the client.
If a request takes longer than the specified duration, it will be aborted and a timeout error will be returned to the client.
Default: `"1m"`
|| **Additional Properties:** not allowed @@ -3102,6 +3108,7 @@ Configuration for the router itself, e.g., for handling incoming requests, or ot dedupe: enabled: false headers: all +max_long_lived_clients: 128 request_timeout: 1m ``` @@ -3113,7 +3120,7 @@ request_timeout: 1m |Name|Type|Description|Required| |----|----|-----------|--------| -|**enabled**|`boolean`|Enables/disables in-flight request deduplication at the router endpoint level.

When enabled, identical incoming GraphQL query requests that are processed at the same time
share the same in-flight execution result.
Default: `false`
|| +|**enabled**|`boolean`|Enables/disables in-flight request and active subscriptions deduplication at the router level.

When enabled, the router deduplicates both queries and subscriptions using the same
fingerprint key (method, path, selected headers, schema checksum, normalized operation
hash, variables, and extensions). The `headers` configuration below controls which
headers participate in that key for all operation types.

For queries, concurrent HTTP requests that produce the same fingerprint share a single
in-flight execution - only the first one runs, and the rest wait for and receive the
same result.

For subscriptions, the mechanism is broadcast-based rather than request-sharing. The
first client with a given fingerprint becomes the leader: it runs the upstream subscription
and its events are fanned out through a broadcast channel backed by an active subscriptions
registry. Any subsequent client that arrives with an identical fingerprint while that subscription
is still active joins as a listener on the same broadcast channel instead of starting a new upstream
connection. When all listeners have dropped and the leader finishes, the entry is removed from the
registry.

WebSocket connections participate in the same deduplication space as HTTP. Each
subscribe message is processed with a synthetic request assembled from the WebSocket
path and the headers derived from the `websocket.headers` config. The fingerprint is computed
from those synthetic headers using the same header policy, so a subscription started over HTTP
and an identical one started over WebSocket will deduplicate against each other.

The deduplication is transport agnostic. A query over WebSocket would get deduplicated with an
identical query over HTTP if they arrive at the same time and have the same fingerprint.

Note: `content-type` is part of the fingerprint when `headers` includes it (e.g. `all`).
Since HTTP streaming clients send different `accept` headers than WebSocket clients,
cross-transport deduplication for subscriptions only applies when `content-type` (and
transport-specific headers) are excluded from the key. Configure `headers: none` or
`headers: { include: [] }` (or exclude the relevant headers) to enable true cross-transport
deduplication, where a WebSocket subscription and an SSE subscription with the same operation
share a single upstream connection and the events are fanned out to both.
Default: `false`
|| |**headers**||Header configuration participating in the dedupe key.

Accepted forms:
- `all`
- `none`
- `{ include: ["authorization", "cookie"] }`

Header names are case-insensitive and validated as standard HTTP header names.
Default: `"all"`
|| **Additional Properties:** not allowed diff --git a/e2e/src/subscriptions.rs b/e2e/src/subscriptions.rs index 7951dcf67..bf6196be9 100644 --- a/e2e/src/subscriptions.rs +++ b/e2e/src/subscriptions.rs @@ -1394,4 +1394,436 @@ mod subscriptions_e2e_tests { event: complete "#); } + + #[ntex::test] + async fn active_subscriptions_deduplication() { + let subgraphs = TestSubgraphs::builder().build().start().await; + let router = TestRouter::builder() + .with_subgraphs(&subgraphs) + .inline_config( + r#" + supergraph: + source: file + path: supergraph.graphql + subscriptions: + enabled: true + traffic_shaping: + router: + dedupe: + enabled: true + "#, + ) + .build() + .start() + .await; + + let query = r#" + subscription { + reviewAdded(intervalInMs: 100) { + id + product { + name + } + } + } + "#; + let headers = some_header_map! { + http::header::ACCEPT => "text/event-stream" + }; + + let (sub1, sub2, sub3) = tokio::join!( + router.send_graphql_request(query, None, headers.clone()), + router.send_graphql_request(query, None, headers.clone()), + router.send_graphql_request(query, None, headers.clone()), + ); + + for sub in [&sub1, &sub2, &sub3] { + let body = sub.string_body().await; + assert!( + body.contains("event: next") && body.contains("event: complete"), + "Expected subscription to receive events and complete" + ); + } + + let reviews_requests = subgraphs.get_requests_log("reviews").unwrap_or_default(); + assert_eq!( + reviews_requests.len(), + 1, + "Expected requests to reviews subgraph to be deduplicated" + ); + } + + #[ntex::test] + async fn active_subscriptions_deduplication_promotion() { + use futures::StreamExt; + + let subgraphs = TestSubgraphs::builder().build().start().await; + let router = TestRouter::builder() + .with_subgraphs(&subgraphs) + .inline_config( + r#" + supergraph: + source: file + path: supergraph.graphql + subscriptions: + enabled: true + traffic_shaping: + router: + dedupe: + enabled: true + "#, + ) + .build() + .start() + .await; + + let query = r#" + subscription { + reviewAdded(intervalInMs: 100) { + id + } + } + "#; + let headers = some_header_map! { + http::header::ACCEPT => "text/event-stream" + }; + + let mut sub1 = router + .send_graphql_request(query, None, headers.clone()) + .await; + + assert!(sub1.status().is_success(), "Expected 200 OK"); + + // consume 2 events from sub1 to let the source stream advance + let chunk = sub1.next().await.unwrap().unwrap(); + assert!( + std::str::from_utf8(&chunk).unwrap().contains(r#""id":"1""#), + "Expected first event to be id=1" + ); + let chunk = sub1.next().await.unwrap().unwrap(); + assert!( + std::str::from_utf8(&chunk).unwrap().contains(r#""id":"2""#), + "Expected second event to be id=2" + ); + let chunk = sub1.next().await.unwrap().unwrap(); + assert!( + std::str::from_utf8(&chunk).unwrap().contains(r#""id":"3""#), + "Expected third event to be id=3" + ); + + // subscribe again with the same query - dedup promotes sub2 onto the live source + let sub2 = router + .send_graphql_request(query, None, headers.clone()) + .await; + + assert!(sub2.status().is_success(), "Expected 200 OK"); + + // drop sub1 now that sub2 is connected; sub2 must become the active subscriber + drop(sub1); + + // sub2 should receive the remainder of the stream from where the source left off + let body = sub2.string_body().await; + assert!( + body.contains("event: next") && body.contains("event: complete"), + "Expected sub2 to receive remaining events and complete, got: {body}" + ); + + // sub2 must not have received the first 3 events that were already consumed by sub1 + assert!( + !body.contains(r#""id":"1""#) + && !body.contains(r#""id":"2""#) + && !body.contains(r#""id":"3""#), + "Expected sub2 to not replay events already consumed by sub1, got: {body}" + ); + + // only one subgraph request should have been made + let reviews_requests = subgraphs.get_requests_log("reviews").unwrap_or_default(); + assert_eq!( + reviews_requests.len(), + 1, + "Expected requests to reviews subgraph to be deduplicated" + ); + } + + #[ntex::test] + async fn active_across_transports_subscriptions_deduplication() { + use futures::StreamExt; + use hive_router_plan_executor::executors::{ + graphql_transport_ws::SubscribePayload, websocket_client::WsClient, + }; + + let subgraphs = TestSubgraphs::builder().build().start().await; + let router = TestRouter::builder() + .with_subgraphs(&subgraphs) + .inline_config( + r#" + supergraph: + source: file + path: supergraph.graphql + subscriptions: + enabled: true + websocket: + enabled: true + traffic_shaping: + router: + dedupe: + enabled: true + headers: none + "#, + ) + .build() + .start() + .await; + + let query = r#" + subscription { + reviewAdded(intervalInMs: 100) { + id + product { + name + } + } + } + "#; + + let sse_headers = some_header_map! { + http::header::ACCEPT => "text/event-stream" + }; + let multipart_headers = some_header_map! { + http::header::ACCEPT => "multipart/mixed;subscriptionSpec=1.0" + }; + + let wsconn = router.ws().await; + let mut ws_client = WsClient::init(wsconn, None) + .await + .expect("Failed to init WsClient"); + let ws_payload = SubscribePayload { + query: query.into(), + ..Default::default() + }; + let mut ws_stream = ws_client.subscribe(ws_payload).await; + + let (sub_sse, sub_multipart) = tokio::join!( + router.send_graphql_request(query, None, sse_headers), + router.send_graphql_request(query, None, multipart_headers), + ); + + let sse_body = sub_sse.string_body().await; + assert!( + sse_body.contains("event: next") && sse_body.contains("event: complete"), + "Expected SSE subscription to receive events and complete" + ); + + let multipart_body = sub_multipart.string_body().await; + assert!( + multipart_body.contains("--graphql") && multipart_body.contains("--graphql--"), + "Expected multipart subscription to receive events and complete" + ); + + let mut ws_received = 0; + while let Some(response) = ws_stream.next().await { + assert!( + response.errors.is_none(), + "Expected no errors from WS subscription" + ); + assert!( + !response.data.is_null(), + "Expected data from WS subscription" + ); + ws_received += 1; + } + assert!( + ws_received > 0, + "Expected WS subscription to receive at least one event" + ); + + let reviews_requests = subgraphs.get_requests_log("reviews").unwrap_or_default(); + assert_eq!( + reviews_requests.len(), + 1, + "Expected requests to reviews subgraph to be deduplicated across transports" + ); + } + + #[ntex::test] + async fn active_across_transports_subscriptions_deduplication_promotion() { + use futures::StreamExt; + use hive_router_plan_executor::executors::{ + graphql_transport_ws::SubscribePayload, websocket_client::WsClient, + }; + + let subgraphs = TestSubgraphs::builder().build().start().await; + let router = TestRouter::builder() + .with_subgraphs(&subgraphs) + .inline_config( + r#" + supergraph: + source: file + path: supergraph.graphql + subscriptions: + enabled: true + websocket: + enabled: true + traffic_shaping: + router: + dedupe: + enabled: true + headers: none + "#, + ) + .build() + .start() + .await; + + let query = r#" + subscription { + reviewAdded(intervalInMs: 100) { + id + } + } + "#; + + let wsconn = router.ws().await; + let mut ws_client = WsClient::init(wsconn, None) + .await + .expect("Failed to init WsClient"); + let ws_payload = SubscribePayload { + query: query.into(), + ..Default::default() + }; + let mut ws_stream = ws_client.subscribe(ws_payload).await; + + // consume 3 events from sub1 to let the source stream advance + let response = ws_stream.next().await.unwrap(); + assert!( + response.data.to_string().contains(r#""id": "1""#), + "Expected first event to be id=1" + ); + let response = ws_stream.next().await.unwrap(); + assert!( + response.data.to_string().contains(r#""id": "2""#), + "Expected second event to be id=2" + ); + let response = ws_stream.next().await.unwrap(); + assert!( + response.data.to_string().contains(r#""id": "3""#), + "Expected third event to be id=3" + ); + + // subscribe again with SSE - dedup promotes sub2 onto the live source + let sse_headers = some_header_map! { + http::header::ACCEPT => "text/event-stream" + }; + let sub2 = router.send_graphql_request(query, None, sse_headers).await; + + assert!(sub2.status().is_success(), "Expected 200 OK"); + + // drop the WS sub now that sub2 is connected; sub2 must become the active subscriber + drop(ws_stream); + drop(ws_client); + + // sub2 should receive the remainder of the stream from where the source left off + let body = sub2.string_body().await; + assert!( + body.contains("event: next") && body.contains("event: complete"), + "Expected sub2 to receive remaining events and complete, got: {body}" + ); + + // sub2 must not have received the first 3 events already consumed by the WS sub + assert!( + !body.contains(r#""id":"1""#) + && !body.contains(r#""id":"2""#) + && !body.contains(r#""id":"3""#), + "Expected sub2 to not replay events already consumed by the WS sub, got: {body}" + ); + + // only one subgraph request should have been made + let reviews_requests = subgraphs.get_requests_log("reviews").unwrap_or_default(); + assert_eq!( + reviews_requests.len(), + 1, + "Expected requests to reviews subgraph to be deduplicated across transports" + ); + } + + #[ntex::test] + async fn max_long_lived_clients_rejects_over_limit() { + use futures::StreamExt; + + let subgraphs = TestSubgraphs::builder().build().start().await; + let router = TestRouter::builder() + .with_subgraphs(&subgraphs) + .inline_config( + r#" + supergraph: + source: file + path: supergraph.graphql + subscriptions: + enabled: true + traffic_shaping: + router: + max_long_lived_clients: 2 + "#, + ) + .build() + .start() + .await; + + let query = r#" + subscription { + reviewAdded(intervalInMs: 200) { + id + } + } + "#; + let headers = some_header_map! { + http::header::ACCEPT => "text/event-stream" + }; + + // open two subscriptions and keep them alive by reading the first event + let mut sub1 = router + .send_graphql_request(query, None, headers.clone()) + .await; + assert!(sub1.status().is_success(), "sub1 should be accepted"); + let _ = sub1.next().await; + + let mut sub2 = router + .send_graphql_request(query, None, headers.clone()) + .await; + assert!(sub2.status().is_success(), "sub2 should be accepted"); + let _ = sub2.next().await; + + // the third subscriber exceeds the limit and must be rejected + let sub3 = router + .send_graphql_request(query, None, headers.clone()) + .await; + assert_eq!( + sub3.status(), + reqwest::StatusCode::SERVICE_UNAVAILABLE, + "sub3 should be rejected with 503 when the limit is reached" + ); + let retry_after = sub3.header("retry-after"); + assert!( + retry_after.is_some(), + "rejected response should include a Retry-After header" + ); + let body = sub3.string_body().await; + assert_eq!(body, "Too many long-lived clients"); + + // release the two held subscriptions + drop(sub1); + drop(sub2); + + // wait briefly for the slots to be freed + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // a new subscriber should now be accepted again + let sub4 = router + .send_graphql_request(query, None, headers.clone()) + .await; + assert!( + sub4.status().is_success(), + "sub4 should be accepted after the previous slots were freed" + ); + } } diff --git a/e2e/src/testkit/mod.rs b/e2e/src/testkit/mod.rs index d38cab9ca..72804b642 100644 --- a/e2e/src/testkit/mod.rs +++ b/e2e/src/testkit/mod.rs @@ -27,6 +27,7 @@ use tracing::{info, warn}; use hive_router::{ add_callback_handler, background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app, init_rustls_crypto_provider, invoke_shutdown_hooks, + pipeline::long_lived_client_limit::LongLivedClientLimitService, plugins::plugins_service::PluginService, telemetry::Telemetry, PluginRegistry, RouterPaths, RouterSharedState, SchemaState, }; @@ -707,7 +708,7 @@ impl TestRouter { let serv_shared_state = shared_state.clone(); let serv_schema_state = schema_state.clone(); - let serv_active_subs = schema_state.active_callback_subscriptions.clone(); + let serv_callback_subs = schema_state.callback_subscriptions.clone(); let serv_graphql_path = self.graphql_path.clone(); let serv_websocket_path = self.websocket_path.clone(); @@ -721,13 +722,13 @@ impl TestRouter { }) => { let cb_path = path.to_string(); let cb_addr = listen.to_string(); - let cb_active_subs = schema_state.active_callback_subscriptions.clone(); + let cb_subs = schema_state.callback_subscriptions.clone(); let server = web::HttpServer::new(async move || { - let active_subs = cb_active_subs.clone(); + let cb_subs = cb_subs.clone(); let cb_path = cb_path.clone(); web::App::new() - .state(active_subs) + .state(cb_subs) .configure(move |m| add_callback_handler(m, &cb_path)) }) .bind(&cb_addr) @@ -755,13 +756,15 @@ impl TestRouter { let serv_paths = paths.clone(); let serv_prometheus = prometheus.clone(); + let long_lived_limit = LongLivedClientLimitService::new(&shared_state.router_config); let serv = test::server_with(test::config().port(self.port), move || { let shared_state = serv_shared_state.clone(); let schema_state = serv_schema_state.clone(); let paths = serv_paths.clone(); let prometheus = serv_prometheus.clone(); let serv_callback_path = serv_callback_path.clone(); - let active_subs = serv_active_subs.clone(); + let callback_subs = serv_callback_subs.clone(); + let long_lived_limit = long_lived_limit.clone(); // set the tracing dispatch on the server thread. the guard is // intentionally leaked: dropping it would restore the no-op default @@ -775,10 +778,11 @@ impl TestRouter { async move { web::App::new() + .middleware(long_lived_limit) .middleware(PluginService) .state(shared_state) .state(schema_state) - .state(active_subs) + .state(callback_subs) .configure(|m| configure_ntex_app(m, &paths, prometheus)) .configure(|m| { if let Some(ref callback) = serv_callback_path { diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index c5cadda65..abd74c042 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -56,7 +56,7 @@ bumpalo = "3.19.0" sonic-simd = "0.1.2" async-stream = "0.3.6" futures-util = "0.3.31" -uuid = { version = "1", features = ["v4"] } +ulid = "1.2.1" [dev-dependencies] subgraphs = { path = "../../bench/subgraphs" } diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index 3dcf87624..667515074 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -62,16 +62,3 @@ impl SubgraphExecutionRequest<'_> { .insert(key, value); } } - -// the channel capacity for buffering subscription events between websockets or the callback -// handler and the stream consumer. back-pressure flows like this: -// -// ntex h1 dispatcher (send to client) -// poll_flush() blocks when TCP send buffer is full (slow client) -// poll_next_chunk() is not called until flush completes -// rx.recv() in the async_stream is not polled -// channel fills up -// -// so this bound only triggers when the client reading from the router is too slow, hence -// backpressure comes from the client itself, not the router -pub const SUBSCRIPTION_EVENT_BUFFER_CAPACITY: usize = 256; diff --git a/lib/executor/src/executors/http_callback.rs b/lib/executor/src/executors/http_callback.rs index 14a6fcb40..66afe761e 100644 --- a/lib/executor/src/executors/http_callback.rs +++ b/lib/executor/src/executors/http_callback.rs @@ -11,11 +11,9 @@ use http_body_util::Full; use hyper::Version; use tokio::sync::mpsc; use tracing::{debug, error, trace}; -use uuid::Uuid; +use ulid::Ulid; -use crate::executors::common::{ - SubgraphExecutionRequest, SubgraphExecutor, SUBSCRIPTION_EVENT_BUFFER_CAPACITY, -}; +use crate::executors::common::{SubgraphExecutionRequest, SubgraphExecutor}; use crate::executors::error::SubgraphExecutorError; use crate::executors::http::{build_request_body, HttpClient}; use crate::plugin_context::PluginRequestState; @@ -25,16 +23,14 @@ use crate::response::subgraph_response::SubgraphResponse; pub const CALLBACK_PROTOCOL_VERSION: &str = "callback/1.0"; pub const SUBSCRIPTION_PROTOCOL_HEADER: &str = "subscription-protocol"; -type SubscriptionId = String; - #[derive(Clone)] -pub struct ActiveSubscription { +pub struct CallbackSubscription { pub verifier: String, pub sender: mpsc::Sender, pub last_heartbeat: Arc>, } -impl ActiveSubscription { +impl CallbackSubscription { pub fn record_heartbeat(&self) { *self.last_heartbeat.lock().unwrap() = Instant::now(); } @@ -46,16 +42,16 @@ pub enum CallbackMessage { Complete { errors: Option> }, } -pub type ActiveSubscriptionsMap = Arc>; +pub type CallbackSubscriptionsMap = Arc>; -struct SubscriptionGuard { - subscription_id: SubscriptionId, - active_subscriptions: ActiveSubscriptionsMap, +struct CallbackSubscriptionGuard { + subscription_id: String, + callback_subscriptions: CallbackSubscriptionsMap, } -impl Drop for SubscriptionGuard { +impl Drop for CallbackSubscriptionGuard { fn drop(&mut self) { - self.active_subscriptions.remove(&self.subscription_id); + self.callback_subscriptions.remove(&self.subscription_id); trace!(subscription_id = %self.subscription_id, "HTTP callback subscription entry removed from active subscriptions"); } } @@ -67,7 +63,7 @@ pub struct HttpCallbackSubgraphExecutor { pub header_map: HeaderMap, pub callback_base_url: String, pub heartbeat_interval_ms: u64, - pub active_subscriptions: ActiveSubscriptionsMap, + pub active_subscriptions: CallbackSubscriptionsMap, } impl HttpCallbackSubgraphExecutor { @@ -77,7 +73,7 @@ impl HttpCallbackSubgraphExecutor { http_client: Arc, callback_base_url: String, heartbeat_interval_ms: u64, - active_subscriptions: ActiveSubscriptionsMap, + active_subscriptions: CallbackSubscriptionsMap, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -154,15 +150,20 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { BoxStream<'static, Result, SubgraphExecutorError>>, SubgraphExecutorError, > { - let subscription_id = Uuid::new_v4().to_string(); - let verifier = Uuid::new_v4().to_string(); + let subscription_id = Ulid::new().to_string(); + let verifier = Ulid::new().to_string(); let body = self.build_request_body(&mut execution_request, &subscription_id, &verifier)?; - let (tx, mut rx) = mpsc::channel::(SUBSCRIPTION_EVENT_BUFFER_CAPACITY); + // all subscriptions emit events into the shared active subscriptions broadcaster + // which itself handles back-pressure by dropping old events when the buffer is full, + // so we can use a small buffer here + // TODO: do we thererefore need to buffer at all? + let (tx, mut rx) = mpsc::channel::(16); + self.active_subscriptions.insert( subscription_id.clone(), - ActiveSubscription { + CallbackSubscription { verifier, sender: tx, // initialize last_heartbeat to now + heartbeat_interval so the enforcer @@ -177,9 +178,9 @@ impl SubgraphExecutor for HttpCallbackSubgraphExecutor { ); // guard removes the entry from `active_subscriptions` when dropped - let guard = SubscriptionGuard { + let guard = CallbackSubscriptionGuard { subscription_id: subscription_id.clone(), - active_subscriptions: self.active_subscriptions.clone(), + callback_subscriptions: self.active_subscriptions.clone(), }; let mut req = hyper::Request::builder() diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 6e04d575c..56eef932f 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -30,7 +30,7 @@ use crate::{ common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient, SubgraphHttpResponse}, - http_callback::{ActiveSubscriptionsMap, HttpCallbackSubgraphExecutor}, + http_callback::{CallbackSubscriptionsMap, HttpCallbackSubgraphExecutor}, websocket::WsSubgraphExecutor, }, hooks::on_subgraph_execute::{ @@ -75,7 +75,7 @@ pub struct SubgraphExecutorMap { in_flight_requests: InflightRequestsMap, telemetry_context: Arc, /// Shared map of active HTTP callback subscriptions - active_callback_subscriptions: ActiveSubscriptionsMap, + callback_subscriptions: CallbackSubscriptionsMap, } fn build_https_executor() -> Result, SubgraphExecutorError> { @@ -112,7 +112,7 @@ impl SubgraphExecutorMap { timeouts_by_subgraph: Default::default(), global_timeout, telemetry_context, - active_callback_subscriptions: Arc::new(DashMap::new()), + callback_subscriptions: Arc::new(DashMap::new()), }) } @@ -120,7 +120,7 @@ impl SubgraphExecutorMap { subgraph_endpoint_map: &HashMap, config: Arc, telemetry_context: Arc, - active_callback_subscriptions: ActiveSubscriptionsMap, + active_callback_subscriptions: CallbackSubscriptionsMap, ) -> Result { let global_timeout = DurationOrProgram::compile( &config.traffic_shaping.all.request_timeout, @@ -131,7 +131,7 @@ impl SubgraphExecutorMap { })?; let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), global_timeout, telemetry_context)?; - subgraph_executor_map.active_callback_subscriptions = active_callback_subscriptions; + subgraph_executor_map.callback_subscriptions = active_callback_subscriptions; for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.iter() { let endpoint_config = config @@ -157,8 +157,8 @@ impl SubgraphExecutorMap { } /// Returns the shared active callback subscriptions map for use by callback handlers. - pub fn active_callback_subscriptions(&self) -> ActiveSubscriptionsMap { - self.active_callback_subscriptions.clone() + pub fn callback_subscriptions(&self) -> CallbackSubscriptionsMap { + self.callback_subscriptions.clone() } pub async fn execute<'exec>( @@ -505,7 +505,7 @@ impl SubgraphExecutorMap { self.client.clone(), callback_config.public_url.to_string(), heartbeat_interval_ms, - self.active_callback_subscriptions.clone(), + self.callback_subscriptions.clone(), ) .to_boxed_arc(); diff --git a/lib/executor/src/executors/websocket.rs b/lib/executor/src/executors/websocket.rs index 4e1f53c07..4509e0eb9 100644 --- a/lib/executor/src/executors/websocket.rs +++ b/lib/executor/src/executors/websocket.rs @@ -7,9 +7,7 @@ use ntex::rt::Arbiter; use tokio::sync::mpsc; use tracing::{debug, warn}; -use crate::executors::common::{ - SubgraphExecutionRequest, SubgraphExecutor, SUBSCRIPTION_EVENT_BUFFER_CAPACITY, -}; +use crate::executors::common::{SubgraphExecutionRequest, SubgraphExecutor}; use crate::executors::error::SubgraphExecutorError; use crate::executors::graphql_transport_ws::build_subscribe_payload; use crate::executors::websocket_client::{connect, WsClient}; @@ -114,9 +112,12 @@ impl SubgraphExecutor for WsSubgraphExecutor { BoxStream<'static, Result, SubgraphExecutorError>>, SubgraphExecutorError, > { - let (tx, mut rx) = mpsc::channel::, SubgraphExecutorError>>( - SUBSCRIPTION_EVENT_BUFFER_CAPACITY, - ); + // all subscriptions emit events into the shared active subscriptions broadcaster + // which itself handles back-pressure by dropping old events when the buffer is full, + // so we can use a small buffer here + // TODO: do we thererefore need to buffer at all? + let (tx, mut rx) = + mpsc::channel::, SubgraphExecutorError>>(16); let endpoint = self.endpoint.clone(); let subgraph_name = self.subgraph_name.clone(); diff --git a/lib/internal/src/inflight.rs b/lib/internal/src/inflight.rs index 7ac3c2237..d9fb5ce3e 100644 --- a/lib/internal/src/inflight.rs +++ b/lib/internal/src/inflight.rs @@ -99,11 +99,9 @@ where Fut: Future>, { let mut did_initialize = false; - let key = self.key.clone(); - let map = self.map.clone(); + let InFlightClaim { key, map, cell } = self; - let value = self - .cell + let value = cell .get_or_try_init(|| { did_initialize = true; async { @@ -122,9 +120,45 @@ where Ok((value, role)) } + + /// Like [`get_or_try_init`](Self::get_or_try_init), but passes the cleanup guard to the + /// caller. The caller is responsible for dropping the guard when deduplication should end. + /// This is useful for long-lived operations like subscriptions where the deduplication + /// window extends beyond the init future. + #[inline] + pub async fn get_or_try_init_with_guard( + self, + init: F, + ) -> Result<(Arc, InFlightRole), E> + where + F: FnOnce(InFlightCleanupGuard) -> Fut, + Fut: Future>, + { + let mut did_initialize = false; + let InFlightClaim { key, map, cell } = self; + + let value = cell + .get_or_try_init(|| { + did_initialize = true; + async { + let guard = InFlightCleanupGuard { key, map }; + init(guard).await.map(Arc::new) + } + }) + .await? + .clone(); + + let role = if did_initialize { + InFlightRole::Leader + } else { + InFlightRole::Joiner + }; + + Ok((value, role)) + } } -struct InFlightCleanupGuard +pub struct InFlightCleanupGuard where K: Eq + Hash, S: BuildHasher + Clone, diff --git a/lib/router-config/src/subscriptions.rs b/lib/router-config/src/subscriptions.rs index f3c727362..c922780fd 100644 --- a/lib/router-config/src/subscriptions.rs +++ b/lib/router-config/src/subscriptions.rs @@ -15,6 +15,20 @@ pub struct SubscriptionsConfig { /// You can override this setting by setting the `SUBSCRIPTIONS_ENABLED` environment variable to `true` or `false`. #[serde(default)] pub enabled: bool, + /// The capacity of the broadcast channel used to fan out subscription events to all active listeners. + /// + /// Each active subscription has its own broadcast channel. This value controls how many events + /// can be buffered in that channel before slow consumers start lagging. If a consumer falls too + /// far behind and the buffer is full, it will skip the missed messages and continue from the + /// latest available event. + /// + /// Subscription events are typically low-frequency, so the default of 32 is sufficient for most + /// use cases. Increase this value if you expect bursts of events or have slow consumers that + /// need more headroom to catch up. + /// + /// Defaults to 32. + #[serde(default = "default_broadcast_capacity")] + pub broadcast_capacity: usize, /// Configuration for subgraphs using the HTTP Callback protocol. #[serde(default, skip_serializing_if = "Option::is_none")] pub callback: Option, @@ -61,6 +75,10 @@ pub struct CallbackConfig { pub subgraphs: HashSet, } +fn default_broadcast_capacity() -> usize { + 32 +} + fn default_callback_path() -> AbsolutePath { AbsolutePath::try_from("/callback").expect("default callback path is valid") } diff --git a/lib/router-config/src/traffic_shaping.rs b/lib/router-config/src/traffic_shaping.rs index 8b362b6c6..7dc22492d 100644 --- a/lib/router-config/src/traffic_shaping.rs +++ b/lib/router-config/src/traffic_shaping.rs @@ -180,15 +180,53 @@ pub struct TrafficShapingRouterConfig { )] #[schemars(with = "String")] pub request_timeout: Duration, + + /// Maximum number of concurrent long-lived clients (WebSocket connections and HTTP streaming responses). + /// Regular non-streaming requests are not counted toward this limit. + /// When the limit is reached, new WebSocket and streaming HTTP requests are rejected with 503. + /// If both WebSockets and Subscriptions are disabled, this setting has no effect. + #[serde(default = "default_max_long_lived_clients")] + pub max_long_lived_clients: usize, } #[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] #[serde(deny_unknown_fields)] pub struct TrafficShapingRouterDedupeConfig { - /// Enables/disables in-flight request deduplication at the router endpoint level. + /// Enables/disables in-flight request and active subscriptions deduplication at the router level. + /// + /// When enabled, the router deduplicates both queries and subscriptions using the same + /// fingerprint key (method, path, selected headers, schema checksum, normalized operation + /// hash, variables, and extensions). The `headers` configuration below controls which + /// headers participate in that key for all operation types. + /// + /// For queries, concurrent HTTP requests that produce the same fingerprint share a single + /// in-flight execution - only the first one runs, and the rest wait for and receive the + /// same result. + /// + /// For subscriptions, the mechanism is broadcast-based rather than request-sharing. The + /// first client with a given fingerprint becomes the leader: it runs the upstream subscription + /// and its events are fanned out through a broadcast channel backed by an active subscriptions + /// registry. Any subsequent client that arrives with an identical fingerprint while that subscription + /// is still active joins as a listener on the same broadcast channel instead of starting a new upstream + /// connection. When all listeners have dropped and the leader finishes, the entry is removed from the + /// registry. /// - /// When enabled, identical incoming GraphQL query requests that are processed at the same time - /// share the same in-flight execution result. + /// WebSocket connections participate in the same deduplication space as HTTP. Each + /// subscribe message is processed with a synthetic request assembled from the WebSocket + /// path and the headers derived from the `websocket.headers` config. The fingerprint is computed + /// from those synthetic headers using the same header policy, so a subscription started over HTTP + /// and an identical one started over WebSocket will deduplicate against each other. + /// + /// The deduplication is transport agnostic. A query over WebSocket would get deduplicated with an + /// identical query over HTTP if they arrive at the same time and have the same fingerprint. + /// + /// Note: `content-type` is part of the fingerprint when `headers` includes it (e.g. `all`). + /// Since HTTP streaming clients send different `accept` headers than WebSocket clients, + /// cross-transport deduplication for subscriptions only applies when `content-type` (and + /// transport-specific headers) are excluded from the key. Configure `headers: none` or + /// `headers: { include: [] }` (or exclude the relevant headers) to enable true cross-transport + /// deduplication, where a WebSocket subscription and an SSE subscription with the same operation + /// share a single upstream connection and the events are fanned out to both. #[serde(default = "default_router_dedupe_enabled")] pub enabled: bool, @@ -238,11 +276,16 @@ fn default_router_request_timeout() -> Duration { Duration::from_secs(60) } +fn default_max_long_lived_clients() -> usize { + 128 +} + impl Default for TrafficShapingRouterConfig { fn default() -> Self { Self { dedupe: Default::default(), request_timeout: default_router_request_timeout(), + max_long_lived_clients: default_max_long_lived_clients(), } } }