Skip to content

Commit 8526723

Browse files
authored
feat: #1859 add runtime function tool concurrency config (#3152)
1 parent 0466636 commit 8526723

7 files changed

Lines changed: 219 additions & 5 deletions

File tree

src/agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
Runner,
110110
ToolErrorFormatter,
111111
ToolErrorFormatterArgs,
112+
ToolExecutionConfig,
112113
)
113114
from .run_context import AgentHookContext, RunContextWrapper, TContext
114115
from .run_error_handlers import (
@@ -432,6 +433,7 @@ def enable_verbose_stdout_logging():
432433
"ResponsesWebSocketSession",
433434
"RunConfig",
434435
"ReasoningItemIdPolicy",
436+
"ToolExecutionConfig",
435437
"ToolErrorFormatter",
436438
"ToolErrorFormatterArgs",
437439
"RunState",

src/agents/run.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
RunOptions,
4141
ToolErrorFormatter,
4242
ToolErrorFormatterArgs,
43+
ToolExecutionConfig,
4344
)
4445
from .run_context import RunContextWrapper, TContext
4546
from .run_error_handlers import RunErrorHandlers
@@ -136,6 +137,7 @@
136137
"CallModelData",
137138
"CallModelInputFilter",
138139
"ReasoningItemIdPolicy",
140+
"ToolExecutionConfig",
139141
"ToolErrorFormatter",
140142
"ToolErrorFormatterArgs",
141143
"DEFAULT_MAX_TURNS",

src/agents/run_config.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,24 @@ class ToolErrorFormatterArgs(Generic[TContext]):
8888
ToolErrorFormatter = Callable[[ToolErrorFormatterArgs[Any]], MaybeAwaitable[str | None]]
8989

9090

91+
@dataclass
class ToolExecutionConfig:
    """Grouped SDK-side execution settings for local tool calls.

    Raises:
        TypeError: If `max_function_tool_concurrency` is neither `None` nor an integer.
        ValueError: If `max_function_tool_concurrency` is an integer smaller than 1.
    """

    max_function_tool_concurrency: int | None = None
    """Maximum number of local function tool calls to execute concurrently.

    Set to `None` to preserve the default behavior, which starts all function tool calls
    emitted in a turn. This does not change provider-side `parallel_tool_calls` behavior.
    """

    def __post_init__(self) -> None:
        """Validate the configured limit as soon as the config object is built."""
        limit = self.max_function_tool_concurrency
        if limit is None:
            return
        # A fractional limit such as 1.5 would slip past a bare `< 1` comparison even
        # though concurrency is counted in whole calls, so reject non-integers up front.
        if not isinstance(limit, int):
            raise TypeError(
                "tool_execution.max_function_tool_concurrency must be an integer or None"
            )
        if limit < 1:
            raise ValueError("tool_execution.max_function_tool_concurrency must be at least 1")
107+
108+
91109
@dataclass
92110
class SandboxConcurrencyLimits:
93111
"""Concurrency limits for sandbox materialization work."""
@@ -255,6 +273,9 @@ class RunConfig:
255273
sandbox: SandboxRunConfig | None = None
256274
"""Optional sandbox runtime configuration for `SandboxAgent` execution."""
257275

276+
tool_execution: ToolExecutionConfig | None = None
277+
"""Optional SDK-side execution settings for local tool calls."""
278+
258279

259280
class RunOptions(TypedDict, Generic[TContext]):
260281
"""Arguments for ``AgentRunner`` methods."""
@@ -297,6 +318,7 @@ class RunOptions(TypedDict, Generic[TContext]):
297318
"RunOptions",
298319
"SandboxConcurrencyLimits",
299320
"SandboxRunConfig",
321+
"ToolExecutionConfig",
300322
"ToolErrorFormatter",
301323
"ToolErrorFormatterArgs",
302324
"_default_trace_include_sensitive_data",

