Skip to content

Commit 1c1e706

Browse files
anticomputerCopilot
andcommitted
Address PR review feedback for CLI model config
- Fix docstring arg order: match resume_session_id, cli_model_config parameter order in run_main signature - Persist cli_model_config in TaskflowSession for deterministic resume: restored automatically on resume, can be explicitly overridden with --model-config - Add TestCliModelConfigOverride test suite (6 tests) covering override precedence, resolution, session persistence, and resume behavior - Add session-level tests for cli_model_config persistence and defaults Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent f9e6169 commit 1c1e706

4 files changed

Lines changed: 112 additions & 2 deletions

File tree

src/seclab_taskflow_agent/runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,8 @@ async def run_main(
501501
taskflow_path: Taskflow module path, or None.
502502
cli_globals: Global variables from CLI.
503503
prompt: User prompt text.
504-
cli_model_config: Model configuration module path, or None.
505504
resume_session_id: Session ID to resume from a checkpoint.
505+
cli_model_config: Model configuration module path, or None.
506506
"""
507507
from .session import TaskflowSession
508508

@@ -545,6 +545,9 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo
545545
cli_globals = session.cli_globals
546546
prompt = session.prompt
547547
last_mcp_tool_results = list(session.last_tool_results)
548+
# Restore persisted model config unless explicitly overridden
549+
if not cli_model_config and session.cli_model_config:
550+
cli_model_config = session.cli_model_config
548551
await render_model_output(
549552
f"** 🤖🔄 Resuming session {resume_session_id} from task {session.next_task_index}\n"
550553
)
@@ -578,6 +581,7 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo
578581
cli_globals=cli_globals,
579582
prompt=prompt or "",
580583
total_tasks=len(taskflow_doc.taskflow),
584+
cli_model_config=cli_model_config or "",
581585
)
582586
session.save()
583587
await render_model_output(f"** 🤖📋 Session: {session.session_id}\n")

src/seclab_taskflow_agent/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class TaskflowSession(BaseModel):
6060
finished: bool = False
6161
error: str = ""
6262

63+
# CLI model config override persisted for deterministic resume
64+
cli_model_config: str = ""
65+
6366
# Accumulated tool results carried across tasks (used by repeat_prompt)
6467
last_tool_results: list[str] = Field(default_factory=list)
6568

tests/test_runner.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,95 @@ def test_task_engine_keys_override_config(self):
278278

279279

280280
# ===================================================================
281-
# _build_prompts_to_run
281+
# CLI model config override
282282
# ===================================================================
283283

284+
class TestCliModelConfigOverride:
285+
"""Tests for CLI model config overriding taskflow model_config_ref."""
286+
287+
def test_cli_overrides_taskflow_model_config(self):
288+
"""cli_model_config takes precedence over taskflow_doc.model_config_ref."""
289+
taskflow_ref = "taskflow.models.default"
290+
cli_ref = "cli.models.override"
291+
292+
# Simulate the override logic from run_main
293+
model_config_ref = taskflow_ref
294+
if cli_ref:
295+
model_config_ref = cli_ref
296+
297+
assert model_config_ref == cli_ref
298+
299+
def test_taskflow_model_config_used_when_cli_absent(self):
300+
"""Taskflow model_config_ref is used when cli_model_config is None."""
301+
taskflow_ref = "taskflow.models.default"
302+
cli_ref = None
303+
304+
model_config_ref = taskflow_ref
305+
if cli_ref:
306+
model_config_ref = cli_ref
307+
308+
assert model_config_ref == taskflow_ref
309+
310+
def test_cli_model_config_resolves_via_available_tools(self):
311+
"""CLI-provided model config is resolved through _resolve_model_config."""
312+
at = _mock_available_tools()
313+
at.get_model_config.return_value = _make_model_config(
314+
models={"fast": "gpt-4o-mini"},
315+
)
316+
keys, mdict, params, api_type, backend = _resolve_model_config(at, "cli.override.ref")
317+
at.get_model_config.assert_called_once_with("cli.override.ref")
318+
assert mdict == {"fast": "gpt-4o-mini"}
319+
320+
def test_cli_model_config_persisted_in_session(self):
321+
"""cli_model_config is stored in session for deterministic resume."""
322+
from seclab_taskflow_agent.session import TaskflowSession
323+
324+
session = TaskflowSession(
325+
taskflow_path="test.flow",
326+
cli_model_config="cli.models.fast",
327+
)
328+
assert session.cli_model_config == "cli.models.fast"
329+
330+
def test_session_resume_restores_cli_model_config(self, tmp_path, monkeypatch):
331+
"""Resumed session restores cli_model_config when not overridden."""
332+
monkeypatch.setattr("seclab_taskflow_agent.session.session_dir", lambda: tmp_path)
333+
from seclab_taskflow_agent.session import TaskflowSession
334+
335+
session = TaskflowSession(
336+
taskflow_path="test.flow",
337+
cli_model_config="persisted.models.ref",
338+
)
339+
session.save()
340+
341+
loaded = TaskflowSession.load(session.session_id)
342+
343+
# Simulate the resume logic from run_main
344+
cli_model_config = None # not passed on resume
345+
if not cli_model_config and loaded.cli_model_config:
346+
cli_model_config = loaded.cli_model_config
347+
348+
assert cli_model_config == "persisted.models.ref"
349+
350+
def test_session_resume_cli_override_takes_precedence(self, tmp_path, monkeypatch):
351+
"""Explicit --model-config on resume overrides persisted value."""
352+
monkeypatch.setattr("seclab_taskflow_agent.session.session_dir", lambda: tmp_path)
353+
from seclab_taskflow_agent.session import TaskflowSession
354+
355+
session = TaskflowSession(
356+
taskflow_path="test.flow",
357+
cli_model_config="persisted.models.ref",
358+
)
359+
session.save()
360+
361+
loaded = TaskflowSession.load(session.session_id)
362+
363+
# Simulate the resume logic from run_main with explicit override
364+
cli_model_config = "new.override.ref"
365+
if not cli_model_config and loaded.cli_model_config:
366+
cli_model_config = loaded.cli_model_config
367+
368+
assert cli_model_config == "new.override.ref"
369+
284370
class TestBuildPromptsToRun:
285371
"""Tests for _build_prompts_to_run (async, run via asyncio.run)."""
286372

tests/test_session.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,23 @@ def test_list_sessions(self, tmp_path, monkeypatch):
7878
assert s1.session_id in ids
7979
assert s2.session_id in ids
8080

81+
def test_cli_model_config_persisted(self, tmp_path, monkeypatch):
82+
"""cli_model_config is persisted and restored on load."""
83+
monkeypatch.setattr("seclab_taskflow_agent.session.session_dir", lambda: tmp_path)
84+
s = TaskflowSession(
85+
taskflow_path="examples.taskflows.echo",
86+
cli_model_config="custom.models.fast",
87+
)
88+
s.save()
89+
90+
loaded = TaskflowSession.load(s.session_id)
91+
assert loaded.cli_model_config == "custom.models.fast"
92+
93+
def test_cli_model_config_defaults_empty(self):
94+
"""cli_model_config defaults to empty string."""
95+
s = TaskflowSession(taskflow_path="test.flow")
96+
assert s.cli_model_config == ""
97+
8198

8299
class TestCompletedTask:
83100
"""Tests for CompletedTask model."""

0 commit comments

Comments
 (0)