Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.

Commit 226f710

Browse files
authored
Merge pull request #333 from PrimeIntellect-ai/improvement/single-task-to-group
Improvement: Ensure a group is only working on a single task
1 parent 402cb46 commit 226f710

4 files changed

Lines changed: 138 additions & 38 deletions

File tree

.github/workflows/checks.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ jobs:
4343
- name: Install Foundry
4444
uses: foundry-rs/foundry-toolchain@v1
4545
with:
46-
version: nightly
46+
version: v1.1.0
4747

4848
- name: Install Redis binary
4949
run: |
@@ -57,4 +57,4 @@ jobs:
5757
if: success() || failure()
5858
run: |
5959
redis-server --version
60-
cargo test -- --nocapture
60+
RUST_BACKTRACE=1 cargo test -- --nocapture

crates/orchestrator/src/status_update/plugins/node_groups.rs

Lines changed: 128 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
use alloy::primitives::Address;
22
use anyhow::Error;
3-
use log::info;
3+
use log::{error, info, warn};
4+
use rand::seq::IndexedRandom;
45
use redis::Commands;
56
use serde::{Deserialize, Serialize};
67
use shared::models::task::Task;
@@ -18,6 +19,7 @@ use super::StatusUpdatePlugin;
1819

1920
const GROUP_KEY_PREFIX: &str = "node_group:";
2021
const NODE_GROUP_MAP_KEY: &str = "node_to_group";
22+
const GROUP_TASK_KEY_PREFIX: &str = "group_task:";
2123

