Skip to content

Commit 912343e

Browse files
Preserve DAG order for disabled nodes
Co-authored-by: EvalOpsBot <EvalOpsBot@users.noreply.github.com>
1 parent f6547d3 commit 912343e

File tree

1 file changed

+79
-76
lines changed

1 file changed

+79
-76
lines changed

src/core/dag.rs

Lines changed: 79 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -183,47 +183,11 @@ where
183183
let mut launch_sequence = 0usize;
184184

185185
while completed.len() < specs.len() {
186-
let mut progressed = false;
187-
loop {
188-
let Some(spec) = specs
189-
.iter()
190-
.find(|candidate| {
191-
!candidate.enabled
192-
&& !completed.contains(&candidate.id)
193-
&& !in_flight.contains(&candidate.id)
194-
&& candidate
195-
.dependencies
196-
.iter()
197-
.all(|dependency| completed.contains(dependency))
198-
})
199-
.cloned()
200-
else {
201-
break;
202-
};
203-
204-
recorded.push((
205-
launch_sequence,
206-
DagExecutionRecord {
207-
name: spec.id.name().to_string(),
208-
enabled: false,
209-
duration_ms: 0,
210-
},
211-
));
212-
launch_sequence += 1;
213-
completed.insert(spec.id);
214-
progressed = true;
215-
}
216-
217-
if completed.len() == specs.len() {
218-
break;
219-
}
220-
221186
let ready_indices = specs
222187
.iter()
223188
.enumerate()
224189
.filter(|(_, candidate)| {
225-
candidate.enabled
226-
&& !completed.contains(&candidate.id)
190+
!completed.contains(&candidate.id)
227191
&& !in_flight.contains(&candidate.id)
228192
&& candidate
229193
.dependencies
@@ -234,12 +198,24 @@ where
234198
.collect::<Vec<_>>();
235199

236200
if join_set.is_empty() {
237-
if let Some(index) = ready_indices
238-
.iter()
239-
.copied()
240-
.find(|index| !specs[*index].hints.parallelizable)
241-
{
242-
let spec = specs[index].clone();
201+
let Some(index) = ready_indices.first().copied() else {
202+
anyhow::bail!("DAG has unresolved or cyclic dependencies");
203+
};
204+
let spec = specs[index].clone();
205+
if !spec.enabled {
206+
recorded.push((
207+
launch_sequence,
208+
DagExecutionRecord {
209+
name: spec.id.name().to_string(),
210+
enabled: false,
211+
duration_ms: 0,
212+
},
213+
));
214+
launch_sequence += 1;
215+
completed.insert(spec.id);
216+
continue;
217+
}
218+
if !spec.hints.parallelizable {
243219
let started = Instant::now();
244220
let output = spawn(spec.id.clone())?.await?;
245221
apply(spec.id.clone(), output)?;
@@ -255,38 +231,23 @@ where
255231
completed.insert(spec.id);
256232
continue;
257233
}
234+
}
258235

259-
for index in ready_indices.iter().copied() {
260-
let spec = specs[index].clone();
261-
let sequence = launch_sequence;
262-
launch_sequence += 1;
263-
let id = spec.id.clone();
264-
let future = spawn(id.clone())?;
265-
let started = Instant::now();
266-
in_flight.insert(id.clone());
267-
join_set.spawn(async move {
268-
let output = future.await;
269-
(sequence, id, started.elapsed().as_millis() as u64, output)
270-
});
271-
}
272-
} else {
273-
for index in ready_indices
274-
.iter()
275-
.copied()
276-
.filter(|index| specs[*index].hints.parallelizable)
277-
{
278-
let spec = specs[index].clone();
279-
let sequence = launch_sequence;
280-
launch_sequence += 1;
281-
let id = spec.id.clone();
282-
let future = spawn(id.clone())?;
283-
let started = Instant::now();
284-
in_flight.insert(id.clone());
285-
join_set.spawn(async move {
286-
let output = future.await;
287-
(sequence, id, started.elapsed().as_millis() as u64, output)
288-
});
236+
for index in ready_indices.iter().copied() {
237+
let spec = specs[index].clone();
238+
if !spec.enabled || !spec.hints.parallelizable {
239+
break;
289240
}
241+
let sequence = launch_sequence;
242+
launch_sequence += 1;
243+
let id = spec.id.clone();
244+
let future = spawn(id.clone())?;
245+
let started = Instant::now();
246+
in_flight.insert(id.clone());
247+
join_set.spawn(async move {
248+
let output = future.await;
249+
(sequence, id, started.elapsed().as_millis() as u64, output)
250+
});
290251
}
291252

292253
if !join_set.is_empty() {
@@ -310,9 +271,7 @@ where
310271
continue;
311272
}
312273

313-
if !progressed {
314-
anyhow::bail!("DAG has unresolved or cyclic dependencies");
315-
}
274+
anyhow::bail!("DAG has unresolved or cyclic dependencies");
316275
}
317276

318277
recorded.sort_by_key(|(sequence, _)| *sequence);
@@ -659,4 +618,48 @@ mod tests {
659618
assert!(applied.contains(&"branch".to_string()));
660619
assert!(applied.contains(&"leaf".to_string()));
661620
}
621+
622+
#[tokio::test]
623+
async fn execute_dag_with_parallelism_preserves_ready_spec_order_for_disabled_nodes() {
624+
let specs = vec![
625+
DagNodeSpec {
626+
id: TestNode::Root,
627+
dependencies: vec![],
628+
hints: hints(false),
629+
enabled: true,
630+
},
631+
DagNodeSpec {
632+
id: TestNode::Branch,
633+
dependencies: vec![TestNode::Root],
634+
hints: hints(true),
635+
enabled: true,
636+
},
637+
DagNodeSpec {
638+
id: TestNode::Leaf,
639+
dependencies: vec![TestNode::Root],
640+
hints: hints(true),
641+
enabled: false,
642+
},
643+
];
644+
let mut applied = Vec::new();
645+
646+
let records = execute_dag_with_parallelism(
647+
&specs,
648+
|node| {
649+
Ok(async move { Ok(node.name().to_string()) }.boxed())
650+
},
651+
|_, output| {
652+
applied.push(output);
653+
Ok(())
654+
},
655+
)
656+
.await
657+
.unwrap();
658+
659+
assert_eq!(
660+
records.iter().map(|record| record.name.as_str()).collect::<Vec<_>>(),
661+
vec!["root", "branch", "leaf"]
662+
);
663+
assert_eq!(applied, vec!["root", "branch"]);
664+
}
662665
}

0 commit comments

Comments
 (0)