src/agents/run_internal/tool_execution.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,9 @@ def __init__(
13781378
self.pending_tasks: set[asyncio.Task[Any]] = set()
13791379
self.propagating_failure: BaseException | None = None
13801380
self.available_function_tools: list[FunctionTool] = []
1381+
self.max_function_tool_concurrency = (
1382+
config.tool_execution.max_function_tool_concurrency if config.tool_execution else None
1383+
)
13811384

13821385
async def execute(
13831386
self,
@@ -1406,11 +1409,11 @@ async def execute(
14061409
if function_tool_id not in enabled_function_tool_ids:
14071410
self.available_function_tools.append(tool_run.function_tool)
14081411
enabled_function_tool_ids.add(function_tool_id)
1409-
for order, tool_run in enumerate(self.tool_runs):
1410-
self._create_tool_task(tool_run, order)
1412+
pending_tool_runs = list(enumerate(self.tool_runs))
1413+
self._fill_tool_task_slots(pending_tool_runs)
14111414

14121415
try:
1413-
await self._drain_pending_tasks()
1416+
await self._drain_pending_tasks(pending_tool_runs)
14141417
except asyncio.CancelledError as exc:
14151418
if self.propagating_failure is exc:
14161419
raise
@@ -1423,6 +1426,18 @@ async def execute(
14231426
self.tool_output_guardrail_results,
14241427
)
14251428

1429+
def _fill_tool_task_slots(self, pending_tool_runs: list[tuple[int, ToolRunFunction]]) -> None:
    """Start queued function tool calls until the concurrency limit is reached.

    Args:
        pending_tool_runs: Queue of `(order, tool_run)` pairs that have not been started
            yet. Entries handed to `_create_tool_task` are removed from the front of this
            list in place, so the caller's list keeps only the still-unstarted calls.
    """
    limit = self.max_function_tool_concurrency
    if limit is None:
        # No cap configured: every queued call may start right away.
        free_slots = len(pending_tool_runs)
    else:
        # Tasks that are still pending keep occupying their slots.
        free_slots = limit - len(self.pending_tasks)
    if free_slots <= 0 or not pending_tool_runs:
        return

    # Remove the started prefix with one slice deletion instead of repeated `pop(0)`,
    # each of which shifts the whole remainder of the list.
    handed_off = 0
    try:
        for order, tool_run in pending_tool_runs[:free_slots]:
            # Count the entry before starting it so it leaves the queue even if task
            # creation raises, matching pop-then-create ordering.
            handed_off += 1
            self._create_tool_task(tool_run, order)
    finally:
        del pending_tool_runs[:handed_off]
1440+
14261441
def _create_tool_task(self, tool_run: ToolRunFunction, order: int) -> None:
14271442
task_state = _FunctionToolTaskState(tool_run=tool_run, order=order)
14281443
task = asyncio.create_task(
@@ -1435,7 +1450,10 @@ def _create_tool_task(self, tool_run: ToolRunFunction, order: int) -> None:
14351450
self.task_states[task] = task_state
14361451
self.pending_tasks.add(task)
14371452

1438-
async def _drain_pending_tasks(self) -> None:
1453+
async def _drain_pending_tasks(
1454+
self,
1455+
pending_tool_runs: list[tuple[int, ToolRunFunction]],
1456+
) -> None:
14391457
while self.pending_tasks:
14401458
done_tasks, self.pending_tasks = await asyncio.wait(
14411459
self.pending_tasks,
@@ -1448,6 +1466,7 @@ async def _drain_pending_tasks(self) -> None:
14481466
)
14491467
if failure is not None:
14501468
await self._raise_failure_after_draining_siblings(failure)
1469+
self._fill_tool_task_slots(pending_tool_runs)
14511470

14521471
async def _raise_failure_after_draining_siblings(
14531472
self,

tests/test_run_config.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from agents import Agent, RunConfig, Runner
5+
from agents import Agent, RunConfig, Runner, ToolExecutionConfig
66
from agents.model_settings import ModelSettings
77
from agents.models.interface import Model, ModelProvider
88

@@ -185,3 +185,18 @@ def test_trace_include_sensitive_data_explicit_override_takes_precedence(monkeyp
185185
monkeypatch.setenv("OPENAI_AGENTS_TRACE_INCLUDE_SENSITIVE_DATA", "true")
186186
config = RunConfig(trace_include_sensitive_data=False)
187187
assert config.trace_include_sensitive_data is False
188+
189+
190+
def test_tool_execution_config_rejects_invalid_function_tool_concurrency() -> None:
    """Limits below 1 are rejected with a `ValueError` naming the offending setting."""
    # `match` is applied as a regular expression, so the dot is escaped to require the
    # literal setting name rather than any character in that position.
    expected_message = r"tool_execution\.max_function_tool_concurrency must be at least 1"
    for invalid_limit in (0, -1):
        with pytest.raises(ValueError, match=expected_message):
            ToolExecutionConfig(max_function_tool_concurrency=invalid_limit)
196+
197+
198+
def test_tool_execution_config_is_public_from_agents_package() -> None:
    """A `ToolExecutionConfig` passed to `RunConfig` is stored with its limit intact."""
    tool_execution = ToolExecutionConfig(max_function_tool_concurrency=2)
    config = RunConfig(tool_execution=tool_execution)

    stored = config.tool_execution
    assert stored is not None
    assert stored.max_function_tool_concurrency == 2

tests/test_run_step_execution.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
ToolApprovalItem,
3838
ToolCallItem,
3939
ToolCallOutputItem,
40+
ToolExecutionConfig,
4041
ToolGuardrailFunctionOutput,
4142
ToolInputGuardrail,
4243
ToolOutputGuardrailData,
@@ -232,6 +233,122 @@ async def test_plaintext_agent_with_tool_call_is_run_again():
232233
assert isinstance(result.next_step, NextStepRunAgain)
233234

234235

236+
@pytest.mark.asyncio
async def test_function_tool_concurrency_default_starts_all_calls():
    """Without a configured cap, all three tool calls are observed running together."""
    running = 0
    peak_running = 0

    async def tracked_tool(value: int) -> str:
        # Track how many invocations overlap so the peak can be asserted afterwards.
        nonlocal running, peak_running
        running += 1
        peak_running = max(peak_running, running)
        try:
            await asyncio.sleep(0.01)
            return f"ok-{value}"
        finally:
            running -= 1

    agent = Agent(
        name="test",
        tools=[function_tool(tracked_tool, name_override="tracked_tool")],
    )
    tool_calls = [
        get_function_tool_call(
            "tracked_tool",
            json.dumps({"value": value}),
            call_id=f"call_{value}",
        )
        for value in (1, 2, 3)
    ]
    response = ModelResponse(output=tool_calls, usage=Usage(), response_id="resp")

    result = await get_execute_result(agent, response)

    assert running == 0
    assert peak_running == 3
    # Tool outputs occupy generated items 3..5, in the order the calls were emitted.
    for position, expected_output in enumerate(("ok-1", "ok-2", "ok-3"), start=3):
        assert_item_is_function_tool_call_output(
            result.generated_items[position], expected_output
        )
270+
271+
272+
@pytest.mark.asyncio
async def test_function_tool_concurrency_cap_limits_calls_and_preserves_output_order():
    """A cap of two keeps at most two calls running while outputs stay in call order."""
    running = 0
    peak_running = 0

    async def tracked_tool(value: int) -> str:
        # Track how many invocations overlap so the peak can be asserted afterwards.
        nonlocal running, peak_running
        running += 1
        peak_running = max(peak_running, running)
        # The first call sleeps longest, so the later calls complete before it does.
        delay = 0.03 if value == 1 else 0.001
        try:
            await asyncio.sleep(delay)
            return f"ok-{value}"
        finally:
            running -= 1

    agent = Agent(
        name="test",
        tools=[function_tool(tracked_tool, name_override="tracked_tool")],
    )
    tool_calls = [
        get_function_tool_call(
            "tracked_tool",
            json.dumps({"value": value}),
            call_id=f"call_{value}",
        )
        for value in (1, 2, 3)
    ]
    response = ModelResponse(output=tool_calls, usage=Usage(), response_id="resp")
    capped_config = RunConfig(
        tool_execution=ToolExecutionConfig(max_function_tool_concurrency=2)
    )

    result = await get_execute_result(agent, response, run_config=capped_config)

    assert running == 0
    assert peak_running == 2
    # Tool outputs occupy generated items 3..5, in the order the calls were emitted.
    for position, expected_output in enumerate(("ok-1", "ok-2", "ok-3"), start=3):
        assert_item_is_function_tool_call_output(
            result.generated_items[position], expected_output
        )
310+
311+
312+
@pytest.mark.asyncio
async def test_function_tool_concurrency_cap_leaves_queued_calls_unstarted_after_failure():
    """With a cap of one, a failing first call means the queued call never starts."""
    invoked: list[str] = []

    async def failing_tool() -> str:
        invoked.append("failing_tool")
        raise RuntimeError("boom")

    async def queued_tool() -> str:
        invoked.append("queued_tool")
        return "should-not-run"

    agent = Agent(
        name="test",
        tools=[
            function_tool(
                failing_tool,
                name_override="failing_tool",
                failure_error_function=None,
            ),
            function_tool(queued_tool, name_override="queued_tool"),
        ],
    )
    response = ModelResponse(
        output=[
            get_function_tool_call(tool_name, "{}", call_id=call_id)
            for tool_name, call_id in (
                ("failing_tool", "call_1"),
                ("queued_tool", "call_2"),
            )
        ],
        usage=Usage(),
        response_id="resp",
    )
    serial_config = RunConfig(
        tool_execution=ToolExecutionConfig(max_function_tool_concurrency=1)
    )

    with pytest.raises(UserError, match="Error running tool failing_tool: boom"):
        await get_execute_result(agent, response, run_config=serial_config)

    # Only the failing tool recorded a start; the queued tool was never invoked.
    assert invoked == ["failing_tool"]
350+
351+
235352
@pytest.mark.asyncio
236353
async def test_plaintext_agent_hosted_shell_items_without_message_runs_again():
237354
shell_tool = ShellTool(environment={"type": "container_auto"})

tests/test_source_compat_constructors.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
RunResult,
1818
RunResultStreaming,
1919
SessionSettings,
20+
ToolExecutionConfig,
2021
ToolGuardrailFunctionOutput,
2122
ToolInputGuardrailData,
2223
ToolOutputGuardrailData,
@@ -92,6 +93,42 @@ def test_run_config_reasoning_item_id_policy_positional_binding() -> None:
9293

9394
assert config.session_settings == session_settings
9495
assert config.reasoning_item_id_policy == "omit"
96+
assert config.sandbox is None
97+
assert config.tool_execution is None
98+
99+
100+
def test_run_config_tool_execution_append_preserves_sandbox_position() -> None:
    """Positional `RunConfig` construction binds `tool_execution` after `sandbox`."""
    session_settings = SessionSettings(limit=123)
    tool_execution = ToolExecutionConfig(max_function_tool_concurrency=2)
    positional_args = (
        None,
        MultiProvider(),
        None,
        None,
        False,
        None,
        None,
        None,
        False,
        None,
        True,
        "Agent workflow",
        None,
        None,
        None,
        None,
        None,
        None,
        session_settings,
        "omit",  # reasoning_item_id_policy
        None,  # sandbox
        tool_execution,  # newly appended trailing field
    )

    config = RunConfig(*positional_args)

    assert config.session_settings == session_settings
    assert config.reasoning_item_id_policy == "omit"
    assert config.sandbox is None
    assert config.tool_execution is tool_execution
95132

96133

97134
def test_model_settings_context_management_append_preserves_retry_position() -> None:

0 commit comments

Comments
 (0)