Skip to content

Commit 69fc1e3

Browse files
committed
fix: address remaining PR feedback, expand test coverage
- cli: TASK_AGENT_DEBUG="0"/"false" no longer enables debug mode - capi: allow arbitrary API endpoints with graceful fallback - runner: defer tool result pop until after template rendering - test: 72 new unit tests for runner, cli, session, prompt parser, capi - examples: add edge_case_test.yaml for nested JSON repeat_prompt
1 parent 468f97b commit 69fc1e3

10 files changed

Lines changed: 947 additions & 19 deletions

File tree

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# SPDX-FileCopyrightText: GitHub, Inc.
2+
# SPDX-License-Identifier: MIT
3+
4+
# Edge-case test taskflow targeting less-exercised code paths:
5+
# - shell task producing nested JSON for repeat_prompt
6+
# - repeat_prompt over dictionary items (not just arrays)
7+
# - env variable scoping (task-level env)
8+
# - globals CLI override combined with file defaults
9+
# - max_steps constraint
10+
# - must_complete on a non-tool task
11+
# - empty taskflow section handling
12+
13+
seclab-taskflow-agent:
14+
version: "1.0"
15+
filetype: taskflow
16+
17+
model_config: examples.model_configs.model_config
18+
19+
globals:
20+
category: edge-cases
21+
default_value: from-file
22+
23+
taskflow:
24+
# ---------------------------------------------------------------
25+
# Task 1: Shell task with nested JSON structure
26+
# Tests: run, must_complete, complex JSON output
27+
# ---------------------------------------------------------------
28+
- task:
29+
name: nested-json-shell
30+
must_complete: true
31+
run: |
32+
echo '[{"id": 1, "data": {"label": "alpha", "score": 0.95}}, {"id": 2, "data": {"label": "beta", "score": 0.87}}]'
33+
34+
# ---------------------------------------------------------------
35+
# Task 2: Repeat over nested structure, sequential (not async)
36+
# Tests: repeat_prompt (sequential), nested result access,
37+
# globals reference, inputs, env scoping, max_steps
38+
# ---------------------------------------------------------------
39+
- task:
40+
name: sequential-repeat
41+
repeat_prompt: true
42+
must_complete: true
43+
model: gpt_default
44+
max_steps: 5
45+
agents:
46+
- examples.personalities.fruit_expert
47+
inputs:
48+
output_format: json
49+
env:
50+
EDGE_TEST_MODE: "sequential"
51+
user_prompt: |
52+
Category: {{ globals.category }}, default: {{ globals.default_value }}.
53+
Item ID {{ result.id }}: label={{ result.data.label }}, score={{ result.data.score }}.
54+
Respond with exactly one sentence summarizing this item in {{ inputs.output_format }} awareness.
55+
56+
# ---------------------------------------------------------------
57+
# Task 3: Simple prompt with no tools (headless, no toolboxes)
58+
# Tests: pure LLM task, exclude_from_context, model alias
59+
# ---------------------------------------------------------------
60+
- task:
61+
name: pure-llm-task
62+
model: gpt_default
63+
exclude_from_context: true
64+
agents:
65+
- examples.personalities.fruit_expert
66+
max_steps: 3
67+
user_prompt: |
68+
The category is {{ globals.category }}.
69+
Say "edge case test passed" and nothing else.

