Skip to content
This repository was archived by the owner on Jun 1, 2026. It is now read-only.

Commit f44a12d

Browse files
committed
Refactor websockets; add device stream events
Introduce a unified websocket helper and refactor websocket handling, plus add device stream event support across services. Key changes: - New apps/api/src/websocket.rs: common Redis setup (setup_ws_resources) and PushInfo decoding (decode_push_message). - Refactor websocket_prices and websocket_stream: remove duplicated Redis setup, improve message handling, move configs out of Mutex where safe, better logging/error handling, and simplify stream loops. - Split price handling into a dedicated PriceHandler for websocket_stream to encapsulate asset tracking, subscription logic, and payload building. - Improve apps/api/src/main.rs: better CLI service parsing with strum iteration and unified Rocket launch path. - Add device stream feature: DeviceStreamPayload/DeviceStreamEvent types (crates/streamer), queue name and producer API for DeviceStreamEvents, daemon consumer DeviceStreamConsumer to publish device-scoped stream events to cacher. - Add CacherClient::publish helper to publish JSON messages to Redis channels. - Various small API/serialization and testkit cleanups. Why: removes duplication, isolates responsibilities (price handling, Redis setup), adds support for pushing device-targeted stream events, and improves robustness and logging for websocket components.
1 parent b96f865 commit f44a12d

19 files changed

Lines changed: 513 additions & 370 deletions

File tree

apps/api/src/main.rs

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ mod swap;
2222
mod transactions;
2323
mod wallets;
2424
mod webhooks;
25+
mod websocket;
2526
mod websocket_prices;
2627
mod websocket_stream;
2728

2829
use std::{str::FromStr, sync::Arc};
30+
use strum::IntoEnumIterator;
2931

