Skip to content

Commit e75afed

Browse files
authored
feat(core): per-state concurrency limits for batch execution (#407)
## Summary - ConcurrencyConfig with per-status parallelism caps and effective_workers() calculation - BatchConfig in EnvironmentConfig for YAML persistence - parse_concurrency_by_status() for CLI string parsing - Wired into _execute_parallel for actual group dispatch ## Validation - Demo: All criteria verified (per-status limits, global bound, config roundtrip, fallback, parsing) - Tests: 21 unit tests, 1755 total core tests passing - CI: All checks green - Review: CodeRabbit findings addressed (effective_workers wiring, error messages) Closes #407
1 parent 249e2c4 commit e75afed

3 files changed

Lines changed: 289 additions & 2 deletions

File tree

codeframe/core/conductor.py

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,76 @@ def _utc_now() -> datetime:
414414
return datetime.now(timezone.utc)
415415

416416

417+
@dataclass
418+
class ConcurrencyConfig:
419+
"""Per-state concurrency limits for batch execution.
420+
421+
When ``by_status`` is non-empty, each status key caps the number of
422+
concurrent workers for tasks in that state. Unspecified statuses
423+
fall back to ``max_parallel``.
424+
"""
425+
426+
max_parallel: int = 4
427+
by_status: dict[str, int] = field(default_factory=dict)
428+
429+
def get_limit_for_status(self, status: str) -> int:
430+
"""Return the concurrency limit for a given task status."""
431+
return self.by_status.get(status, self.max_parallel)
432+
433+
def effective_workers(
434+
self,
435+
*,
436+
statuses: list[str],
437+
group_size: int,
438+
global_running: int,
439+
) -> int:
440+
"""Compute the effective worker count for a group of tasks.
441+
442+
Takes the minimum of:
443+
- Global slots remaining (max_parallel - global_running)
444+
- Per-status limit for the most constrained status in the group
445+
- Group size
446+
"""
447+
global_slots = max(1, self.max_parallel - global_running)
448+
if statuses and self.by_status:
449+
per_status = min(self.get_limit_for_status(s) for s in statuses)
450+
else:
451+
per_status = self.max_parallel
452+
return max(1, min(global_slots, per_status, group_size))
453+
454+
455+
def parse_concurrency_by_status(value: str | None) -> dict[str, int]:
456+
"""Parse a --max-parallel-by-status string into a dict.
457+
458+
Format: "READY=3,IN_PROGRESS=2"
459+
460+
Raises:
461+
ValueError: On invalid status names or format.
462+
"""
463+
from codeframe.core.state_machine import TaskStatus
464+
465+
if not value:
466+
return {}
467+
468+
valid_statuses = {s.value for s in TaskStatus}
469+
result: dict[str, int] = {}
470+
471+
for pair in value.split(","):
472+
pair = pair.strip()
473+
if "=" not in pair:
474+
raise ValueError(f"Invalid format '{pair}'. Expected STATUS=N")
475+
key, val = pair.split("=", 1)
476+
key = key.strip().upper()
477+
if key not in valid_statuses:
478+
raise ValueError(f"Invalid status '{key}'. Valid: {', '.join(sorted(valid_statuses))}")
479+
try:
480+
result[key] = int(val.strip())
481+
except ValueError:
482+
raise ValueError(f"Invalid value '{val.strip()}' for status '{key}'. Must be an integer.")
483+
484+
return result
485+
486+
417487
class BatchStatus(str, Enum):
418488
"""Status of a batch execution."""
419489

@@ -462,6 +532,7 @@ class BatchRun:
462532
engine: str = "react"
463533
stall_timeout_s: int = 300
464534
stall_action: str = "blocker"
535+
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
465536

466537

467538
def start_batch(
@@ -476,6 +547,7 @@ def start_batch(
476547
engine: str = "react",
477548
stall_timeout_s: int = 300,
478549
stall_action: str = "blocker",
550+
concurrency_by_status: Optional[dict[str, int]] = None,
479551
) -> BatchRun:
480552
"""Start a batch execution of multiple tasks.
481553
@@ -510,6 +582,11 @@ def start_batch(
510582
now = _utc_now()
511583
on_failure_enum = OnFailure(on_failure)
512584

585+
concurrency = ConcurrencyConfig(
586+
max_parallel=max_parallel,
587+
by_status=concurrency_by_status or {},
588+
)
589+
513590
batch = BatchRun(
514591
id=batch_id,
515592
workspace_id=workspace.id,
@@ -524,6 +601,7 @@ def start_batch(
524601
engine=engine,
525602
stall_timeout_s=stall_timeout_s,
526603
stall_action=stall_action,
604+
concurrency=concurrency,
527605
)
528606

529607
# Save to database
@@ -1446,8 +1524,18 @@ def _execute_parallel(
14461524
else:
14471525
failed_count += 1
14481526
else:
1449-
# Multiple tasks - run in parallel
1450-
effective_workers = min(group_size, batch.max_parallel)
1527+
# Multiple tasks - run in parallel (use per-status limits if configured)
1528+
if batch.concurrency.by_status:
1529+
group_statuses = []
1530+
for tid in group:
1531+
t = tasks.get(workspace, tid)
1532+
if t:
1533+
group_statuses.append(t.status.value)
1534+
effective_workers = batch.concurrency.effective_workers(
1535+
statuses=group_statuses, group_size=group_size, global_running=0,
1536+
)
1537+
else:
1538+
effective_workers = min(group_size, batch.max_parallel)
14511539
print(f"Running {group_size} tasks with {effective_workers} workers")
14521540

14531541
# Execute group in parallel

codeframe/core/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ class AgentBudgetConfig:
8181
stall_timeout_s: int = 300
8282

8383

84+
@dataclass
85+
class BatchConfig:
86+
"""Batch execution configuration."""
87+
88+
max_parallel: int = 4
89+
max_parallel_by_status: dict[str, int] = dataclass_field(default_factory=dict)
90+
91+
8492
@dataclass
8593
class HooksConfig:
8694
"""Workspace lifecycle hooks configuration.
@@ -124,6 +132,9 @@ class EnvironmentConfig:
124132
# Agent budget
125133
agent_budget: AgentBudgetConfig = dataclass_field(default_factory=AgentBudgetConfig)
126134

135+
# Batch execution
136+
batch: BatchConfig = dataclass_field(default_factory=BatchConfig)
137+
127138
# Workspace lifecycle hooks
128139
hooks: HooksConfig = dataclass_field(default_factory=HooksConfig)
129140

@@ -281,6 +292,8 @@ def from_dict(cls, data: dict[str, Any]) -> "EnvironmentConfig":
281292
data["context"] = ContextConfig(**data["context"])
282293
if "agent_budget" in data and isinstance(data["agent_budget"], dict):
283294
data["agent_budget"] = AgentBudgetConfig(**data["agent_budget"])
295+
if "batch" in data and isinstance(data["batch"], dict):
296+
data["batch"] = BatchConfig(**data["batch"])
284297
if "hooks" in data and isinstance(data["hooks"], dict):
285298
data["hooks"] = HooksConfig(**data["hooks"])
286299
return cls(**data)
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""Tests for per-state concurrency limits in batch execution."""
2+
3+
from unittest.mock import MagicMock, patch
4+
5+
import pytest
6+
7+
pytestmark = pytest.mark.v2
8+
9+
10+
class TestConcurrencyConfig:
11+
"""Test ConcurrencyConfig dataclass."""
12+
13+
def test_defaults(self) -> None:
14+
from codeframe.core.conductor import ConcurrencyConfig
15+
16+
cfg = ConcurrencyConfig()
17+
assert cfg.max_parallel == 4
18+
assert cfg.by_status == {}
19+
20+
def test_custom_values(self) -> None:
21+
from codeframe.core.conductor import ConcurrencyConfig
22+
23+
cfg = ConcurrencyConfig(max_parallel=8, by_status={"READY": 3, "IN_PROGRESS": 2})
24+
assert cfg.max_parallel == 8
25+
assert cfg.by_status["READY"] == 3
26+
27+
def test_get_limit_for_status_configured(self) -> None:
28+
from codeframe.core.conductor import ConcurrencyConfig
29+
30+
cfg = ConcurrencyConfig(max_parallel=4, by_status={"READY": 2})
31+
assert cfg.get_limit_for_status("READY") == 2
32+
33+
def test_get_limit_for_status_fallback_to_global(self) -> None:
34+
from codeframe.core.conductor import ConcurrencyConfig
35+
36+
cfg = ConcurrencyConfig(max_parallel=4, by_status={"READY": 2})
37+
assert cfg.get_limit_for_status("IN_PROGRESS") == 4
38+
39+
def test_get_limit_for_status_empty_by_status(self) -> None:
40+
from codeframe.core.conductor import ConcurrencyConfig
41+
42+
cfg = ConcurrencyConfig(max_parallel=6)
43+
assert cfg.get_limit_for_status("READY") == 6
44+
45+
def test_effective_workers_global_limit(self) -> None:
46+
from codeframe.core.conductor import ConcurrencyConfig
47+
48+
cfg = ConcurrencyConfig(max_parallel=2, by_status={"READY": 5})
49+
# Global limit (2) is less than per-status limit (5)
50+
workers = cfg.effective_workers(statuses=["READY"], group_size=10, global_running=0)
51+
assert workers == 2
52+
53+
def test_effective_workers_per_status_limit(self) -> None:
54+
from codeframe.core.conductor import ConcurrencyConfig
55+
56+
cfg = ConcurrencyConfig(max_parallel=10, by_status={"READY": 3})
57+
# Per-status limit (3) is less than global (10)
58+
workers = cfg.effective_workers(statuses=["READY"], group_size=10, global_running=0)
59+
assert workers == 3
60+
61+
def test_effective_workers_group_size_limit(self) -> None:
62+
from codeframe.core.conductor import ConcurrencyConfig
63+
64+
cfg = ConcurrencyConfig(max_parallel=10)
65+
workers = cfg.effective_workers(statuses=["READY"], group_size=2, global_running=0)
66+
assert workers == 2
67+
68+
def test_effective_workers_accounts_for_running(self) -> None:
69+
from codeframe.core.conductor import ConcurrencyConfig
70+
71+
cfg = ConcurrencyConfig(max_parallel=4)
72+
workers = cfg.effective_workers(statuses=["READY"], group_size=10, global_running=3)
73+
assert workers == 1 # Only 1 global slot left
74+
75+
def test_effective_workers_mixed_statuses(self) -> None:
76+
from codeframe.core.conductor import ConcurrencyConfig
77+
78+
cfg = ConcurrencyConfig(max_parallel=10, by_status={"READY": 3, "IN_PROGRESS": 1})
79+
# Mixed group: bottleneck is IN_PROGRESS (1)
80+
workers = cfg.effective_workers(statuses=["READY", "IN_PROGRESS"], group_size=5, global_running=0)
81+
assert workers == 1
82+
83+
def test_effective_workers_never_negative(self) -> None:
84+
from codeframe.core.conductor import ConcurrencyConfig
85+
86+
cfg = ConcurrencyConfig(max_parallel=2)
87+
workers = cfg.effective_workers(statuses=["READY"], group_size=5, global_running=10)
88+
assert workers >= 1 # At least 1 worker
89+
90+
91+
class TestBatchConfig:
92+
"""Test BatchConfig in EnvironmentConfig."""
93+
94+
def test_defaults(self) -> None:
95+
from codeframe.core.config import EnvironmentConfig
96+
97+
cfg = EnvironmentConfig()
98+
assert cfg.batch.max_parallel == 4
99+
assert cfg.batch.max_parallel_by_status == {}
100+
101+
def test_from_dict(self) -> None:
102+
from codeframe.core.config import EnvironmentConfig
103+
104+
cfg = EnvironmentConfig.from_dict({
105+
"batch": {
106+
"max_parallel": 8,
107+
"max_parallel_by_status": {"READY": 3, "IN_PROGRESS": 2},
108+
}
109+
})
110+
assert cfg.batch.max_parallel == 8
111+
assert cfg.batch.max_parallel_by_status["READY"] == 3
112+
113+
def test_roundtrip(self) -> None:
114+
from codeframe.core.config import BatchConfig, EnvironmentConfig
115+
116+
orig = EnvironmentConfig(batch=BatchConfig(max_parallel=6, max_parallel_by_status={"READY": 2}))
117+
d = orig.to_dict()
118+
restored = EnvironmentConfig.from_dict(d)
119+
assert restored.batch.max_parallel == 6
120+
assert restored.batch.max_parallel_by_status["READY"] == 2
121+
122+
123+
class TestStartBatchConcurrency:
124+
"""Test start_batch with concurrency_by_status."""
125+
126+
def test_start_batch_accepts_concurrency_by_status(self) -> None:
127+
from codeframe.core.conductor import start_batch
128+
129+
workspace = MagicMock()
130+
workspace.id = "w1"
131+
132+
mock_task = MagicMock()
133+
mock_task.id = "t1"
134+
mock_task.title = "Test"
135+
136+
with patch("codeframe.core.conductor.tasks.get", return_value=mock_task):
137+
with patch("codeframe.core.conductor._save_batch"):
138+
with patch("codeframe.core.conductor.events.emit_for_workspace"):
139+
with patch("codeframe.core.conductor._execute_serial"):
140+
batch = start_batch(
141+
workspace, ["t1"],
142+
concurrency_by_status={"READY": 2},
143+
)
144+
145+
assert batch.concurrency.by_status == {"READY": 2}
146+
assert batch.concurrency.max_parallel == 4 # default
147+
148+
149+
class TestParseConcurrencyString:
150+
"""Test parsing of --max-parallel-by-status CLI flag."""
151+
152+
def test_parse_valid_string(self) -> None:
153+
from codeframe.core.conductor import parse_concurrency_by_status
154+
155+
result = parse_concurrency_by_status("READY=3,IN_PROGRESS=2")
156+
assert result == {"READY": 3, "IN_PROGRESS": 2}
157+
158+
def test_parse_single_value(self) -> None:
159+
from codeframe.core.conductor import parse_concurrency_by_status
160+
161+
result = parse_concurrency_by_status("READY=5")
162+
assert result == {"READY": 5}
163+
164+
def test_parse_none_returns_empty(self) -> None:
165+
from codeframe.core.conductor import parse_concurrency_by_status
166+
167+
result = parse_concurrency_by_status(None)
168+
assert result == {}
169+
170+
def test_parse_empty_returns_empty(self) -> None:
171+
from codeframe.core.conductor import parse_concurrency_by_status
172+
173+
result = parse_concurrency_by_status("")
174+
assert result == {}
175+
176+
def test_parse_invalid_status_raises(self) -> None:
177+
from codeframe.core.conductor import parse_concurrency_by_status
178+
179+
with pytest.raises(ValueError, match="Invalid status"):
180+
parse_concurrency_by_status("INVALID=3")
181+
182+
def test_parse_invalid_format_raises(self) -> None:
183+
from codeframe.core.conductor import parse_concurrency_by_status
184+
185+
with pytest.raises(ValueError):
186+
parse_concurrency_by_status("READY:3")

0 commit comments

Comments
 (0)