Skip to content

Commit 6859474

Browse files
committed
fix: address PR review findings
- remove global OpenAI client/api mutations (race condition with concurrent agents) - add empty resolved_agents validation before deploy_task_agents - store toolbox name on MCPServerEntry instead of accessing private _name - update last_tool_results unconditionally in session.record_task - move raises into match default cases to fix mixed return warnings - catch non-SystemExit exceptions in prompt parser
1 parent 3fdd0fa commit 6859474

6 files changed

Lines changed: 25 additions & 18 deletions

File tree

src/seclab_taskflow_agent/agent.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
TContext,
2020
Tool,
2121
result,
22-
set_default_openai_api,
23-
set_default_openai_client,
2422
set_tracing_disabled,
2523
)
2624
from agents.agent import FunctionToolResult, ModelSettings, ToolsToFinalOutputResult
@@ -196,8 +194,6 @@ def __init__(
196194
api_key=resolved_token,
197195
default_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
198196
)
199-
set_default_openai_client(client)
200-
set_default_openai_api(api_type)
201197
set_tracing_disabled(True)
202198
self.run_hooks = run_hooks or TaskRunHooks()
203199

src/seclab_taskflow_agent/capi.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def to_url(self) -> str:
3737
return f"https://{self}/inference"
3838
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
3939
return f"https://{self}/v1"
40-
raise ValueError(f"Unsupported endpoint: {self}")
40+
case _:
41+
raise ValueError(f"Unsupported endpoint: {self}")
4142

4243

4344
COPILOT_INTEGRATION_ID = "vscode-chat"
@@ -121,10 +122,11 @@ def supports_tool_calls(model: str, models: dict[str, dict]) -> bool:
121122
return "tool-calling" in models.get(model, {}).get("capabilities", [])
122123
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
123124
return "gpt-" in model.lower()
124-
raise ValueError(
125-
f"Unsupported Model Endpoint: {api_endpoint}\n"
126-
f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}"
127-
)
125+
case _:
126+
raise ValueError(
127+
f"Unsupported Model Endpoint: {api_endpoint}\n"
128+
f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}"
129+
)
128130

129131

130132
def list_tool_call_models(token: str) -> dict[str, dict]:

src/seclab_taskflow_agent/mcp_lifecycle.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,12 @@
3333
class MCPServerEntry:
3434
"""A paired MCP server wrapper and optional local process."""
3535

36-
__slots__ = ("server", "process")
36+
__slots__ = ("server", "process", "name")
3737

38-
def __init__(self, server: MCPNamespaceWrap, process: StreamableMCPThread | None = None):
38+
def __init__(self, server: MCPNamespaceWrap, process: StreamableMCPThread | None = None, name: str = ""):
3939
self.server = server
4040
self.process = process
41+
self.name = name
4142

4243

4344
def build_mcp_servers(
@@ -117,7 +118,7 @@ def _print_err(line: str) -> None:
117118
case _:
118119
raise ValueError(f"Unsupported MCP transport: {params['kind']}")
119120

120-
entries.append(MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), server_proc))
121+
entries.append(MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), server_proc, name=tb))
121122

122123
return entries
123124

@@ -136,7 +137,7 @@ async def mcp_session_task(
136137
"""
137138
try:
138139
for entry in entries:
139-
logging.debug(f"Connecting mcp server: {entry.server._name}")
140+
logging.debug(f"Connecting mcp server: {entry.name}")
140141
if entry.process is not None:
141142
entry.process.start()
142143
await entry.process.async_wait_for_connection(poll_interval=0.1)
@@ -147,17 +148,17 @@ async def mcp_session_task(
147148

148149
for entry in list(reversed(entries)):
149150
try:
150-
logging.debug(f"Starting cleanup for mcp server: {entry.server._name}")
151+
logging.debug(f"Starting cleanup for mcp server: {entry.name}")
151152
await entry.server.cleanup()
152-
logging.debug(f"Cleaned up mcp server: {entry.server._name}")
153+
logging.debug(f"Cleaned up mcp server: {entry.name}")
153154
if entry.process is not None:
154155
entry.process.stop()
155156
try:
156157
await asyncio.to_thread(entry.process.join_and_raise)
157158
except Exception as e:
158159
logging.warning(f"Streamable mcp server process exception: {e}")
159160
except asyncio.CancelledError:
160-
logging.exception(f"Timeout on cleanup for mcp server: {entry.server._name}")
161+
logging.exception(f"Timeout on cleanup for mcp server: {entry.name}")
161162
except RuntimeError:
162163
logging.exception("RuntimeError in mcp session task")
163164
except asyncio.CancelledError:

src/seclab_taskflow_agent/prompt_parser.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def parse_prompt_args(
5454
if e.code == 2:
5555
logging.exception(f"User provided incomplete prompt: {user_prompt}")
5656
return None, None, None, None, "", help_msg
57+
except Exception:
58+
logging.exception(f"Failed to parse prompt: {user_prompt}")
59+
return None, None, None, None, "", help_msg
5760
p = args[0].p.strip() if args[0].p else None
5861
t = args[0].t.strip() if args[0].t else None
5962
list_models = args[0].l

src/seclab_taskflow_agent/runner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,12 @@ async def run_prompts(async_task: bool = False, max_concurrent_tasks: int = 5) -
603603
raise ValueError(f"No such personality: {agent_name}")
604604
resolved_agents[agent_name] = personality
605605

606+
if not resolved_agents:
607+
raise ValueError(
608+
"No agents resolved for this task. "
609+
"Specify a personality with -p or provide an agents list."
610+
)
611+
606612
async def _deploy(ra: dict, pp: str) -> bool:
607613
async with semaphore:
608614
return await deploy_task_agents(

src/seclab_taskflow_agent/session.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,7 @@ def record_task(
9999
tool_results=tool_results or [],
100100
)
101101
)
102-
if tool_results:
103-
self.last_tool_results = list(tool_results)
102+
self.last_tool_results = list(tool_results or [])
104103
self.save()
105104

106105
def mark_finished(self) -> None:

0 commit comments

Comments
 (0)