src/seclab_taskflow_agent/capi.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,8 @@ def list_capi_models(token: str) -> dict[str, dict]:
7979
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
8080
models_catalog = "models"
8181
case _:
82-
raise ValueError(
83-
f"Unsupported Model Endpoint: {api_endpoint}\n"
84-
f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}"
85-
)
82+
# Unknown endpoint — try the OpenAI-style models catalog
83+
models_catalog = "models"
8684
r = httpx.get(
8785
httpx.URL(api_endpoint).join(models_catalog),
8886
headers={
@@ -100,6 +98,10 @@ def list_capi_models(token: str) -> dict[str, dict]:
10098
models_list = r.json()
10199
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
102100
models_list = r.json().get("data", [])
101+
case _:
102+
# Unknown endpoint — try OpenAI-style {"data": [...]}
103+
body = r.json()
104+
models_list = body.get("data", body) if isinstance(body, dict) else body
103105
for model in models_list:
104106
models[model.get("id")] = dict(model)
105107
except httpx.RequestError:
@@ -123,10 +125,9 @@ def supports_tool_calls(model: str, models: dict[str, dict]) -> bool:
123125
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
124126
return "gpt-" in model.lower()
125127
case _:
126-
raise ValueError(
127-
f"Unsupported Model Endpoint: {api_endpoint}\n"
128-
f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}"
129-
)
128+
# Unknown endpoint — optimistically assume tool-call support
129+
# if the model is present in the catalog.
130+
return model in models
130131

131132

132133
def list_tool_call_models(token: str) -> dict[str, dict]:

src/seclab_taskflow_agent/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def main(
113113
) -> None:
114114
"""Run a taskflow or personality-based agent session."""
115115
# Debug mode from flag or env var
116-
debug = debug or bool(os.getenv("TASK_AGENT_DEBUG"))
116+
debug = debug or os.getenv("TASK_AGENT_DEBUG", "").strip().lower() in ("1", "true", "yes")
117117

118118
# Validate mutual exclusivity (resume is standalone)
119119
if resume and (personality or taskflow or list_models):

src/seclab_taskflow_agent/runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,6 @@ async def _build_prompts_to_run(
214214
logging.critical("No last MCP tool result available")
215215
raise
216216

217-
# Consume only after successful parse
218-
last_mcp_tool_results.pop()
219-
220217
if not iterable_result:
221218
await render_model_output("** 🤖❗MCP tool result iterable is empty!\n")
222219
else:
@@ -234,6 +231,10 @@ async def _build_prompts_to_run(
234231
except jinja2.TemplateError as e:
235232
logging.error(f"Error rendering template for result {value}: {e}")
236233
raise ValueError(f"Template rendering failed: {e}")
234+
235+
# Consume only after all prompts rendered successfully so that
236+
# the result remains available for retry/resume on failure.
237+
last_mcp_tool_results.pop()
237238
else:
238239
prompts_to_run.append(task_prompt)
239240
return prompts_to_run

tests/test_api_endpoint_config.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,14 @@ def test_to_url_openai(self):
6262
assert endpoint.to_url() == "https://api.openai.com/v1"
6363

6464
def test_unsupported_endpoint(self, monkeypatch):
65-
"""Test that unsupported API endpoint raises ValueError."""
65+
"""Test that unsupported API endpoint falls back gracefully."""
6666
api_endpoint = "https://unsupported.example.com"
6767
monkeypatch.setenv("AI_API_ENDPOINT", api_endpoint)
68-
with pytest.raises(ValueError) as excinfo:
69-
list_capi_models("abc")
70-
msg = str(excinfo.value)
71-
assert "Unsupported Model Endpoint" in msg
72-
assert "https://models.github.ai/inference" in msg
73-
assert "https://api.githubcopilot.com" in msg
68+
# Unknown endpoints should not raise; they try OpenAI-style catalog
69+
# and return an empty dict on connection failure.
70+
result = list_capi_models("abc")
71+
assert isinstance(result, dict)
72+
assert result == {}
7473

7574

7675
if __name__ == "__main__":

tests/test_capi_extended.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# SPDX-FileCopyrightText: GitHub, Inc.
2+
# SPDX-License-Identifier: MIT
3+
4+
"""Extended tests for capi module."""
5+
6+
from __future__ import annotations
7+
8+
from seclab_taskflow_agent.capi import AI_API_ENDPOINT_ENUM, supports_tool_calls
9+
10+
11+
class TestSupportsToolCalls:
12+
"""Tests for supports_tool_calls with unknown endpoints."""
13+
14+
def test_unknown_endpoint_known_model(self, monkeypatch):
15+
"""Unknown endpoint returns True when model is in the catalog."""
16+
monkeypatch.setenv("AI_API_ENDPOINT", "https://custom.api.example.com/v1")
17+
models = {"my-model": {"id": "my-model"}}
18+
assert supports_tool_calls("my-model", models) is True
19+
20+
def test_unknown_endpoint_unknown_model(self, monkeypatch):
21+
"""Unknown endpoint returns False when model is NOT in the catalog."""
22+
monkeypatch.setenv("AI_API_ENDPOINT", "https://custom.api.example.com/v1")
23+
models = {"other-model": {"id": "other-model"}}
24+
assert supports_tool_calls("missing-model", models) is False
25+
26+
def test_copilot_endpoint_with_capabilities(self, monkeypatch):
27+
"""Copilot endpoint checks capabilities.supports.tool_calls."""
28+
monkeypatch.setenv("AI_API_ENDPOINT", "https://api.githubcopilot.com")
29+
models = {
30+
"gpt-4o": {
31+
"id": "gpt-4o",
32+
"capabilities": {"supports": {"tool_calls": True}},
33+
}
34+
}
35+
assert supports_tool_calls("gpt-4o", models) is True
36+
37+
def test_copilot_endpoint_without_capabilities(self, monkeypatch):
38+
"""Copilot endpoint returns False when tool_calls not in capabilities."""
39+
monkeypatch.setenv("AI_API_ENDPOINT", "https://api.githubcopilot.com")
40+
models = {
41+
"text-only": {
42+
"id": "text-only",
43+
"capabilities": {"supports": {}},
44+
}
45+
}
46+
assert supports_tool_calls("text-only", models) is False
47+
48+
def test_models_github_endpoint(self, monkeypatch):
49+
"""models.github.ai checks for 'tool-calling' in capabilities list."""
50+
monkeypatch.setenv("AI_API_ENDPOINT", "https://models.github.ai/inference")
51+
models = {
52+
"openai/gpt-4o": {
53+
"id": "openai/gpt-4o",
54+
"capabilities": ["tool-calling", "chat"],
55+
}
56+
}
57+
assert supports_tool_calls("openai/gpt-4o", models) is True
58+
59+
def test_models_github_endpoint_no_tool_calling(self, monkeypatch):
60+
"""models.github.ai returns False when 'tool-calling' not in list."""
61+
monkeypatch.setenv("AI_API_ENDPOINT", "https://models.github.ai/inference")
62+
models = {
63+
"some-model": {
64+
"id": "some-model",
65+
"capabilities": ["chat"],
66+
}
67+
}
68+
assert supports_tool_calls("some-model", models) is False
69+
70+
def test_openai_endpoint_gpt_model(self, monkeypatch):
71+
"""OpenAI endpoint returns True for models containing 'gpt-'."""
72+
monkeypatch.setenv("AI_API_ENDPOINT", "https://api.openai.com/v1")
73+
assert supports_tool_calls("gpt-4o", {}) is True
74+
75+
def test_openai_endpoint_non_gpt_model(self, monkeypatch):
76+
"""OpenAI endpoint returns False for non-GPT models."""
77+
monkeypatch.setenv("AI_API_ENDPOINT", "https://api.openai.com/v1")
78+
assert supports_tool_calls("claude-3-opus", {}) is False
79+
80+
81+
class TestAIAPIEndpointEnum:
82+
"""Tests for the AI_API_ENDPOINT_ENUM StrEnum."""
83+
84+
def test_enum_values(self):
85+
"""All expected endpoint values exist."""
86+
assert AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB == "models.github.ai"
87+
assert AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT == "api.githubcopilot.com"
88+
assert AI_API_ENDPOINT_ENUM.AI_API_OPENAI == "api.openai.com"
89+
90+
def test_to_url_models_github(self):
91+
assert AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB.to_url() == "https://models.github.ai/inference"
92+
93+
def test_to_url_copilot(self):
94+
assert AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT.to_url() == "https://api.githubcopilot.com"
95+
96+
def test_to_url_openai(self):
97+
assert AI_API_ENDPOINT_ENUM.AI_API_OPENAI.to_url() == "https://api.openai.com/v1"

tests/test_cli.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# SPDX-FileCopyrightText: GitHub, Inc.
2+
# SPDX-License-Identifier: MIT
3+
4+
"""Unit tests for the Typer CLI module."""
5+
6+
from __future__ import annotations
7+
8+
import pytest
9+
import typer
10+
11+
from seclab_taskflow_agent.cli import _parse_global
12+
13+
14+
class TestParseGlobal:
15+
"""Tests for _parse_global KEY=VALUE parsing."""
16+
17+
def test_valid_key_value(self):
18+
"""Standard KEY=VALUE is parsed correctly."""
19+
assert _parse_global("fruit=apple") == ("fruit", "apple")
20+
21+
def test_missing_equals_raises(self):
22+
"""A string without '=' raises BadParameter."""
23+
with pytest.raises(typer.BadParameter, match="Expected KEY=VALUE"):
24+
_parse_global("no_equals_here")
25+
26+
def test_value_with_equals_sign(self):
27+
"""Only the first '=' is used as the delimiter."""
28+
key, val = _parse_global("url=https://example.com?foo=bar")
29+
assert key == "url"
30+
assert val == "https://example.com?foo=bar"
31+
32+
def test_whitespace_stripped(self):
33+
"""Leading/trailing whitespace in key and value is stripped."""
34+
key, val = _parse_global(" key = value ")
35+
assert key == "key"
36+
assert val == "value"
37+
38+
def test_empty_value(self):
39+
"""An empty value after '=' is allowed."""
40+
key, val = _parse_global("key=")
41+
assert key == "key"
42+
assert val == ""
43+
44+
def test_empty_key(self):
45+
"""An empty key before '=' is technically allowed by the parser."""
46+
key, val = _parse_global("=value")
47+
assert key == ""
48+
assert val == "value"
49+
50+
51+
class TestDebugEnvParsing:
52+
"""Tests for the TASK_AGENT_DEBUG environment variable expression."""
53+
54+
@staticmethod
55+
def _is_debug(env_value: str) -> bool:
56+
"""Reproduce the debug expression from cli.py."""
57+
return env_value.strip().lower() in ("1", "true", "yes")
58+
59+
def test_zero_is_false(self):
60+
assert self._is_debug("0") is False
61+
62+
def test_one_is_true(self):
63+
assert self._is_debug("1") is True
64+
65+
def test_true_string_is_true(self):
66+
assert self._is_debug("true") is True
67+
68+
def test_TRUE_string_is_true(self):
69+
assert self._is_debug("TRUE") is True
70+
71+
def test_yes_string_is_true(self):
72+
assert self._is_debug("yes") is True
73+
74+
def test_empty_string_is_false(self):
75+
assert self._is_debug("") is False
76+
77+
def test_false_string_is_false(self):
78+
assert self._is_debug("false") is False
79+
80+
def test_whitespace_trimmed(self):
81+
assert self._is_debug(" 1 ") is True
82+
83+
def test_random_text_is_false(self):
84+
assert self._is_debug("enabled") is False

0 commit comments

Comments
 (0)