Skip to content

Commit 9801a57

Browse files
committed
use real types for namespace and channel name
1 parent 9662761 commit 9801a57

3 files changed

Lines changed: 44 additions & 41 deletions

File tree

Cargo.lock

Lines changed: 7 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ edition = "2024"
77
axum = { version = "0.8" }
88
clap = { version = "4", features = ["derive", "env"] }
99
flume = "0.11"
10+
serde = { version = "1", features = ["derive"] }
1011
tokio = { version = "1", features = ["full"] }
1112
tokio-stream = { version = "0.1" }
1213
tower-http = { version = "0.6", features = [
@@ -20,7 +21,6 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
2021

2122
[dev-dependencies]
2223
reqwest = { version = "0.12", features = ["json"] }
23-
serde = { version = "1", features = ["derive"] }
2424
serde_json = "1"
2525

2626
[profile.release]

src/channel.rs

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,26 @@ use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
77
use axum::middleware::{self, Next};
88
use axum::response::{IntoResponse, Response};
99
use axum::routing::{get, post};
10+
use serde::{Deserialize, Serialize};
1011
use std::collections::HashMap;
1112
use std::sync::Arc;
1213
use tokio::sync::{Mutex, oneshot};
1314

14-
type Namespace = String;
15-
type ChannelName = String;
15+
/// the data the producer is sending to the consumer
16+
pub(crate) struct Payload {
17+
body_stream: BodyDataStream,
18+
headers: HeaderMap,
19+
drop_guard: DropGuard,
20+
}
21+
22+
#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
23+
pub(crate) struct Namespace(String);
24+
25+
#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
26+
pub(crate) struct ChannelName(String);
27+
1628
pub(crate) type ChannelClients = Mutex<
17-
HashMap<
18-
Namespace,
19-
HashMap<
20-
ChannelName,
21-
(
22-
flume::Sender<(BodyDataStream, HeaderMap, DropGuard)>,
23-
flume::Receiver<(BodyDataStream, HeaderMap, DropGuard)>,
24-
),
25-
>,
26-
>,
29+
HashMap<Namespace, HashMap<ChannelName, (flume::Sender<Payload>, flume::Receiver<Payload>)>>,
2730
>;
2831

2932
pub(crate) fn routes(state: Arc<AppState>) -> Router<Arc<AppState>> {
@@ -47,7 +50,7 @@ pub(crate) fn routes(state: Arc<AppState>) -> Router<Arc<AppState>> {
4750
}
4851

4952
async fn clean_up_unused_channels(
50-
Path((namespace, channel_name)): Path<(String, String)>,
53+
Path((namespace, channel_name)): Path<(Namespace, ChannelName)>,
5154
State(state): State<Arc<AppState>>,
5255
request: Request,
5356
next: Next,
@@ -82,7 +85,7 @@ async fn clean_up_unused_channels(
8285

8386
async fn list_all_namespaces(
8487
State(state): State<Arc<AppState>>,
85-
) -> axum::response::Result<axum::Json<Vec<String>>> {
88+
) -> axum::response::Result<axum::Json<Vec<Namespace>>> {
8689
let channel_clients = state.channel_clients.lock().await;
8790

8891
Ok(axum::Json(
@@ -91,9 +94,9 @@ async fn list_all_namespaces(
9194
}
9295

9396
async fn list_all_namespace_channels(
94-
Path(namespace): Path<String>,
97+
Path(namespace): Path<Namespace>,
9598
State(state): State<Arc<AppState>>,
96-
) -> axum::response::Result<axum::Json<Vec<String>>> {
99+
) -> axum::response::Result<axum::Json<Vec<ChannelName>>> {
97100
let channel_clients = state.channel_clients.lock().await;
98101

99102
let namespaced_channels = if let Some(channels) = channel_clients.get(&namespace) {
@@ -108,8 +111,8 @@ async fn list_all_namespace_channels(
108111
}
109112

110113
async fn broadcast_to_channel(
111-
request_headers: HeaderMap,
112-
Path((namespace, channel_name)): Path<(String, String)>,
114+
headers: HeaderMap,
115+
Path((namespace, channel_name)): Path<(Namespace, ChannelName)>,
113116
State(state): State<Arc<AppState>>,
114117
body: Body,
115118
) -> axum::response::Result<()> {
@@ -132,13 +135,17 @@ async fn broadcast_to_channel(
132135

133136
drop(channel_clients);
134137

135-
let request_body_stream = body.into_data_stream();
138+
let body_stream = body.into_data_stream();
136139

137140
let (drop_guard, drop_guard_rx) = DropGuard::new();
138141

139-
tx.send_async((request_body_stream, request_headers, drop_guard))
140-
.await
141-
.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
142+
tx.send_async(Payload {
143+
body_stream,
144+
headers,
145+
drop_guard,
146+
})
147+
.await
148+
.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
142149

143150
drop_guard_rx
144151
.await
@@ -148,7 +155,7 @@ async fn broadcast_to_channel(
148155
}
149156

150157
async fn subscribe_to_channel(
151-
Path((namespace, channel_name)): Path<(String, String)>,
158+
Path((namespace, channel_name)): Path<(Namespace, ChannelName)>,
152159
State(state): State<Arc<AppState>>,
153160
) -> axum::response::Result<impl IntoResponse> {
154161
let mut channel_clients = state.channel_clients.lock().await;
@@ -172,10 +179,13 @@ async fn subscribe_to_channel(
172179

173180
let rx = rx.into_recv_async();
174181

175-
let (request_body_stream, producer_request_headers, _drop_guard) =
176-
rx.await.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
182+
let Payload {
183+
body_stream,
184+
headers: producer_request_headers,
185+
drop_guard: _drop_guard,
186+
} = rx.await.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
177187

178-
let body = Body::from_stream(request_body_stream);
188+
let body = Body::from_stream(body_stream);
179189

180190
// we do this because by default, POSTs from curl are `x-www-form-urlencoded`
181191
let producer_content_type = producer_request_headers

0 commit comments

Comments
 (0)