11use alloy:: primitives:: Address ;
22use anyhow:: Error ;
3- use log:: info;
3+ use log:: { error, info, warn} ;
4+ use rand:: seq:: IndexedRandom ;
45use redis:: Commands ;
56use serde:: { Deserialize , Serialize } ;
67use shared:: models:: task:: Task ;
@@ -18,6 +19,7 @@ use super::StatusUpdatePlugin;
1819
1920const GROUP_KEY_PREFIX : & str = "node_group:" ;
2021const NODE_GROUP_MAP_KEY : & str = "node_to_group" ;
22+ const GROUP_TASK_KEY_PREFIX : & str = "group_task:" ;
2123
2224#[ derive( Debug , Serialize , Deserialize , PartialEq , Eq , Clone ) ]
2325pub struct NodeGroup {
@@ -58,6 +60,7 @@ impl NodeGroupsPlugin {
5860 fn get_group_key ( group_id : & str ) -> String {
5961 format ! ( "{}{}" , GROUP_KEY_PREFIX , group_id)
6062 }
63+
6164 fn try_form_new_group ( & self , node_addr : Option < & str > ) -> Result < Option < NodeGroup > , Error > {
6265 let mut conn = self . store . client . get_connection ( ) ?;
6366
@@ -199,6 +202,26 @@ impl NodeGroupsPlugin {
199202
200203 Ok ( None )
201204 }
205+ fn get_current_group_task ( & self , group_id : & str ) -> Result < Option < Task > , Error > {
206+ let mut conn = self . store . client . get_connection ( ) ?;
207+ let task_key = format ! ( "{}{}" , GROUP_TASK_KEY_PREFIX , group_id) ;
208+ let task_id: Option < String > = conn. get ( & task_key) ?;
209+
210+ if let Some ( task_id) = task_id {
211+ if let Some ( task) = self . store_context . task_store . get_task ( & task_id) {
212+ return Ok ( Some ( task) ) ;
213+ }
214+ warn ! ( "Task id set but task not found" ) ;
215+ }
216+ Ok ( None )
217+ }
218+
219+ fn assign_task_to_group ( & self , group_id : & str , task_id : & str ) -> Result < bool , Error > {
220+ let mut conn = self . store . client . get_connection ( ) ?;
221+ let task_key = format ! ( "{}{}" , GROUP_TASK_KEY_PREFIX , group_id) ;
222+ let result: bool = conn. set_nx :: < _ , _ , bool > ( & task_key, task_id) ?;
223+ Ok ( result)
224+ }
202225}
203226
204227impl Plugin for NodeGroupsPlugin { }
@@ -277,9 +300,39 @@ impl SchedulerPlugin for NodeGroupsPlugin {
277300 . position ( |n| n == & node_address. to_string ( ) )
278301 . unwrap ( ) ;
279302
280- let mut final_tasks: Vec < Task > = Vec :: new ( ) ;
281- for task in tasks {
282- let mut task_clone = task. clone ( ) ;
303+ let mut current_task: Option < Task > = None ;
304+ match self . get_current_group_task ( & group. id ) {
305+ Ok ( Some ( task) ) => {
306+ current_task = Some ( task) ;
307+ }
308+ Ok ( None ) => {
309+ if tasks. is_empty ( ) {
310+ return vec ! [ ] ;
311+ }
312+ if let Some ( new_task) = tasks. choose ( & mut rand:: rng ( ) ) {
313+ let task_id = new_task. id . to_string ( ) ;
314+ match self . assign_task_to_group ( & group. id , & task_id) {
315+ Ok ( true ) => {
316+ // Successfully assigned the task
317+ current_task = Some ( new_task. clone ( ) ) ;
318+ }
319+ Ok ( false ) => {
320+ // Another node already assigned a task, try to get it
321+ if let Ok ( Some ( task) ) = self . get_current_group_task ( & group. id ) {
322+ current_task = Some ( task) ;
323+ }
324+ }
325+ Err ( e) => {
326+ error ! ( "Failed to assign task to group: {}" , e) ;
327+ }
328+ }
329+ }
330+ }
331+ _ => { }
332+ }
333+
334+ if let Some ( t) = current_task {
335+ let mut task_clone = t. clone ( ) ;
283336
284337 let next_node_idx = ( node_group_index + 1 ) % group. nodes . len ( ) ;
285338 let next_node_addr = group. nodes . iter ( ) . nth ( next_node_idx) . unwrap ( ) ;
@@ -317,17 +370,8 @@ impl SchedulerPlugin for NodeGroupsPlugin {
317370 } )
318371 . collect :: < Vec < String > > ( )
319372 } ) ;
320-
321- final_tasks. push ( task_clone) ;
373+ return vec ! [ task_clone] ;
322374 }
323-
324- info ! (
325- "Returning {} tasks for node {} in group {}" ,
326- final_tasks. len( ) ,
327- node_address,
328- group. id
329- ) ;
330- return final_tasks;
331375 }
332376 info ! (
333377 "Node {} is not in a group, skipping all tasks" ,
@@ -345,6 +389,7 @@ mod tests {
345389 use alloy:: primitives:: Address ;
346390 use shared:: models:: task:: TaskState ;
347391 use std:: { collections:: HashMap , str:: FromStr , sync:: Arc } ;
392+
348393 use uuid:: Uuid ;
349394
350395 fn create_test_node ( addr : & str , status : NodeStatus ) -> OrchestratorNode {
@@ -468,37 +513,87 @@ mod tests {
468513 created_at : 0 ,
469514 updated_at : None ,
470515 } ;
516+ plugin. store_context . task_store . add_task ( task1. clone ( ) ) ;
517+
518+ let mut task2 = task1. clone ( ) ;
519+ task2. id = Uuid :: new_v4 ( ) ;
520+ plugin. store_context . task_store . add_task ( task2. clone ( ) ) ;
471521
472- let tasks = vec ! [ task1] ;
522+ let mut task3 = task1. clone ( ) ;
523+ task3. id = Uuid :: new_v4 ( ) ;
524+ plugin. store_context . task_store . add_task ( task3. clone ( ) ) ;
525+
526+ let tasks = vec ! [ task1, task2, task3] ;
473527
474528 let filtered_tasks = plugin. filter_tasks ( & tasks, & node1. address ) ;
475529 assert_eq ! ( filtered_tasks. len( ) , 0 ) ;
476530
477531 let _ = plugin
478532 . handle_status_change ( & node1, & NodeStatus :: Healthy )
479533 . await ;
534+ let mut tasks_clone = tasks. clone ( ) ;
535+ tasks_clone. reverse ( ) ;
536+ assert_ne ! ( tasks_clone[ 0 ] . id, tasks[ 0 ] . id) ;
480537
481- let filtered_tasks = plugin. filter_tasks ( & tasks, & node1. address ) ;
538+ let ( filtered_tasks_1, filtered_tasks_2) = tokio:: join!(
539+ async { plugin. filter_tasks( & tasks, & node1. address) } ,
540+ async { plugin. filter_tasks( & tasks_clone, & node2. address) }
541+ ) ;
482542
483543 // Check both nodes get assigned valid and different indexes
484- assert_eq ! ( filtered_tasks. len( ) , 1 ) ;
485- let task = & filtered_tasks[ 0 ] ;
486- let env_vars = task. env_vars . as_ref ( ) . unwrap ( ) ;
487- assert_eq ! ( env_vars. get( "GROUP_INDEX" ) . unwrap( ) , "0" ) ;
488- assert_eq ! ( env_vars. get( "RANK" ) . unwrap( ) , "0" ) ;
489- assert_eq ! ( env_vars. get( "WORLD_SIZE" ) . unwrap( ) , "2" ) ;
490- assert_eq ! ( task. args. as_ref( ) . unwrap( ) [ 3 ] , "model/Qwen3-14B-0.2" ) ;
491- assert_ne ! ( env_vars. get( "GROUP_ID" ) . unwrap( ) , "${GROUP_ID}" ) ;
544+ // Also ensure both nodes get the same task
545+ assert_eq ! ( filtered_tasks_1. len( ) , 1 ) ;
546+ let task_node_1 = & filtered_tasks_1[ 0 ] ;
547+ let env_vars_1 = task_node_1. env_vars . as_ref ( ) . unwrap ( ) ;
548+ assert_eq ! ( env_vars_1. get( "GROUP_INDEX" ) . unwrap( ) , "0" ) ;
549+ assert_eq ! ( env_vars_1. get( "RANK" ) . unwrap( ) , "0" ) ;
550+ assert_eq ! ( env_vars_1. get( "WORLD_SIZE" ) . unwrap( ) , "2" ) ;
551+ assert_eq ! ( task_node_1. args. as_ref( ) . unwrap( ) [ 3 ] , "model/Qwen3-14B-0.2" ) ;
552+ assert_ne ! ( env_vars_1. get( "GROUP_ID" ) . unwrap( ) , "${GROUP_ID}" ) ;
553+
554+ assert_eq ! ( filtered_tasks_2. len( ) , 1 ) ;
555+ let task_node_2 = & filtered_tasks_2[ 0 ] ;
556+ let env_vars_2 = task_node_2. env_vars . as_ref ( ) . unwrap ( ) ;
557+ assert_eq ! ( env_vars_2. get( "GROUP_INDEX" ) . unwrap( ) , "1" ) ;
558+ assert_eq ! ( env_vars_2. get( "RANK" ) . unwrap( ) , "1" ) ;
559+ assert_eq ! ( env_vars_2. get( "WORLD_SIZE" ) . unwrap( ) , "2" ) ;
560+ assert_eq ! ( task_node_2. args. as_ref( ) . unwrap( ) [ 3 ] , "model/Qwen3-14B-1.2" ) ;
561+ assert_ne ! ( env_vars_2. get( "GROUP_ID" ) . unwrap( ) , "${GROUP_ID}" ) ;
562+
563+ assert_eq ! ( task_node_1. id, task_node_2. id) ;
564+ }
565+
566+ #[ tokio:: test]
567+ async fn test_group_scheduling_without_tasks ( ) {
568+ let store: Arc < RedisStore > = Arc :: new ( RedisStore :: new_test ( ) ) ;
569+ let context_store = store. clone ( ) ;
570+ let store_context = Arc :: new ( StoreContext :: new ( context_store) ) ;
571+
572+ let plugin = NodeGroupsPlugin :: new ( 2 , 5 , store. clone ( ) , store_context) ;
573+ let node1 = create_test_node (
574+ "0x1234567890123456789012345678901234567890" ,
575+ NodeStatus :: Healthy ,
576+ ) ;
577+ plugin. store_context . node_store . add_node ( node1. clone ( ) ) ;
578+ let node2 = create_test_node (
579+ "0x2234567890123456789012345678901234567890" ,
580+ NodeStatus :: Healthy ,
581+ ) ;
582+ plugin. store_context . node_store . add_node ( node2. clone ( ) ) ;
583+ let tasks = vec ! [ ] ;
584+
585+ let filtered_tasks = plugin. filter_tasks ( & tasks, & node1. address ) ;
586+ assert_eq ! ( filtered_tasks. len( ) , 0 ) ;
587+
588+ let _ = plugin
589+ . handle_status_change ( & node1, & NodeStatus :: Healthy )
590+ . await ;
591+
592+ let filtered_tasks = plugin. filter_tasks ( & tasks, & node1. address ) ;
593+ assert_eq ! ( filtered_tasks. len( ) , 0 ) ;
492594
493595 let filtered_tasks = plugin. filter_tasks ( & tasks, & node2. address ) ;
494- assert_eq ! ( filtered_tasks. len( ) , 1 ) ;
495- let task = & filtered_tasks[ 0 ] ;
496- let env_vars = task. env_vars . as_ref ( ) . unwrap ( ) ;
497- assert_eq ! ( env_vars. get( "GROUP_INDEX" ) . unwrap( ) , "1" ) ;
498- assert_eq ! ( env_vars. get( "RANK" ) . unwrap( ) , "1" ) ;
499- assert_eq ! ( env_vars. get( "WORLD_SIZE" ) . unwrap( ) , "2" ) ;
500- assert_eq ! ( task. args. as_ref( ) . unwrap( ) [ 3 ] , "model/Qwen3-14B-1.2" ) ;
501- assert_ne ! ( env_vars. get( "GROUP_ID" ) . unwrap( ) , "${GROUP_ID}" ) ;
596+ assert_eq ! ( filtered_tasks. len( ) , 0 ) ;
502597 }
503598
504599 #[ tokio:: test]
@@ -556,6 +651,7 @@ mod tests {
556651 created_at : 0 ,
557652 updated_at : None ,
558653 } ;
654+ plugin. store_context . task_store . add_task ( task. clone ( ) ) ;
559655
560656 let tasks = vec ! [ task] ;
561657
@@ -855,8 +951,6 @@ mod tests {
855951 let _ = plugin
856952 . handle_status_change ( & node2, & NodeStatus :: Healthy )
857953 . await ;
858- let nodes = plugin. store_context . node_store . get_nodes ( ) ;
859- println ! ( "nodes {:?}" , nodes) ;
860954
861955 let node_2_group_id: Option < String > = conn
862956 . hget ( NODE_GROUP_MAP_KEY , node2. address . to_string ( ) )
0 commit comments