Skip to content

Commit b87bada

Browse files
committed
feat: support different underlying unbounded channels
- enable `crossfire-channel` feature to use crossfire's unbounded channels
1 parent 51eb073 commit b87bada

8 files changed

Lines changed: 197 additions & 8 deletions

File tree

Cargo.lock

Lines changed: 57 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ tracing-subscriber = { workspace = true }
6060

6161
[features]
6262
debugging = ["netmito/debugging"]
63+
crossfire-channel = ["netmito/crossfire-channel"]
6364

6465
# The profile that 'cargo dist' will build with
6566
[profile.dist]

netmito/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ axum-extra = { version = "0.10.1", features = ["typed-header", "query"] }
2323
base64 = "0.22.1"
2424
clap = { workspace = true }
2525
clap-repl = "0.3.2"
26+
crossfire = { version = "2.0.26", optional = true }
2627
derive_more = { version = "2.0.1", features = ["from"] }
2728
dirs = "6.0.0"
2829
figment = { version = "0.10", features = ["toml", "env"] }
@@ -81,3 +82,4 @@ async-once-cell = "0.5.4"
8182
[features]
8283
# default = ["debugging"]
8384
debugging = ["dep:http-body-util"]
85+
crossfire-channel = ["dep:crossfire"]

netmito/src/client/redis.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,45 @@ use redis::{
33
aio::{MultiplexedConnection, PubSub},
44
AsyncCommands, Commands, PubSubCommands, PushInfo,
55
};
6+
#[cfg(feature = "crossfire-channel")]
7+
use std::ops::Deref;
8+
#[cfg(not(feature = "crossfire-channel"))]
69
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
710
use uuid::Uuid;
811

912
use crate::entity::state::TaskExecState;
1013
pub use redis::{ControlFlow, Msg};
1114

