Skip to content

Commit 31298e2

Browse files
authored
Merge pull request dmnd-pool#96 from Priceless-P/fix/handle-downstream-disconnect
handle downstream disconnection
2 parents 1d5a394 + d07b7f5 commit 31298e2

4 files changed

Lines changed: 90 additions & 21 deletions

File tree

src/translator/downstream/notify.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,16 +123,13 @@ pub async fn start_notify(
123123
Downstream::send_message_downstream(downstream.clone(), message).await;
124124
}
125125
}
126-
// TODO here we want to be sure that on drop this is called
127-
let _ = Downstream::remove_downstream_hashrate_from_channel(&downstream);
128-
// TODO here we want to kill the tasks
129126
warn!(
130127
"Downstream: Shutting down sv1 downstream job notifier for {}",
131128
&host
132129
);
133130
})
134131
};
135-
TaskManager::add_notify(task_manager, handle.into())
132+
TaskManager::add_notify(task_manager, handle.into(), connection_id)
136133
.await
137134
.map_err(|_| Error::TranslatorTaskManagerFailed)
138135
}
@@ -166,7 +163,7 @@ async fn start_update(
166163
};
167164
}
168165
});
169-
TaskManager::add_update(task_manager, handle.into())
166+
TaskManager::add_update(task_manager, handle.into(), connection_id)
170167
.await
171168
.map_err(|_| Error::TranslatorTaskManagerFailed)
172169
}

src/translator/downstream/receive_from_downstream.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub async fn start_receive_downstream(
1616
mut recv_from_down: mpsc::Receiver<String>,
1717
connection_id: u32,
1818
) -> Result<(), Error<'static>> {
19+
let task_manager_clone = task_manager.clone();
1920
let handle = task::spawn(async move {
2021
while let Some(incoming) = recv_from_down.recv().await {
2122
let incoming: Result<json_rpc::Message, _> = serde_json::from_str(&incoming);
@@ -53,8 +54,18 @@ pub async fn start_receive_downstream(
5354
"Downstream: Shutting down sv1 downstream reader {}",
5455
connection_id
5556
);
57+
58+
if let Err(e) = Downstream::remove_downstream_hashrate_from_channel(&downstream) {
59+
error!("Failed to remove downstream hashrate from channel: {}", e)
60+
};
61+
if task_manager_clone
62+
.safe_lock(|tm| tm.abort_tasks_for_connection_id(connection_id))
63+
.is_err()
64+
{
65+
error!("TaskManager mutex poisoned")
66+
};
5667
});
57-
TaskManager::add_receive_downstream(task_manager, handle.into())
68+
TaskManager::add_receive_downstream(task_manager, handle.into(), connection_id)
5869
.await
5970
.map_err(|_| Error::TranslatorTaskManagerFailed)
6071
}

src/translator/downstream/send_to_downstream.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pub async fn start_send_to_downstream(
3333
connection_id
3434
);
3535
});
36-
TaskManager::add_send_downstream(task_manager, handle.into())
36+
TaskManager::add_send_downstream(task_manager, handle.into(), connection_id)
3737
.await
3838
.map_err(|_| Error::TranslatorTaskManagerFailed)
3939
}

src/translator/downstream/task_manager.rs

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::sync::Arc;
1+
use std::{collections::HashMap, sync::Arc};
22

33
use crate::shared::utils::AbortOnDrop;
44
use roles_logic_sv2::utils::Mutex;
@@ -16,17 +16,32 @@ enum Task {
1616
}
1717

1818
pub struct TaskManager {
19-
send_task: mpsc::Sender<Task>,
19+
send_task: mpsc::Sender<(Option<u32>, Task)>,
2020
abort: Option<AbortOnDrop>,
21+
tasks: Arc<Mutex<HashMap<Option<u32>, Vec<AbortOnDrop>>>>, // Track tasks by connection_id
2122
}
2223

