Skip to content

Commit 2df3f69

Browse files
author
fi3
committed
Fix proxy restart
d07b7f5 introduced a way to kill the specific tasks that handle a downstream connection, so that when a downstream disconnects we make sure to kill all the tasks related to that downstream. In doing so, it saved the task aborters in a field of the TaskManager, making it impossible to kill those tasks when the aborter was called somewhere else. Now we make sure that all the aborters live only under tasks that get killed when the main aborter of the TaskManager is called.
1 parent f98d583 commit 2df3f69

5 files changed

Lines changed: 109 additions & 79 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "demand-cli"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
edition = "2021"
55

66
[dependencies]

src/shared/utils.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,34 @@ use tokio::task::JoinHandle;
66

77
#[derive(Debug)]
88
pub struct AbortOnDrop {
9-
abort_handle: AbortHandle,
9+
abort_handle: Vec<AbortHandle>,
1010
}
1111

1212
impl AbortOnDrop {
1313
pub fn new<T: Send + 'static>(handle: JoinHandle<T>) -> Self {
14-
let abort_handle = handle.abort_handle();
14+
let abort_handle = vec![handle.abort_handle()];
1515
Self { abort_handle }
1616
}
1717

1818
pub fn is_finished(&self) -> bool {
19-
self.abort_handle.is_finished()
19+
for task in &self.abort_handle {
20+
if !task.is_finished() {
21+
return false;
22+
}
23+
}
24+
true
25+
}
26+
27+
pub fn add_task<T: Send + 'static>(&mut self, handle: JoinHandle<T>) {
28+
self.abort_handle.push(handle.abort_handle());
2029
}
2130
}
2231

2332
impl core::ops::Drop for AbortOnDrop {
2433
fn drop(&mut self) {
25-
self.abort_handle.abort()
34+
for task in &self.abort_handle {
35+
task.abort();
36+
}
2637
}
2738
}
2839

