Commit da89dad

sriumcp and claude committed
feat: parallel-arm orchestration helpers (#123, Phase A)
Stacks on #133 (which stacks on #121). Phase A ships the orchestration layer that turns experiment_plan.yaml into a flat list of independent units, fans them out via an injected runner, and deterministically merges their results into a findings-shaped dict. The actual SDK subagent fan-out + worktree isolation per unit (the issue's main thrust) is Phase B, once #121 + #133 merge.

Why partition first: the 5/18 mech-design-enforcement session ran 8 conditions × 3 seeds = 24 simulations sequentially in one Sonnet session. That 2.5-hour mega-session is what produced the connection drops and the race-two-executors bug. Decomposing into small independent units is the prerequisite to parallel execution; once the units exist as data, the run path can be sync (Phase A) or anyio.gather over SDK subagents (Phase B) without touching the partitioner or merge.

Phase A surface:

partition_plan(plan) -> list[ArmUnit]
    Turns experiment_plan.yaml into one ArmUnit per (arm × condition × seed). Default seed when none is specified is "seed-1"; multi-seed conditions fan out. Skips conditions with no command. Each unit's relative_results_dir is results/<arm>/<seed>, so the seeds of one condition never share a path. The path does not include the condition name, so two conditions of the same arm that use the same seed (including the default) resolve to the same directory.

run_units(units, *, runner, max_parallel) -> list[ArmUnitResult]
    Runs each unit through the injected runner. Catches runner exceptions and converts them to failed ArmUnitResults so a single arm crashing doesn't abort the iteration. Returns results in input order so callers can pair them deterministically. In Phase A the units run one at a time; max_parallel is validated but has no other effect.

merge_unit_results(results, *, plan) -> dict
    Deterministic merge into a findings-shaped structure: arms grouped by arm_id (sorted), arm.status="failed" when any unit failed, units within an arm sorted by (seed, condition). Byte-equal across repeated calls, which is the criterion the issue asks for.

failed_units(results) -> list[ArmUnit]
    Helper for partial-retry: which units need re-running?

default_max_parallel() -> int
    The min(CPU, 4) default the issue calls out.
Behavioral tests (14 in tests/test_parallel_arms.py):

partition_plan:
- single arm/condition with default seed
- multi-seed condition fans out
- multiple arms × conditions: 3 units; sorted assertion
- results_dir doesn't overlap across seeds
- arm without command skipped

run_units:
- results in input order (the determinism contract for merge)
- runner exception becomes failed unit, doesn't abort run
- max_parallel < 1 raises ValueError

merge_unit_results:
- arms grouped by arm_id, sorted
- arm.status="failed" when any unit failed
- failed_unit_count + total_unit_count correct
- byte-equal across repeated calls
- units within arm sorted by (seed, condition)

failed_units:
- returns only failed units (the partial-retry contract)

Out of scope (Phase B):
- SDKDispatcher integration: a runner that actually spawns Agent(isolation="worktree") per unit
- anyio.gather + semaphore for real parallelism
- Wire-up into iteration.py so EXECUTE_ANALYZE picks parallel mode when max_parallel_arms > 1
- Wall-clock measurement on a multi-arm campaign (the "significantly less wall-clock" criterion)

Test suite (this branch, stacked on #133): 346 + 14 new = 360 passing.

Refs #120, #123. Stacked on #143 (#133) which stacks on #136 (#121).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
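The bounded fan-out deferred to Phase B is not part of this commit. As a rough illustration of the shape it could take, the sketch below uses the standard library's asyncio.gather with an asyncio.Semaphore as a dependency-free stand-in for the anyio-based version the message mentions. `Unit`, `Result`, `fake_runner`, and `run_units_async` are hypothetical names chosen for the example, not the project's API.

```python
import asyncio
from dataclasses import dataclass


@dataclass(frozen=True)
class Unit:  # hypothetical stand-in for ArmUnit
    arm_id: str
    seed: str


@dataclass
class Result:  # hypothetical stand-in for ArmUnitResult
    unit: Unit
    status: str  # "complete" | "failed"
    error: str = ""


async def run_units_async(units, runner, max_parallel):
    """Run units concurrently, at most max_parallel at once, in input order."""
    if max_parallel < 1:
        raise ValueError("max_parallel must be >= 1")
    gate = asyncio.Semaphore(max_parallel)

    async def one(unit):
        async with gate:  # blocks while max_parallel runners are in flight
            try:
                return await runner(unit)
            except Exception as exc:  # a crashing unit becomes a failed result
                return Result(unit, "failed", f"{type(exc).__name__}: {exc}")

    # gather returns results in argument order, regardless of finish order.
    return await asyncio.gather(*(one(u) for u in units))


stats = {"in_flight": 0, "peak": 0}


async def fake_runner(unit):
    """Test double: tracks concurrency and fails for arm 'h-bad'."""
    stats["in_flight"] += 1
    stats["peak"] = max(stats["peak"], stats["in_flight"])
    await asyncio.sleep(0.01 if unit.seed == "s1" else 0)
    stats["in_flight"] -= 1
    if unit.arm_id == "h-bad":
        raise RuntimeError("boom")
    return Result(unit, "complete")


units = [
    Unit("h-main", "s1"),
    Unit("h-bad", "s1"),
    Unit("h-main", "s2"),
    Unit("h-main", "s3"),
]
results = asyncio.run(run_units_async(units, fake_runner, max_parallel=2))
for r in results:
    print(r.unit.arm_id, r.unit.seed, r.status, r.error)
```

Because the results come back in input order, a merge step like the one in this commit would not need to change. anyio's structured-concurrency API is built around task groups, so an anyio version would likely collect results differently from this gather-based sketch.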
1 parent d80230c commit da89dad

2 files changed

Lines changed: 390 additions & 0 deletions

orchestrator/parallel_arms.py

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
"""Parallel-arm execution orchestration (issue #123, Phase A).

After DESIGN produces ``experiment_plan.yaml``, EXECUTE_ANALYZE today
runs every (arm × seed × condition) tuple sequentially in one Sonnet
session. That mega-session is what produced the 5/18 connection-drop
incidents and is the proximate cause of the "race two executors" bug
that #71/#111 partly fixed at the symptom level.

The fix: partition the plan into independent units, fan them out to
per-unit subagents (each in its own worktree via #133), wait for all,
and run the existing deterministic merge into findings.json +
principle_updates.json.

Phase A scope:

  * partition_plan(plan): turn experiment_plan.yaml into a flat list
    of ArmUnit descriptors.
  * run_units(units, *, runner, max_parallel): fan out via an injected
    runner callable, collect ArmUnitResult records (one per unit).
  * merge_unit_results(results, *, plan): deterministic merge into a
    findings-shaped dict (the schema validation step is reused from
    the existing executor pipeline).

Phase B (lands when #121 + #133 merge):

  * SDKDispatcher integration: the runner spawns
    ``Agent(isolation="worktree", subagent_type="claude")`` per unit.
  * Real ``anyio.gather`` for actual parallelism with a CPU-bounded
    semaphore.
  * Wire-up into iteration.py so EXECUTE_ANALYZE picks parallel mode
    when ``max_parallel_arms > 1``.
"""
from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import Callable


@dataclass(frozen=True)
class ArmUnit:
    """A single (arm, seed, condition) work item."""

    arm_id: str
    seed: str
    condition_name: str
    command: str

    @property
    def relative_results_dir(self) -> str:
        """Where this unit's results land.

        Distinct per (arm, seed). The condition name is not part of the
        path, so conditions of one arm that share a seed map to the same
        directory.
        """
        return f"results/{self.arm_id}/{self.seed}"


@dataclass
class ArmUnitResult:
    unit: ArmUnit
    status: str  # "complete" | "failed"
    duration_ms: int = 0
    output_files: list[str] = field(default_factory=list)
    error: str = ""


def partition_plan(plan: dict) -> list[ArmUnit]:
    """Turn an experiment_plan.yaml-shaped dict into a list of ArmUnits.

    Each (arm × condition) becomes one unit. Seed defaults to ``"seed-1"``
    when the condition doesn't carry an explicit seed list; multi-seed
    conditions fan out to one unit per seed.
    """
    units: list[ArmUnit] = []
    for arm in plan.get("arms", []) or []:
        if not isinstance(arm, dict):
            continue
        arm_id = str(arm.get("arm_id") or arm.get("type") or "?")
        for cond in arm.get("conditions", []) or []:
            if not isinstance(cond, dict):
                continue
            command = str(cond.get("command") or cond.get("cmd") or "")
            if not command:
                continue
            cond_name = str(cond.get("name") or cond.get("id") or "default")
            seeds = cond.get("seeds") or [cond.get("seed") or "seed-1"]
            if not isinstance(seeds, list):
                seeds = [str(seeds)]
            for s in seeds:
                units.append(ArmUnit(
                    arm_id=arm_id,
                    seed=str(s),
                    condition_name=cond_name,
                    command=command,
                ))
    return units


ArmRunner = Callable[[ArmUnit], ArmUnitResult]
"""Callable that executes one ArmUnit and returns its result.

The default real-world implementation spawns an SDK subagent with
``isolation="worktree"`` and the planned command. Tests inject a
deterministic fake.
"""


def run_units(
    units: list[ArmUnit],
    *,
    runner: ArmRunner,
    max_parallel: int | None = None,
) -> list[ArmUnitResult]:
    """Fan out units to the runner.

    ``max_parallel`` is honored as an upper bound on simultaneous
    in-flight runner calls. Phase A is synchronous over the runner;
    the bound is enforced trivially. Phase B replaces this with
    ``anyio.gather`` + a semaphore for real parallelism.

    Returns results in the same order as ``units`` so callers can pair
    them deterministically with their inputs (the merge step depends
    on this; it would be nondeterministic otherwise).
    """
    if max_parallel is not None and max_parallel < 1:
        raise ValueError("max_parallel must be >= 1")
    results: list[ArmUnitResult] = []
    for unit in units:
        try:
            result = runner(unit)
        except Exception as exc:  # runner exceptions become failed units
            result = ArmUnitResult(
                unit=unit,
                status="failed",
                error=f"{type(exc).__name__}: {exc}",
            )
        results.append(result)
    return results


def default_max_parallel() -> int:
    """Issue default: ``min(CPU, 4)``."""
    cpus = os.cpu_count() or 1
    return max(1, min(cpus, 4))


def merge_unit_results(
    results: list[ArmUnitResult],
    *,
    plan: dict | None = None,
) -> dict:
    """Deterministic merge of unit results into a findings-shaped dict.

    Output keys (sorted):
      - ``arms``: list of ``{arm_id, status, units}`` rows
      - ``failed_unit_count``: int
      - ``total_unit_count``: int

    No timestamps, no random ordering. Calling twice on the same input
    must produce byte-equal output.
    """
    by_arm: dict[str, list[ArmUnitResult]] = {}
    for r in results:
        by_arm.setdefault(r.unit.arm_id, []).append(r)

    arms_out: list[dict] = []
    for arm_id in sorted(by_arm):
        arm_results = by_arm[arm_id]
        # Arm status: complete only when every unit completed; otherwise
        # failed. Granular per-unit status is preserved in `units`.
        any_failed = any(r.status == "failed" for r in arm_results)
        arms_out.append({
            "arm_id": arm_id,
            "status": "failed" if any_failed else "complete",
            "units": [
                {
                    "seed": r.unit.seed,
                    "condition": r.unit.condition_name,
                    "status": r.status,
                    "duration_ms": r.duration_ms,
                    "output_files": sorted(r.output_files),
                    "error": r.error,
                }
                for r in sorted(
                    arm_results,
                    key=lambda x: (x.unit.seed, x.unit.condition_name),
                )
            ],
        })

    failed_count = sum(1 for r in results if r.status == "failed")
    return {
        "arms": arms_out,
        "failed_unit_count": failed_count,
        "total_unit_count": len(results),
    }


def failed_units(results: list[ArmUnitResult]) -> list[ArmUnit]:
    """Helper for the partial-retry path: which units need re-running?"""
    return [r.unit for r in results if r.status == "failed"]
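The `results/<arm>/<seed>` layout keeps seeds apart, but two conditions of the same arm that share a seed resolve to the same directory. The following is a minimal sketch of a pre-flight check for that case. It assumes a hypothetical helper `find_dir_collisions` and a simplified stand-in for `ArmUnit`. The condition-qualified path at the end is one possible alternative layout, not something this commit implements.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Unit:  # simplified stand-in for ArmUnit (command field omitted)
    arm_id: str
    seed: str
    condition_name: str

    @property
    def relative_results_dir(self) -> str:
        # Same layout as the module above: results/<arm>/<seed>
        return f"results/{self.arm_id}/{self.seed}"


def find_dir_collisions(units):
    """Hypothetical helper: map each shared results dir to the units using it."""
    by_dir = defaultdict(list)
    for u in units:
        by_dir[u.relative_results_dir].append(u)
    return {d: us for d, us in by_dir.items() if len(us) > 1}


units = [
    Unit("h-main", "seed-1", "a"),
    Unit("h-main", "seed-1", "b"),  # same arm and seed as the unit above
    Unit("h-ablation", "seed-1", "c"),
]
collisions = find_dir_collisions(units)
for path, shared in collisions.items():
    print(path, [u.condition_name for u in shared])

# One possible condition-qualified layout (an assumption, not the commit's):
qualified = {f"results/{u.arm_id}/{u.condition_name}/{u.seed}" for u in units}
print(len(qualified))
```

Running a check like this before the fan-out would surface overlapping output paths while the units are still plain data, before any runner writes to disk.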

tests/test_parallel_arms.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
"""Behavioral tests for the parallel-arm orchestration (#123 Phase A)."""
from __future__ import annotations

import json

import pytest

from orchestrator.parallel_arms import (
    ArmUnit,
    ArmUnitResult,
    failed_units,
    merge_unit_results,
    partition_plan,
    run_units,
)


# ─── Plan partitioning ─────────────────────────────────────────────────────

class TestPartitionPlan:

    def test_single_arm_single_condition_default_seed(self):
        plan = {"arms": [{
            "arm_id": "h-main",
            "conditions": [{"name": "baseline", "command": "./blis run"}],
        }]}
        units = partition_plan(plan)
        assert len(units) == 1
        assert units[0].arm_id == "h-main"
        assert units[0].seed == "seed-1"
        assert units[0].condition_name == "baseline"
        assert units[0].command == "./blis run"

    def test_multi_seed_condition_fans_out(self):
        plan = {"arms": [{
            "arm_id": "h-main",
            "conditions": [{
                "name": "x", "command": "./run",
                "seeds": ["s1", "s2", "s3"],
            }],
        }]}
        units = partition_plan(plan)
        assert len(units) == 3
        assert sorted(u.seed for u in units) == ["s1", "s2", "s3"]

    def test_multiple_arms_and_conditions(self):
        plan = {"arms": [
            {"arm_id": "h-main", "conditions": [
                {"name": "a", "command": "./a"},
                {"name": "b", "command": "./b"},
            ]},
            {"arm_id": "h-ablation", "conditions": [
                {"name": "c", "command": "./c"},
            ]},
        ]}
        units = partition_plan(plan)
        assert len(units) == 3
        ids = sorted((u.arm_id, u.condition_name) for u in units)
        assert ids == [("h-ablation", "c"), ("h-main", "a"), ("h-main", "b")]

    def test_relative_results_dir_does_not_overlap(self):
        plan = {"arms": [{
            "arm_id": "h-main",
            "conditions": [{
                "name": "x", "command": "./run", "seeds": ["s1", "s2"],
            }],
        }]}
        units = partition_plan(plan)
        dirs = {u.relative_results_dir for u in units}
        assert len(dirs) == 2  # s1 and s2 land in different paths

    def test_skips_arms_without_command(self):
        plan = {"arms": [{
            "arm_id": "h-main",
            "conditions": [{"name": "no-cmd"}],
        }]}
        assert partition_plan(plan) == []


# ─── Run units ─────────────────────────────────────────────────────────────

class _RecordingRunner:
    def __init__(self, statuses: dict[str, str] | None = None):
        self.calls: list[ArmUnit] = []
        self.statuses = statuses or {}

    def __call__(self, unit: ArmUnit) -> ArmUnitResult:
        self.calls.append(unit)
        status = self.statuses.get(unit.arm_id, "complete")
        return ArmUnitResult(
            unit=unit, status=status, duration_ms=100,
            output_files=[f"{unit.relative_results_dir}/out.json"],
        )


class TestRunUnits:

    def test_results_returned_in_input_order(self):
        units = [
            ArmUnit("h-main", "s1", "x", "./a"),
            ArmUnit("h-main", "s2", "x", "./a"),
            ArmUnit("h-ablation", "s1", "y", "./b"),
        ]
        runner = _RecordingRunner()
        results = run_units(units, runner=runner)
        assert [r.unit.seed for r in results] == ["s1", "s2", "s1"]

    def test_runner_exception_becomes_failed_unit(self):
        units = [ArmUnit("h-main", "s1", "x", "./a")]

        def crash(_):
            raise RuntimeError("boom")

        results = run_units(units, runner=crash)
        assert results[0].status == "failed"
        assert "boom" in results[0].error
        assert "RuntimeError" in results[0].error

    def test_max_parallel_must_be_positive(self):
        with pytest.raises(ValueError):
            run_units([], runner=_RecordingRunner(), max_parallel=0)


# ─── Merge ─────────────────────────────────────────────────────────────────

class TestMergeUnitResults:

    def _results(self) -> list[ArmUnitResult]:
        return [
            ArmUnitResult(
                unit=ArmUnit("h-main", "s1", "x", "./a"),
                status="complete", duration_ms=100,
                output_files=["results/h-main/s1/out.json"],
            ),
            ArmUnitResult(
                unit=ArmUnit("h-main", "s2", "x", "./a"),
                status="complete", duration_ms=120,
                output_files=["results/h-main/s2/out.json"],
            ),
            ArmUnitResult(
                unit=ArmUnit("h-ablation", "s1", "y", "./b"),
                status="failed", error="exit 1",
            ),
        ]

    def test_arms_grouped_by_arm_id(self):
        out = merge_unit_results(self._results())
        ids = [a["arm_id"] for a in out["arms"]]
        # Sorted for determinism.
        assert ids == ["h-ablation", "h-main"]

    def test_arm_status_failed_when_any_unit_failed(self):
        out = merge_unit_results(self._results())
        by_id = {a["arm_id"]: a for a in out["arms"]}
        assert by_id["h-ablation"]["status"] == "failed"
        assert by_id["h-main"]["status"] == "complete"

    def test_failed_count_correct(self):
        out = merge_unit_results(self._results())
        assert out["failed_unit_count"] == 1
        assert out["total_unit_count"] == 3

    def test_byte_equal_across_repeated_calls(self):
        a = json.dumps(merge_unit_results(self._results()), sort_keys=True)
        b = json.dumps(merge_unit_results(self._results()), sort_keys=True)
        assert a == b

    def test_units_within_arm_sorted_by_seed_and_condition(self):
        results = [
            ArmUnitResult(unit=ArmUnit("h-main", "s2", "b", "./x"), status="complete"),
            ArmUnitResult(unit=ArmUnit("h-main", "s1", "a", "./x"), status="complete"),
            ArmUnitResult(unit=ArmUnit("h-main", "s1", "b", "./x"), status="complete"),
        ]
        out = merge_unit_results(results)
        seeds = [u["seed"] for u in out["arms"][0]["units"]]
        conds = [u["condition"] for u in out["arms"][0]["units"]]
        assert list(zip(seeds, conds)) == [("s1", "a"), ("s1", "b"), ("s2", "b")]


# ─── Partial-retry helper ──────────────────────────────────────────────────

class TestFailedUnits:

    def test_returns_only_failed_units(self):
        results = [
            ArmUnitResult(unit=ArmUnit("h-main", "s1", "x", "./a"), status="complete"),
            ArmUnitResult(unit=ArmUnit("h-main", "s2", "x", "./a"), status="failed"),
            ArmUnitResult(unit=ArmUnit("h-ablation", "s1", "y", "./b"), status="failed"),
        ]
        failed = failed_units(results)
        assert len(failed) == 2
        assert all(r.arm_id != "h-main" or r.seed == "s2" for r in failed)
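The partial-retry contract tested above only identifies which units failed. How a caller loops on that is left open by the commit. A minimal sketch of one way to do it follows, assuming a hypothetical `run_with_retries` helper, simplified stand-ins for `ArmUnit` and `ArmUnitResult`, and a made-up `flaky_runner` that fails once for seed `s2`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Unit:  # simplified stand-in for ArmUnit; frozen, so usable as a dict key
    arm_id: str
    seed: str


@dataclass
class Result:  # simplified stand-in for ArmUnitResult
    unit: Unit
    status: str  # "complete" | "failed"


def run_with_retries(units, runner, max_attempts=3):
    """Hypothetical retry loop: re-run only failed units, return in input order."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be >= 1")
    latest = {}  # most recent result per unit
    pending = list(units)
    for _ in range(max_attempts):
        if not pending:
            break
        for unit in pending:
            latest[unit] = runner(unit)
        # Same filter as failed_units(): keep only units whose last run failed.
        pending = [u for u in pending if latest[u].status == "failed"]
    return [latest[u] for u in units]


attempts = {}


def flaky_runner(unit):
    """Made-up runner: seed 's2' fails on its first attempt, then succeeds."""
    attempts[unit] = attempts.get(unit, 0) + 1
    ok = unit.seed != "s2" or attempts[unit] >= 2
    return Result(unit, "complete" if ok else "failed")


units = [Unit("h-main", "s1"), Unit("h-main", "s2")]
results = run_with_retries(units, flaky_runner)
for r in results:
    print(r.unit.seed, r.status, attempts[r.unit])
```

Because the final list is rebuilt from the original `units` order, the merged output stays deterministic no matter how many retry rounds a unit needed.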