2324
impl TaskManager {
2425
pub fn initialize() -> Arc<Mutex<Self>> {
25-
let (sender, mut receiver) = mpsc::channel(10);
26+
type TaskMessage = (Option<u32>, Task);
27+
28+
let (sender, mut receiver): (mpsc::Sender<TaskMessage>, mpsc::Receiver<TaskMessage>) =
29+
mpsc::channel(10);
30+
31+
let tasks = Arc::new(Mutex::new(HashMap::new()));
32+
let task_clone = tasks.clone();
2633
let handle = tokio::task::spawn(async move {
27-
let mut tasks = vec![];
28-
while let Some(task) = receiver.recv().await {
29-
tasks.push(task);
34+
while let Some((connection_id, task)) = receiver.recv().await {
35+
if task_clone
36+
.safe_lock(|tasks| {
37+
let tasks_list: &mut Vec<AbortOnDrop> =
38+
tasks.entry(connection_id).or_default();
39+
tasks_list.push(task.into());
40+
})
41+
.is_err()
42+
{
43+
tracing::error!("TasKManager Mutex Poisoned")
44+
};
3045
}
3146
warn!("Translator downstream task manager stopped, keep alive tasks");
3247
loop {
@@ -36,6 +51,7 @@ impl TaskManager {
3651
Arc::new(Mutex::new(Self {
3752
send_task: sender,
3853
abort: Some(handle.into()),
54+
tasks,
3955
}))
4056
}
4157

@@ -46,34 +62,44 @@ impl TaskManager {
4662
pub async fn add_receive_downstream(
4763
self_: Arc<Mutex<Self>>,
4864
abortable: AbortOnDrop,
65+
connection_id: u32,
4966
) -> Result<(), ()> {
5067
let send_task = self_.safe_lock(|s| s.send_task.clone()).unwrap();
5168
send_task
52-
.send(Task::ReceiveDownstream(abortable))
69+
.send((Some(connection_id), Task::ReceiveDownstream(abortable)))
5370
.await
5471
.map_err(|_| ())
5572
}
56-
pub async fn add_update(self_: Arc<Mutex<Self>>, abortable: AbortOnDrop) -> Result<(), ()> {
73+
pub async fn add_update(
74+
self_: Arc<Mutex<Self>>,
75+
abortable: AbortOnDrop,
76+
connection_id: u32,
77+
) -> Result<(), ()> {
5778
let send_task = self_.safe_lock(|s| s.send_task.clone()).unwrap();
5879
send_task
59-
.send(Task::Update(abortable))
80+
.send((Some(connection_id), Task::Update(abortable)))
6081
.await
6182
.map_err(|_| ())
6283
}
63-
pub async fn add_notify(self_: Arc<Mutex<Self>>, abortable: AbortOnDrop) -> Result<(), ()> {
84+
pub async fn add_notify(
85+
self_: Arc<Mutex<Self>>,
86+
abortable: AbortOnDrop,
87+
connection_id: u32,
88+
) -> Result<(), ()> {
6489
let send_task = self_.safe_lock(|s| s.send_task.clone()).unwrap();
6590
send_task
66-
.send(Task::Notify(abortable))
91+
.send((Some(connection_id), Task::Notify(abortable)))
6792
.await
6893
.map_err(|_| ())
6994
}
7095
pub async fn add_send_downstream(
7196
self_: Arc<Mutex<Self>>,
7297
abortable: AbortOnDrop,
98+
connection_id: u32,
7399
) -> Result<(), ()> {
74100
let send_task = self_.safe_lock(|s| s.send_task.clone()).unwrap();
75101
send_task
76-
.send(Task::SendDownstream(abortable))
102+
.send((Some(connection_id), Task::SendDownstream(abortable)))
77103
.await
78104
.map_err(|_| ())
79105
}
@@ -83,7 +109,7 @@ impl TaskManager {
83109
) -> Result<(), ()> {
84110
let send_task = self_.safe_lock(|s| s.send_task.clone()).unwrap();
85111
send_task
86-
.send(Task::AcceptConnection(abortable))
112+
.send((None, Task::AcceptConnection(abortable)))
87113
.await
88114
.map_err(|_| ())
89115
}
@@ -94,8 +120,43 @@ impl TaskManager {
94120
) -> Result<(), ()> {
95121
let send_task = self_.safe_lock(|s| s.send_task.clone()).unwrap();
96122
send_task
97-
.send(Task::SharesMonitor(abortable))
123+
.send((None, Task::SharesMonitor(abortable)))
98124
.await
99125
.map_err(|_| ())
100126
}
127+
128+
/// Kills all tasks for a given `connection_id` and removes them from TaskManager.
129+
pub fn abort_tasks_for_connection_id(&mut self, connection_id: u32) {
130+
if self
131+
.tasks
132+
.safe_lock(|tasks| {
133+
if let Some(handles) = tasks.remove(&Some(connection_id)) {
134+
for handle in handles {
135+
drop(handle);
136+
}
137+
}
138+
})
139+
.is_err()
140+
{
141+
tracing::error!("TasKManager Mutex Poisoned")
142+
};
143+
tracing::info!(
144+
"Aborted all tasks for downstream connection ID {}",
145+
connection_id
146+
);
147+
}
148+
}
149+
150+
/// Unwraps a `Task` into the `AbortOnDrop` handle it carries, whichever variant it is.
151+
impl From<Task> for AbortOnDrop {
152+
fn from(task: Task) -> Self {
153+
match task {
154+
Task::AcceptConnection(handle) => handle,
155+
Task::ReceiveDownstream(handle) => handle,
156+
Task::SendDownstream(handle) => handle,
157+
Task::Notify(handle) => handle,
158+
Task::Update(handle) => handle,
159+
Task::SharesMonitor(handle) => handle,
160+
}
161+
}
101162
}

0 commit comments

Comments
 (0)