Skip to content

Commit 4674ce7

Browse files
committed
Enable eval DAG batching
1 parent 1a3884e commit 4674ce7

File tree

2 files changed

+97
-2
lines changed

2 files changed

+97
-2
lines changed

src/commands/eval/runner/execute/dag.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ fn stage_hints(stage: EvalFixtureStage) -> DagNodeExecutionHints {
197197
subgraph: None,
198198
},
199199
EvalFixtureStage::ReproductionValidation => DagNodeExecutionHints {
200-
parallelizable: false,
200+
parallelizable: true,
201201
retryable: true,
202202
side_effects: false,
203203
subgraph: None,
@@ -691,6 +691,17 @@ mod tests {
691691
assert_eq!(artifact.dependencies, vec!["benchmark_metrics"]);
692692
}
693693

694+
#[test]
695+
fn build_stage_specs_marks_reproduction_parallelizable() {
696+
let specs = build_stage_specs(true);
697+
let reproduction = specs
698+
.iter()
699+
.find(|spec| spec.id == EvalFixtureStage::ReproductionValidation)
700+
.unwrap();
701+
702+
assert!(reproduction.hints.parallelizable);
703+
}
704+
694705
#[test]
695706
fn eval_fixture_graph_contract_exposes_reproduction_outputs() {
696707
let graph = describe_eval_fixture_graph(true);

src/core/dag.rs

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ where
236236
for index in ready_indices.iter().copied() {
237237
let spec = specs[index].clone();
238238
if !spec.enabled || !spec.hints.parallelizable {
239-
break;
239+
continue;
240240
}
241241
let sequence = launch_sequence;
242242
launch_sequence += 1;
@@ -663,4 +663,88 @@ mod tests {
663663
);
664664
assert_eq!(applied, vec!["root", "branch"]);
665665
}
666+
667+
#[tokio::test]
668+
async fn execute_dag_with_parallelism_skips_non_parallel_ready_nodes_when_batching() {
669+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
670+
enum MixedNode {
671+
Root,
672+
FastA,
673+
Slow,
674+
FastB,
675+
}
676+
677+
impl DagNode for MixedNode {
678+
fn name(&self) -> &'static str {
679+
match self {
680+
Self::Root => "root",
681+
Self::FastA => "fast_a",
682+
Self::Slow => "slow",
683+
Self::FastB => "fast_b",
684+
}
685+
}
686+
}
687+
688+
let specs = vec![
689+
DagNodeSpec {
690+
id: MixedNode::Root,
691+
dependencies: vec![],
692+
hints: hints(false),
693+
enabled: true,
694+
},
695+
DagNodeSpec {
696+
id: MixedNode::FastA,
697+
dependencies: vec![MixedNode::Root],
698+
hints: hints(true),
699+
enabled: true,
700+
},
701+
DagNodeSpec {
702+
id: MixedNode::Slow,
703+
dependencies: vec![MixedNode::Root],
704+
hints: hints(false),
705+
enabled: true,
706+
},
707+
DagNodeSpec {
708+
id: MixedNode::FastB,
709+
dependencies: vec![MixedNode::Root],
710+
hints: hints(true),
711+
enabled: true,
712+
},
713+
];
714+
let active = Arc::new(AtomicUsize::new(0));
715+
let max_active = Arc::new(AtomicUsize::new(0));
716+
717+
let records = execute_dag_with_parallelism(
718+
&specs,
719+
|node| {
720+
let active = Arc::clone(&active);
721+
let max_active = Arc::clone(&max_active);
722+
Ok(async move {
723+
if matches!(node, MixedNode::FastA | MixedNode::FastB) {
724+
let current = active.fetch_add(1, Ordering::SeqCst) + 1;
725+
let observed_max = max_active.load(Ordering::SeqCst);
726+
if current > observed_max {
727+
max_active.store(current, Ordering::SeqCst);
728+
}
729+
tokio::time::sleep(Duration::from_millis(25)).await;
730+
active.fetch_sub(1, Ordering::SeqCst);
731+
}
732+
Ok(node.name().to_string())
733+
}
734+
.boxed())
735+
},
736+
|_, _| Ok(()),
737+
)
738+
.await
739+
.unwrap();
740+
741+
assert_eq!(
742+
records
743+
.iter()
744+
.map(|record| record.name.as_str())
745+
.collect::<Vec<_>>(),
746+
vec!["root", "fast_a", "fast_b", "slow"]
747+
);
748+
assert!(max_active.load(Ordering::SeqCst) >= 2);
749+
}
666750
}

0 commit comments

Comments
 (0)