1- use std:: sync:: Arc ;
1+ use std:: { collections :: HashMap , sync:: Arc } ;
22
33use crate :: shared:: utils:: AbortOnDrop ;
44use roles_logic_sv2:: utils:: Mutex ;
@@ -16,17 +16,32 @@ enum Task {
1616}
1717
1818pub struct TaskManager {
19- send_task : mpsc:: Sender < Task > ,
19+ send_task : mpsc:: Sender < ( Option < u32 > , Task ) > ,
2020 abort : Option < AbortOnDrop > ,
21+ tasks : Arc < Mutex < HashMap < Option < u32 > , Vec < AbortOnDrop > > > > , // Track tasks by connection_id
2122}
2223
2324impl TaskManager {
2425 pub fn initialize ( ) -> Arc < Mutex < Self > > {
25- let ( sender, mut receiver) = mpsc:: channel ( 10 ) ;
26+ type TaskMessage = ( Option < u32 > , Task ) ;
27+
28+ let ( sender, mut receiver) : ( mpsc:: Sender < TaskMessage > , mpsc:: Receiver < TaskMessage > ) =
29+ mpsc:: channel ( 10 ) ;
30+
31+ let tasks = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
32+ let task_clone = tasks. clone ( ) ;
2633 let handle = tokio:: task:: spawn ( async move {
27- let mut tasks = vec ! [ ] ;
28- while let Some ( task) = receiver. recv ( ) . await {
29- tasks. push ( task) ;
34+ while let Some ( ( connection_id, task) ) = receiver. recv ( ) . await {
35+ if task_clone
36+ . safe_lock ( |tasks| {
37+ let tasks_list: & mut Vec < AbortOnDrop > =
38+ tasks. entry ( connection_id) . or_default ( ) ;
39+ tasks_list. push ( task. into ( ) ) ;
40+ } )
41+ . is_err ( )
42+ {
43+ tracing:: error!( "TasKManager Mutex Poisoned" )
44+ } ;
3045 }
3146 warn ! ( "Translator downstream task manager stopped, keep alive tasks" ) ;
3247 loop {
@@ -36,6 +51,7 @@ impl TaskManager {
3651 Arc :: new ( Mutex :: new ( Self {
3752 send_task : sender,
3853 abort : Some ( handle. into ( ) ) ,
54+ tasks,
3955 } ) )
4056 }
4157
@@ -46,34 +62,44 @@ impl TaskManager {
4662 pub async fn add_receive_downstream (
4763 self_ : Arc < Mutex < Self > > ,
4864 abortable : AbortOnDrop ,
65+ connection_id : u32 ,
4966 ) -> Result < ( ) , ( ) > {
5067 let send_task = self_. safe_lock ( |s| s. send_task . clone ( ) ) . unwrap ( ) ;
5168 send_task
52- . send ( Task :: ReceiveDownstream ( abortable) )
69+ . send ( ( Some ( connection_id ) , Task :: ReceiveDownstream ( abortable) ) )
5370 . await
5471 . map_err ( |_| ( ) )
5572 }
56- pub async fn add_update ( self_ : Arc < Mutex < Self > > , abortable : AbortOnDrop ) -> Result < ( ) , ( ) > {
73+ pub async fn add_update (
74+ self_ : Arc < Mutex < Self > > ,
75+ abortable : AbortOnDrop ,
76+ connection_id : u32 ,
77+ ) -> Result < ( ) , ( ) > {
5778 let send_task = self_. safe_lock ( |s| s. send_task . clone ( ) ) . unwrap ( ) ;
5879 send_task
59- . send ( Task :: Update ( abortable) )
80+ . send ( ( Some ( connection_id ) , Task :: Update ( abortable) ) )
6081 . await
6182 . map_err ( |_| ( ) )
6283 }
63- pub async fn add_notify ( self_ : Arc < Mutex < Self > > , abortable : AbortOnDrop ) -> Result < ( ) , ( ) > {
84+ pub async fn add_notify (
85+ self_ : Arc < Mutex < Self > > ,
86+ abortable : AbortOnDrop ,
87+ connection_id : u32 ,
88+ ) -> Result < ( ) , ( ) > {
6489 let send_task = self_. safe_lock ( |s| s. send_task . clone ( ) ) . unwrap ( ) ;
6590 send_task
66- . send ( Task :: Notify ( abortable) )
91+ . send ( ( Some ( connection_id ) , Task :: Notify ( abortable) ) )
6792 . await
6893 . map_err ( |_| ( ) )
6994 }
7095 pub async fn add_send_downstream (
7196 self_ : Arc < Mutex < Self > > ,
7297 abortable : AbortOnDrop ,
98+ connection_id : u32 ,
7399 ) -> Result < ( ) , ( ) > {
74100 let send_task = self_. safe_lock ( |s| s. send_task . clone ( ) ) . unwrap ( ) ;
75101 send_task
76- . send ( Task :: SendDownstream ( abortable) )
102+ . send ( ( Some ( connection_id ) , Task :: SendDownstream ( abortable) ) )
77103 . await
78104 . map_err ( |_| ( ) )
79105 }
@@ -83,7 +109,7 @@ impl TaskManager {
83109 ) -> Result < ( ) , ( ) > {
84110 let send_task = self_. safe_lock ( |s| s. send_task . clone ( ) ) . unwrap ( ) ;
85111 send_task
86- . send ( Task :: AcceptConnection ( abortable) )
112+ . send ( ( None , Task :: AcceptConnection ( abortable) ) )
87113 . await
88114 . map_err ( |_| ( ) )
89115 }
@@ -94,8 +120,43 @@ impl TaskManager {
94120 ) -> Result < ( ) , ( ) > {
95121 let send_task = self_. safe_lock ( |s| s. send_task . clone ( ) ) . unwrap ( ) ;
96122 send_task
97- . send ( Task :: SharesMonitor ( abortable) )
123+ . send ( ( None , Task :: SharesMonitor ( abortable) ) )
98124 . await
99125 . map_err ( |_| ( ) )
100126 }
127+
128+ /// Kills all tasks for a given `connection_id` and removes them from TaskManager.
129+ pub fn abort_tasks_for_connection_id ( & mut self , connection_id : u32 ) {
130+ if self
131+ . tasks
132+ . safe_lock ( |tasks| {
133+ if let Some ( handles) = tasks. remove ( & Some ( connection_id) ) {
134+ for handle in handles {
135+ drop ( handle) ;
136+ }
137+ }
138+ } )
139+ . is_err ( )
140+ {
141+ tracing:: error!( "TasKManager Mutex Poisoned" )
142+ } ;
143+ tracing:: info!(
144+ "Aborted all tasks for downstream connection ID {}" ,
145+ connection_id
146+ ) ;
147+ }
148+ }
149+
150+ /// Converts a `Task` into its `AbortHandle` for task management.
151+ impl From < Task > for AbortOnDrop {
152+ fn from ( task : Task ) -> Self {
153+ match task {
154+ Task :: AcceptConnection ( handle) => handle,
155+ Task :: ReceiveDownstream ( handle) => handle,
156+ Task :: SendDownstream ( handle) => handle,
157+ Task :: Notify ( handle) => handle,
158+ Task :: Update ( handle) => handle,
159+ Task :: SharesMonitor ( handle) => handle,
160+ }
161+ }
101162}
0 commit comments