2224
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
2325
pub struct NodeGroup {
@@ -58,6 +60,7 @@ impl NodeGroupsPlugin {
5860
fn get_group_key(group_id: &str) -> String {
5961
format!("{}{}", GROUP_KEY_PREFIX, group_id)
6062
}
63+
6164
fn try_form_new_group(&self, node_addr: Option<&str>) -> Result<Option<NodeGroup>, Error> {
6265
let mut conn = self.store.client.get_connection()?;
6366

@@ -199,6 +202,26 @@ impl NodeGroupsPlugin {
199202

200203
Ok(None)
201204
}
205+
fn get_current_group_task(&self, group_id: &str) -> Result<Option<Task>, Error> {
206+
let mut conn = self.store.client.get_connection()?;
207+
let task_key = format!("{}{}", GROUP_TASK_KEY_PREFIX, group_id);
208+
let task_id: Option<String> = conn.get(&task_key)?;
209+
210+
if let Some(task_id) = task_id {
211+
if let Some(task) = self.store_context.task_store.get_task(&task_id) {
212+
return Ok(Some(task));
213+
}
214+
warn!("Task id set but task not found");
215+
}
216+
Ok(None)
217+
}
218+
219+
/// Tries to pin `task_id` as the task for the group `group_id`.
///
/// Uses a set-if-not-exists write on the `group_task:<group_id>` key.
/// Returns `Ok(true)` when this call stored the id and `Ok(false)` when the
/// key was already present (an assignment already exists for the group).
///
/// # Errors
/// Returns an error if the Redis connection cannot be obtained or the write
/// command fails.
fn assign_task_to_group(&self, group_id: &str, task_id: &str) -> Result<bool, Error> {
    let key = format!("{GROUP_TASK_KEY_PREFIX}{group_id}");
    let mut redis_conn = self.store.client.get_connection()?;
    let was_set = redis_conn.set_nx::<_, _, bool>(&key, task_id)?;
    Ok(was_set)
}
202225
}
203226

204227
impl Plugin for NodeGroupsPlugin {}
@@ -277,9 +300,39 @@ impl SchedulerPlugin for NodeGroupsPlugin {
277300
.position(|n| n == &node_address.to_string())
278301
.unwrap();
279302

280-
let mut final_tasks: Vec<Task> = Vec::new();
281-
for task in tasks {
282-
let mut task_clone = task.clone();
303+
let mut current_task: Option<Task> = None;
304+
match self.get_current_group_task(&group.id) {
305+
Ok(Some(task)) => {
306+
current_task = Some(task);
307+
}
308+
Ok(None) => {
309+
if tasks.is_empty() {
310+
return vec![];
311+
}
312+
if let Some(new_task) = tasks.choose(&mut rand::rng()) {
313+
let task_id = new_task.id.to_string();
314+
match self.assign_task_to_group(&group.id, &task_id) {
315+
Ok(true) => {
316+
// Successfully assigned the task
317+
current_task = Some(new_task.clone());
318+
}
319+
Ok(false) => {
320+
// Another node already assigned a task, try to get it
321+
if let Ok(Some(task)) = self.get_current_group_task(&group.id) {
322+
current_task = Some(task);
323+
}
324+
}
325+
Err(e) => {
326+
error!("Failed to assign task to group: {}", e);
327+
}
328+
}
329+
}
330+
}
331+
_ => {}
332+
}
333+
334+
if let Some(t) = current_task {
335+
let mut task_clone = t.clone();
283336

284337
let next_node_idx = (node_group_index + 1) % group.nodes.len();
285338
let next_node_addr = group.nodes.iter().nth(next_node_idx).unwrap();
@@ -317,17 +370,8 @@ impl SchedulerPlugin for NodeGroupsPlugin {
317370
})
318371
.collect::<Vec<String>>()
319372
});
320-
321-
final_tasks.push(task_clone);
373+
return vec![task_clone];
322374
}
323-
324-
info!(
325-
"Returning {} tasks for node {} in group {}",
326-
final_tasks.len(),
327-
node_address,
328-
group.id
329-
);
330-
return final_tasks;
331375
}
332376
info!(
333377
"Node {} is not in a group, skipping all tasks",
@@ -345,6 +389,7 @@ mod tests {
345389
use alloy::primitives::Address;
346390
use shared::models::task::TaskState;
347391
use std::{collections::HashMap, str::FromStr, sync::Arc};
392+
348393
use uuid::Uuid;
349394

350395
fn create_test_node(addr: &str, status: NodeStatus) -> OrchestratorNode {
@@ -468,37 +513,87 @@ mod tests {
468513
created_at: 0,
469514
updated_at: None,
470515
};
516+
plugin.store_context.task_store.add_task(task1.clone());
517+
518+
let mut task2 = task1.clone();
519+
task2.id = Uuid::new_v4();
520+
plugin.store_context.task_store.add_task(task2.clone());
471521

472-
let tasks = vec![task1];
522+
let mut task3 = task1.clone();
523+
task3.id = Uuid::new_v4();
524+
plugin.store_context.task_store.add_task(task3.clone());
525+
526+
let tasks = vec![task1, task2, task3];
473527

474528
let filtered_tasks = plugin.filter_tasks(&tasks, &node1.address);
475529
assert_eq!(filtered_tasks.len(), 0);
476530

477531
let _ = plugin
478532
.handle_status_change(&node1, &NodeStatus::Healthy)
479533
.await;
534+
let mut tasks_clone = tasks.clone();
535+
tasks_clone.reverse();
536+
assert_ne!(tasks_clone[0].id, tasks[0].id);
480537

481-
let filtered_tasks = plugin.filter_tasks(&tasks, &node1.address);
538+
let (filtered_tasks_1, filtered_tasks_2) = tokio::join!(
539+
async { plugin.filter_tasks(&tasks, &node1.address) },
540+
async { plugin.filter_tasks(&tasks_clone, &node2.address) }
541+
);
482542

483543
// Check both nodes get assigned valid and different indexes
484-
assert_eq!(filtered_tasks.len(), 1);
485-
let task = &filtered_tasks[0];
486-
let env_vars = task.env_vars.as_ref().unwrap();
487-
assert_eq!(env_vars.get("GROUP_INDEX").unwrap(), "0");
488-
assert_eq!(env_vars.get("RANK").unwrap(), "0");
489-
assert_eq!(env_vars.get("WORLD_SIZE").unwrap(), "2");
490-
assert_eq!(task.args.as_ref().unwrap()[3], "model/Qwen3-14B-0.2");
491-
assert_ne!(env_vars.get("GROUP_ID").unwrap(), "${GROUP_ID}");
544+
// Also ensure both nodes get the same task
545+
assert_eq!(filtered_tasks_1.len(), 1);
546+
let task_node_1 = &filtered_tasks_1[0];
547+
let env_vars_1 = task_node_1.env_vars.as_ref().unwrap();
548+
assert_eq!(env_vars_1.get("GROUP_INDEX").unwrap(), "0");
549+
assert_eq!(env_vars_1.get("RANK").unwrap(), "0");
550+
assert_eq!(env_vars_1.get("WORLD_SIZE").unwrap(), "2");
551+
assert_eq!(task_node_1.args.as_ref().unwrap()[3], "model/Qwen3-14B-0.2");
552+
assert_ne!(env_vars_1.get("GROUP_ID").unwrap(), "${GROUP_ID}");
553+
554+
assert_eq!(filtered_tasks_2.len(), 1);
555+
let task_node_2 = &filtered_tasks_2[0];
556+
let env_vars_2 = task_node_2.env_vars.as_ref().unwrap();
557+
assert_eq!(env_vars_2.get("GROUP_INDEX").unwrap(), "1");
558+
assert_eq!(env_vars_2.get("RANK").unwrap(), "1");
559+
assert_eq!(env_vars_2.get("WORLD_SIZE").unwrap(), "2");
560+
assert_eq!(task_node_2.args.as_ref().unwrap()[3], "model/Qwen3-14B-1.2");
561+
assert_ne!(env_vars_2.get("GROUP_ID").unwrap(), "${GROUP_ID}");
562+
563+
assert_eq!(task_node_1.id, task_node_2.id);
564+
}
565+
566+
#[tokio::test]
567+
async fn test_group_scheduling_without_tasks() {
568+
let store: Arc<RedisStore> = Arc::new(RedisStore::new_test());
569+
let context_store = store.clone();
570+
let store_context = Arc::new(StoreContext::new(context_store));
571+
572+
let plugin = NodeGroupsPlugin::new(2, 5, store.clone(), store_context);
573+
let node1 = create_test_node(
574+
"0x1234567890123456789012345678901234567890",
575+
NodeStatus::Healthy,
576+
);
577+
plugin.store_context.node_store.add_node(node1.clone());
578+
let node2 = create_test_node(
579+
"0x2234567890123456789012345678901234567890",
580+
NodeStatus::Healthy,
581+
);
582+
plugin.store_context.node_store.add_node(node2.clone());
583+
let tasks = vec![];
584+
585+
let filtered_tasks = plugin.filter_tasks(&tasks, &node1.address);
586+
assert_eq!(filtered_tasks.len(), 0);
587+
588+
let _ = plugin
589+
.handle_status_change(&node1, &NodeStatus::Healthy)
590+
.await;
591+
592+
let filtered_tasks = plugin.filter_tasks(&tasks, &node1.address);
593+
assert_eq!(filtered_tasks.len(), 0);
492594

493595
let filtered_tasks = plugin.filter_tasks(&tasks, &node2.address);
494-
assert_eq!(filtered_tasks.len(), 1);
495-
let task = &filtered_tasks[0];
496-
let env_vars = task.env_vars.as_ref().unwrap();
497-
assert_eq!(env_vars.get("GROUP_INDEX").unwrap(), "1");
498-
assert_eq!(env_vars.get("RANK").unwrap(), "1");
499-
assert_eq!(env_vars.get("WORLD_SIZE").unwrap(), "2");
500-
assert_eq!(task.args.as_ref().unwrap()[3], "model/Qwen3-14B-1.2");
501-
assert_ne!(env_vars.get("GROUP_ID").unwrap(), "${GROUP_ID}");
596+
assert_eq!(filtered_tasks.len(), 0);
502597
}
503598

504599
#[tokio::test]
@@ -556,6 +651,7 @@ mod tests {
556651
created_at: 0,
557652
updated_at: None,
558653
};
654+
plugin.store_context.task_store.add_task(task.clone());
559655

560656
let tasks = vec![task];
561657

@@ -855,8 +951,6 @@ mod tests {
855951
let _ = plugin
856952
.handle_status_change(&node2, &NodeStatus::Healthy)
857953
.await;
858-
let nodes = plugin.store_context.node_store.get_nodes();
859-
println!("nodes {:?}", nodes);
860954

861955
let node_2_group_id: Option<String> = conn
862956
.hget(NODE_GROUP_MAP_KEY, node2.address.to_string())

crates/orchestrator/src/store/domains/task_store.rs

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,4 +56,11 @@ impl TaskStore {
5656
// Remove task ID from list
5757
let _: () = con.lrem(TASK_LIST_KEY, 0, id).unwrap();
5858
}
59+
60+
/// Looks up a single task by its id.
///
/// Reads the value stored under `TASK_KEY_PREFIX` followed by `id` and
/// returns `None` when nothing is stored under that key.
///
/// # Panics
/// Panics if a Redis connection cannot be obtained or if the `get` call
/// returns an error.
pub fn get_task(&self, id: &str) -> Option<Task> {
    let mut redis_conn = self.redis.client.get_connection().unwrap();
    let key = format!("{TASK_KEY_PREFIX}{id}");
    redis_conn.get(&key).unwrap()
}
5966
}

crates/worker/src/checks/hardware/hardware_check.rs

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,8 +49,7 @@ impl HardwareChecker {
4949
IssueType::InsufficientCpu,
5050
"Failed to detect CPU information",
5151
);
52-
return Err(Box::new(std::io::Error::new(
53-
std::io::ErrorKind::Other,
52+
return Err(Box::new(std::io::Error::other(
5453
"Failed to detect CPU information",
5554
)));
5655
}

0 commit comments

Comments
 (0)