Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.

Commit 76c2ee3

Browse files
authored
improve task creation api (#459)
* Deprecate the old "command" + "args" task API and introduce a simpler task creation API based on "cmd" + "entrypoint"
1 parent cc4818f commit 76c2ee3

6 files changed

Lines changed: 57 additions & 42 deletions

File tree

crates/orchestrator/src/plugins/node_groups/scheduler_impl.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ impl SchedulerPlugin for NodeGroupsPlugin {
152152
*value = new_value;
153153
}
154154
task_clone.env_vars = Some(env_vars);
155-
task_clone.args = task_clone.args.map(|args| {
156-
args.into_iter()
155+
task_clone.cmd = task_clone.cmd.map(|cmd| {
156+
cmd.into_iter()
157157
.map(|arg| {
158158
arg.replace("${GROUP_INDEX}", &idx.to_string())
159159
.replace("${GROUP_SIZE}", &group.nodes.len().to_string())

crates/orchestrator/src/plugins/node_groups/tests.rs

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -533,8 +533,8 @@ async fn test_group_scheduling() {
533533
image: "prime-vllm".to_string(),
534534
name: "test-task".to_string(),
535535
env_vars: Some(env_vars),
536-
command: Some("uv".to_string()),
537-
args: Some(vec![
536+
cmd: Some(vec![
537+
"uv".to_string(),
538538
"run".to_string(),
539539
"generate.py".to_string(),
540540
"--model".to_string(),
@@ -548,6 +548,7 @@ async fn test_group_scheduling() {
548548
"--file-number".to_string(),
549549
"${LAST_FILE_IDX}".to_string(),
550550
]),
551+
entrypoint: None,
551552
state: TaskState::PENDING,
552553
created_at: 0,
553554
..Default::default()
@@ -609,23 +610,23 @@ async fn test_group_scheduling() {
609610
assert_eq!(env_vars_1.get("GROUP_INDEX").unwrap(), "0");
610611
assert_eq!(env_vars_1.get("RANK").unwrap(), "0");
611612
assert_eq!(env_vars_1.get("WORLD_SIZE").unwrap(), "2");
612-
assert_eq!(task_node_1.args.as_ref().unwrap()[3], "model/Qwen3-14B-0.2");
613+
assert_eq!(task_node_1.cmd.as_ref().unwrap()[4], "model/Qwen3-14B-0.2");
613614
assert_ne!(env_vars_1.get("GROUP_ID").unwrap(), "${GROUP_ID}");
614615
assert_eq!(env_vars_1.get("TOTAL_UPLOAD_COUNT").unwrap(), "1");
615616
assert_eq!(env_vars_1.get("LAST_FILE_IDX").unwrap(), "0");
616-
assert_eq!(task_node_1.args.as_ref().unwrap()[9], "1"); // Check upload count in args
617+
assert_eq!(task_node_1.cmd.as_ref().unwrap()[10], "1"); // Check upload count in cmd
617618

618619
assert_eq!(filtered_tasks_2.len(), 1);
619620
let task_node_2 = &filtered_tasks_2[0];
620621
let env_vars_2 = task_node_2.env_vars.as_ref().unwrap();
621622
assert_eq!(env_vars_2.get("GROUP_INDEX").unwrap(), "1");
622623
assert_eq!(env_vars_2.get("RANK").unwrap(), "1");
623624
assert_eq!(env_vars_2.get("WORLD_SIZE").unwrap(), "2");
624-
assert_eq!(task_node_2.args.as_ref().unwrap()[3], "model/Qwen3-14B-1.2");
625+
assert_eq!(task_node_2.cmd.as_ref().unwrap()[4], "model/Qwen3-14B-1.2");
625626
assert_ne!(env_vars_2.get("GROUP_ID").unwrap(), "${GROUP_ID}");
626627
assert_eq!(env_vars_2.get("TOTAL_UPLOAD_COUNT").unwrap(), "0");
627628
assert_eq!(env_vars_2.get("LAST_FILE_IDX").unwrap(), "0");
628-
assert_eq!(task_node_2.args.as_ref().unwrap()[9], "0"); // Check upload count in args
629+
assert_eq!(task_node_2.cmd.as_ref().unwrap()[10], "0"); // Check upload count in cmd
629630

630631
assert_eq!(task_node_1.id, task_node_2.id);
631632
}
@@ -752,8 +753,12 @@ async fn test_group_formation_with_max_size() {
752753
image: "test-image".to_string(),
753754
name: "test-task".to_string(),
754755
env_vars: Some(env_vars),
755-
command: Some("run".to_string()),
756-
args: Some(vec!["--index".to_string(), "${GROUP_INDEX}".to_string()]),
756+
cmd: Some(vec![
757+
"run".to_string(),
758+
"--index".to_string(),
759+
"${GROUP_INDEX}".to_string(),
760+
]),
761+
entrypoint: None,
757762
state: TaskState::PENDING,
758763
created_at: 0,
759764
..Default::default()
@@ -851,8 +856,12 @@ async fn test_node_groups_with_allowed_topologies() {
851856
image: "test-image".to_string(),
852857
name: "test-task".to_string(),
853858
env_vars: None,
854-
command: Some("run".to_string()),
855-
args: Some(vec!["--index".to_string(), "${GROUP_INDEX}".to_string()]),
859+
cmd: Some(vec![
860+
"run".to_string(),
861+
"--index".to_string(),
862+
"${GROUP_INDEX}".to_string(),
863+
]),
864+
entrypoint: None,
856865
scheduling_config: Some(SchedulingConfig {
857866
plugins: Some(HashMap::from([(
858867
"node_groups".to_string(),
@@ -884,8 +893,12 @@ async fn test_node_groups_with_allowed_topologies() {
884893
image: "test-image".to_string(),
885894
name: "test-task".to_string(),
886895
env_vars: None,
887-
command: Some("run".to_string()),
888-
args: Some(vec!["--index".to_string(), "${GROUP_INDEX}".to_string()]),
896+
cmd: Some(vec![
897+
"run".to_string(),
898+
"--index".to_string(),
899+
"${GROUP_INDEX}".to_string(),
900+
]),
901+
entrypoint: None,
889902
scheduling_config: Some(SchedulingConfig {
890903
plugins: Some(HashMap::from([(
891904
"node_groups".to_string(),

crates/orchestrator/src/scheduler/mod.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ impl Scheduler {
4444
}
4545
}
4646

47-
// Replace variables in args
48-
if let Some(args) = &mut task.args {
49-
for arg in args.iter_mut() {
47+
// Replace variables in cmd
48+
if let Some(cmd) = &mut task.cmd {
49+
for arg in cmd.iter_mut() {
5050
*arg = arg
5151
.replace("${TASK_ID}", &task.id.to_string())
5252
.replace("${NODE_ADDRESS}", &node_address.to_string());
@@ -107,10 +107,11 @@ mod tests {
107107
state: TaskState::PENDING,
108108
created_at: 1,
109109
env_vars: Some(env_vars),
110-
args: Some(vec![
110+
cmd: Some(vec![
111111
"--task=${TASK_ID}".to_string(),
112112
"--node=${NODE_ADDRESS}".to_string(),
113113
]),
114+
entrypoint: None,
114115
..Default::default()
115116
};
116117

@@ -132,9 +133,9 @@ mod tests {
132133
&format!("node-{}", node_address)
133134
);
134135

135-
// Check args replacement
136-
let args = returned_task.args.unwrap();
137-
assert_eq!(args[0], format!("--task={}", task.id));
138-
assert_eq!(args[1], format!("--node={}", node_address));
136+
// Check cmd replacement
137+
let cmd = returned_task.cmd.unwrap();
138+
assert_eq!(cmd[0], format!("--task={}", task.id));
139+
assert_eq!(cmd[1], format!("--node={}", node_address));
139140
}
140141
}

crates/shared/src/models/task.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ pub struct TaskRequest {
6363
pub image: String,
6464
pub name: String,
6565
pub env_vars: Option<std::collections::HashMap<String, String>>,
66-
pub command: Option<String>,
67-
pub args: Option<Vec<String>>,
66+
pub cmd: Option<Vec<String>>,
67+
pub entrypoint: Option<Vec<String>>,
6868
pub scheduling_config: Option<SchedulingConfig>,
6969
pub storage_config: Option<StorageConfig>,
7070
pub metadata: Option<TaskMetadata>,
@@ -82,8 +82,8 @@ pub struct Task {
8282
pub image: String,
8383
pub name: String,
8484
pub env_vars: Option<std::collections::HashMap<String, String>>,
85-
pub command: Option<String>,
86-
pub args: Option<Vec<String>>,
85+
pub cmd: Option<Vec<String>>,
86+
pub entrypoint: Option<Vec<String>>,
8787
pub state: TaskState,
8888
#[serde(default)]
8989
pub created_at: i64,
@@ -104,8 +104,8 @@ impl Default for Task {
104104
image: String::new(),
105105
name: String::new(),
106106
env_vars: None,
107-
command: None,
108-
args: None,
107+
cmd: None,
108+
entrypoint: None,
109109
state: TaskState::default(),
110110
created_at: 0,
111111
updated_at: None,
@@ -147,6 +147,7 @@ impl StorageConfig {
147147
Ok(())
148148
}
149149
}
150+
150151
impl TryFrom<TaskRequest> for Task {
151152
type Error = String;
152153

@@ -159,8 +160,8 @@ impl TryFrom<TaskRequest> for Task {
159160
id: Uuid::new_v4(),
160161
image: request.image,
161162
name: request.name,
162-
command: request.command,
163-
args: request.args,
163+
cmd: request.cmd,
164+
entrypoint: request.entrypoint,
164165
env_vars: request.env_vars,
165166
state: TaskState::PENDING,
166167
created_at: Utc::now().timestamp_millis(),

crates/worker/src/docker/docker_manager.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ impl DockerManager {
123123
// Simple Vec of (host_path, container_path, read_only)
124124
volumes: Option<Vec<(String, String, bool)>>,
125125
shm_size: Option<u64>,
126+
entrypoint: Option<Vec<String>>,
126127
) -> Result<String, DockerError> {
127128
info!("Starting to pull image: {}", image);
128129

@@ -261,6 +262,9 @@ impl DockerManager {
261262
cmd: command
262263
.as_ref()
263264
.map(|c| c.iter().map(String::as_str).collect()),
265+
entrypoint: entrypoint
266+
.as_ref()
267+
.map(|e| e.iter().map(String::as_str).collect()),
264268
host_config,
265269
..Default::default()
266270
};

crates/worker/src/docker/service.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -189,21 +189,17 @@ impl DockerService {
189189
return;
190190
}
191191
};
192-
let cmd_full = (payload.command, payload.args);
193-
let cmd = match cmd_full {
194-
(Some(c), Some(a)) => {
195-
let mut cmd = vec![c];
196-
cmd.extend(a.into_iter().map(|arg| {
192+
let cmd = match payload.cmd {
193+
Some(cmd_vec) => {
194+
cmd_vec.into_iter().map(|arg| {
197195
if let Some(seed) = p2p_seed {
198196
arg.replace("${WORKER_P2P_SEED}", &seed.to_string())
199197
} else {
200198
arg
201199
}
202-
}));
203-
cmd
200+
}).collect()
204201
}
205-
(Some(c), None) => vec![c],
206-
_ => vec!["sleep".to_string(), "infinity".to_string()],
202+
None => vec!["sleep".to_string(), "infinity".to_string()],
207203
};
208204

209205
let mut env_vars: HashMap<String, String> = HashMap::new();
@@ -231,7 +227,7 @@ impl DockerService {
231227
67108864 // Default to 64MB in bytes
232228
}
233229
};
234-
match manager_clone.start_container(&payload.image, &container_task_id, Some(env_vars), Some(cmd), gpu, Some(volumes), Some(shm_size)).await {
230+
match manager_clone.start_container(&payload.image, &container_task_id, Some(env_vars), Some(cmd), gpu, Some(volumes), Some(shm_size), payload.entrypoint).await {
235231
Ok(container_id) => {
236232
Console::info("DockerService", &format!("Container started with id: {}", container_id));
237233
},
@@ -375,8 +371,8 @@ mod tests {
375371
name: "test".to_string(),
376372
id: Uuid::new_v4(),
377373
env_vars: None,
378-
command: Some("sleep".to_string()),
379-
args: Some(vec!["5".to_string()]), // Reduced sleep time
374+
cmd: Some(vec!["sleep".to_string(), "5".to_string()]), // Reduced sleep time
375+
entrypoint: None,
380376
state: TaskState::PENDING,
381377
created_at: Utc::now().timestamp_millis(),
382378
..Default::default()

0 commit comments

Comments
 (0)