3032
use ::fiat::FiatClient;
3133
use ::nft::{NFTClient, NFTProviderConfig};
@@ -268,7 +270,7 @@ async fn rocket_ws_prices(settings: Settings) -> Rocket<Build> {
268270
};
269271
rocket::build()
270272
.manage(Arc::new(Mutex::new(price_client)))
271-
.manage(Arc::new(Mutex::new(price_observer_config)))
273+
.manage(Arc::new(price_observer_config))
272274
.mount("/", routes![websocket_prices::ws_health])
273275
.mount("/v1/ws", routes![websocket_prices::ws_prices])
274276
.register("/", catchers![catchers::default_catcher])
@@ -292,7 +294,7 @@ async fn rocket_ws_stream(settings: Settings) -> Rocket<Build> {
292294
.manage(auth_config)
293295
.manage(database)
294296
.manage(Arc::new(Mutex::new(price_client)))
295-
.manage(Arc::new(Mutex::new(stream_observer_config)))
297+
.manage(Arc::new(stream_observer_config))
296298
.mount("/v2/devices", routes![websocket_stream::ws_stream])
297299
.mount("/", routes![websocket_stream::ws_health])
298300
.register("/", catchers![catchers::default_catcher])
@@ -302,25 +304,22 @@ async fn rocket_ws_stream(settings: Settings) -> Rocket<Build> {
302304
async fn main() {
303305
let settings = Settings::new().unwrap();
304306

305-
let service = std::env::args().nth(1).unwrap_or_default();
306-
let service = APIService::from_str(service.as_str()).ok().unwrap_or(APIService::Api);
307+
let service = match std::env::args().nth(1) {
308+
Some(arg) => APIService::from_str(&arg).unwrap_or_else(|_| {
309+
let services: Vec<_> = APIService::iter().map(|s| format!("api {}", s.as_ref())).collect();
310+
panic!("unknown service: {arg}\nAvailable:\n {}", services.join("\n "))
311+
}),
312+
None => APIService::Api,
313+
};
307314

308315
println!("api start service: {}", service.as_ref());
309316

310-
match service {
311-
APIService::WebsocketPrices => {
312-
let rocket_api = rocket_ws_prices(settings.clone()).await;
313-
rocket_api.launch().await.expect("Failed to launch Rocket");
314-
}
315-
APIService::WebsocketStream => {
316-
let rocket_api = rocket_ws_stream(settings.clone()).await;
317-
rocket_api.launch().await.expect("Failed to launch Rocket");
318-
}
319-
APIService::Api => {
320-
let rocket_api = rocket_api(settings.clone()).await;
321-
rocket_api.launch().await.expect("Failed to launch Rocket");
322-
}
323-
}
317+
let rocket = match service {
318+
APIService::Api => rocket_api(settings).await,
319+
APIService::WebsocketPrices => rocket_ws_prices(settings).await,
320+
APIService::WebsocketStream => rocket_ws_stream(settings).await,
321+
};
322+
rocket.launch().await.expect("Failed to launch Rocket");
324323
}
325324

326325
#[cfg(test)]

apps/api/src/websocket.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
use std::error::Error;
2+
3+
use redis::aio::MultiplexedConnection;
4+
use redis::{PushInfo, PushKind};
5+
use rocket_ws::stream::DuplexStream;
6+
use tokio::sync::mpsc::UnboundedReceiver;
7+
8+
pub fn decode_push_message(message: &PushInfo) -> Option<(&str, &[u8])> {
9+
match (&message.kind, message.data.as_slice()) {
10+
(PushKind::Message, [redis::Value::BulkString(channel), redis::Value::BulkString(value)]) => Some((std::str::from_utf8(channel).ok()?, value)),
11+
_ => None,
12+
}
13+
}
14+
15+
pub async fn setup_ws_resources(redis_url: &str, stream: DuplexStream) -> Result<(DuplexStream, MultiplexedConnection, UnboundedReceiver<PushInfo>), Box<dyn Error + Send + Sync>> {
16+
let client = redis::Client::open(redis_url)?;
17+
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
18+
let config = redis::AsyncConnectionConfig::new().set_push_sender(tx);
19+
let redis_connection = client.get_multiplexed_async_connection_with_config(&config).await?;
20+
Ok((stream, redis_connection, rx))
21+
}

apps/api/src/websocket_prices/client.rs

Lines changed: 43 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,23 @@ use std::sync::Arc;
44

55
use pricer::PriceClient;
66
use primitives::{AssetId, AssetPrice, AssetPriceInfo, WebSocketPriceAction, WebSocketPriceActionType, WebSocketPricePayload, asset::AssetHashSetExt};
7+
use redis::PushInfo;
78
use redis::aio::MultiplexedConnection;
8-
use redis::{PushInfo, PushKind};
99
use rocket::futures::SinkExt;
1010
use rocket::serde::json::serde_json;
1111
use rocket::tokio::sync::Mutex;
1212
use rocket_ws::Message;
1313
use rocket_ws::stream::DuplexStream;
1414

15+
use crate::websocket::decode_push_message;
16+
1517
pub struct PriceObserverConfig {
1618
pub redis_url: String,
1719
}
1820

1921
pub struct PriceObserverClient {
20-
pub price_client: Arc<Mutex<PriceClient>>,
21-
pub assets: HashSet<AssetId>,
22+
price_client: Arc<Mutex<PriceClient>>,
23+
assets: HashSet<AssetId>,
2224
prices_to_publish: HashMap<String, AssetPrice>,
2325
interval: rocket::tokio::time::Interval,
2426
}
@@ -37,105 +39,82 @@ impl PriceObserverClient {
3739
self.interval.tick().await;
3840
}
3941

40-
pub fn get_asset_ids(&self) -> Vec<String> {
41-
self.assets.iter().map(|id| id.to_string()).collect()
42-
}
43-
44-
pub fn add_price_to_publish(&mut self, price: AssetPrice) {
45-
self.prices_to_publish.insert(price.asset_id.to_string(), price);
46-
}
47-
48-
pub fn clear_prices_to_publish(&mut self) {
49-
self.prices_to_publish.clear();
42+
pub fn take_prices(&mut self) -> Vec<AssetPrice> {
43+
self.prices_to_publish.drain().map(|(_, v)| v).collect()
5044
}
5145

52-
pub fn get_prices_to_publish(&self) -> Vec<AssetPrice> {
53-
self.prices_to_publish.values().cloned().collect()
46+
fn get_channel_ids(&self) -> Vec<String> {
47+
self.assets.iter().map(|id| id.to_string()).collect()
5448
}
5549

56-
pub async fn fetch_payload_data(&mut self, fetch_rates: bool) -> Result<WebSocketPricePayload, Box<dyn Error + Send + Sync>> {
57-
let price_client_clone_prices = Arc::clone(&self.price_client);
58-
let assets_clone_prices = self.assets.clone();
59-
let prices = price_client_clone_prices
60-
.lock()
61-
.await
62-
.get_cache_prices(assets_clone_prices.ids())
50+
async fn fetch_payload(&self, fetch_rates: bool) -> Result<WebSocketPricePayload, Box<dyn Error + Send + Sync>> {
51+
let client = self.price_client.lock().await;
52+
let prices = client
53+
.get_cache_prices(self.assets.ids())
6354
.await?
6455
.into_iter()
6556
.map(|x| x.as_asset_price_primitive())
6657
.collect();
67-
68-
if fetch_rates {
69-
let rates = self.price_client.lock().await.get_cache_fiat_rates().await?;
70-
Ok(WebSocketPricePayload { prices, rates })
71-
} else {
72-
Ok(WebSocketPricePayload { prices, rates: vec![] })
73-
}
74-
}
75-
76-
pub async fn build_and_send_payload(&mut self, stream: &mut DuplexStream, payload: WebSocketPricePayload) -> Result<(), Box<dyn Error + Send + Sync>> {
77-
let text = serde_json::to_string(&payload)?;
78-
let item = Message::Text(text);
79-
Ok(stream.send(item).await?)
58+
let rates = if fetch_rates { client.get_cache_fiat_rates().await? } else { vec![] };
59+
Ok(WebSocketPricePayload { prices, rates })
8060
}
8161

8262
pub async fn handle_ws_message(
8363
&mut self,
84-
message: rocket_ws::Message,
64+
message: Message,
8565
redis_connection: &mut MultiplexedConnection,
8666
stream: &mut DuplexStream,
8767
) -> Result<(), Box<dyn Error + Send + Sync>> {
8868
match message {
8969
Message::Binary(data) => self.handle_message_payload(data, redis_connection, stream).await,
9070
Message::Text(text) => self.handle_message_payload(text.into_bytes(), redis_connection, stream).await.or(Ok(())),
9171
Message::Ping(data) => Ok(stream.send(Message::Pong(data)).await?),
92-
Message::Close(_) => {
93-
println!("Client closed connection gracefully");
94-
Ok(())
95-
}
72+
Message::Close(_) => Ok(()),
9673
Message::Pong(_) | Message::Frame(_) => Ok(()),
9774
}
9875
}
9976

100-
pub async fn handle_message_payload(
101-
&mut self,
102-
data: Vec<u8>,
103-
redis_connection: &mut MultiplexedConnection,
104-
stream: &mut DuplexStream,
105-
) -> Result<(), Box<dyn Error + Send + Sync>> {
77+
async fn handle_message_payload(&mut self, data: Vec<u8>, redis_connection: &mut MultiplexedConnection, stream: &mut DuplexStream) -> Result<(), Box<dyn Error + Send + Sync>> {
10678
let action = serde_json::from_slice::<WebSocketPriceAction>(&data)?;
10779
let new_assets: HashSet<AssetId> = action.assets.iter().cloned().collect();
10880

109-
match action.action {
81+
let needs_rates = match action.action {
11082
WebSocketPriceActionType::Subscribe => {
83+
let old_channels = self.get_channel_ids();
11184
self.assets.clear();
85+
self.prices_to_publish.clear();
86+
if !old_channels.is_empty() {
87+
redis_connection.unsubscribe(old_channels).await?;
88+
}
11289
self.assets.extend(new_assets);
90+
true
11391
}
11492
WebSocketPriceActionType::Add => {
11593
self.assets.extend(new_assets);
94+
false
11695
}
117-
}
96+
};
11897

119-
let asset_ids = self.assets.ids();
120-
let _ = self.price_client.lock().await.track_observed_assets(&asset_ids).await;
98+
let _ = self.price_client.lock().await.track_observed_assets(&self.assets.ids()).await;
12199

122-
let needs_rates = action.action == WebSocketPriceActionType::Subscribe;
123-
let payload = self.fetch_payload_data(needs_rates).await?;
100+
let payload = self.fetch_payload(needs_rates).await?;
101+
self.send_payload(stream, payload).await?;
124102

125-
self.build_and_send_payload(stream, payload).await?;
103+
redis_connection.subscribe(self.get_channel_ids()).await?;
104+
Ok(())
105+
}
126106

127-
Ok(redis_connection.subscribe(self.get_asset_ids()).await?)
107+
pub async fn send_payload(&self, stream: &mut DuplexStream, payload: WebSocketPricePayload) -> Result<(), Box<dyn Error + Send + Sync>> {
108+
let text = serde_json::to_string(&payload)?;
109+
Ok(stream.send(Message::Text(text)).await?)
128110
}
129111

130-
pub fn handle_redis_message(&mut self, message: &PushInfo) -> Result<(), String> {
131-
match (message.kind.clone(), message.data.last()) {
132-
(PushKind::Message, Some(redis::Value::BulkString(value))) => {
133-
let asset_price_info = serde_json::from_slice::<AssetPriceInfo>(value).map_err(|e| format!("Failed to deserialize AssetPrice: {e}"))?;
134-
let asset_price = asset_price_info.as_asset_price_primitive();
135-
self.add_price_to_publish(asset_price);
136-
Ok(())
137-
}
138-
_ => Ok(()),
139-
}
112+
pub fn handle_redis_message(&mut self, message: &PushInfo) -> Result<(), Box<dyn Error + Send + Sync>> {
113+
let Some((_, value)) = decode_push_message(message) else {
114+
return Ok(());
115+
};
116+
let info = serde_json::from_slice::<AssetPriceInfo>(value)?;
117+
self.prices_to_publish.insert(info.asset_id.to_string(), info.as_asset_price_primitive());
118+
Ok(())
140119
}
141120
}

apps/api/src/websocket_prices/mod.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,17 @@ use rocket_ws::{Channel, WebSocket};
99
mod client;
1010
mod stream;
1111

12-
pub use client::{PriceObserverClient, PriceObserverConfig};
13-
use stream::Stream;
12+
pub use client::PriceObserverConfig;
1413

1514
#[rocket::get("/prices")]
16-
pub async fn ws_prices(ws: WebSocket, price_client: &State<Arc<Mutex<PriceClient>>>, config: &State<Arc<Mutex<PriceObserverConfig>>>) -> Channel<'static> {
15+
pub async fn ws_prices(ws: WebSocket, price_client: &State<Arc<Mutex<PriceClient>>>, config: &State<Arc<PriceObserverConfig>>) -> Channel<'static> {
1716
let price_client = price_client.inner().clone();
18-
let redis_url = config.lock().await.redis_url.clone();
17+
let redis_url = config.redis_url.clone();
1918

20-
ws.channel(move |stream| {
19+
ws.channel(move |ws_stream| {
2120
Box::pin(async move {
22-
let mut observer = PriceObserverClient::new(price_client.clone());
23-
Stream::new_stream(&redis_url, &mut observer, stream).await;
21+
let mut observer = client::PriceObserverClient::new(price_client);
22+
stream::new_stream(&redis_url, &mut observer, ws_stream).await;
2423
Ok::<(), rocket_ws::result::Error>(())
2524
})
2625
})

0 commit comments

Comments
 (0)