Skip to content

Commit 7e254e7

Browse files
committed
refactor: rename transport registration function for clarity and update imports
1 parent 3f85525 commit 7e254e7

2 files changed

Lines changed: 7 additions & 12 deletions

File tree

src/seclab_taskflow_agent/mcp_lifecycle.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010
from __future__ import annotations
1111

12-
__all__ = ["MCP_CLEANUP_TIMEOUT", "build_mcp_servers", "mcp_session_task"]
12+
__all__ = ["MCP_CLEANUP_TIMEOUT", "build_mcp_servers", "mcp_session_task", "register_transport"]
1313

1414
import asyncio
1515
import logging
16-
from types import MappingProxyType
16+
1717
from typing import TYPE_CHECKING, Any, Callable
1818

1919
from agents.mcp import MCPServerSse, MCPServerStdio, MCPServerStreamableHttp, create_static_tool_filter
@@ -48,15 +48,15 @@ def __init__(self, server: MCPNamespaceWrap, process: StreamableMCPThread | None
4848
MCP_TRANSPORT_REGISTRY: dict[str, MCPServerBuilder] = {}
4949

5050

51-
def _register_transport(kind: str) -> Callable[[MCPServerBuilder], MCPServerBuilder]:
51+
def register_transport(kind: str) -> Callable[[MCPServerBuilder], MCPServerBuilder]:
5252
"""Decorator to register an MCP transport builder."""
5353
def decorator(builder: MCPServerBuilder) -> MCPServerBuilder:
5454
MCP_TRANSPORT_REGISTRY[kind] = builder
5555
return builder
5656
return decorator
5757

5858

59-
@_register_transport("stdio")
59+
@register_transport("stdio")
6060
def _build_stdio(tb: str, params: dict[str, Any], tool_filter: Any, client_session_timeout: int, confirms: list[str]) -> MCPServerEntry:
6161
if params.get("reconnecting", False):
6262
mcp_server = ReconnectingMCPServerStdio(
@@ -77,7 +77,7 @@ def _build_stdio(tb: str, params: dict[str, Any], tool_filter: Any, client_sessi
7777
return MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), None, name=tb)
7878

7979

80-
@_register_transport("sse")
80+
@register_transport("sse")
8181
def _build_sse(tb: str, params: dict[str, Any], tool_filter: Any, client_session_timeout: int, confirms: list[str]) -> MCPServerEntry:
8282
mcp_server = MCPServerSse(
8383
name=tb,
@@ -88,7 +88,7 @@ def _build_sse(tb: str, params: dict[str, Any], tool_filter: Any, client_session
8888
return MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), None, name=tb)
8989

9090

91-
@_register_transport("streamable")
91+
@register_transport("streamable")
9292
def _build_streamable(tb: str, params: dict[str, Any], tool_filter: Any, client_session_timeout: int, confirms: list[str]) -> MCPServerEntry:
9393
server_proc = None
9494
if "command" in params:
@@ -115,11 +115,6 @@ def _print_err(line: str) -> None:
115115
return MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), server_proc, name=tb)
116116

117117

118-
# Freeze the registry after all transports are registered to prevent
119-
# accidental mutation in tests or at runtime.
120-
MCP_TRANSPORT_REGISTRY = MappingProxyType(MCP_TRANSPORT_REGISTRY)
121-
122-
123118
def build_mcp_servers(
124119
available_tools: AvailableTools,
125120
toolboxes: list[str],

src/seclab_taskflow_agent/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
3434
from openai import APIConnectionError, APITimeoutError, BadRequestError, RateLimitError
3535
from openai.types.responses import ResponseTextDeltaEvent
36-
from tenacity import AsyncRetrying, retry, retry_if_exception_type, stop_after_attempt, wait_exponential
36+
from tenacity import AsyncRetrying, retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt, wait_exponential
3737

3838
from .agent import TaskAgent, TaskAgentHooks, TaskRunHooks
3939
from .capi import get_default_model

0 commit comments

Comments
 (0)