Skip to content

Commit 62595b2

Browse files
authored
[BREAKING] Python: Refactor workflows kwargs (microsoft#5010)
* Refactor workflows kwargs usage * Update sample * Add tests * Update samples * Fix formatting * Comments * Comments 2 * Comments 3 * Fix test and typing
1 parent fd253c0 commit 62595b2

10 files changed

Lines changed: 1097 additions & 831 deletions

File tree

python/packages/core/agent_framework/_workflows/_agent.py

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import sys
88
import uuid
9-
from collections.abc import AsyncIterable, Awaitable, Sequence
9+
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
1010
from dataclasses import dataclass
1111
from datetime import datetime, timezone
1212
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
@@ -152,7 +152,8 @@ def run(
152152
session: AgentSession | None = None,
153153
checkpoint_id: str | None = None,
154154
checkpoint_storage: CheckpointStorage | None = None,
155-
**kwargs: Any,
155+
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
156+
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
156157
) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ...
157158

158159
@overload
@@ -164,7 +165,8 @@ async def run(
164165
session: AgentSession | None = None,
165166
checkpoint_id: str | None = None,
166167
checkpoint_storage: CheckpointStorage | None = None,
167-
**kwargs: Any,
168+
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
169+
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
168170
) -> AgentResponse: ...
169171

170172
def run(
@@ -175,7 +177,8 @@ def run(
175177
session: AgentSession | None = None,
176178
checkpoint_id: str | None = None,
177179
checkpoint_storage: CheckpointStorage | None = None,
178-
**kwargs: Any,
180+
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
181+
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
179182
) -> ResponseStream[AgentResponseUpdate, AgentResponse] | Awaitable[AgentResponse]:
180183
"""Get a response from the workflow agent.
181184
@@ -192,8 +195,12 @@ def run(
192195
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
193196
used to load and restore the checkpoint. When provided without checkpoint_id,
194197
enables checkpointing for this run.
195-
**kwargs: Additional keyword arguments passed through to underlying workflow
196-
and tool functions.
198+
function_invocation_kwargs: Keyword arguments forwarded to tool invocations in
199+
subagents. Either a mapping of agent name/executor id to kwargs, or a flat
200+
mapping of kwargs for all tool invocations.
201+
client_kwargs: Keyword arguments forwarded to chat client calls in
202+
subagents. Either a mapping of agent name/executor id to kwargs, or a flat
203+
mapping of kwargs for all chat client calls.
197204
198205
Returns:
199206
When stream=True: An AsyncIterable[AgentResponseUpdate] for streaming updates.
@@ -208,10 +215,26 @@ def run(
208215
response_id = str(uuid.uuid4())
209216
if stream:
210217
return ResponseStream(
211-
self._run_stream_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs),
218+
self._run_stream_impl(
219+
messages,
220+
response_id,
221+
session,
222+
checkpoint_id,
223+
checkpoint_storage,
224+
function_invocation_kwargs=function_invocation_kwargs,
225+
client_kwargs=client_kwargs,
226+
),
212227
finalizer=AgentResponse.from_updates,
213228
)
214-
return self._run_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs)
229+
return self._run_impl(
230+
messages,
231+
response_id,
232+
session,
233+
checkpoint_id,
234+
checkpoint_storage,
235+
function_invocation_kwargs=function_invocation_kwargs,
236+
client_kwargs=client_kwargs,
237+
)
215238

216239
async def _run_impl(
217240
self,
@@ -220,7 +243,8 @@ async def _run_impl(
220243
session: AgentSession | None,
221244
checkpoint_id: str | None = None,
222245
checkpoint_storage: CheckpointStorage | None = None,
223-
**kwargs: Any,
246+
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
247+
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
224248
) -> AgentResponse:
225249
"""Internal implementation of non-streaming execution.
226250
@@ -230,8 +254,8 @@ async def _run_impl(
230254
session: The agent session for conversation context.
231255
checkpoint_id: ID of checkpoint to restore from.
232256
checkpoint_storage: Runtime checkpoint storage.
233-
**kwargs: Additional keyword arguments passed through to the underlying
234-
workflow and tool functions.
257+
function_invocation_kwargs: Optional kwargs for tool invocations.
258+
client_kwargs: Optional kwargs for chat client calls.
235259
236260
Returns:
237261
An AgentResponse representing the workflow execution results.
@@ -264,7 +288,12 @@ async def _run_impl(
264288

265289
output_events: list[WorkflowEvent[Any]] = []
266290
async for event in self._run_core(
267-
session_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs
291+
session_messages,
292+
checkpoint_id,
293+
checkpoint_storage,
294+
streaming=False,
295+
function_invocation_kwargs=function_invocation_kwargs,
296+
client_kwargs=client_kwargs,
268297
):
269298
if event.type == "output" or event.type == "request_info":
270299
output_events.append(event)
@@ -285,7 +314,8 @@ async def _run_stream_impl(
285314
session: AgentSession | None,
286315
checkpoint_id: str | None = None,
287316
checkpoint_storage: CheckpointStorage | None = None,
288-
**kwargs: Any,
317+
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
318+
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
289319
) -> AsyncIterable[AgentResponseUpdate]:
290320
"""Internal implementation of streaming execution.
291321
@@ -295,8 +325,8 @@ async def _run_stream_impl(
295325
session: The agent session for conversation context.
296326
checkpoint_id: ID of checkpoint to restore from.
297327
checkpoint_storage: Runtime checkpoint storage.
298-
**kwargs: Additional keyword arguments passed through to the underlying
299-
workflow and tool functions.
328+
function_invocation_kwargs: Optional kwargs for tool invocations.
329+
client_kwargs: Optional kwargs for chat client calls.
300330
301331
Yields:
302332
AgentResponseUpdate objects representing the workflow execution progress.
@@ -329,7 +359,12 @@ async def _run_stream_impl(
329359
session_messages: list[Message] = session_context.get_messages(include_input=True)
330360
all_updates: list[AgentResponseUpdate] = []
331361
async for event in self._run_core(
332-
session_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs
362+
session_messages,
363+
checkpoint_id,
364+
checkpoint_storage,
365+
streaming=True,
366+
function_invocation_kwargs=function_invocation_kwargs,
367+
client_kwargs=client_kwargs,
333368
):
334369
updates = self._convert_workflow_event_to_agent_response_updates(response_id, event)
335370
for update in updates:
@@ -349,7 +384,8 @@ async def _run_core(
349384
checkpoint_id: str | None,
350385
checkpoint_storage: CheckpointStorage | None,
351386
streaming: bool,
352-
**kwargs: Any,
387+
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
388+
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
353389
) -> AsyncIterable[WorkflowEvent]:
354390
"""Core implementation that yields workflow events for both streaming and non-streaming modes.
355391
@@ -358,8 +394,8 @@ async def _run_core(
358394
checkpoint_id: ID of checkpoint to restore from.
359395
checkpoint_storage: Runtime checkpoint storage.
360396
streaming: Whether to use streaming workflow methods.
361-
**kwargs: Additional keyword arguments passed through to the underlying
362-
workflow and tool functions.
397+
function_invocation_kwargs: Optional kwargs for tool invocations.
398+
client_kwargs: Optional kwargs for chat client calls.
363399
364400
Yields:
365401
WorkflowEvent objects from the workflow execution.
@@ -371,10 +407,19 @@ async def _run_core(
371407
if bool(self.pending_requests):
372408
function_responses = self._process_pending_requests(input_messages)
373409
if streaming:
374-
async for event in self.workflow.run(responses=function_responses, stream=True, **kwargs):
410+
async for event in self.workflow.run(
411+
responses=function_responses,
412+
stream=True,
413+
function_invocation_kwargs=function_invocation_kwargs,
414+
client_kwargs=client_kwargs,
415+
):
375416
yield event
376417
else:
377-
for event in await self.workflow.run(responses=function_responses, **kwargs):
418+
for event in await self.workflow.run(
419+
responses=function_responses,
420+
function_invocation_kwargs=function_invocation_kwargs,
421+
client_kwargs=client_kwargs,
422+
):
378423
yield event
379424

380425
elif checkpoint_id is not None:
@@ -383,14 +428,16 @@ async def _run_core(
383428
stream=True,
384429
checkpoint_id=checkpoint_id,
385430
checkpoint_storage=checkpoint_storage,
386-
**kwargs,
431+
function_invocation_kwargs=function_invocation_kwargs,
432+
client_kwargs=client_kwargs,
387433
):
388434
yield event
389435
else:
390436
for event in await self.workflow.run(
391437
checkpoint_id=checkpoint_id,
392438
checkpoint_storage=checkpoint_storage,
393-
**kwargs,
439+
function_invocation_kwargs=function_invocation_kwargs,
440+
client_kwargs=client_kwargs,
394441
):
395442
yield event
396443

@@ -400,14 +447,16 @@ async def _run_core(
400447
message=input_messages,
401448
stream=True,
402449
checkpoint_storage=checkpoint_storage,
403-
**kwargs,
450+
function_invocation_kwargs=function_invocation_kwargs,
451+
client_kwargs=client_kwargs,
404452
):
405453
yield event
406454
else:
407455
for event in await self.workflow.run(
408456
message=input_messages,
409457
checkpoint_storage=checkpoint_storage,
410-
**kwargs,
458+
function_invocation_kwargs=function_invocation_kwargs,
459+
client_kwargs=client_kwargs,
411460
):
412461
yield event
413462

0 commit comments

Comments
 (0)