src/translator/downstream/receive_from_downstream.rs

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,55 +16,65 @@ pub async fn start_receive_downstream(
1616
mut recv_from_down: mpsc::Receiver<String>,
1717
connection_id: u32,
1818
) -> Result<(), Error<'static>> {
19-
let task_manager_clone = task_manager.clone();
20-
let handle = task::spawn(async move {
21-
while let Some(incoming) = recv_from_down.recv().await {
22-
let incoming: Result<json_rpc::Message, _> = serde_json::from_str(&incoming);
23-
if let Ok(incoming) = incoming {
24-
// if message is Submit Shares update difficulty management
25-
if let sv1_api::Message::StandardRequest(standard_req) = incoming.clone() {
26-
if let Ok(Submit { .. }) = standard_req.try_into() {
27-
if let Err(e) = Downstream::save_share(downstream.clone()) {
28-
error!("{}", e);
29-
break;
19+
let handle = {
20+
let task_manager = task_manager.clone();
21+
task::spawn(async move {
22+
while let Some(incoming) = recv_from_down.recv().await {
23+
let incoming: Result<json_rpc::Message, _> = serde_json::from_str(&incoming);
24+
if let Ok(incoming) = incoming {
25+
// if message is Submit Shares update difficulty management
26+
if let sv1_api::Message::StandardRequest(standard_req) = incoming.clone() {
27+
if let Ok(Submit { .. }) = standard_req.try_into() {
28+
if let Err(e) = Downstream::save_share(downstream.clone()) {
29+
error!("{}", e);
30+
break;
31+
}
3032
}
3133
}
32-
}
3334

34-
if let Err(error) =
35-
Downstream::handle_incoming_sv1(downstream.clone(), incoming).await
36-
{
37-
error!("Failed to handle incoming sv1 msg: {:?}", error);
38-
ProxyState::update_downstream_state(DownstreamType::TranslatorDownstream);
39-
};
40-
} else {
41-
// Message received could not be converted to rpc message
42-
error!(
43-
"{}",
44-
Error::V1Protocol(Box::new(sv1_api::error::Error::InvalidJsonRpcMessageKind))
45-
);
46-
return;
35+
if let Err(error) =
36+
Downstream::handle_incoming_sv1(downstream.clone(), incoming).await
37+
{
38+
error!("Failed to handle incoming sv1 msg: {:?}", error);
39+
ProxyState::update_downstream_state(DownstreamType::TranslatorDownstream);
40+
};
41+
} else {
42+
// Message received could not be converted to rpc message
43+
error!(
44+
"{}",
45+
Error::V1Protocol(Box::new(
46+
sv1_api::error::Error::InvalidJsonRpcMessageKind
47+
))
48+
);
49+
return;
50+
}
4751
}
48-
}
49-
if let Ok(stats_sender) = downstream.safe_lock(|d| d.stats_sender.clone()) {
50-
stats_sender.remove_stats(connection_id);
51-
}
52-
// No message to receive
53-
warn!(
54-
"Downstream: Shutting down sv1 downstream reader {}",
55-
connection_id
56-
);
52+
if let Ok(stats_sender) = downstream.safe_lock(|d| d.stats_sender.clone()) {
53+
stats_sender.remove_stats(connection_id);
54+
}
55+
// No message to receive
56+
warn!(
57+
"Downstream: Shutting down sv1 downstream reader {}",
58+
connection_id
59+
);
5760

58-
if let Err(e) = Downstream::remove_downstream_hashrate_from_channel(&downstream) {
59-
error!("Failed to remove downstream hashrate from channel: {}", e)
60-
};
61-
if task_manager_clone
62-
.safe_lock(|tm| tm.abort_tasks_for_connection_id(connection_id))
63-
.is_err()
64-
{
65-
error!("TaskManager mutex poisoned")
66-
};
67-
});
61+
if let Err(e) = Downstream::remove_downstream_hashrate_from_channel(&downstream) {
62+
error!("Failed to remove downstream hashrate from channel: {}", e)
63+
};
64+
// Apparently there is no way to make the compiler happy without unwrapping here. But
65+
// it is not an issue since:
66+
// 1. the mutex should never get poisoned, and if it does it will be very rare
67+
// 2. restarting the process after the unwrapping or restarting all the tasks from
68+
// inside the process (that is what we should do here) is almost the same thing
69+
let send_kill_signal = task_manager
70+
.safe_lock(|tm| tm.send_kill_signal.clone())
71+
.unwrap();
72+
if send_kill_signal.send(connection_id).await.is_err() {
73+
error!("Proxy can not abort downstreams tasks");
74+
ProxyState::update_inconsistency(Some(1));
75+
}
76+
})
77+
};
6878
TaskManager::add_receive_downstream(task_manager, handle.into(), connection_id)
6979
.await
7080
.map_err(|_| Error::TranslatorTaskManagerFailed)

src/translator/downstream/task_manager.rs

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{collections::HashMap, sync::Arc};
22

3-
use crate::shared::utils::AbortOnDrop;
3+
use crate::{proxy_state::ProxyState, shared::utils::AbortOnDrop};
44
use roles_logic_sv2::utils::Mutex;
55
use tokio::sync::mpsc;
66
use tracing::warn;
@@ -15,23 +15,31 @@ enum Task {
1515
SharesMonitor(AbortOnDrop),
1616
}
1717

18+
type TaskMessage = (Option<u32>, Task);
19+
1820
pub struct TaskManager {
19-
send_task: mpsc::Sender<(Option<u32>, Task)>,
21+
send_task: mpsc::Sender<TaskMessage>,
2022
abort: Option<AbortOnDrop>,
21-
tasks: Arc<Mutex<HashMap<Option<u32>, Vec<AbortOnDrop>>>>, // Track tasks by connection_id
23+
pub send_kill_signal: mpsc::Sender<u32>,
2224
}
2325

2426
impl TaskManager {
2527
pub fn initialize() -> Arc<Mutex<Self>> {
26-
type TaskMessage = (Option<u32>, Task);
27-
2828
let (sender, mut receiver): (mpsc::Sender<TaskMessage>, mpsc::Receiver<TaskMessage>) =
2929
mpsc::channel(10);
30+
let (send_kill_signal, mut receiver_kill_signal) = mpsc::channel(10);
3031

3132
let tasks = Arc::new(Mutex::new(HashMap::new()));
3233
let task_clone = tasks.clone();
3334
let handle = tokio::task::spawn(async move {
3435
while let Some((connection_id, task)) = receiver.recv().await {
36+
// The tasks map is used to save tasks related to downstream management; some of them
37+
// are "global" in the sense that they live for the whole life of the translator (like
38+
// the task that creates new downstreams when a downstream connects), others are
39+
// specific to a downstream, like the one that receives messages from it. Specific
40+
// tasks have an id that is the connection id, "global" ones do not have one; for
41+
// that TaskMessage is an (Option<u32>, Task) where u32 is the connection id.
42+
// "Global" tasks are saved in the map under the None key.
3543
if task_clone
3644
.safe_lock(|tasks| {
3745
let tasks_list: &mut Vec<AbortOnDrop> =
@@ -48,10 +56,33 @@ impl TaskManager {
4856
tokio::time::sleep(std::time::Duration::from_secs(1000)).await;
4957
}
5058
});
59+
let kill_tasks = tokio::task::spawn(async move {
60+
while let Some(connection_id) = receiver_kill_signal.recv().await {
61+
if tasks
62+
.safe_lock(|tasks| {
63+
if let Some(handles) = tasks.remove(&Some(connection_id)) {
64+
for handle in handles {
65+
drop(handle);
66+
}
67+
}
68+
})
69+
.is_err()
70+
{
71+
tracing::error!("TasKManager Mutex Poisoned");
72+
ProxyState::update_inconsistency(Some(1));
73+
};
74+
tracing::info!(
75+
"Aborted all tasks for downstream connection ID {}",
76+
connection_id
77+
);
78+
}
79+
});
80+
let mut aborter: AbortOnDrop = handle.into();
81+
aborter.add_task(kill_tasks);
5182
Arc::new(Mutex::new(Self {
5283
send_task: sender,
53-
abort: Some(handle.into()),
54-
tasks,
84+
abort: Some(aborter),
85+
send_kill_signal,
5586
}))
5687
}
5788

@@ -124,29 +155,7 @@ impl TaskManager {
124155
.await
125156
.map_err(|_| ())
126157
}
127-
128-
/// Kills all tasks for a given `connection_id` and removes them from TaskManager.
129-
pub fn abort_tasks_for_connection_id(&mut self, connection_id: u32) {
130-
if self
131-
.tasks
132-
.safe_lock(|tasks| {
133-
if let Some(handles) = tasks.remove(&Some(connection_id)) {
134-
for handle in handles {
135-
drop(handle);
136-
}
137-
}
138-
})
139-
.is_err()
140-
{
141-
tracing::error!("TasKManager Mutex Poisoned")
142-
};
143-
tracing::info!(
144-
"Aborted all tasks for downstream connection ID {}",
145-
connection_id
146-
);
147-
}
148158
}
149-
150159
/// Converts a `Task` into its `AbortHandle` for task management.
151160
impl From<Task> for AbortOnDrop {
152161
fn from(task: Task) -> Self {

0 commit comments

Comments
 (0)