15+
#[cfg(feature = "crossfire-channel")]
16+
pub struct UnboundedReceiver<T>(crossfire::AsyncRx<T>);
17+
18+
#[cfg(feature = "crossfire-channel")]
19+
impl redis::aio::AsyncPushSender for UnboundedSender<PushInfo> {
20+
fn send(&self, info: PushInfo) -> std::result::Result<(), redis::aio::SendError> {
21+
self.0.send(info).map_err(|_| redis::aio::SendError)
22+
}
23+
}
24+
25+
#[cfg(feature = "crossfire-channel")]
26+
impl<T> Deref for UnboundedReceiver<T> {
27+
type Target = crossfire::AsyncRx<T>;
28+
fn deref(&self) -> &Self::Target {
29+
&self.0
30+
}
31+
}
32+
33+
#[cfg(feature = "crossfire-channel")]
34+
#[derive(Clone)]
35+
pub struct UnboundedSender<T>(crossfire::MTx<T>);
36+
37+
#[cfg(feature = "crossfire-channel")]
38+
impl<T> Deref for UnboundedSender<T> {
39+
type Target = crossfire::MTx<T>;
40+
fn deref(&self) -> &Self::Target {
41+
&self.0
42+
}
43+
}
44+
1245
#[self_referencing]
1346
pub struct MitoRedisPubSubClient {
1447
pub client: redis::Client,
@@ -167,7 +200,13 @@ impl MitoAsyncRedisClient {
167200
}
168201

169202
pub async fn get_resp3_pubsub(&mut self) -> crate::error::Result<AsyncPubSub> {
203+
#[cfg(not(feature = "crossfire-channel"))]
170204
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
205+
#[cfg(feature = "crossfire-channel")]
206+
let (tx, rx) = {
207+
let (tx, rx) = crossfire::mpsc::unbounded_async();
208+
(UnboundedSender(tx), UnboundedReceiver(rx))
209+
};
171210
let config = redis::AsyncConnectionConfig::new().set_push_sender(tx.clone());
172211
let con = self
173212
.client

netmito/src/config/coordinator.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use redis::{acl::Rule, AsyncCommands};
1717
use sea_orm::DatabaseConnection;
1818
use serde::{Deserialize, Serialize};
1919
use time::Duration;
20+
#[cfg(not(feature = "crossfire-channel"))]
2021
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
2122
use tokio_util::sync::CancellationToken;
2223
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer};
@@ -197,7 +198,8 @@ impl CoordinatorConfig {
197198
pub fn build_worker_task_queue(
198199
&self,
199200
cancel_token: CancellationToken,
200-
rx: UnboundedReceiver<TaskDispatcherOp>,
201+
#[cfg(not(feature = "crossfire-channel"))] rx: UnboundedReceiver<TaskDispatcherOp>,
202+
#[cfg(feature = "crossfire-channel")] rx: crossfire::AsyncRx<TaskDispatcherOp>,
201203
) -> TaskDispatcher {
202204
TaskDispatcher::new(cancel_token, rx)
203205
}
@@ -206,7 +208,8 @@ impl CoordinatorConfig {
206208
&self,
207209
cancel_token: CancellationToken,
208210
pool: InfraPool,
209-
rx: UnboundedReceiver<HeartbeatOp>,
211+
#[cfg(not(feature = "crossfire-channel"))] rx: UnboundedReceiver<HeartbeatOp>,
212+
#[cfg(feature = "crossfire-channel")] rx: crossfire::AsyncRx<HeartbeatOp>,
210213
) -> HeartbeatQueue {
211214
HeartbeatQueue::new(cancel_token, self.heartbeat_timeout, pool, rx)
212215
}
@@ -273,8 +276,18 @@ impl CoordinatorConfig {
273276

274277
pub async fn build_infra_pool(
275278
&self,
276-
worker_task_queue_tx: UnboundedSender<TaskDispatcherOp>,
277-
worker_heartbeat_queue_tx: UnboundedSender<HeartbeatOp>,
279+
#[cfg(not(feature = "crossfire-channel"))] worker_task_queue_tx: UnboundedSender<
280+
TaskDispatcherOp,
281+
>,
282+
#[cfg(feature = "crossfire-channel")] worker_task_queue_tx: crossfire::MTx<
283+
TaskDispatcherOp,
284+
>,
285+
#[cfg(not(feature = "crossfire-channel"))] worker_heartbeat_queue_tx: UnboundedSender<
286+
HeartbeatOp,
287+
>,
288+
#[cfg(feature = "crossfire-channel")] worker_heartbeat_queue_tx: crossfire::MTx<
289+
HeartbeatOp,
290+
>,
278291
) -> crate::error::Result<InfraPool> {
279292
let db = sea_orm::Database::connect(&self.db_url).await?;
280293
let credential = Credentials::new(
@@ -399,8 +412,14 @@ impl CoordinatorConfig {
399412
pub struct InfraPool {
400413
pub db: DatabaseConnection,
401414
pub s3: S3Client,
415+
#[cfg(not(feature = "crossfire-channel"))]
402416
pub worker_task_queue_tx: UnboundedSender<TaskDispatcherOp>,
417+
#[cfg(feature = "crossfire-channel")]
418+
pub worker_task_queue_tx: crossfire::MTx<TaskDispatcherOp>,
419+
#[cfg(not(feature = "crossfire-channel"))]
403420
pub worker_heartbeat_queue_tx: UnboundedSender<HeartbeatOp>,
421+
#[cfg(feature = "crossfire-channel")]
422+
pub worker_heartbeat_queue_tx: crossfire::MTx<HeartbeatOp>,
404423
}
405424

406425
#[derive(Debug)]

netmito/src/coordinator.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,16 @@ impl MitoCoordinator {
8585
.map_err(|_| crate::error::Error::Custom("set shutdown secret failed".to_string()))?;
8686
let cancel_token = CancellationToken::new();
8787

88+
#[cfg(not(feature = "crossfire-channel"))]
8889
let (worker_task_queue_tx, worker_task_queue_rx) = tokio::sync::mpsc::unbounded_channel();
90+
#[cfg(feature = "crossfire-channel")]
91+
let (worker_task_queue_tx, worker_task_queue_rx) = crossfire::mpsc::unbounded_async();
92+
#[cfg(not(feature = "crossfire-channel"))]
8993
let (worker_heartbeat_queue_tx, worker_heartbeat_queue_rx) =
9094
tokio::sync::mpsc::unbounded_channel();
95+
#[cfg(feature = "crossfire-channel")]
96+
let (worker_heartbeat_queue_tx, worker_heartbeat_queue_rx) =
97+
crossfire::mpsc::unbounded_async();
9198

9299
// Setup worker task queue
93100
let worker_task_queue =

netmito/src/service/worker/heartbeat.rs

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ use std::{cmp::Reverse, time::Duration};
22

33
use priority_queue::PriorityQueue;
44
use sea_orm::prelude::*;
5-
use tokio::{sync::mpsc::UnboundedReceiver, time::Instant};
5+
#[cfg(not(feature = "crossfire-channel"))]
6+
use tokio::sync::mpsc::UnboundedReceiver;
7+
use tokio::time::Instant;
68
use tokio_util::sync::CancellationToken;
79

810
use crate::{config::InfraPool, entity::workers as Worker};
@@ -14,7 +16,10 @@ pub struct HeartbeatQueue {
1416
cancel_token: CancellationToken,
1517
heartbeat_timeout: Duration,
1618
pool: InfraPool,
19+
#[cfg(not(feature = "crossfire-channel"))]
1720
rx: UnboundedReceiver<HeartbeatOp>,
21+
#[cfg(feature = "crossfire-channel")]
22+
rx: crossfire::AsyncRx<HeartbeatOp>,
1823
}
1924

2025
pub enum HeartbeatOp {
@@ -27,7 +32,8 @@ impl HeartbeatQueue {
2732
cancel_token: CancellationToken,
2833
heartbeat_timeout: Duration,
2934
pool: InfraPool,
30-
rx: UnboundedReceiver<HeartbeatOp>,
35+
#[cfg(not(feature = "crossfire-channel"))] rx: UnboundedReceiver<HeartbeatOp>,
36+
#[cfg(feature = "crossfire-channel")] rx: crossfire::AsyncRx<HeartbeatOp>,
3137
) -> Self {
3238
Self {
3339
workers: PriorityQueue::new(),
@@ -75,6 +81,7 @@ impl HeartbeatQueue {
7581

7682
pub async fn run(&mut self) {
7783
let mut timeout_duration = self.heartbeat_timeout;
84+
#[cfg(not(feature = "crossfire-channel"))]
7885
loop {
7986
tokio::select! {
8087
biased;
@@ -108,5 +115,39 @@ impl HeartbeatQueue {
108115
}
109116
}
110117
}
118+
#[cfg(feature = "crossfire-channel")]
119+
loop {
120+
tokio::select! {
121+
biased;
122+
_ = self.cancel_token.cancelled() => {
123+
break;
124+
}
125+
op = self.rx.recv() => match op.ok() {
126+
None => {
127+
break;
128+
}
129+
Some(op) => {
130+
self.handle_op(op);
131+
timeout_duration = self
132+
.workers
133+
.peek()
134+
.map(|(_, r)| r.0 - Instant::now())
135+
.unwrap_or(self.heartbeat_timeout);
136+
}
137+
},
138+
_ = tokio::time::sleep(timeout_duration) => {
139+
if let Err(e) = self.handle_timeout().await {
140+
tracing::error!("handle timeout failed: {:?}", e);
141+
self.cancel_token.cancel();
142+
break;
143+
}
144+
timeout_duration = self
145+
.workers
146+
.peek()
147+
.map(|(_, r)| r.0 - Instant::now())
148+
.unwrap_or(self.heartbeat_timeout);
149+
}
150+
}
151+
}
111152
}
112153
}

netmito/src/service/worker/queue.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ use std::collections::HashMap;
22

33
use priority_queue::PriorityQueue;
44

5-
use tokio::sync::{mpsc::UnboundedReceiver, oneshot::Sender};
5+
#[cfg(not(feature = "crossfire-channel"))]
6+
use tokio::sync::mpsc::UnboundedReceiver;
7+
use tokio::sync::oneshot::Sender;
68
use tokio_util::sync::CancellationToken;
79

810
// MARK: TaskDispatcher
@@ -12,7 +14,10 @@ pub struct TaskDispatcher {
1214
/// Every task is represented by a tuple of (task id, priority).
1315
pub workers: HashMap<i64, PriorityQueue<i64, i32>>,
1416
cancel_token: CancellationToken,
17+
#[cfg(not(feature = "crossfire-channel"))]
1518
rx: UnboundedReceiver<TaskDispatcherOp>,
19+
#[cfg(feature = "crossfire-channel")]
20+
rx: crossfire::AsyncRx<TaskDispatcherOp>,
1621
}
1722

1823
pub enum TaskDispatcherOp {
@@ -28,7 +33,11 @@ pub enum TaskDispatcherOp {
2833
}
2934

3035
impl TaskDispatcher {
31-
pub fn new(cancel_token: CancellationToken, rx: UnboundedReceiver<TaskDispatcherOp>) -> Self {
36+
pub fn new(
37+
cancel_token: CancellationToken,
38+
#[cfg(not(feature = "crossfire-channel"))] rx: UnboundedReceiver<TaskDispatcherOp>,
39+
#[cfg(feature = "crossfire-channel")] rx: crossfire::AsyncRx<TaskDispatcherOp>,
40+
) -> Self {
3241
Self {
3342
workers: HashMap::new(),
3443
cancel_token,
@@ -138,6 +147,7 @@ impl TaskDispatcher {
138147
}
139148

140149
pub async fn run(&mut self) {
150+
#[cfg(not(feature = "crossfire-channel"))]
141151
loop {
142152
tokio::select! {
143153
biased;
@@ -150,5 +160,18 @@ impl TaskDispatcher {
150160
}
151161
}
152162
}
163+
#[cfg(feature = "crossfire-channel")]
164+
loop {
165+
tokio::select! {
166+
biased;
167+
_ = self.cancel_token.cancelled() => {
168+
break;
169+
}
170+
op = self.rx.recv() => if self.handle_op(op.ok()) {
171+
self.cancel_token.cancel();
172+
break;
173+
}
174+
}
175+
}
153176
}
154177
}

0 commit comments

Comments
 (0)