From 937d5768519002f39d3699135007f1dc4ab04cd5 Mon Sep 17 00:00:00 2001 From: pgherveou Date: Wed, 1 Jul 2026 11:18:51 +0200 Subject: [PATCH] feat(truapi-server): add wire and chain infrastructure --- .../crates/truapi-server/src/chain_runtime.rs | 1592 +++++++++++++++++ rust/crates/truapi-server/src/dispatcher.rs | 272 +++ rust/crates/truapi-server/src/frame.rs | 433 +++++ .../truapi-server/src/host_rpc_client.rs | 587 ++++++ rust/crates/truapi-server/src/lib.rs | 14 +- rust/crates/truapi-server/src/subscription.rs | 495 +++++ rust/crates/truapi-server/src/transport.rs | 12 + .../truapi-server/tests/golden_frame.rs | 53 + .../tests/snapshots/golden-account-get.bin | Bin 0 -> 14 bytes .../tests/wire_table_ts_parity.rs | 228 +++ 10 files changed, 3683 insertions(+), 3 deletions(-) create mode 100644 rust/crates/truapi-server/src/chain_runtime.rs create mode 100644 rust/crates/truapi-server/src/dispatcher.rs create mode 100644 rust/crates/truapi-server/src/frame.rs create mode 100644 rust/crates/truapi-server/src/host_rpc_client.rs create mode 100644 rust/crates/truapi-server/src/subscription.rs create mode 100644 rust/crates/truapi-server/src/transport.rs create mode 100644 rust/crates/truapi-server/tests/golden_frame.rs create mode 100644 rust/crates/truapi-server/tests/snapshots/golden-account-get.bin create mode 100644 rust/crates/truapi-server/tests/wire_table_ts_parity.rs diff --git a/rust/crates/truapi-server/src/chain_runtime.rs b/rust/crates/truapi-server/src/chain_runtime.rs new file mode 100644 index 00000000..b8fcacca --- /dev/null +++ b/rust/crates/truapi-server/src/chain_runtime.rs @@ -0,0 +1,1592 @@ +//! ChainHead v1 state machine used by `PlatformRuntimeHost`. +//! +//! [`ChainRuntime`] keeps one [`ChainConnection`] per chain (keyed by genesis +//! hash) on top of the platform-provided [`JsonRpcConnection`]. The generic +//! JSON-RPC mechanics are delegated to [`crate::host_rpc_client`], while +//! `subxt-rpcs` owns the raw `chainHead_v1` method shapes and event parsing. +//! This module keeps the TrUAPI-facing local follow ids and maps subxt DTOs to +//! public v01 [`RemoteChainHeadFollowItem`] values. +//! +//! The chain-side traits return [`RuntimeFailure`], a local classification +//! that the [`crate::runtime`] layer maps to [`truapi::CallError`] variants +//! (`Unsupported`, `HostFailure`, ...). This avoids leaking json-rpc plumbing +//! into the public API. + +#![allow(dead_code)] + +use core::pin::Pin; +use core::task::{Context, Poll}; +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex; + +use futures::FutureExt; +use futures::channel::mpsc; +use futures::future::{AbortHandle, Abortable}; +use futures::future::{BoxFuture, Shared}; +use futures::stream::BoxStream; +use futures::{Stream, StreamExt}; +use parity_scale_codec::{Decode, Error as ScaleError, Input}; +use primitive_types::H256; +use serde::de::{Deserializer, Error as DeError}; +use serde_json::Value; +use subxt_rpcs::client::RpcClient; +use subxt_rpcs::methods::chain_head as subxt_chain; +use subxt_rpcs::{ChainHeadRpcMethods, Error as SubxtRpcError, RpcConfig}; +use tracing::instrument; +use truapi::v01::{ + OperationStartedResult, RemoteChainHeadBodyRequest, RemoteChainHeadBodyResponse, + RemoteChainHeadCallRequest, RemoteChainHeadCallResponse, RemoteChainHeadContinueRequest, + RemoteChainHeadFollowItem, RemoteChainHeadFollowRequest, RemoteChainHeadHeaderRequest, + RemoteChainHeadHeaderResponse, RemoteChainHeadStopOperationRequest, + RemoteChainHeadStorageRequest, RemoteChainHeadStorageResponse, RemoteChainHeadUnpinRequest, + RemoteChainSpecChainNameResponse, RemoteChainSpecGenesisHashResponse, + RemoteChainSpecPropertiesResponse, RemoteChainTransactionBroadcastRequest, + RemoteChainTransactionBroadcastResponse, RemoteChainTransactionStopRequest, RuntimeApi, + RuntimeSpec, RuntimeType, StorageQueryItem, StorageQueryType, StorageResultItem, +}; +use truapi_platform::JsonRpcConnection; + +use crate::host_rpc_client::HostRpcClient; +use crate::subscription::Spawner; + +const FOLLOW_METHOD: &str = "remote_chain_head_follow"; + +struct TruapiRpcConfig; + +impl RpcConfig for TruapiRpcConfig { + type Header = RawHeader; + type Hash = H256; + type AccountId = (); +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct RawHeader(Vec); + +impl Decode for RawHeader { + fn decode(input: &mut I) -> Result { + let Some(len) = input.remaining_len()? else { + return Err("raw header input length is unknown".into()); + }; + let mut bytes = vec![0u8; len]; + input.read(&mut bytes)?; + Ok(Self(bytes)) + } +} + +impl<'de> serde::Deserialize<'de> for RawHeader { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let bytes = subxt_chain::Bytes::deserialize(deserializer).map_err(D::Error::custom)?; + Ok(Self(bytes.0)) + } +} + +/// Shared, single-flight `chainHead_v1_follow` setup keyed by local follow id. +/// Concurrent callers for the same id await one in-flight request rather than +/// each opening (and leaking) a separate remote subscription. +type FollowSetup = Shared>>; + +/// Shared, single-flight provider connect keyed by genesis hash. Concurrent +/// first connections for the same chain await one in-flight `connect` rather +/// than each opening a connection and orphaning all but the last insert. +type ConnectionSetup = Shared, RuntimeFailure>>>; + +/// Classification of framework-level chain failures separate from JSON-RPC +/// domain errors. Maps cleanly to [`truapi::CallError`] variants at the +/// `PlatformRuntimeHost` boundary. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RuntimeFailureKind { + /// Backend is not wired or refused the request for plumbing reasons. + Unavailable, + /// Backend responded but the payload was malformed or the call failed. + HostFailure, +} + +/// Framework-level chain failure with a diagnostic reason. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RuntimeFailure { + kind: RuntimeFailureKind, + method: &'static str, + reason: Option, +} + +impl RuntimeFailure { + /// Backend refused the call for unavailability reasons (no provider, the + /// connection died, etc.). + pub fn unavailable(method: &'static str) -> Self { + Self { + kind: RuntimeFailureKind::Unavailable, + method, + reason: None, + } + } + + /// Backend produced a structural error (malformed json-rpc, unexpected + /// shape, ...). + pub fn host_failure(method: &'static str, reason: impl Into) -> Self { + Self { + kind: RuntimeFailureKind::HostFailure, + method, + reason: Some(reason.into()), + } + } + + /// Failure classification. + pub fn kind(&self) -> RuntimeFailureKind { + self.kind + } + + /// Method tag the failure originated from. + #[allow(dead_code)] + pub fn method(&self) -> &'static str { + self.method + } + + /// Diagnostic reason. Always non-empty for `HostFailure`. + pub fn reason(&self) -> String { + match &self.reason { + Some(reason) => format!("{}: {}", self.method, reason), + None => self.method.to_string(), + } + } + + /// Re-tag this failure under `method`, preserving its kind and reason. + fn reclassify(&self, method: &'static str) -> RuntimeFailure { + match self.kind() { + RuntimeFailureKind::Unavailable => RuntimeFailure::unavailable(method), + RuntimeFailureKind::HostFailure => RuntimeFailure::host_failure(method, self.reason()), + } + } +} + +/// Provider of `JsonRpcConnection` instances keyed by chain genesis hash. +/// The default [`UnavailableChainProvider`] makes every call fail; real +/// hosts plug in the platform-side `ChainProvider`. +#[async_trait::async_trait] +pub trait RuntimeChainProvider: Send + Sync { + /// Open or reuse a JSON-RPC connection for the chain identified by + /// `genesis_hash`. + async fn connect( + &self, + genesis_hash: Vec, + ) -> Result, RuntimeFailure>; +} + +/// Default provider: every `connect` call fails with `Unavailable`, so each +/// chain RPC surfaces a typed "unavailable" error to the product. +#[allow(dead_code)] +#[derive(Default)] +pub struct UnavailableChainProvider; + +#[async_trait::async_trait] +impl RuntimeChainProvider for UnavailableChainProvider { + async fn connect( + &self, + _genesis_hash: Vec, + ) -> Result, RuntimeFailure> { + Err(RuntimeFailure::unavailable("remote_chain_connect")) + } +} + +/// chainHead-v1 state machine on top of a [`RuntimeChainProvider`]. +/// +/// Each method maps a typed v01 chain request to one or more json-rpc calls, +/// shares one `chainHead_v1_follow` subscription per (genesis_hash, local +/// follow id) pair, and parses follow events back into typed +/// [`RemoteChainHeadFollowItem`] values. +#[derive(Clone)] +pub struct ChainRuntime { + provider: Arc, + spawner: Spawner, + connections: Arc>>>, + connection_setups: Arc>>, +} + +impl ChainRuntime { + /// Build a `ChainRuntime` driven by `provider`. Background tasks (response + /// pumps, follow setup) are spawned on `spawner`. + pub fn new(provider: Arc, spawner: Spawner) -> Self { + Self { + provider, + spawner, + connections: Arc::new(Mutex::new(HashMap::new())), + connection_setups: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Start (or attach to an existing) `chainHead_v1_follow` subscription. + /// Returns a stream of typed follow items that closes when the remote + /// sends `stop` or the connection drops. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.follow"))] + pub fn remote_chain_head_follow( + &self, + follow_subscription_id: String, + request: RemoteChainHeadFollowRequest, + ) -> BoxStream<'static, RemoteChainHeadFollowItem> { + let (tx, rx) = mpsc::unbounded(); + let runtime = self.clone(); + let cleanup_runtime = self.clone(); + let cleanup_genesis_hash = request.genesis_hash.clone(); + let cleanup_follow_id = follow_subscription_id.clone(); + + let fut = async move { + if runtime + .start_follow(follow_subscription_id, request, Some(tx.clone())) + .await + .is_err() + { + let _ = tx.unbounded_send(FollowSignal::Interrupt); + } + }; + (self.spawner)(fut.boxed()); + + ManagedSubscription::new( + rx.boxed(), + Some(Box::new(move || { + cleanup_runtime.cleanup_follow(&cleanup_genesis_hash, &cleanup_follow_id); + })), + ) + .filter_map(|signal| async move { + match signal { + FollowSignal::Item(item) => Some(item), + FollowSignal::Interrupt => None, + } + }) + .boxed() + } + + /// Fetch a block header. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.header"))] + pub async fn remote_chain_head_header( + &self, + request: RemoteChainHeadHeaderRequest, + ) -> Result { + let method = "remote_chain_head_header"; + let connection = self.connection_for(method, &request.genesis_hash).await?; + let remote_follow_id = self + .ensure_follow_context(method, &connection, request.follow_subscription_id, false) + .await?; + + let hash = hash_from_bytes(method, &request.hash)?; + let header = connection + .methods + .chainhead_v1_header(&remote_follow_id, hash) + .await + .map_err(|err| rpc_failure(method, err))? + .map(|header| header.0); + Ok(RemoteChainHeadHeaderResponse { header }) + } + + /// Start a chainHead_v1_body operation. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.body"))] + pub async fn remote_chain_head_body( + &self, + request: RemoteChainHeadBodyRequest, + ) -> Result { + let method = "remote_chain_head_body"; + let connection = self.connection_for(method, &request.genesis_hash).await?; + let remote_follow_id = self + .ensure_follow_context(method, &connection, request.follow_subscription_id, false) + .await?; + + let operation = connection + .methods + .chainhead_v1_body(&remote_follow_id, hash_from_bytes(method, &request.hash)?) + .await + .map_err(|err| rpc_failure(method, err)) + .and_then(operation_started_result)?; + Ok(RemoteChainHeadBodyResponse { operation }) + } + + /// Start a chainHead_v1_storage operation. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.storage"))] + pub async fn remote_chain_head_storage( + &self, + request: RemoteChainHeadStorageRequest, + ) -> Result { + let method = "remote_chain_head_storage"; + let connection = self.connection_for(method, &request.genesis_hash).await?; + let remote_follow_id = self + .ensure_follow_context(method, &connection, request.follow_subscription_id, false) + .await?; + + let items = request + .items + .iter() + .map(map_storage_query_item) + .collect::>(); + + let operation = connection + .methods + .chainhead_v1_storage( + &remote_follow_id, + hash_from_bytes(method, &request.hash)?, + items, + request.child_trie.as_deref(), + ) + .await + .map_err(|err| rpc_failure(method, err)) + .and_then(operation_started_result)?; + Ok(RemoteChainHeadStorageResponse { operation }) + } + + /// Start a chainHead_v1_call operation. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.call"))] + pub async fn remote_chain_head_call( + &self, + request: RemoteChainHeadCallRequest, + ) -> Result { + let method = "remote_chain_head_call"; + let connection = self.connection_for(method, &request.genesis_hash).await?; + let remote_follow_id = self + .ensure_follow_context(method, &connection, request.follow_subscription_id, true) + .await?; + + let operation = connection + .methods + .chainhead_v1_call( + &remote_follow_id, + hash_from_bytes(method, &request.hash)?, + &request.function, + &request.call_parameters, + ) + .await + .map_err(|err| rpc_failure(method, err)) + .and_then(operation_started_result)?; + Ok(RemoteChainHeadCallResponse { operation }) + } + + /// Release pinned blocks. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.unpin"))] + pub async fn remote_chain_head_unpin( + &self, + request: RemoteChainHeadUnpinRequest, + ) -> Result<(), RuntimeFailure> { + let method = "remote_chain_head_unpin"; + let connection = self.connection_for(method, &request.genesis_hash).await?; + let remote_follow_id = self + .ensure_follow_context(method, &connection, request.follow_subscription_id, false) + .await?; + for hash in request.hashes { + connection + .methods + .chainhead_v1_unpin(&remote_follow_id, hash_from_bytes(method, &hash)?) + .await + .map_err(|err| rpc_failure(method, err))?; + } + Ok(()) + } + + /// Continue a paused operation. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.continue"))] + pub async fn remote_chain_head_continue( + &self, + request: RemoteChainHeadContinueRequest, + ) -> Result<(), RuntimeFailure> { + let method = "remote_chain_head_continue"; + let connection = self.connection_for(method, &request.genesis_hash).await?; + let remote_follow_id = self + .ensure_follow_context(method, &connection, request.follow_subscription_id, false) + .await?; + connection + .methods + .chainhead_v1_continue(&remote_follow_id, &request.operation_id) + .await + .map_err(|err| rpc_failure(method, err)) + } + + /// Stop a chain-head operation. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.stop_operation"))] + pub async fn remote_chain_head_stop_operation( + &self, + request: RemoteChainHeadStopOperationRequest, + ) -> Result<(), RuntimeFailure> { + let method = "remote_chain_head_stop_operation"; + let connection = self.connection_for(method, &request.genesis_hash).await?; + let remote_follow_id = self + .ensure_follow_context(method, &connection, request.follow_subscription_id, false) + .await?; + connection + .methods + .chainhead_v1_stop_operation(&remote_follow_id, &request.operation_id) + .await + .map_err(|err| rpc_failure(method, err)) + } + + /// Echo back the chain genesis hash via chainSpec_v1_genesisHash. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.spec_genesis_hash"))] + pub async fn remote_chain_spec_genesis_hash( + &self, + genesis_hash: Vec, + ) -> Result { + let method = "remote_chain_spec_genesis_hash"; + let connection = self.connection_for(method, &genesis_hash).await?; + let genesis_hash = connection + .methods + .chainspec_v1_genesis_hash() + .await + .map_err(|err| rpc_failure(method, err)) + .map(hash_to_bytes)?; + Ok(RemoteChainSpecGenesisHashResponse { genesis_hash }) + } + + /// Fetch the chain display name via chainSpec_v1_chainName. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.spec_chain_name"))] + pub async fn remote_chain_spec_chain_name( + &self, + genesis_hash: Vec, + ) -> Result { + let method = "remote_chain_spec_chain_name"; + let connection = self.connection_for(method, &genesis_hash).await?; + let chain_name = connection + .methods + .chainspec_v1_chain_name() + .await + .map_err(|err| rpc_failure(method, err))?; + Ok(RemoteChainSpecChainNameResponse { chain_name }) + } + + /// Fetch the chain JSON properties via chainSpec_v1_properties. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.spec_properties"))] + pub async fn remote_chain_spec_properties( + &self, + genesis_hash: Vec, + ) -> Result { + let method = "remote_chain_spec_properties"; + let connection = self.connection_for(method, &genesis_hash).await?; + let value = connection + .methods + .chainspec_v1_properties::() + .await + .map_err(|err| rpc_failure(method, err))?; + let properties = serde_json::to_string(&value) + .map_err(|err| RuntimeFailure::host_failure(method, err.to_string()))?; + Ok(RemoteChainSpecPropertiesResponse { properties }) + } + + /// Broadcast a signed transaction via transaction_v1_broadcast. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.transaction_broadcast"))] + pub async fn remote_chain_transaction_broadcast( + &self, + request: RemoteChainTransactionBroadcastRequest, + ) -> Result { + let method = "remote_chain_transaction_broadcast"; + let connection = self.connection_for(method, &request.genesis_hash).await?; + let operation_id = connection + .methods + .transaction_v1_broadcast(&request.transaction) + .await + .map_err(|err| rpc_failure(method, err))?; + Ok(RemoteChainTransactionBroadcastResponse { operation_id }) + } + + /// Stop a transaction broadcast via transaction_v1_stop. + #[instrument(skip_all, fields(runtime.method = "chain_runtime.transaction_stop"))] + pub async fn remote_chain_transaction_stop( + &self, + request: RemoteChainTransactionStopRequest, + ) -> Result<(), RuntimeFailure> { + let method = "remote_chain_transaction_stop"; + let connection = self.connection_for(method, &request.genesis_hash).await?; + connection + .methods + .transaction_v1_stop(&request.operation_id) + .await + .map_err(|err| rpc_failure(method, err)) + } + + #[instrument(skip_all, fields(runtime.method = "chain_runtime.connection_for", method = method))] + async fn connection_for( + &self, + method: &'static str, + genesis_hash: &[u8], + ) -> Result, RuntimeFailure> { + let key = encode_hex(genesis_hash); + let setup = { + let mut connections = self.connections.lock().unwrap(); + match connections.get(&key) { + Some(connection) if !connection.is_closed() => return Ok(connection.clone()), + Some(_) => { + connections.remove(&key); + } + None => {} + } + // Single-flight the provider connect (same shape as + // `follow_setups`): concurrent first connections for the same + // chain share one in-flight `connect` instead of racing the + // insert and orphaning the loser's connection. + let mut setups = self.connection_setups.lock().unwrap(); + if let Some(existing) = setups.get(&key) { + existing.clone() + } else { + let provider = self.provider.clone(); + let spawner = self.spawner.clone(); + let connections = self.connections.clone(); + let setups_map = self.connection_setups.clone(); + let setup_key = key.clone(); + let genesis_hash = genesis_hash.to_owned(); + let setup: ConnectionSetup = async move { + let result = provider.connect(genesis_hash).await.map(|rpc| { + let connection = ChainConnection::new(rpc, spawner); + connections + .lock() + .unwrap() + .insert(setup_key.clone(), connection.clone()); + connection + }); + setups_map.lock().unwrap().remove(&setup_key); + result + } + .boxed() + .shared(); + setups.insert(key, setup.clone()); + setup + } + }; + + setup.await.map_err(|failure| failure.reclassify(method)) + } + + #[instrument(skip_all, fields(runtime.method = "chain_runtime.start_follow"))] + async fn start_follow( + &self, + local_follow_id: String, + request: RemoteChainHeadFollowRequest, + sender: Option>, + ) -> Result<(), RuntimeFailure> { + let connection = self + .connection_for(FOLLOW_METHOD, &request.genesis_hash) + .await?; + // Record this subscriber's sender before kicking off (or joining) the + // single-flight setup so events route to it regardless of which caller + // wins the setup. + connection.register_follow_intent(&local_follow_id, request.with_runtime, sender); + connection + .ensure_remote_follow(local_follow_id, request.with_runtime) + .await?; + Ok(()) + } + + #[instrument(skip_all, fields(runtime.method = "chain_runtime.ensure_follow_context", method = method))] + async fn ensure_follow_context( + &self, + method: &'static str, + connection: &Arc, + local_follow_id: String, + with_runtime: bool, + ) -> Result { + let remote_follow_id = connection + .require_remote_follow(method, local_follow_id.clone()) + .await?; + if with_runtime && !connection.follow_with_runtime(&local_follow_id) { + return Err(RuntimeFailure::host_failure( + method, + "follow subscription was created without runtime metadata", + )); + } + Ok(remote_follow_id) + } + + #[instrument(skip_all, fields(runtime.method = "chain_runtime.cleanup_follow"))] + fn cleanup_follow(&self, genesis_hash: &[u8], local_follow_id: &str) { + let key = encode_hex(genesis_hash); + let Some(connection) = self.connections.lock().unwrap().get(&key).cloned() else { + return; + }; + connection.unfollow(local_follow_id); + } +} + +/// One delivery on the local follow stream. `Interrupt` signals an +/// abnormal close (connection dropped, follow setup failed); it produces no +/// item but ends the stream. +enum FollowSignal { + Item(RemoteChainHeadFollowItem), + Interrupt, +} + +struct ChainConnection { + rpc_client: HostRpcClient, + methods: ChainHeadRpcMethods, + spawner: Spawner, + follows: Mutex>, + follow_setups: Mutex>, +} + +impl ChainConnection { + fn new(rpc: Arc, spawner: Spawner) -> Arc { + let rpc_client = HostRpcClient::new(rpc, spawner.clone()); + let methods = ChainHeadRpcMethods::new(RpcClient::new(rpc_client.clone())); + Arc::new(Self { + rpc_client, + methods, + spawner, + follows: Mutex::new(HashMap::new()), + follow_setups: Mutex::new(HashMap::new()), + }) + } + + fn is_closed(&self) -> bool { + self.rpc_client.is_closed() + } + + fn follow_with_runtime(&self, local_follow_id: &str) -> bool { + self.follows + .lock() + .unwrap() + .get(local_follow_id) + .is_some_and(|follow| follow.with_runtime) + } + + fn remote_follow_id(&self, local_follow_id: &str) -> Option { + self.follows + .lock() + .unwrap() + .get(local_follow_id) + .and_then(|follow| follow.remote_subscription_id.clone()) + } + + /// Record intent to follow `local_follow_id`, attaching `sender` for a + /// follow subscriber. Idempotent: an existing follow keeps its + /// `with_runtime` flag and remote id; only the sender is (re)attached. + fn register_follow_intent( + &self, + local_follow_id: &str, + with_runtime: bool, + sender: Option>, + ) { + let mut follows = self.follows.lock().unwrap(); + match follows.get_mut(local_follow_id) { + Some(follow) => { + if sender.is_some() { + follow.sender = sender; + } + } + None => { + follows.insert( + local_follow_id.to_string(), + FollowState { + with_runtime, + remote_subscription_id: None, + abort: None, + sender, + }, + ); + } + } + } + + /// Issue `chainHead_v1_follow` exactly once per local follow id and return + /// the remote subscription id. Concurrent callers for the same id share + /// one in-flight setup instead of each opening a duplicate remote + /// subscription that would then leak. + #[instrument(skip_all, fields(runtime.method = "chain_connection.ensure_remote_follow"))] + async fn ensure_remote_follow( + self: &Arc, + local_follow_id: String, + with_runtime: bool, + ) -> Result { + if let Some(remote_follow_id) = self.remote_follow_id(&local_follow_id) { + return Ok(remote_follow_id); + } + + let setup = { + let mut setups = self.follow_setups.lock().unwrap(); + if let Some(existing) = setups.get(&local_follow_id) { + existing.clone() + } else { + let connection = self.clone(); + let id = local_follow_id.clone(); + let setup: FollowSetup = + async move { connection.run_follow_setup(id, with_runtime).await } + .boxed() + .shared(); + setups.insert(local_follow_id.clone(), setup.clone()); + setup + } + }; + + let result = setup.await; + // On failure, drop the cached setup so a later re-subscribe can retry. + // On success the established follow short-circuits the fast path above, + // and `remove_follow` clears the entry at teardown. + if result.is_err() { + self.follow_setups.lock().unwrap().remove(&local_follow_id); + } + result + } + + /// Return the remote follow id for an already-created local follow. + /// + /// Follow-bound request methods must not create remote follows themselves: + /// the local follow stream owns cleanup, so only `follow_head_subscribe` + /// may establish the remote subscription. + #[instrument(skip_all, fields(runtime.method = "chain_connection.require_remote_follow"))] + async fn require_remote_follow( + self: &Arc, + method: &'static str, + local_follow_id: String, + ) -> Result { + if let Some(remote_follow_id) = self.remote_follow_id(&local_follow_id) { + return Ok(remote_follow_id); + } + + let setup = { + let follows = self.follows.lock().unwrap(); + if !follows.contains_key(&local_follow_id) { + return Err(RuntimeFailure::host_failure( + method, + format!("unknown follow subscription id {local_follow_id:?}"), + )); + } + self.follow_setups + .lock() + .unwrap() + .get(&local_follow_id) + .cloned() + }; + + match setup { + Some(setup) => setup.await.map_err(|failure| failure.reclassify(method)), + None => Err(RuntimeFailure::host_failure( + method, + format!("follow subscription {local_follow_id:?} is not established"), + )), + } + } + + /// Body of the single-flight follow setup: ensure the `FollowState` + /// exists, issue `chainHead_v1_follow`, and record the remote id. + #[instrument(skip_all, fields(runtime.method = "chain_connection.run_follow_setup"))] + async fn run_follow_setup( + self: Arc, + local_follow_id: String, + with_runtime: bool, + ) -> Result { + self.follows + .lock() + .unwrap() + .entry(local_follow_id.clone()) + .or_insert_with(|| FollowState { + with_runtime, + remote_subscription_id: None, + abort: None, + sender: None, + }); + + let mut follow = self + .methods + .chainhead_v1_follow(with_runtime) + .await + .map_err(|err| { + self.remove_follow(&local_follow_id); + rpc_failure(FOLLOW_METHOD, err) + })?; + let remote_follow_id = follow + .subscription_id() + .ok_or_else(|| { + RuntimeFailure::host_failure(FOLLOW_METHOD, "missing follow subscription id") + })? + .to_string(); + + let (abort, abort_registration) = AbortHandle::new_pair(); + let connection = self.clone(); + let pump_follow_id = local_follow_id.clone(); + let pump = async move { + while let Some(item) = follow.next().await { + match item { + Ok(event) => match map_follow_event(event) { + Ok(item) => { + let is_stop = matches!(item, RemoteChainHeadFollowItem::Stop); + connection.deliver_follow_event(&pump_follow_id, item, false); + if is_stop { + break; + } + } + Err(_) => { + connection.interrupt_follow(&pump_follow_id, false); + break; + } + }, + Err(_) => { + connection.interrupt_follow(&pump_follow_id, false); + break; + } + } + } + connection.remove_follow_without_abort(&pump_follow_id); + }; + + if !self.attach_remote_follow(&local_follow_id, remote_follow_id.clone(), abort) { + return Err(RuntimeFailure::unavailable(FOLLOW_METHOD)); + } + + (self.spawner)(Abortable::new(pump, abort_registration).map(|_| ()).boxed()); + Ok(remote_follow_id) + } + + fn attach_remote_follow( + &self, + local_follow_id: &str, + remote_follow_id: String, + abort: AbortHandle, + ) -> bool { + let mut follows = self.follows.lock().unwrap(); + let Some(follow) = follows.get_mut(local_follow_id) else { + return false; + }; + follow.remote_subscription_id = Some(remote_follow_id); + follow.abort = Some(abort); + true + } + + fn remove_follow(&self, local_follow_id: &str) { + self.follow_setups.lock().unwrap().remove(local_follow_id); + if let Some(mut follow) = self.follows.lock().unwrap().remove(local_follow_id) + && let Some(abort) = follow.abort.take() + { + abort.abort(); + } + } + + fn remove_follow_without_abort(&self, local_follow_id: &str) { + self.follow_setups.lock().unwrap().remove(local_follow_id); + self.follows.lock().unwrap().remove(local_follow_id); + } + + fn unfollow(&self, local_follow_id: &str) { + self.remove_follow(local_follow_id); + } + + fn deliver_follow_event( + &self, + local_follow_id: &str, + event: RemoteChainHeadFollowItem, + abort_on_stop: bool, + ) { + let sender = self + .follows + .lock() + .unwrap() + .get(local_follow_id) + .and_then(|follow| follow.sender.clone()); + let is_stop = matches!(event, RemoteChainHeadFollowItem::Stop); + if let Some(sender) = sender { + let _ = sender.unbounded_send(FollowSignal::Item(event)); + } + if is_stop { + if abort_on_stop { + self.remove_follow(local_follow_id); + } else { + self.remove_follow_without_abort(local_follow_id); + } + } + } + + fn interrupt_follow(&self, local_follow_id: &str, abort: bool) { + let sender = self + .follows + .lock() + .unwrap() + .get(local_follow_id) + .and_then(|follow| follow.sender.clone()); + if let Some(sender) = sender { + let _ = sender.unbounded_send(FollowSignal::Interrupt); + } + if abort { + self.remove_follow(local_follow_id); + } else { + self.remove_follow_without_abort(local_follow_id); + } + } +} + +struct FollowState { + with_runtime: bool, + remote_subscription_id: Option, + abort: Option, + sender: Option>, +} + +/// Subscription wrapper that runs an `on_drop` cleanup when the stream is +/// dropped. Used by `remote_chain_head_follow` to send `chainHead_v1_unfollow` +/// when the local follow stream is dropped. +struct ManagedSubscription { + inner: BoxStream<'static, T>, + on_drop: Option>, +} + +impl ManagedSubscription { + fn new(inner: BoxStream<'static, T>, on_drop: Option>) -> Self { + Self { inner, on_drop } + } +} + +impl Drop for ManagedSubscription { + fn drop(&mut self) { + if let Some(on_drop) = self.on_drop.take() { + on_drop(); + } + } +} + +impl Stream for ManagedSubscription { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + this.inner.as_mut().poll_next(cx) + } +} + +fn operation_started_result( + response: subxt_chain::MethodResponse, +) -> Result { + match response { + subxt_chain::MethodResponse::Started(started) => Ok(OperationStartedResult::Started { + operation_id: started.operation_id, + }), + subxt_chain::MethodResponse::LimitReached => Ok(OperationStartedResult::LimitReached), + } +} + +fn map_follow_event( + event: subxt_chain::FollowEvent, +) -> Result { + match event { + subxt_chain::FollowEvent::Initialized(event) => { + Ok(RemoteChainHeadFollowItem::Initialized { + finalized_block_hashes: event + .finalized_block_hashes + .into_iter() + .map(hash_to_bytes) + .collect(), + finalized_block_runtime: event + .finalized_block_runtime + .map(map_runtime_event) + .transpose()?, + }) + } + subxt_chain::FollowEvent::NewBlock(event) => Ok(RemoteChainHeadFollowItem::NewBlock { + block_hash: hash_to_bytes(event.block_hash), + parent_block_hash: hash_to_bytes(event.parent_block_hash), + new_runtime: event.new_runtime.map(map_runtime_event).transpose()?, + }), + subxt_chain::FollowEvent::BestBlockChanged(event) => { + Ok(RemoteChainHeadFollowItem::BestBlockChanged { + best_block_hash: hash_to_bytes(event.best_block_hash), + }) + } + subxt_chain::FollowEvent::Finalized(event) => Ok(RemoteChainHeadFollowItem::Finalized { + finalized_block_hashes: event + .finalized_block_hashes + .into_iter() + .map(hash_to_bytes) + .collect(), + pruned_block_hashes: event + .pruned_block_hashes + .into_iter() + .map(hash_to_bytes) + .collect(), + }), + subxt_chain::FollowEvent::OperationBodyDone(event) => { + Ok(RemoteChainHeadFollowItem::OperationBodyDone { + operation_id: event.operation_id, + value: event.value.into_iter().map(|bytes| bytes.0).collect(), + }) + } + subxt_chain::FollowEvent::OperationCallDone(event) => { + Ok(RemoteChainHeadFollowItem::OperationCallDone { + operation_id: event.operation_id, + output: event.output.0, + }) + } + subxt_chain::FollowEvent::OperationStorageItems(event) => { + Ok(RemoteChainHeadFollowItem::OperationStorageItems { + operation_id: event.operation_id, + items: event + .items + .into_iter() + .map(map_storage_result) + .collect::, _>>()?, + }) + } + subxt_chain::FollowEvent::OperationStorageDone(event) => { + Ok(RemoteChainHeadFollowItem::OperationStorageDone { + operation_id: event.operation_id, + }) + } + subxt_chain::FollowEvent::OperationWaitingForContinue(event) => { + Ok(RemoteChainHeadFollowItem::OperationWaitingForContinue { + operation_id: event.operation_id, + }) + } + subxt_chain::FollowEvent::OperationInaccessible(event) => { + Ok(RemoteChainHeadFollowItem::OperationInaccessible { + operation_id: event.operation_id, + }) + } + subxt_chain::FollowEvent::OperationError(event) => { + Ok(RemoteChainHeadFollowItem::OperationError { + operation_id: event.operation_id, + error: event.error, + }) + } + subxt_chain::FollowEvent::Stop => Ok(RemoteChainHeadFollowItem::Stop), + } +} + +fn map_runtime_event(event: subxt_chain::RuntimeEvent) -> Result { + match event { + subxt_chain::RuntimeEvent::Valid(event) => { + let mut apis = event + .spec + .apis + .into_iter() + .map(|(name, version)| RuntimeApi { name, version }) + .collect::>(); + apis.sort_by(|left, right| left.name.cmp(&right.name)); + Ok(RuntimeType::Valid(RuntimeSpec { + spec_name: event.spec.spec_name, + impl_name: event.spec.impl_name, + spec_version: event.spec.spec_version, + impl_version: event.spec.impl_version, + transaction_version: Some(event.spec.transaction_version), + apis, + })) + } + subxt_chain::RuntimeEvent::Invalid(event) => { + Ok(RuntimeType::Invalid { error: event.error }) + } + } +} + +fn map_storage_query_item(item: &StorageQueryItem) -> subxt_chain::StorageQuery<&[u8]> { + subxt_chain::StorageQuery { + key: item.key.as_slice(), + query_type: match item.query_type { + StorageQueryType::Value => subxt_chain::StorageQueryType::Value, + StorageQueryType::Hash => subxt_chain::StorageQueryType::Hash, + StorageQueryType::ClosestDescendantMerkleValue => { + subxt_chain::StorageQueryType::ClosestDescendantMerkleValue + } + StorageQueryType::DescendantsValues => subxt_chain::StorageQueryType::DescendantsValues, + StorageQueryType::DescendantsHashes => subxt_chain::StorageQueryType::DescendantsHashes, + }, + } +} + +fn map_storage_result( + item: subxt_chain::StorageResult, +) -> Result { + let mut result = StorageResultItem { + key: item.key.0, + value: None, + hash: None, + closest_descendant_merkle_value: None, + }; + match item.result { + subxt_chain::StorageResultType::Value(value) => result.value = Some(value.0), + subxt_chain::StorageResultType::Hash(hash) => result.hash = Some(hash.0), + subxt_chain::StorageResultType::ClosestDescendantMerkleValue(value) => { + result.closest_descendant_merkle_value = Some(value.0); + } + } + Ok(result) +} + +fn hash_from_bytes(method: &'static str, bytes: &[u8]) -> Result { + if bytes.len() != 32 { + return Err(RuntimeFailure::host_failure( + method, + format!("expected 32-byte hash, got {}", bytes.len()), + )); + } + Ok(H256::from_slice(bytes)) +} + +fn hash_to_bytes(hash: H256) -> Vec { + hash.as_bytes().to_vec() +} + +fn rpc_failure(method: &'static str, error: SubxtRpcError) -> RuntimeFailure { + match error { + SubxtRpcError::Client(_) | SubxtRpcError::DisconnectedWillReconnect(_) => { + RuntimeFailure::unavailable(method) + } + error => RuntimeFailure::host_failure(method, error.to_string()), + } +} + +/// Encode a byte slice as a `0x`-prefixed lowercase hex string. +pub(crate) fn encode_hex(value: &[u8]) -> String { + format!("0x{}", hex::encode(value)) +} + +#[cfg(test)] +fn decode_hex(value: &str) -> Result, String> { + hex::decode(value.strip_prefix("0x").unwrap_or(value)).map_err(|_| "invalid hex".to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + use futures::channel::mpsc as fut_mpsc; + use futures::stream::BoxStream; + use std::sync::atomic::{AtomicUsize, Ordering}; + + fn spawner_for_tests() -> Spawner { + #[cfg(not(target_arch = "wasm32"))] + { + crate::subscription::thread_per_subscription_spawner() + } + #[cfg(target_arch = "wasm32")] + { + Arc::new(futures::executor::block_on) + } + } + + /// Provider that echoes a canned response for every request it sees, + /// driven by a `respond` closure. The closure receives each json-rpc + /// request string and returns the response string the test wants the + /// server to deliver. Keeps the response loop synchronized with the + /// request stream so there is no race between `send` and the response + /// loop draining frames before pending requests have registered. + type Responder = Arc Option + Send + Sync>; + + struct ScriptedProvider { + respond: Responder, + sent: Arc>>, + sender: Arc>>>, + receiver: Arc>>>, + connect_calls: Arc, + } + + impl ScriptedProvider { + fn new(respond: F) -> Self + where + F: Fn(&str) -> Option + Send + Sync + 'static, + { + let (tx, rx) = fut_mpsc::unbounded(); + Self { + respond: Arc::new(respond), + sent: Arc::new(Mutex::new(Vec::new())), + sender: Arc::new(Mutex::new(Some(tx))), + receiver: Arc::new(Mutex::new(Some(rx))), + connect_calls: Arc::new(AtomicUsize::new(0)), + } + } + } + + struct ScriptedConnection { + respond: Responder, + sent: Arc>>, + sender: Arc>>>, + receiver: Mutex>>, + } + + impl JsonRpcConnection for ScriptedConnection { + fn send(&self, request: String) { + self.sent.lock().unwrap().push(request.clone()); + if let Some(response) = (self.respond)(&request) + && let Some(sender) = self.sender.lock().unwrap().as_ref() + { + let _ = sender.unbounded_send(response); + } + } + fn responses(&self) -> BoxStream<'static, String> { + let rx = self + .receiver + .lock() + .unwrap() + .take() + .expect("ScriptedConnection::responses called twice"); + rx.boxed() + } + + fn close(&self) { + self.sender.lock().unwrap().take(); + } + } + + #[async_trait] + impl RuntimeChainProvider for ScriptedProvider { + async fn connect( + &self, + _genesis_hash: Vec, + ) -> Result, RuntimeFailure> { + self.connect_calls.fetch_add(1, Ordering::SeqCst); + let receiver = self.receiver.lock().unwrap().take(); + Ok(Arc::new(ScriptedConnection { + respond: self.respond.clone(), + sent: self.sent.clone(), + sender: self.sender.clone(), + receiver: Mutex::new(receiver), + })) + } + } + + /// Clone of the scripted notification sender, used by tests to push + /// asynchronous frames (e.g. follow events) into the response stream. + fn notification_sender(provider: &ScriptedProvider) -> fut_mpsc::UnboundedSender { + provider + .sender + .lock() + .unwrap() + .as_ref() + .expect("notification sender available") + .clone() + } + + #[test] + fn unavailable_provider_surfaces_failure() { + let provider = Arc::new(UnavailableChainProvider); + let result = futures::executor::block_on(provider.connect(vec![0u8; 32])); + let err = match result { + Ok(_) => panic!("expected failure"), + Err(err) => err, + }; + assert_eq!(err.kind(), RuntimeFailureKind::Unavailable); + assert_eq!(err.method(), "remote_chain_connect"); + } + + /// Find the json-rpc request id of the just-sent frame so the scripted + /// responder can mirror it back to the dispatcher. + fn extract_id(request: &str) -> Option { + let value: Value = serde_json::from_str(request).ok()?; + value.get("id")?.as_str().map(ToString::to_string) + } + + fn wait_for_sent( + provider: &ScriptedProvider, + predicate: impl Fn(&[String]) -> bool, + ) -> Vec { + for _ in 0..500 { + let sent = provider.sent.lock().unwrap().clone(); + if predicate(&sent) { + return sent; + } + std::thread::sleep(std::time::Duration::from_millis(10)); + } + provider.sent.lock().unwrap().clone() + } + + #[test] + fn header_request_reuses_existing_follow() { + let provider = Arc::new(ScriptedProvider::new(|request| { + let id = extract_id(request).unwrap(); + if request.contains("chainHead_v1_follow") { + Some(format!( + r#"{{"jsonrpc":"2.0","id":"{id}","result":"REMOTE-FOLLOW"}}"# + )) + } else if request.contains("chainHead_v1_header") { + Some(format!( + r#"{{"jsonrpc":"2.0","id":"{id}","result":"0xdeadbeef"}}"# + )) + } else { + None + } + })); + let runtime = ChainRuntime::new(provider.clone(), spawner_for_tests()); + let _follow_stream = runtime.remote_chain_head_follow( + "local-follow".to_string(), + RemoteChainHeadFollowRequest { + genesis_hash: vec![0u8; 32], + with_runtime: false, + }, + ); + let sent = wait_for_sent(&provider, |sent| { + sent.iter() + .any(|request| request.contains("chainHead_v1_follow")) + }); + assert!( + sent.iter() + .any(|request| request.contains("chainHead_v1_follow")), + "follow setup did not start; sent: {sent:?}", + ); + + let response = futures::executor::block_on(runtime.remote_chain_head_header( + RemoteChainHeadHeaderRequest { + genesis_hash: vec![0u8; 32], + follow_subscription_id: "local-follow".to_string(), + hash: vec![1u8; 32], + }, + )) + .expect("ok response"); + assert_eq!(response.header, Some(vec![0xde, 0xad, 0xbe, 0xef])); + assert_eq!(provider.connect_calls.load(Ordering::SeqCst), 1); + let sent = provider.sent.lock().unwrap().clone(); + assert_eq!(sent.len(), 2); + assert!(sent[0].contains("chainHead_v1_follow")); + assert!(sent[1].contains("chainHead_v1_header")); + } + + #[test] + fn header_request_rejects_unknown_follow_id_without_opening_follow() { + let provider = Arc::new(ScriptedProvider::new(|request| { + let id = extract_id(request).unwrap(); + if request.contains("chainHead_v1_follow") { + Some(format!( + r#"{{"jsonrpc":"2.0","id":"{id}","result":"REMOTE-FOLLOW"}}"# + )) + } else if request.contains("chainHead_v1_header") { + Some(format!( + r#"{{"jsonrpc":"2.0","id":"{id}","result":"0xdeadbeef"}}"# + )) + } else { + None + } + })); + let runtime = ChainRuntime::new(provider.clone(), spawner_for_tests()); + + let err = futures::executor::block_on(runtime.remote_chain_head_header( + RemoteChainHeadHeaderRequest { + genesis_hash: vec![0u8; 32], + follow_subscription_id: "missing-follow".to_string(), + hash: vec![1u8; 32], + }, + )) + .expect_err("unknown follow id should fail"); + + assert_eq!(err.kind(), RuntimeFailureKind::HostFailure); + assert!( + err.reason().contains("unknown follow subscription id"), + "unexpected error: {}", + err.reason(), + ); + assert!(provider.sent.lock().unwrap().is_empty()); + } + + /// Two concurrent calls for the same chain must share one provider + /// `connect` instead of racing the first connection and orphaning the + /// loser. + #[test] + fn concurrent_connection_for_shares_one_connect() { + struct SlowConnectProvider { + inner: ScriptedProvider, + } + + #[async_trait] + impl RuntimeChainProvider for SlowConnectProvider { + async fn connect( + &self, + genesis_hash: Vec, + ) -> Result, RuntimeFailure> { + futures_timer::Delay::new(std::time::Duration::from_millis(50)).await; + self.inner.connect(genesis_hash).await + } + } + + let provider = Arc::new(SlowConnectProvider { + inner: ScriptedProvider::new(|request| { + let id = extract_id(request).unwrap(); + if request.contains("chainSpec_v1_chainName") { + Some(format!( + r#"{{"jsonrpc":"2.0","id":"{id}","result":"Polkadot"}}"# + )) + } else { + None + } + }), + }); + let runtime = ChainRuntime::new(provider.clone(), spawner_for_tests()); + + let (first, second) = futures::executor::block_on(futures::future::join( + runtime.remote_chain_spec_chain_name(vec![0u8; 32]), + runtime.remote_chain_spec_chain_name(vec![0u8; 32]), + )); + + assert_eq!(first.unwrap().chain_name, "Polkadot"); + assert_eq!(second.unwrap().chain_name, "Polkadot"); + assert_eq!(provider.inner.connect_calls.load(Ordering::SeqCst), 1); + } + + #[test] + fn unknown_genesis_chain_spec_propagates_failure() { + let provider = Arc::new(UnavailableChainProvider); + let runtime = ChainRuntime::new(provider, spawner_for_tests()); + let err = match futures::executor::block_on( + runtime.remote_chain_spec_chain_name(vec![0u8; 32]), + ) { + Ok(_) => panic!("expected failure"), + Err(err) => err, + }; + assert_eq!(err.kind(), RuntimeFailureKind::Unavailable); + assert_eq!(err.method(), "remote_chain_spec_chain_name"); + } + + #[test] + fn json_rpc_error_becomes_host_failure() { + let provider = Arc::new(ScriptedProvider::new(|request| { + let id = extract_id(request).unwrap(); + Some(format!( + r#"{{"jsonrpc":"2.0","id":"{id}","error":{{"code":-32601,"message":"method not found"}}}}"# + )) + })); + let runtime = ChainRuntime::new(provider, spawner_for_tests()); + let err = match futures::executor::block_on( + runtime.remote_chain_spec_chain_name(vec![0u8; 32]), + ) { + Ok(_) => panic!("expected failure"), + Err(err) => err, + }; + assert_eq!(err.kind(), RuntimeFailureKind::HostFailure); + assert!( + err.reason().contains("method not found"), + "unexpected reason: {}", + err.reason() + ); + } + + #[test] + fn follow_event_initialized_translates_to_v01_item() { + // Answer `chainHead_v1_follow` through the synchronized responder so + // the ack cannot reach the response loop before the pending request + // is registered. + let provider = Arc::new(ScriptedProvider::new(|request| { + let id = extract_id(request).unwrap(); + if request.contains("chainHead_v1_follow") { + Some(format!( + r#"{{"jsonrpc":"2.0","id":"{id}","result":"REMOTE-FOLLOW"}}"# + )) + } else { + None + } + })); + let runtime = ChainRuntime::new(provider.clone(), spawner_for_tests()); + + let mut stream = runtime.remote_chain_head_follow( + "local-follow".to_string(), + RemoteChainHeadFollowRequest { + genesis_hash: vec![0u8; 32], + with_runtime: false, + }, + ); + + // Push follow events keyed by remote subscription id. Events that + // land before the follow ack are buffered by remote id and replayed + // once the follow is established. + let tx = notification_sender(&provider); + tx.unbounded_send( + r#"{"jsonrpc":"2.0","method":"chainHead_v1_followEvent","params":{"subscription":"REMOTE-FOLLOW","result":{"event":"initialized","finalizedBlockHashes":["0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"]}}}"# + .to_string(), + ).unwrap(); + tx.unbounded_send( + r#"{"jsonrpc":"2.0","method":"chainHead_v1_followEvent","params":{"subscription":"REMOTE-FOLLOW","result":{"event":"stop"}}}"# + .to_string(), + ).unwrap(); + + let items: Vec<_> = futures::executor::block_on(async { + let mut out = Vec::new(); + while let Some(item) = stream.next().await { + let is_stop = matches!(item, RemoteChainHeadFollowItem::Stop); + out.push(item); + if is_stop { + break; + } + } + out + }); + + match &items[0] { + RemoteChainHeadFollowItem::Initialized { + finalized_block_hashes, + finalized_block_runtime, + } => { + assert_eq!(finalized_block_hashes, &vec![vec![0xaa; 32]]); + assert!(finalized_block_runtime.is_none()); + } + other => panic!("expected Initialized, got {other:?}"), + } + assert!(matches!(items[1], RemoteChainHeadFollowItem::Stop)); + } + + #[cfg_attr(target_arch = "wasm32", ignore)] + #[test] + fn drop_follow_stream_sends_unfollow() { + let provider = Arc::new(ScriptedProvider::new(|request| { + let id = extract_id(request).unwrap(); + if request.contains("chainHead_v1_follow") { + Some(format!( + r#"{{"jsonrpc":"2.0","id":"{id}","result":"REMOTE-FOLLOW"}}"# + )) + } else { + None + } + })); + let runtime = ChainRuntime::new(provider.clone(), spawner_for_tests()); + let sent = provider.sent.clone(); + + let stream = runtime.remote_chain_head_follow( + "local-follow".to_string(), + RemoteChainHeadFollowRequest { + genesis_hash: vec![0u8; 32], + with_runtime: false, + }, + ); + + // Wait until the follow setup roundtrips and lands in `sent`. + // Generous timeout so the test stays robust under loaded CI runners + // where the spawner can be slow to schedule the request task. + for _ in 0..500 { + if !sent.lock().unwrap().is_empty() { + break; + } + std::thread::sleep(std::time::Duration::from_millis(10)); + } + + drop(stream); + + // Wait for the cleanup task to run and emit the unfollow request. + for _ in 0..500 { + if sent.lock().unwrap().len() >= 2 { + break; + } + std::thread::sleep(std::time::Duration::from_millis(10)); + } + + let messages = sent.lock().unwrap().clone(); + assert!( + messages.iter().any(|m| m.contains("chainHead_v1_unfollow")), + "unfollow not sent; messages: {messages:?}", + ); + } + + #[test] + fn encode_hex_round_trip() { + let bytes = vec![0x00u8, 0x12, 0xab, 0xff]; + let s = encode_hex(&bytes); + assert_eq!(s, "0x0012abff"); + assert_eq!(decode_hex(&s).unwrap(), bytes); + } + + #[test] + fn parse_runtime_type_valid_sorts_apis() { + let runtime_type = map_runtime_event(subxt_chain::RuntimeEvent::Valid( + subxt_chain::RuntimeVersionEvent { + spec: subxt_chain::RuntimeSpec { + spec_name: "polkadot".to_string(), + impl_name: "parity-polkadot".to_string(), + spec_version: 1000, + impl_version: 1, + transaction_version: 24, + apis: HashMap::from([("0xbeef".to_string(), 2), ("0xbabe".to_string(), 4)]), + }, + }, + )) + .unwrap(); + match runtime_type { + RuntimeType::Valid(spec) => { + assert_eq!(spec.apis.len(), 2); + assert_eq!(spec.apis[0].name, "0xbabe"); + assert_eq!(spec.apis[1].name, "0xbeef"); + assert_eq!(spec.transaction_version, Some(24)); + } + other => panic!("expected Valid, got {other:?}"), + } + } +} diff --git a/rust/crates/truapi-server/src/dispatcher.rs b/rust/crates/truapi-server/src/dispatcher.rs new file mode 100644 index 00000000..da27c8c0 --- /dev/null +++ b/rust/crates/truapi-server/src/dispatcher.rs @@ -0,0 +1,272 @@ +//! Request dispatcher. +//! +//! Routes incoming frames to the appropriate trait method based on the +//! numeric wire discriminant. The handler set is registered by the +//! auto-generated [`crate::generated::dispatcher::register`] function; this +//! module provides the framework that owns the registration tables and the +//! routing logic. + +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use futures::future::LocalBoxFuture; +use tracing::instrument; + +use crate::frame::{Payload, ProtocolMessage}; +use crate::generated::wire_table::{RequestFrameIds, SubscriptionFrameIds}; +use crate::subscription::{Spawner, SubscriptionManager, SubscriptionStream}; +use crate::transport::Transport; + +/// A handler for a request-response method. The returned future is not +/// required to be `Send` because the truapi trait uses `async fn`, whose +/// auto-Send-ness is not guaranteed. The `request_id` is the per-frame +/// identifier; handlers thread it into the `CallContext` so trait methods +/// can correlate logs/cancellation with the originating request. On the +/// error path handlers return the complete SCALE-encoded response payload. +pub type RequestHandler = + Arc) -> LocalBoxFuture<'static, Result, Vec>> + Send + Sync>; + +/// A handler for a subscription method. On the error path the handler returns +/// the complete SCALE-encoded `_interrupt` payload. +pub type SubscriptionHandler = Arc< + dyn Fn(String, Vec) -> LocalBoxFuture<'static, Result>> + + Send + + Sync, +>; + +/// A registered request handler plus the discriminants it replies on. +pub struct RequestEntry { + ids: RequestFrameIds, + handler: RequestHandler, +} + +/// A registered subscription handler plus the discriminants its frames carry. +pub struct SubscriptionEntry { + ids: SubscriptionFrameIds, + handler: SubscriptionHandler, +} + +/// Routes incoming protocol messages to registered handlers, keyed on the +/// numeric wire discriminant. +pub struct Dispatcher { + by_request: HashMap, + by_start: HashMap, + stop_ids: HashSet, + subscriptions: SubscriptionManager, +} + +impl Dispatcher { + /// Construct a dispatcher whose subscriptions are driven on `spawner`. + pub fn new(spawner: Spawner) -> Self { + Self { + by_request: HashMap::new(), + by_start: HashMap::new(), + stop_ids: HashSet::new(), + subscriptions: SubscriptionManager::new(spawner), + } + } + + /// Register a request-response handler, keyed on `ids.request_id`. Returns + /// the previously registered entry if any; callers (the generated + /// `dispatcher::register`) should treat `Some` as a programming error + /// since each request id must own exactly one handler. + pub fn on_request(&mut self, ids: RequestFrameIds, handler: F) -> Option + where + F: Fn(String, Vec) -> LocalBoxFuture<'static, Result, Vec>> + + Send + + Sync + + 'static, + { + self.by_request.insert( + ids.request_id, + RequestEntry { + ids, + handler: Arc::new(handler), + }, + ) + } + + /// Register a subscription handler, keyed on `ids.start_id`, and record + /// `ids.stop_id` so a matching `_stop` frame tears the subscription down. + /// Returns the previously registered entry if any. + pub fn on_subscription( + &mut self, + ids: SubscriptionFrameIds, + handler: F, + ) -> Option + where + F: Fn(String, Vec) -> LocalBoxFuture<'static, Result>> + + Send + + Sync + + 'static, + { + self.stop_ids.insert(ids.stop_id); + self.by_start.insert( + ids.start_id, + SubscriptionEntry { + ids, + handler: Arc::new(handler), + }, + ) + } + + /// Process an incoming protocol message, sending any responses or + /// subscription frames through `transport`. A discriminant with no + /// registered handler is dropped. + #[instrument(skip_all, fields(runtime.method = "dispatcher.dispatch"))] + pub async fn dispatch(&self, message: ProtocolMessage, transport: Arc) { + let id = message.payload.id; + + if let Some(entry) = self.by_request.get(&id) { + let request_id = message.request_id.clone(); + let value = (entry.handler)(request_id, message.payload.value) + .await + .unwrap_or_else(|value| value); + transport.send(ProtocolMessage { + request_id: message.request_id, + payload: Payload { + id: entry.ids.response_id, + value, + }, + }); + } else if let Some(entry) = self.by_start.get(&id) { + // Reserve the slot before awaiting the handler so a `_stop` + // arriving while the handler resolves cancels the pending + // subscription instead of racing the registration. + let token = self.subscriptions.reserve(message.request_id.clone()); + let request_id = message.request_id.clone(); + match (entry.handler)(request_id, message.payload.value).await { + Ok(stream) => { + self.subscriptions.activate( + token, + entry.ids.receive_id, + entry.ids.interrupt_id, + stream, + transport, + ); + } + Err(err_bytes) => { + self.subscriptions.cancel_reservation(token); + transport.send(ProtocolMessage { + request_id: message.request_id, + payload: Payload { + id: entry.ids.interrupt_id, + value: err_bytes, + }, + }); + } + } + } else if self.stop_ids.contains(&id) { + self.subscriptions.handle_stop(&message.request_id); + } + // Unknown discriminant: drop. Response / receive / interrupt frames are + // handled by the client side and never registered here. + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + fn test_spawner() -> Spawner { + #[cfg(not(target_arch = "wasm32"))] + { + crate::subscription::thread_per_subscription_spawner() + } + #[cfg(target_arch = "wasm32")] + { + Arc::new(futures::executor::block_on) + } + } + + #[derive(Default)] + struct RecordingTransport { + sent: Mutex>, + } + + impl RecordingTransport { + fn sent(&self) -> Vec { + self.sent.lock().unwrap().clone() + } + } + + impl Transport for RecordingTransport { + fn send(&self, message: ProtocolMessage) { + self.sent.lock().unwrap().push(message); + } + fn on_message( + &self, + _handler: Box, + ) -> Box { + Box::new(|| {}) + } + } + + fn make_frame(id: u8, value: Vec) -> ProtocolMessage { + ProtocolMessage { + request_id: "p:1".into(), + payload: Payload { id, value }, + } + } + + /// A frame whose discriminant has no registered handler is dropped: no + /// response, no interrupt. (In production `register` registers every wire + /// method, so this only happens for malformed or client-bound ids.) + #[test] + fn dispatch_unregistered_id_sends_nothing() { + let dispatcher = Dispatcher::new(test_spawner()); + let transport = Arc::new(RecordingTransport::default()); + let transport_dyn: Arc = transport.clone(); + let frame = make_frame(250, Vec::new()); + futures::executor::block_on(dispatcher.dispatch(frame, transport_dyn)); + assert!( + transport.sent().is_empty(), + "an unregistered discriminant must produce no frame" + ); + } + + /// A handler error already owns the complete response payload. The + /// dispatcher only routes it to the registered response id. + #[test] + fn dispatch_request_handler_error_emits_response_payload() { + let mut dispatcher = Dispatcher::new(test_spawner()); + let ids = RequestFrameIds { + request_id: 200, + response_id: 201, + }; + dispatcher.on_request(ids, |_request_id, _bytes| { + Box::pin(async move { Err(vec![9, 8, 7]) }) + }); + let transport = Arc::new(RecordingTransport::default()); + let frame = make_frame(200, Vec::new()); + futures::executor::block_on(dispatcher.dispatch(frame, transport.clone())); + let sent = transport.sent(); + assert_eq!(sent.len(), 1, "exactly one response expected"); + assert_eq!(sent[0].payload.id, 201); + assert_eq!(sent[0].payload.value, vec![9, 8, 7]); + } + + /// Registering two handlers under the same key must not silently + /// overwrite. The contract chosen here is "loud": `on_request` + /// returns the previous handler, so callers can detect collisions. + #[test] + fn register_request_twice_returns_previous_handler() { + let mut dispatcher = Dispatcher::new(test_spawner()); + let ids = RequestFrameIds { + request_id: 200, + response_id: 201, + }; + let prev = dispatcher.on_request(ids, |_request_id, _bytes| { + Box::pin(async move { Ok(Vec::new()) }) + }); + assert!(prev.is_none(), "first registration has no predecessor"); + let prev = dispatcher.on_request(ids, |_request_id, _bytes| { + Box::pin(async move { Ok(Vec::new()) }) + }); + assert!( + prev.is_some(), + "second registration must return the previous handler" + ); + } +} diff --git a/rust/crates/truapi-server/src/frame.rs b/rust/crates/truapi-server/src/frame.rs new file mode 100644 index 00000000..c324200a --- /dev/null +++ b/rust/crates/truapi-server/src/frame.rs @@ -0,0 +1,433 @@ +//! Wire protocol frame types. +//! +//! Every message on the wire is a `ProtocolMessage` containing a `requestId` +//! and a `payload`. On the wire the envelope is: +//! +//! ```text +//! [requestId: SCALE str][discriminant: u8][payload bytes...] +//! ``` +//! +//! The discriminant maps to a method/kind slot via the auto-generated +//! [`crate::generated::wire_table::WIRE_TABLE`]. Method ordering is part of +//! the wire protocol; only ever append to the table. The payload bytes are +//! the SCALE-encoded inner value, inlined without a length prefix. +//! +//! In-memory we keep the numeric id directly so dispatch does not need to +//! reconstruct string action tags on every frame. + +use parity_scale_codec::{Decode, Encode, Error as CodecError, Input, Output}; + +use crate::generated::wire_table::{RequestFrameIds, SubscriptionFrameIds, WIRE_TABLE, WireKind}; + +/// Top-level wire message. Encoded as `[requestId][discriminant][bytes]`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ProtocolMessage { + /// Per-message identifier carried by both halves of a request/response. + pub request_id: String, + /// Tagged payload describing the frame kind and SCALE bytes. + pub payload: Payload, +} + +/// Encode `Versioned>` from a versioned success wrapper. +/// +/// TODO(shared-core-wire): once all hosts use the shared Rust core/generated +/// client stack, remove this dispatcher compatibility rewrite and encode the +/// trait return shape directly: `Result>`. +pub fn encode_versioned_ok_payload(value: T) -> Vec { + encode_versioned_result_payload(value, 0) +} + +/// Encode `Versioned>` for methods whose success type is unit. +pub fn encode_versioned_unit_ok_payload(version: u8) -> Vec { + vec![version_index(version), 0] +} + +/// Encode `Versioned>` from an ordinary error value. +pub fn encode_versioned_err_payload(value: T, version: u8) -> Vec { + let encoded = value.encode(); + let mut out = Vec::with_capacity(encoded.len() + 2); + out.push(version_index(version)); + out.push(1); + out.extend_from_slice(&encoded); + out +} + +/// Encode `Result<(), _>` for unversioned methods whose success type is unit. +pub fn encode_raw_unit_ok_payload() -> Vec { + Ok::<(), ()>(()).encode() +} + +/// Encode `Result<(), Err>` for unversioned methods from an ordinary error value. +pub fn encode_raw_err_payload(value: T) -> Vec { + Err::<(), T>(value).encode() +} + +/// Encode a versioned subscription interrupt payload from an ordinary error. +pub fn encode_versioned_interrupt_payload(value: T, version: u8) -> Vec { + let encoded = value.encode(); + let mut out = Vec::with_capacity(encoded.len() + 1); + out.push(version_index(version)); + out.extend_from_slice(&encoded); + out +} + +impl Encode for ProtocolMessage { + fn encode_to(&self, dest: &mut T) { + self.request_id.encode_to(dest); + self.payload.id.encode_to(dest); + // Payload bytes are inlined; the receiver reads "until end of frame" + // because each transport frame is one ProtocolMessage. This matches + // the public versioned enum transport shape (variant payload encoded + // inline, no length prefix), and constrains us to slice-shaped + // `Input`s on the decode side. + dest.write(&self.payload.value); + } +} + +// Callers must hand `Decode` a slice-shaped `Input`; streaming inputs cannot +// decode this envelope because the payload has no length prefix. +impl Decode for ProtocolMessage { + fn decode(input: &mut I) -> Result { + let request_id = String::decode(input)?; + let id = u8::decode(input)?; + // Unknown ids are accepted here; routing is deferred to dispatch, + // which drops frames with no registered handler. + let remaining = input + .remaining_len()? + .ok_or_else(|| CodecError::from("frame input must report remaining length"))?; + let mut value = vec![0u8; remaining]; + input.read(&mut value)?; + Ok(ProtocolMessage { + request_id, + payload: Payload { id, value }, + }) + } +} + +/// Tagged payload. The `id` is the wire discriminant from +/// [`crate::generated::wire_table::WIRE_TABLE`], identifying the frame's method +/// and kind (request/response/start/stop/interrupt/receive). +/// +/// Note: `Payload` does not derive `Encode`/`Decode` directly; the wire +/// representation lives on [`ProtocolMessage`]. `Payload` is kept as a plain +/// data type for in-memory dispatch (key on `id`, value bytes already +/// SCALE-encoded by the call site). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Payload { + /// Wire discriminant identifying the frame's method and kind. + pub id: u8, + /// SCALE-encoded inner value bytes. + pub value: Vec, +} + +/// Request discriminants for a request method, by name. Walks the generated +/// [`WIRE_TABLE`]; intended for tests and embedders that route by method +/// string rather than holding the generated const. +pub fn request_ids(method: &str) -> Option { + WIRE_TABLE + .iter() + .find_map(|entry| match (&entry.kind, entry.method == method) { + (WireKind::Request(ids), true) => Some(*ids), + _ => None, + }) +} + +/// Subscription discriminants for a subscription method, by name. Walks the +/// generated [`WIRE_TABLE`]. +pub fn subscription_ids(method: &str) -> Option { + WIRE_TABLE + .iter() + .find_map(|entry| match (&entry.kind, entry.method == method) { + (WireKind::Subscription(ids), true) => Some(*ids), + _ => None, + }) +} + +/// Unique ID generator with a prefix. +pub struct IdFactory { + prefix: String, + counter: u64, +} + +impl IdFactory { + /// Build a factory that mints IDs of the form `{prefix}{counter}`. + pub fn new(prefix: impl Into) -> Self { + Self { + prefix: prefix.into(), + counter: 0, + } + } + + /// Return the next ID, monotonically increasing from 1. + pub fn next_id(&mut self) -> String { + self.counter += 1; + format!("{}{}", self.prefix, self.counter) + } +} + +fn encode_versioned_result_payload(value: T, result_index: u8) -> Vec { + let encoded = value.encode(); + let Some((&version_index, inner)) = encoded.split_first() else { + return vec![result_index]; + }; + let mut out = Vec::with_capacity(encoded.len() + 1); + out.push(version_index); + out.push(result_index); + out.extend_from_slice(inner); + out +} + +fn version_index(version: u8) -> u8 { + version.saturating_sub(1) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Encode)] + enum TestVersioned { + V1(T), + } + + fn build(id: u8, value: Vec) -> ProtocolMessage { + ProtocolMessage { + request_id: "p:1".to_string(), + payload: Payload { id, value }, + } + } + + fn expected_wire(id: u8, value: &[u8]) -> Vec { + let mut out = Vec::new(); + "p:1".to_string().encode_to(&mut out); + out.push(id); + out.extend_from_slice(value); + out + } + + #[test] + fn handshake_request_encodes_with_discriminant_zero() { + // SCALE-encoded HostHandshakeRequest::V1(1u8) = [0u8 variant][1u8 codec_version] + let inner: Vec = vec![0x00, 0x01]; + let msg = build(0, inner.clone()); + assert_eq!(msg.encode(), expected_wire(0, &inner)); + } + + #[test] + fn get_account_request_encodes_with_discriminant_22() { + let mut inner = vec![0x00]; // V1 variant + "foo".to_string().encode_to(&mut inner); + 0u32.encode_to(&mut inner); + let msg = build(22, inner.clone()); + assert_eq!(msg.encode(), expected_wire(22, &inner)); + } + + #[test] + fn round_trip_preserves_id_and_value() { + let inner: Vec = vec![0x00, 0x42, 0xab, 0xcd]; + let msg = build(12, inner.clone()); + let decoded = ProtocolMessage::decode(&mut &msg.encode()[..]).expect("decode"); + assert_eq!(decoded, msg); + } + + /// An unknown discriminant is no longer rejected at decode; routing is + /// deferred to dispatch (which drops frames with no registered handler). + #[test] + fn unknown_discriminant_decodes_ok() { + let mut bytes = Vec::new(); + "p:1".to_string().encode_to(&mut bytes); + bytes.push(250); // far outside the populated range + bytes.extend_from_slice(&[0xaa, 0xbb]); + let decoded = ProtocolMessage::decode(&mut &bytes[..]).expect("unknown id must decode"); + assert_eq!(decoded.payload.id, 250); + assert_eq!(decoded.payload.value, vec![0xaa, 0xbb]); + } + + /// All four subscription phases round-trip through the codec. Catches a + /// regression where `Decode` mishandles a frame whose payload is empty for + /// `_stop` / `_interrupt` (no inner data) but non-empty for `_start` / + /// `_receive`. The ids are the `account_connection_status_subscribe` + /// quartet (18..=21). + #[test] + fn subscription_phases_round_trip_through_codec() { + let cases: &[(u8, Vec)] = &[ + (18, vec![0x00, 0xaa]), // start + (19, Vec::new()), // stop + (20, Vec::new()), // interrupt + (21, vec![0x01, 0x02, 0x03, 0x04]), // receive + ]; + for (id, value) in cases { + let msg = build(*id, value.clone()); + let bytes = msg.encode(); + assert_eq!( + bytes, + expected_wire(*id, value), + "encode mismatch for id {id}" + ); + let decoded = ProtocolMessage::decode(&mut &bytes[..]).expect("decode"); + assert_eq!(decoded, msg, "round-trip mismatch for id {id}"); + } + } + + /// `request_ids` / `subscription_ids` resolve a method name to its + /// generated discriminants without going through the codec. + #[test] + fn id_helpers_resolve_known_methods() { + let handshake = request_ids("system_handshake").expect("known request method"); + assert_eq!(handshake.request_id, 0); + assert_eq!(handshake.response_id, 1); + + let get_account = request_ids("account_get_account").expect("known request method"); + assert_eq!(get_account.request_id, 22); + + let sub = + subscription_ids("account_connection_status_subscribe").expect("known subscription"); + assert_eq!(sub.start_id, 18); + assert_eq!(sub.stop_id, 19); + assert_eq!(sub.interrupt_id, 20); + assert_eq!(sub.receive_id, 21); + + // A request method is not a subscription and vice versa. + assert!(subscription_ids("system_handshake").is_none()); + assert!(request_ids("account_connection_status_subscribe").is_none()); + assert!(request_ids("not_a_method").is_none()); + } + + /// Genuine zero-byte payload (e.g. unit-typed response). `Decode` must + /// handle `remaining_len == 0` without erroring or reading past EOF. + #[test] + fn empty_payload_round_trips() { + // local_storage_clear_response = 17. + let msg = build(17, Vec::new()); + let bytes = msg.encode(); + // [SCALE compact-len 0x0c][p][:][1][u8 17] = 4 + 1 = 5 bytes total + assert_eq!(bytes.len(), 5); + let decoded = ProtocolMessage::decode(&mut &bytes[..]).expect("decode"); + assert_eq!(decoded, msg); + } + + /// Compact-len mode 1 kicks in for strings with length 64..=16383. Make + /// sure the codec handles a long requestId without truncation. + #[test] + fn long_request_id_round_trips() { + let long_id: String = "x".repeat(200); + let msg = ProtocolMessage { + request_id: long_id, + payload: Payload { + id: 22, + value: vec![0x00, 0xab, 0xcd], + }, + }; + let decoded = ProtocolMessage::decode(&mut &msg.encode()[..]).expect("decode"); + assert_eq!(decoded, msg); + } + + /// Truncated frames must surface a `CodecError`, not panic. + #[test] + fn truncated_frames_error_cleanly() { + // Empty buffer. + assert!(ProtocolMessage::decode(&mut &[][..]).is_err()); + // Just the requestId, no discriminant byte. + let mut only_request_id = Vec::new(); + "p:1".to_string().encode_to(&mut only_request_id); + assert!(ProtocolMessage::decode(&mut &only_request_id[..]).is_err()); + // RequestId header claims length=200 but the buffer is far shorter. + let truncated_str_header = [200u8 << 2, 0x61, 0x62, 0x63]; + assert!(ProtocolMessage::decode(&mut &truncated_str_header[..]).is_err()); + } + + /// Empty requestId (zero-length string) is a valid SCALE-encoded `str` + /// (compact-len 0, no body). The codec must round-trip it without + /// confusing length-0 with EOF. + #[test] + fn empty_request_id_round_trips() { + let msg = ProtocolMessage { + request_id: String::new(), + payload: Payload { + id: 22, + value: vec![0x00, 0x01, 0x02], + }, + }; + let bytes = msg.encode(); + // [SCALE compact-len 0 = 0x00][discriminant][payload] + assert_eq!(bytes[0], 0x00); + let decoded = ProtocolMessage::decode(&mut &bytes[..]).expect("decode"); + assert_eq!(decoded, msg); + } + + /// Unicode characters round-trip through SCALE string encoding. + #[test] + fn unicode_request_id_round_trips() { + let msg = ProtocolMessage { + request_id: "héllo-世界-🦀".to_string(), + payload: Payload { + id: 22, + value: vec![0x00, 0x01], + }, + }; + let decoded = ProtocolMessage::decode(&mut &msg.encode()[..]).expect("decode"); + assert_eq!(decoded, msg); + } + + /// Large payload (>64KiB) round-trips. Catches buffer-size assumptions + /// in the inline-payload read path. + #[test] + fn large_payload_round_trips() { + let big = vec![0xa5u8; 100 * 1024]; + let msg = build(22, big); + let decoded = ProtocolMessage::decode(&mut &msg.encode()[..]).expect("decode"); + assert_eq!(decoded, msg); + } + + #[test] + fn encode_versioned_unit_ok_payload_wraps_unit_success() { + assert_eq!(encode_versioned_unit_ok_payload(1), vec![0u8, 0u8]); + assert_eq!(encode_versioned_unit_ok_payload(0), vec![0u8, 0u8]); + } + + #[test] + fn encode_versioned_ok_payload_wraps_success_values() { + let mut expected = vec![0u8, 0u8]; + 7u32.encode_to(&mut expected); + assert_eq!( + encode_versioned_ok_payload(TestVersioned::V1(7u32)), + expected + ); + } + + #[test] + fn encode_versioned_err_payload_wraps_error_values() { + let mut expected = vec![0u8, 1u8]; + 9u32.encode_to(&mut expected); + assert_eq!(encode_versioned_err_payload(9u32, 1), expected); + } + + #[test] + fn encode_versioned_interrupt_payload_wraps_error_values() { + let mut expected = vec![1u8]; + 9u32.encode_to(&mut expected); + assert_eq!(encode_versioned_interrupt_payload(9u32, 2), expected); + } + + /// IdFactory mints monotonically increasing ids prefixed with the + /// configured string. + #[test] + fn id_factory_minted_ids_are_unique_and_monotonic() { + let mut factory = IdFactory::new("p:"); + assert_eq!(factory.next_id(), "p:1"); + assert_eq!(factory.next_id(), "p:2"); + assert_eq!(factory.next_id(), "p:3"); + } + + /// Two distinct factories each maintain their own counter; minting from + /// one does not advance the other. + #[test] + fn two_factories_dont_share_state() { + let mut a = IdFactory::new("a:"); + let mut b = IdFactory::new("b:"); + assert_eq!(a.next_id(), "a:1"); + assert_eq!(b.next_id(), "b:1"); + assert_eq!(a.next_id(), "a:2"); + assert_eq!(b.next_id(), "b:2"); + } +} diff --git a/rust/crates/truapi-server/src/host_rpc_client.rs b/rust/crates/truapi-server/src/host_rpc_client.rs new file mode 100644 index 00000000..88229e03 --- /dev/null +++ b/rust/crates/truapi-server/src/host_rpc_client.rs @@ -0,0 +1,587 @@ +//! `subxt-rpcs` client adapter for host-provided JSON-RPC pipes. +//! +//! The platform owns the physical chain connection. This module owns only the +//! generic JSON-RPC mechanics needed to expose that pipe as a +//! [`subxt_rpcs::RpcClientT`]: request correlation, subscription routing, and +//! best-effort unsubscribe on subscription drop. + +#![allow(dead_code)] + +use core::fmt; +use core::mem; +use core::pin::Pin; +use core::task::{Context, Poll}; +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; + +use futures::channel::{mpsc, oneshot}; +use futures::{FutureExt, pin_mut}; +use futures::{Stream, StreamExt}; +use serde::Serialize; +use serde_json::value::RawValue; +use subxt_rpcs::client::{RawRpcFuture, RawRpcSubscription, RpcClientT}; +use subxt_rpcs::{Error as RpcError, UserError}; +use tracing::instrument; +use truapi_platform::JsonRpcConnection; + +use crate::subscription::Spawner; + +const MAX_BUFFERED_SUBSCRIPTIONS: usize = 64; +const MAX_BUFFERED_ITEMS_PER_SUBSCRIPTION: usize = 256; + +/// JSON-RPC client backed by a host-owned [`JsonRpcConnection`]. +pub(crate) struct HostRpcClient { + inner: Arc, +} + +struct HostRpcClientInner { + connection: Arc, + request_ids: AtomicU64, + user_handles: AtomicUsize, + closed: AtomicBool, + stop_response_loop: Mutex>>, + pending: Mutex>, + subscriptions: Mutex>, + buffered_subscription_items: Mutex>>>, +} + +struct HostRpcClientLease { + inner: Arc, +} + +struct PendingRequest { + tx: oneshot::Sender, RpcError>>, +} + +#[derive(Clone)] +struct SubscriptionSink { + tx: mpsc::UnboundedSender, RpcError>>, +} + +#[derive(Debug)] +struct HostRpcClientError(String); + +impl fmt::Display for HostRpcClientError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} + +impl std::error::Error for HostRpcClientError {} + +#[derive(Serialize)] +struct JsonRpcRequest<'a> { + jsonrpc: &'static str, + id: &'a str, + method: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + params: Option<&'a RawValue>, +} + +impl HostRpcClient { + /// Wrap `connection` and start the response pump on `spawner`. + pub(crate) fn new(connection: Arc, spawner: Spawner) -> Self { + let (stop_response_tx, stop_response_rx) = oneshot::channel(); + let client = Self { + inner: Arc::new(HostRpcClientInner { + connection, + request_ids: AtomicU64::new(1), + user_handles: AtomicUsize::new(1), + closed: AtomicBool::new(false), + stop_response_loop: Mutex::new(Some(stop_response_tx)), + pending: Mutex::new(HashMap::new()), + subscriptions: Mutex::new(HashMap::new()), + buffered_subscription_items: Mutex::new(HashMap::new()), + }), + }; + client.spawn_response_loop(spawner, stop_response_rx); + client + } + + /// Whether the underlying response stream has ended or failed. + pub(crate) fn is_closed(&self) -> bool { + self.inner.closed.load(Ordering::Relaxed) + } + + /// Send a JSON-RPC request without waiting for its response. + /// + /// Used by best-effort notifications where the caller must not block on + /// the remote endpoint acknowledging the request. + pub(crate) fn send_fire_and_forget( + &self, + method: &str, + params: Option>, + ) -> Result<(), RpcError> { + if self.inner.closed.load(Ordering::Relaxed) { + return Err(client_error("json-rpc connection is closed")); + } + let id = self.inner.next_request_id(); + self.inner.send_request(&id, method, params.as_deref()) + } + + fn spawn_response_loop(&self, spawner: Spawner, stop_rx: oneshot::Receiver<()>) { + let inner = self.inner.clone(); + let fut = async move { + let mut responses = inner.connection.responses(); + let stop = stop_rx.fuse(); + pin_mut!(stop); + loop { + futures::select! { + _ = stop => return, + frame = responses.next().fuse() => match frame { + Some(frame) => { + if let Err(error) = inner.handle_frame(&frame) { + inner.close_with_error(error); + return; + } + } + None => { + inner.close_with_error(client_error("json-rpc response stream ended")); + return; + } + } + } + } + }; + (spawner)(fut.boxed()); + } +} + +impl Clone for HostRpcClient { + fn clone(&self) -> Self { + self.inner.retain_user_handle(); + Self { + inner: self.inner.clone(), + } + } +} + +impl Drop for HostRpcClient { + fn drop(&mut self) { + self.inner.release_user_handle(); + } +} + +impl HostRpcClientInner { + fn retain_user_handle(&self) { + self.user_handles.fetch_add(1, Ordering::Relaxed); + } + + fn acquire_lease(self: &Arc) -> HostRpcClientLease { + self.retain_user_handle(); + HostRpcClientLease { + inner: self.clone(), + } + } + + fn release_user_handle(&self) { + let previous = self.user_handles.fetch_sub(1, Ordering::AcqRel); + debug_assert!(previous > 0, "host rpc client handle count underflow"); + if previous == 1 { + self.close_with_error(client_error("json-rpc client dropped")); + } + } + + fn next_request_id(&self) -> String { + format!( + "truapi:{}", + self.request_ids.fetch_add(1, Ordering::Relaxed) + ) + } + + fn send_request( + &self, + id: &str, + method: &str, + params: Option<&RawValue>, + ) -> Result<(), RpcError> { + let request = JsonRpcRequest { + jsonrpc: "2.0", + id, + method, + params, + }; + let encoded = serde_json::to_string(&request).map_err(RpcError::Serialization)?; + self.connection.send(encoded); + Ok(()) + } + + async fn request( + &self, + method: &str, + params: Option>, + ) -> Result, RpcError> { + let id = self.next_request_id(); + let (tx, rx) = oneshot::channel(); + { + let mut pending = self.pending.lock().unwrap(); + if self.closed.load(Ordering::Relaxed) { + return Err(client_error("json-rpc connection is closed")); + } + pending.insert(id.clone(), PendingRequest { tx }); + } + + if let Err(error) = self.send_request(&id, method, params.as_deref()) { + self.pending.lock().unwrap().remove(&id); + return Err(error); + } + + rx.await + .map_err(|_| client_error("json-rpc request was cancelled"))? + } + + async fn subscribe( + self: Arc, + method: &str, + params: Option>, + unsubscribe_method: &str, + lease: HostRpcClientLease, + ) -> Result { + let raw_id = self.request(method, params).await?; + let subscription_id = subscription_id_from_raw(raw_id.as_ref())?; + let (tx, rx) = mpsc::unbounded(); + { + let mut subscriptions = self.subscriptions.lock().unwrap(); + if self.closed.load(Ordering::Relaxed) { + return Err(client_error("json-rpc connection is closed")); + } + subscriptions.insert(subscription_id.clone(), SubscriptionSink { tx: tx.clone() }); + } + + let buffered = self + .buffered_subscription_items + .lock() + .unwrap() + .remove(&subscription_id) + .unwrap_or_default(); + for item in buffered { + let _ = tx.unbounded_send(Ok(item)); + } + + let stream = SubscriptionStream { + inner: rx, + client: self, + _lease: lease, + subscription_id: subscription_id.clone(), + unsubscribe_method: unsubscribe_method.to_string(), + closed: false, + }; + Ok(RawRpcSubscription { + stream: Box::pin(stream), + id: Some(subscription_id), + }) + } + + fn unsubscribe(&self, subscription_id: &str, unsubscribe_method: &str) { + self.subscriptions.lock().unwrap().remove(subscription_id); + if self.closed.load(Ordering::Relaxed) { + return; + } + let id = self.next_request_id(); + let params = RawValue::from_string(format!( + "[{}]", + serde_json::to_string(subscription_id).unwrap_or_else(|_| "\"\"".to_string()) + )); + if let Ok(params) = params { + let _ = self.send_request(&id, unsubscribe_method, Some(params.as_ref())); + } + } + + #[instrument(skip_all, fields(runtime.method = "host_rpc_client.handle_frame"))] + fn handle_frame(&self, frame: &str) -> Result<(), RpcError> { + let value: serde_json::Value = + serde_json::from_str(frame).map_err(RpcError::Deserialization)?; + + if value.get("method").is_some() && value.get("params").is_some() { + self.handle_notification(&value)?; + return Ok(()); + } + + let Some(request_id) = value.get("id").and_then(json_id) else { + return Ok(()); + }; + let Some(pending) = self.pending.lock().unwrap().remove(&request_id) else { + return Ok(()); + }; + + if let Some(result) = value.get("result") { + let raw = raw_value_from_json(result)?; + let _ = pending.tx.send(Ok(raw)); + return Ok(()); + } + + if let Some(error) = value.get("error") { + let _ = pending.tx.send(Err(user_error_from_json(error))); + return Ok(()); + } + + let _ = pending.tx.send(Err(client_error( + "json-rpc response missing result and error", + ))); + Ok(()) + } + + fn handle_notification(&self, value: &serde_json::Value) -> Result<(), RpcError> { + let Some(params) = value.get("params") else { + return Ok(()); + }; + let Some(subscription_id) = params.get("subscription").and_then(json_id) else { + return Ok(()); + }; + let Some(result) = params.get("result") else { + return Ok(()); + }; + let raw = raw_value_from_json(result)?; + let sink = self + .subscriptions + .lock() + .unwrap() + .get(&subscription_id) + .cloned(); + match sink { + Some(sink) => { + let _ = sink.tx.unbounded_send(Ok(raw)); + } + None => self.buffer_subscription_item(subscription_id, raw), + } + Ok(()) + } + + fn buffer_subscription_item(&self, subscription_id: String, item: Box) { + let mut buffered = self.buffered_subscription_items.lock().unwrap(); + let known = buffered.contains_key(&subscription_id); + if !known && buffered.len() >= MAX_BUFFERED_SUBSCRIPTIONS { + return; + } + let items = buffered.entry(subscription_id).or_default(); + if items.len() >= MAX_BUFFERED_ITEMS_PER_SUBSCRIPTION { + return; + } + items.push(item); + } + + fn close_with_error(&self, error: RpcError) { + if self.closed.swap(true, Ordering::AcqRel) { + return; + } + if let Some(stop) = self.stop_response_loop.lock().unwrap().take() { + let _ = stop.send(()); + } + self.connection.close(); + + let pending = { + let mut pending = self.pending.lock().unwrap(); + mem::take(&mut *pending) + }; + for (_, pending) in pending { + let _ = pending.tx.send(Err(client_error(format!( + "json-rpc connection closed: {error}" + )))); + } + + let subscriptions = mem::take(&mut *self.subscriptions.lock().unwrap()); + for (_, sink) in subscriptions { + let _ = sink.tx.unbounded_send(Err(client_error(format!( + "json-rpc connection closed: {error}" + )))); + } + self.buffered_subscription_items.lock().unwrap().clear(); + } +} + +impl Drop for HostRpcClientLease { + fn drop(&mut self) { + self.inner.release_user_handle(); + } +} + +impl RpcClientT for HostRpcClient { + fn request_raw<'a>( + &'a self, + method: &'a str, + params: Option>, + ) -> RawRpcFuture<'a, Box> { + Box::pin(async move { self.inner.request(method, params).await }) + } + + fn subscribe_raw<'a>( + &'a self, + sub: &'a str, + params: Option>, + unsub: &'a str, + ) -> RawRpcFuture<'a, RawRpcSubscription> { + let lease = self.inner.acquire_lease(); + Box::pin(async move { + self.inner + .clone() + .subscribe(sub, params, unsub, lease) + .await + }) + } +} + +struct SubscriptionStream { + inner: mpsc::UnboundedReceiver, RpcError>>, + client: Arc, + _lease: HostRpcClientLease, + subscription_id: String, + unsubscribe_method: String, + closed: bool, +} + +impl Drop for SubscriptionStream { + fn drop(&mut self) { + if !self.closed { + self.closed = true; + self.client + .unsubscribe(&self.subscription_id, &self.unsubscribe_method); + } + } +} + +impl Stream for SubscriptionStream { + type Item = Result, RpcError>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + match Pin::new(&mut this.inner).poll_next(cx) { + Poll::Ready(None) => { + this.closed = true; + Poll::Ready(None) + } + other => other, + } + } +} + +fn raw_value_from_json(value: &serde_json::Value) -> Result, RpcError> { + RawValue::from_string(value.to_string()).map_err(RpcError::Deserialization) +} + +fn subscription_id_from_raw(raw: &RawValue) -> Result { + let value: serde_json::Value = + serde_json::from_str(raw.get()).map_err(RpcError::Deserialization)?; + json_id(&value).ok_or_else(|| client_error("json-rpc subscription id is not a string")) +} + +fn json_id(value: &serde_json::Value) -> Option { + match value { + serde_json::Value::String(value) => Some(value.clone()), + serde_json::Value::Number(value) => Some(value.to_string()), + _ => None, + } +} + +fn user_error_from_json(value: &serde_json::Value) -> RpcError { + match serde_json::from_value::(value.clone()) { + Ok(error) => RpcError::User(error), + Err(error) => RpcError::Deserialization(error), + } +} + +fn client_error(reason: impl Into) -> RpcError { + RpcError::Client(Box::new(HostRpcClientError(reason.into()))) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::AtomicUsize; + + use futures::executor::block_on; + use futures::stream::BoxStream; + use serde_json::{Value, json}; + use subxt_rpcs::RpcClient; + use subxt_rpcs::client::rpc_params; + + use crate::subscription::thread_per_subscription_spawner; + + struct TrackingConnection { + sender: Mutex>>, + receiver: Mutex>>, + close_count: AtomicUsize, + } + + impl TrackingConnection { + fn new() -> Arc { + let (tx, rx) = mpsc::unbounded(); + Arc::new(Self { + sender: Mutex::new(Some(tx)), + receiver: Mutex::new(Some(rx)), + close_count: AtomicUsize::new(0), + }) + } + + fn close_count(&self) -> usize { + self.close_count.load(Ordering::SeqCst) + } + } + + impl JsonRpcConnection for TrackingConnection { + fn send(&self, request: String) { + let Ok(value) = serde_json::from_str::(&request) else { + return; + }; + let Some(id) = value.get("id").cloned() else { + return; + }; + if value.get("method").and_then(Value::as_str) == Some("sub") { + let response = json!({ + "jsonrpc": "2.0", + "id": id, + "result": "sub-1", + }); + if let Some(sender) = self.sender.lock().unwrap().as_ref() { + let _ = sender.unbounded_send(response.to_string()); + } + } + } + + fn responses(&self) -> BoxStream<'static, String> { + self.receiver + .lock() + .unwrap() + .take() + .expect("responses called twice") + .boxed() + } + + fn close(&self) { + self.close_count.fetch_add(1, Ordering::SeqCst); + self.sender.lock().unwrap().take(); + } + } + + #[test] + fn dropping_one_shot_client_closes_connection_lease() { + let connection = TrackingConnection::new(); + let spawner: Spawner = Arc::new(|_| {}); + + { + let client = HostRpcClient::new(connection.clone(), spawner); + client + .send_fire_and_forget("statement_submit", None) + .unwrap(); + } + + assert_eq!(connection.close_count(), 1); + } + + #[test] + fn subscription_stream_holds_connection_lease_until_dropped() { + let connection = TrackingConnection::new(); + let client = HostRpcClient::new(connection.clone(), thread_per_subscription_spawner()); + let rpc_client = RpcClient::new(client.clone()); + + let subscription = block_on(rpc_client.subscribe::("sub", rpc_params![], "unsub")) + .expect("subscription should start"); + + drop(rpc_client); + drop(client); + assert_eq!(connection.close_count(), 0); + + drop(subscription); + assert_eq!(connection.close_count(), 1); + } +} diff --git a/rust/crates/truapi-server/src/lib.rs b/rust/crates/truapi-server/src/lib.rs index 8b56ec56..9708f559 100644 --- a/rust/crates/truapi-server/src/lib.rs +++ b/rust/crates/truapi-server/src/lib.rs @@ -1,9 +1,17 @@ //! TrUAPI server runtime support. //! -//! This layer contains host-agnostic logic shared by the runtime and target -//! adapters. Wire dispatch and platform runtime wiring are added by later stack -//! layers. +//! This layer contains host-agnostic logic, wire-frame dispatch, and chain +//! JSON-RPC mechanics shared by the runtime and target adapters. Platform +//! runtime wiring is added by a later stack layer. #![forbid(unsafe_code)] +pub(crate) mod chain_runtime; +pub(crate) mod dispatcher; +pub mod frame; pub mod host_logic; +pub(crate) mod host_rpc_client; +pub mod subscription; +pub mod transport; + +pub mod generated; diff --git a/rust/crates/truapi-server/src/subscription.rs b/rust/crates/truapi-server/src/subscription.rs new file mode 100644 index 00000000..8b24fd98 --- /dev/null +++ b/rust/crates/truapi-server/src/subscription.rs @@ -0,0 +1,495 @@ +//! Subscription lifecycle management. +//! +//! Tracks active subscriptions (start/receive/stop/interrupt) and handles +//! cleanup when either side terminates. Each registered subscription drives +//! its stream on a caller-supplied [`Spawner`]; the manager itself never +//! creates threads or runtimes. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; + +use futures::StreamExt; +use futures::future::{BoxFuture, Either, select}; +use futures::stream::BoxStream; +use parity_scale_codec::Encode; + +use crate::frame::{Payload, ProtocolMessage}; +use crate::transport::Transport; + +type StopFn = Box; + +/// Spawns a subscription-driving future onto the caller's runtime. The +/// future is `Send` because the inner [`SubscriptionStream`] is a +/// `BoxStream<'static, _>` and every captured value the manager threads +/// through it is also `Send`. Each platform bridge supplies an +/// implementation that hands the future to the runtime driving its +/// transport (tokio `LocalSet`, `wasm_bindgen_futures::spawn_local`, ...). +pub type Spawner = Arc) + Send + Sync>; + +/// Convenience spawner for tests and embedders that don't yet wire a +/// real runtime: starts a fresh OS thread per subscription and drives the +/// future with `futures::executor::block_on`. Not available on wasm32 since +/// the platform has no threads. +#[cfg(not(target_arch = "wasm32"))] +pub fn thread_per_subscription_spawner() -> Spawner { + Arc::new(|fut: BoxFuture<'static, ()>| { + std::thread::spawn(move || futures::executor::block_on(fut)); + }) +} + +/// One yielded value of a subscription stream after SCALE-encoding. +pub enum SubscriptionOutput { + /// A regular subscription item to deliver as a `_receive` frame. + Item(Vec), + /// Stream-initiated termination delivered as an `_interrupt` frame. + Interrupt(Vec), +} + +/// Boxed stream of [`SubscriptionOutput`] consumed by the dispatcher. +pub type SubscriptionStream = BoxStream<'static, SubscriptionOutput>; + +/// Wrap a host-side stream of typed items into the SCALE-encoded +/// [`SubscriptionStream`] that the dispatcher delivers to the transport. +/// +/// `Item` is the versioned wrapper for each emitted value (e.g. +/// `versioned::account::HostAccountConnectionStatusSubscribeItem`). The +/// generated dispatcher calls this with the second type parameter inferred +/// from the host trait return. +pub fn subscription_stream(stream: S) -> SubscriptionStream +where + Item: Encode + 'static, + S: futures::Stream + Send + 'static, +{ + Box::pin(stream.map(|item| SubscriptionOutput::Item(item.encode()))) +} + +/// Generation-stamped slot tracking the lifecycle of one subscription id. +/// `request_id` is client-controlled and may be reused or raced against a +/// `_stop`, so each reservation carries a monotonic generation and only the +/// owner of the current generation may transition or remove the slot. +enum Slot { + /// Reserved by the dispatcher before its `_start` handler resolved. + /// `cancelled` flips to `true` if a `_stop` arrives in that window so + /// activation aborts instead of leaking an unstoppable stream. + Pending { generation: u64, cancelled: bool }, + /// A live subscription with its cancellation handle. + Live { generation: u64, cancel: StopFn }, +} + +/// Handle returned by [`SubscriptionManager::reserve`] and presented back to +/// [`SubscriptionManager::activate`]. Ties an activation to the exact +/// reservation it belongs to so a superseding `_start` for the same id +/// cannot be activated by a stale handler. +pub struct ReservationToken { + request_id: String, + generation: u64, +} + +/// Manages active subscriptions on the server side. +pub struct SubscriptionManager { + active: Arc>>, + next_generation: Arc, + spawner: Spawner, +} + +impl SubscriptionManager { + /// Create an empty manager driven by `spawner`. + pub fn new(spawner: Spawner) -> Self { + Self { + active: Arc::new(Mutex::new(HashMap::new())), + next_generation: Arc::new(AtomicU64::new(0)), + spawner, + } + } + + /// Reserve the slot for `request_id` before its subscription stream is + /// available. Any live subscription already under that id is stopped and + /// replaced (re-subscribe semantics). A `_stop` arriving before + /// [`activate`](Self::activate) flips the reservation to cancelled. + pub fn reserve(&self, request_id: String) -> ReservationToken { + let generation = self.next_generation.fetch_add(1, Ordering::Relaxed); + let mut active = self.active.lock().unwrap(); + if let Some(Slot::Live { cancel, .. }) = active.insert( + request_id.clone(), + Slot::Pending { + generation, + cancelled: false, + }, + ) { + cancel(); + } + ReservationToken { + request_id, + generation, + } + } + + /// Drop a reservation whose `_start` handler failed before producing a + /// stream. No-op if the slot was superseded by a newer reservation. + pub fn cancel_reservation(&self, token: ReservationToken) { + let mut active = self.active.lock().unwrap(); + let owned = matches!( + active.get(&token.request_id), + Some(Slot::Pending { generation, .. }) if *generation == token.generation + ); + if owned { + active.remove(&token.request_id); + } + } + + /// Activate a reserved subscription with its stream, forwarding stream + /// items as `_receive` frames until the stream ends or `_stop` is + /// received. No-ops without starting the stream if the reservation was + /// cancelled by a `_stop` or superseded by a newer reservation for the + /// same id. + pub fn activate( + &self, + token: ReservationToken, + receive_id: u8, + interrupt_id: u8, + mut stream: SubscriptionStream, + transport: Arc, + ) { + let ReservationToken { + request_id, + generation, + } = token; + let rid = request_id.clone(); + let stream_transport = transport.clone(); + + // Cancellation channel. + let (cancel_tx, cancel_rx) = futures::channel::oneshot::channel::<()>(); + + // Transition the reserved slot to live, unless a `_stop` cancelled it + // or a newer reservation superseded it while the handler resolved. + { + let mut active = self.active.lock().unwrap(); + match active.get(&request_id) { + Some(Slot::Pending { + generation: g, + cancelled, + }) if *g == generation => { + if *cancelled { + active.remove(&request_id); + return; + } + } + _ => return, + } + active.insert( + request_id.clone(), + Slot::Live { + generation, + cancel: Box::new(move || { + let _ = cancel_tx.send(()); + }), + }, + ); + } + + let active = self.active.clone(); + + let future: BoxFuture<'static, ()> = Box::pin(async move { + let completed = { + let mut cancel_rx = cancel_rx; + loop { + match select(cancel_rx, stream.next()).await { + Either::Left((_cancelled, _next)) => break false, + Either::Right((item, next_cancel_rx)) => { + cancel_rx = next_cancel_rx; + match item { + Some(SubscriptionOutput::Item(value)) => { + stream_transport.send(ProtocolMessage { + request_id: rid.clone(), + payload: Payload { + id: receive_id, + value, + }, + }) + } + Some(SubscriptionOutput::Interrupt(value)) => { + stream_transport.send(ProtocolMessage { + request_id: rid.clone(), + payload: Payload { + id: interrupt_id, + value, + }, + }); + break false; + } + None => break true, + } + } + } + } + }; + + // Only remove the slot if it still holds THIS generation; a + // superseding reservation owns its own cleanup. + let removed = { + let mut active = active.lock().unwrap(); + let owned = matches!( + active.get(&request_id), + Some(Slot::Live { generation: g, .. }) if *g == generation + ); + if owned { + active.remove(&request_id); + } + owned + }; + + if completed && removed { + transport.send(ProtocolMessage { + request_id, + payload: Payload { + id: interrupt_id, + value: Vec::new(), + }, + }); + } + }); + + (self.spawner)(future); + } + + /// Convenience for callers that already hold the stream with no async gap + /// between reservation and activation (tests and synchronous embedders). + pub fn register( + &self, + request_id: String, + receive_id: u8, + interrupt_id: u8, + stream: SubscriptionStream, + transport: Arc, + ) { + let token = self.reserve(request_id); + self.activate(token, receive_id, interrupt_id, stream, transport); + } + + /// Handle a `_stop` frame from the product side. Cancels a live + /// subscription, or marks a still-pending reservation cancelled so its + /// in-flight activation aborts rather than leaking an unstoppable stream. + pub fn handle_stop(&self, request_id: &str) { + let mut active = self.active.lock().unwrap(); + match active.get_mut(request_id) { + Some(Slot::Pending { cancelled, .. }) => { + *cancelled = true; + } + Some(Slot::Live { .. }) => { + if let Some(Slot::Live { cancel, .. }) = active.remove(request_id) { + cancel(); + } + } + None => {} + } + } +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tests { + use super::*; + use futures::stream; + use std::sync::atomic::{AtomicUsize, Ordering}; + + /// Transport that records every frame and notifies waiters when it + /// reaches a target count. Used to wait for the subscription's + /// background thread to drain a known number of frames. + struct RecordingTransport { + sent: Mutex>, + cvar: std::sync::Condvar, + } + + impl RecordingTransport { + fn new() -> Self { + Self { + sent: Mutex::new(Vec::new()), + cvar: std::sync::Condvar::new(), + } + } + fn sent(&self) -> Vec { + self.sent.lock().unwrap().clone() + } + /// Wait until at least `count` frames have been recorded, or + /// `timeout` elapses. Returns the number of frames recorded at + /// wake-up time. + fn wait_for(&self, count: usize, timeout: std::time::Duration) -> usize { + let mut guard = self.sent.lock().unwrap(); + let deadline = std::time::Instant::now() + timeout; + while guard.len() < count { + let now = std::time::Instant::now(); + if now >= deadline { + break; + } + let (new_guard, _) = self.cvar.wait_timeout(guard, deadline - now).unwrap(); + guard = new_guard; + } + guard.len() + } + } + + impl Transport for RecordingTransport { + fn send(&self, message: ProtocolMessage) { + self.sent.lock().unwrap().push(message); + self.cvar.notify_all(); + } + fn on_message( + &self, + _handler: Box, + ) -> Box { + Box::new(|| {}) + } + } + + fn dummy_stream(items: Vec>) -> SubscriptionStream { + Box::pin(stream::iter( + items.into_iter().map(SubscriptionOutput::Item), + )) + } + + /// Register a never-ending stream then immediately stop it. The + /// stream's first poll must observe cancellation and exit without + /// having pushed any frame. + #[test] + fn register_then_stop_emits_no_extra_frames() { + let transport_typed = Arc::new(RecordingTransport::new()); + let transport_dyn: Arc = transport_typed.clone(); + let manager = SubscriptionManager::new(thread_per_subscription_spawner()); + let slow_stream: SubscriptionStream = Box::pin(stream::pending()); + manager.register("p:1".to_string(), 99, 98, slow_stream, transport_dyn); + manager.handle_stop("p:1"); + // Give the worker thread a beat to observe the cancel. + std::thread::sleep(std::time::Duration::from_millis(50)); + assert!( + transport_typed.sent().is_empty(), + "stopped subscription must not push any frame" + ); + } + + /// A stream that yields 2 items then ends naturally must produce 2 + /// `_receive` frames followed by one `_interrupt` frame. + #[test] + fn register_completion_emits_interrupt() { + let transport_typed = Arc::new(RecordingTransport::new()); + let transport_dyn: Arc = transport_typed.clone(); + let manager = SubscriptionManager::new(thread_per_subscription_spawner()); + let items = dummy_stream(vec![vec![0xaa], vec![0xbb]]); + manager.register("p:1".to_string(), 99, 98, items, transport_dyn); + let observed = transport_typed.wait_for(3, std::time::Duration::from_secs(2)); + assert_eq!(observed, 3, "expected 2 receive frames + 1 interrupt"); + let frames = transport_typed.sent(); + assert_eq!(frames[0].payload.id, 99); + assert_eq!(frames[0].payload.value, vec![0xaa]); + assert_eq!(frames[1].payload.id, 99); + assert_eq!(frames[1].payload.value, vec![0xbb]); + assert_eq!(frames[2].payload.id, 98); + assert_eq!(frames[2].payload.value, Vec::::new()); + } + + /// Calling `handle_stop` twice on the same request id must be a + /// no-op the second time around (the entry has already been removed, + /// no panic, no extra frames). + #[test] + fn double_stop_is_idempotent() { + let transport_typed = Arc::new(RecordingTransport::new()); + let transport_dyn: Arc = transport_typed.clone(); + let manager = SubscriptionManager::new(thread_per_subscription_spawner()); + let slow_stream: SubscriptionStream = Box::pin(stream::pending()); + manager.register("p:1".to_string(), 99, 98, slow_stream, transport_dyn); + manager.handle_stop("p:1"); + // Second call must not panic and must not emit any frame. + manager.handle_stop("p:1"); + std::thread::sleep(std::time::Duration::from_millis(50)); + assert!( + transport_typed.sent().is_empty(), + "double-stop must not emit any frame" + ); + } + + /// The manager must drive subscriptions through the injected spawner, + /// not by reaching out to `std::thread::spawn` itself. The counter + /// inside the test spawner is the proof. + #[test] + fn subscription_uses_provided_spawner_not_native_thread() { + let invocations = Arc::new(AtomicUsize::new(0)); + let invocations_for_spawner = invocations.clone(); + let spawner: Spawner = Arc::new(move |fut: BoxFuture<'static, ()>| { + invocations_for_spawner.fetch_add(1, Ordering::SeqCst); + std::thread::spawn(move || futures::executor::block_on(fut)); + }); + + let transport_typed = Arc::new(RecordingTransport::new()); + let transport_dyn: Arc = transport_typed.clone(); + let manager = SubscriptionManager::new(spawner); + let items = dummy_stream(vec![vec![0xcc]]); + manager.register("p:1".to_string(), 99, 98, items, transport_dyn); + + // Wait for the worker future to drain to completion so we know + // the spawner closure ran on this path. + let _ = transport_typed.wait_for(2, std::time::Duration::from_secs(2)); + assert_eq!( + invocations.load(Ordering::SeqCst), + 1, + "spawner must be invoked exactly once per register", + ); + } + + /// A `_stop` arriving before `activate` (the stop-before-register race on + /// non-serialized transports) must abort the subscription: no `_receive` + /// frames are emitted even though the stream had items to yield. + #[test] + fn stop_before_activate_aborts_subscription() { + let transport_typed = Arc::new(RecordingTransport::new()); + let transport_dyn: Arc = transport_typed.clone(); + let manager = SubscriptionManager::new(thread_per_subscription_spawner()); + let token = manager.reserve("p:1".to_string()); + manager.handle_stop("p:1"); + let items = dummy_stream(vec![vec![0x01], vec![0x02]]); + manager.activate(token, 99, 98, items, transport_dyn); + std::thread::sleep(std::time::Duration::from_millis(50)); + assert!( + transport_typed.sent().is_empty(), + "a stop before activate must abort the subscription" + ); + } + + /// Re-using a live request id (the duplicate-`_start` case) supersedes the + /// previous subscription rather than leaking it: the first stream is + /// stopped, only the second runs, and the superseded stream leaves no + /// frames behind. + #[test] + fn duplicate_start_supersedes_previous_without_leak() { + let transport_typed = Arc::new(RecordingTransport::new()); + let transport_dyn: Arc = transport_typed.clone(); + let manager = SubscriptionManager::new(thread_per_subscription_spawner()); + + // First subscription never yields; the second reservation for the + // same id must stop it. + let pending: SubscriptionStream = Box::pin(stream::pending()); + manager.register("p:1".to_string(), 99, 98, pending, transport_dyn.clone()); + + // Second subscription yields one item then ends. + let items = dummy_stream(vec![vec![0xaa]]); + manager.register("p:1".to_string(), 99, 98, items, transport_dyn); + + // Exactly the second stream's frames appear: one receive + one + // completion interrupt. The first (pending) stream contributes none. + let observed = transport_typed.wait_for(2, std::time::Duration::from_secs(2)); + assert_eq!( + observed, 2, + "expected the second stream's receive + interrupt only" + ); + let frames = transport_typed.sent(); + assert_eq!(frames[0].payload.id, 99); + assert_eq!(frames[0].payload.value, vec![0xaa]); + assert_eq!(frames[1].payload.id, 98); + + manager.handle_stop("p:1"); + std::thread::sleep(std::time::Duration::from_millis(50)); + assert_eq!( + transport_typed.sent().len(), + 2, + "no leaked frames from the superseded stream" + ); + } +} diff --git a/rust/crates/truapi-server/src/transport.rs b/rust/crates/truapi-server/src/transport.rs new file mode 100644 index 00000000..ba58481f --- /dev/null +++ b/rust/crates/truapi-server/src/transport.rs @@ -0,0 +1,12 @@ +//! Transport abstraction over platform-specific IPC mechanisms. + +use crate::frame::ProtocolMessage; + +/// A raw message pipe. Platform-specific implementations provide this. +pub trait Transport: Send + Sync { + /// Send a protocol message to the other side. + fn send(&self, message: ProtocolMessage); + + /// Register a handler for incoming messages. Returns an unsubscribe handle. + fn on_message(&self, handler: Box) -> Box; +} diff --git a/rust/crates/truapi-server/tests/golden_frame.rs b/rust/crates/truapi-server/tests/golden_frame.rs new file mode 100644 index 00000000..94a75050 --- /dev/null +++ b/rust/crates/truapi-server/tests/golden_frame.rs @@ -0,0 +1,53 @@ +//! Binary golden-frame regression test. +//! +//! Loads `tests/snapshots/golden-account-get.bin` (the captured raw bytes +//! of an `account_get_account_request` frame) and asserts that +//! `ProtocolMessage::decode` produces the expected in-memory shape. +//! +//! The frame encodes: +//! requestId = "p:1" +//! payload = account_get_account_request, +//! inner = HostAccountGetRequest::V1(("foo", 0u32)) +//! +//! On the wire (14 bytes): +//! [0c 70 3a 31] requestId = compact-len(3) + "p:1" +//! [16] discriminant 22 = account_get_account_request +//! [00] versioned wrapper variant V1 +//! [0c 66 6f 6f] "foo" +//! [00 00 00 00] u32 = 0 +//! +//! If this test fails after a wire-protocol change, regenerate the file +//! deliberately and re-check the change against the wire table. + +use parity_scale_codec::{Decode, Encode}; +use truapi_server::frame::{Payload, ProtocolMessage}; +use truapi_server::generated::wire_table; + +const GOLDEN: &[u8] = include_bytes!("snapshots/golden-account-get.bin"); + +#[test] +fn golden_account_get_frame_decodes_to_expected_message() { + let decoded = ProtocolMessage::decode(&mut &GOLDEN[..]) + .expect("golden frame must decode with the current wire codec"); + + let mut expected_inner = Vec::new(); + expected_inner.push(0x00u8); // V1 variant + "foo".to_string().encode_to(&mut expected_inner); + 0u32.encode_to(&mut expected_inner); + + let expected = ProtocolMessage { + request_id: "p:1".to_string(), + payload: Payload { + id: wire_table::ACCOUNT_GET_ACCOUNT.request_id, + value: expected_inner, + }, + }; + assert_eq!(decoded, expected); +} + +#[test] +fn golden_account_get_frame_round_trips() { + // Encoding the in-memory shape must reproduce the on-disk bytes exactly. + let decoded = ProtocolMessage::decode(&mut &GOLDEN[..]).expect("decode"); + assert_eq!(decoded.encode(), GOLDEN); +} diff --git a/rust/crates/truapi-server/tests/snapshots/golden-account-get.bin b/rust/crates/truapi-server/tests/snapshots/golden-account-get.bin new file mode 100644 index 0000000000000000000000000000000000000000..c66be11b9bf19e8c751b7faa4996bf36cd7e90b4 GIT binary patch literal 14 Tcmd-nurd^5;7QBRX8-~K6a)fJ literal 0 HcmV?d00001 diff --git a/rust/crates/truapi-server/tests/wire_table_ts_parity.rs b/rust/crates/truapi-server/tests/wire_table_ts_parity.rs new file mode 100644 index 00000000..e349139c --- /dev/null +++ b/rust/crates/truapi-server/tests/wire_table_ts_parity.rs @@ -0,0 +1,228 @@ +//! Cross-language parity check: the Rust `WIRE_TABLE` and the TS +//! `wire-table.ts` must list the exact same `(method, request_id, response_id)` +//! tuples in the same order. A drift here means a product built against one +//! side will fail to decode frames produced by the other. +//! +//! Both files are auto-generated text artifacts of `truapi-codegen`; the +//! parser is a small line scanner so the test runs as part of `cargo test` +//! without any node/bun dependency. +//! +//! The TS file lives under `js/packages/truapi/src/generated/wire-table.ts` +//! and is `.gitignore`d (regenerated by `scripts/codegen.sh`). When the +//! generated file is absent, the test logs a skip notice and passes, unless +//! `TRUAPI_REQUIRE_GENERATED_TS=1` is set (CI sets it after running codegen), +//! in which case the missing file is a hard failure. + +use std::path::PathBuf; + +const RUST_TABLE: &str = include_str!("../src/generated/wire_table.rs"); + +#[derive(Debug, PartialEq, Eq)] +struct Row { + method: String, + request_or_start: u8, + response_or_receive: u8, + /// Subscription `_stop` / `_interrupt` ids; `None` for request methods. + stop: Option, + interrupt: Option, + is_subscription: bool, +} + +/// Parse a wire id. A malformed id is a hard failure, never a silent `0`: a +/// defensive fallback here would let a symmetric codegen-format change collapse +/// both tables to `0`s and pass the parity check while real drift slipped by. +fn parse_id(raw: &str, method: &str) -> u8 { + raw.trim_end_matches(',') + .trim() + .parse() + .unwrap_or_else(|_| panic!("unparseable wire id for `{method}`: {raw:?}")) +} + +fn parse_rust(src: &str) -> Vec { + // The Rust codegen emits one named `pub const FOO_BAR: RequestFrameIds = ...` + // (or `SubscriptionFrameIds`) per method. The const name is + // `SCREAMING_SNAKE_CASE` of the method name; we lowercase it to match the + // TS const names. This mirrors `parse_ts` below. + let mut out = Vec::new(); + let mut iter = src.lines(); + while let Some(line) = iter.next() { + let trimmed = line.trim(); + let Some(rest) = trimmed.strip_prefix("pub const ") else { + continue; + }; + let Some(colon) = rest.find(':') else { + continue; + }; + let is_subscription = rest.contains("SubscriptionFrameIds"); + // Skip non-id consts (e.g. `WIRE_TABLE: &[WireEntry]`). + if !is_subscription && !rest.contains("RequestFrameIds") { + continue; + } + let method = rest[..colon].trim().to_ascii_lowercase(); + let mut request_or_start = None; + let mut response_or_receive = None; + let mut stop = None; + let mut interrupt = None; + for inner in iter.by_ref() { + let t = inner.trim(); + if t.starts_with("};") { + break; + } + if let Some(rest) = t + .strip_prefix("request_id: ") + .or_else(|| t.strip_prefix("start_id: ")) + { + request_or_start = Some(parse_id(rest, &method)); + } + if let Some(rest) = t + .strip_prefix("response_id: ") + .or_else(|| t.strip_prefix("receive_id: ")) + { + response_or_receive = Some(parse_id(rest, &method)); + } + if let Some(rest) = t.strip_prefix("stop_id: ") { + stop = Some(parse_id(rest, &method)); + } + if let Some(rest) = t.strip_prefix("interrupt_id: ") { + interrupt = Some(parse_id(rest, &method)); + } + } + if let (Some(rs), Some(rr)) = (request_or_start, response_or_receive) { + out.push(Row { + method, + request_or_start: rs, + response_or_receive: rr, + stop, + interrupt, + is_subscription, + }); + } + } + out +} + +fn parse_ts(src: &str) -> Vec { + // The TS codegen emits one named `export const FOO_BAR = { ... }` per + // method. The const name is `SCREAMING_SNAKE_CASE` of the method name; + // we lowercase it to match the Rust `method:` strings. + let mut out = Vec::new(); + let mut iter = src.lines().peekable(); + while let Some(line) = iter.next() { + let trimmed = line.trim(); + let Some(rest) = trimmed.strip_prefix("export const ") else { + continue; + }; + let Some(name_end) = rest.find(|c: char| !(c.is_ascii_alphanumeric() || c == '_')) else { + continue; + }; + let method = rest[..name_end].to_ascii_lowercase(); + let mut request_or_start = None; + let mut response_or_receive = None; + let mut stop = None; + let mut interrupt = None; + let mut is_subscription = false; + for inner in iter.by_ref() { + let t = inner.trim(); + if t.starts_with("start:") || t.contains("SubscriptionFrameIds") { + is_subscription = true; + } + if let Some(rest) = t + .strip_prefix("request: ") + .or_else(|| t.strip_prefix("start: ")) + { + request_or_start = Some(parse_id(rest, &method)); + } + if let Some(rest) = t + .strip_prefix("response: ") + .or_else(|| t.strip_prefix("receive: ")) + { + response_or_receive = Some(parse_id(rest, &method)); + } + if let Some(rest) = t.strip_prefix("stop: ") { + stop = Some(parse_id(rest, &method)); + } + if let Some(rest) = t.strip_prefix("interrupt: ") { + interrupt = Some(parse_id(rest, &method)); + } + if t.starts_with("} as const") || t == "}" { + if let (Some(rs), Some(rr)) = (request_or_start, response_or_receive) { + out.push(Row { + method, + request_or_start: rs, + response_or_receive: rr, + stop, + interrupt, + is_subscription, + }); + } + break; + } + } + } + out +} + +#[test] +fn rust_and_ts_wire_tables_agree() { + let manifest = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let ts_path = manifest + .join("../../../js/packages/truapi/src/generated/wire-table.ts") + .canonicalize(); + + let require_ts = std::env::var("TRUAPI_REQUIRE_GENERATED_TS").as_deref() == Ok("1"); + + let ts_path = match ts_path { + Ok(p) => p, + Err(_) => { + assert!( + !require_ts, + "TRUAPI_REQUIRE_GENERATED_TS=1 but wire-table.ts is missing; run scripts/codegen.sh" + ); + eprintln!( + "skipping wire-table parity check: TS wire-table.ts is not present \ + (run scripts/codegen.sh to generate it)" + ); + return; + } + }; + + let ts_src = match std::fs::read_to_string(&ts_path) { + Ok(s) => s, + Err(_) => { + assert!( + !require_ts, + "TRUAPI_REQUIRE_GENERATED_TS=1 but {} is unreadable", + ts_path.display() + ); + eprintln!( + "skipping wire-table parity check: could not read {}", + ts_path.display() + ); + return; + } + }; + + let rust_rows = parse_rust(RUST_TABLE); + let ts_rows = parse_ts(&ts_src); + // Lower bound pinned to the known table size so a parser/codegen regression + // that quietly shrinks both tables in lockstep cannot pass: `assert_eq!` + // alone is satisfied by two equal-but-truncated tables. + const MIN_EXPECTED_ROWS: usize = 60; + assert!( + rust_rows.len() >= MIN_EXPECTED_ROWS, + "rust parser produced {} entries (expected >= {MIN_EXPECTED_ROWS}); \ + wire_table.rs format may have changed", + rust_rows.len() + ); + assert!( + ts_rows.len() >= MIN_EXPECTED_ROWS, + "ts parser produced {} entries (expected >= {MIN_EXPECTED_ROWS}); \ + wire-table.ts format may have changed", + ts_rows.len() + ); + assert_eq!( + rust_rows, ts_rows, + "Rust WIRE_TABLE and TS wire-table.ts diverged. Regenerate both via \ + `scripts/codegen.sh` so the codegen pipeline produces them in lockstep.", + ); +}