Skip to content

Commit e8dfcc9

Browse files
alliscode and Copilot committed
Pivot: preserve workflow state across run() calls
Replace the prior 'combined message + checkpoint_id in one run()' approach with a cleaner default: Workflow.run no longer wipes shared state or runner- context messages between calls. Iteration counting and per-run kwargs still reset on a fresh-message run; checkpoint and responses runs are continuations that preserve everything. This lets a WorkflowAgent be invoked repeatedly on the same instance and maintain multi-turn context (e.g. accumulated Conversation.messages) without asking developers to opt in. Hosted-agent multi-turn pattern becomes two explicit calls: restore-from-checkpoint (drive to idle), then run-with-message. Key changes: - _workflow.py: drop _state.clear() and reset_for_new_run() from run(). Reset iteration count and run kwargs on fresh-message runs only. Restore 'Cannot provide both message and checkpoint_id' validation. Add async guard: fresh-message run with un-drained pending executor messages from a prior run is invalid. - _runner.py: clear _state before import_state in restore_from_checkpoint so restore is authoritative (import_state merges, not replaces). - _agent.py: revert checkpoint branch to restore-only (no message forward). - _responses.py (foundry_hosting): two-call host pattern - restore checkpoint silently, then run with new user input. - tests: state-preservation is the new default; rebuild Workflow for clean slate. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent baff7e3 commit e8dfcc9

5 files changed

Lines changed: 137 additions & 72 deletions

File tree

python/packages/core/agent_framework/_workflows/_agent.py

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -437,17 +437,15 @@ async def _run_core(
437437
yield event
438438

439439
elif checkpoint_id is not None:
440-
# Restore the prior workflow state from the checkpoint and, if
441-
# there's a new user message in this run, deliver it to the
442-
# start executor in the same call. This is the multi-turn
443-
# continuation path: shared state (e.g. accumulated conversation
444-
# history maintained by the workflow's executors) survives across
445-
# turns because Workflow.run sets reset_context=False whenever
446-
# checkpoint_id is provided.
447-
message_arg: Any | None = list(input_messages) if input_messages else None
440+
# Restore the prior workflow state from the checkpoint. Shared
441+
# state (e.g. accumulated conversation history maintained by the
442+
# workflow's executors) survives across turns because Workflow.run
443+
# no longer wipes state per call. Callers who want to deliver a
444+
# new user message after restore should make a second
445+
# `workflow.run(message=...)` call - they are NOT mutually
446+
# exclusive on the same instance, but each must be its own call.
448447
if streaming:
449448
async for event in self.workflow.run(
450-
message=message_arg,
451449
stream=True,
452450
checkpoint_id=checkpoint_id,
453451
checkpoint_storage=checkpoint_storage,
@@ -457,7 +455,6 @@ async def _run_core(
457455
yield event
458456
else:
459457
for event in await self.workflow.run(
460-
message=message_arg,
461458
checkpoint_id=checkpoint_id,
462459
checkpoint_storage=checkpoint_storage,
463460
function_invocation_kwargs=function_invocation_kwargs,

python/packages/core/agent_framework/_workflows/_runner.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -278,7 +278,12 @@ async def restore_from_checkpoint(
278278
"Please rebuild the original workflow before resuming."
279279
)
280280

281-
# Restore state
281+
# Restore state. Clear first so import_state (which merges) does
282+
# not leak stale keys from a prior run on this Workflow instance.
283+
# This matters more now that Workflow.run() no longer wipes state
284+
# per call - the only reset point for shared state on a reused
285+
# instance is at restore time.
286+
self._state.clear()
282287
self._state.import_state(checkpoint.state)
283288
# Restore executor states using the restored state
284289
await self._restore_executor_states()

python/packages/core/agent_framework/_workflows/_workflow.py

Lines changed: 57 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -299,7 +299,7 @@ def get_executors_list(self) -> list[Executor]:
299299
async def _run_workflow_with_tracing(
300300
self,
301301
initial_executor_fn: Callable[[], Awaitable[None]] | None = None,
302-
reset_context: bool = True,
302+
is_fresh_message_run: bool = True,
303303
streaming: bool = False,
304304
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
305305
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
@@ -310,13 +310,18 @@ async def _run_workflow_with_tracing(
310310
of external callers to maintain context across different workflow runs.
311311
312312
Args:
313-
initial_executor_fn: Optional function to execute initial executor
314-
reset_context: Whether to reset the context for a new run
315-
streaming: Whether to enable streaming mode for agents
313+
initial_executor_fn: Optional function to execute initial executor.
314+
is_fresh_message_run: True when this run is a fresh new turn delivered
315+
via the start executor (i.e. ``message`` is provided without a
316+
``checkpoint_id`` or ``responses``). Resets per-run accounting
317+
(iteration counter and run kwargs) without touching the shared
318+
workflow state. False for checkpoint restores and responses-only
319+
runs, which are continuations of prior work.
320+
streaming: Whether to enable streaming mode for agents.
316321
function_invocation_kwargs: Optional kwargs to store in State for function
317-
invocations in subagents
322+
invocations in subagents.
318323
client_kwargs: Optional kwargs to store in State for chat client
319-
invocations in subagents
324+
invocations in subagents.
320325
321326
Yields:
322327
WorkflowEvent: The events generated during the workflow execution.
@@ -345,16 +350,26 @@ async def _run_workflow_with_tracing(
345350
in_progress = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS)
346351
yield in_progress # noqa: RUF070
347352

348-
# Reset context for a new run if supported
349-
if reset_context:
353+
# Per-run reset for fresh-message runs only. We deliberately
354+
# do NOT clear shared workflow state (`_state.clear()`) or the
355+
# runner context's in-flight messages (`reset_for_new_run()`)
356+
# here - state and pending work persist across `run()` calls
357+
# so that a `WorkflowAgent` can deliver multi-turn input on
358+
# the same instance and have prior turns' context survive.
359+
# Iteration counting and per-run kwargs ARE per-run though,
360+
# so they're reset here.
361+
if is_fresh_message_run:
350362
self._runner.reset_iteration_count()
351-
self._runner.context.reset_for_new_run()
352-
self._state.clear()
353363

354364
# Store run kwargs in State so executors can access them.
355-
# Only overwrite when new kwargs are explicitly provided or state was
356-
# just cleared (fresh run). On continuation (reset_context=False) with
357-
# no new kwargs, preserve the kwargs from the original run.
365+
# Per-run kwargs semantics:
366+
# - On a fresh message run, prior kwargs go away (set to {}
367+
# by default, or to the new kwargs if provided). This
368+
# prevents stale kwargs from a prior turn leaking into the
369+
# current turn.
370+
# - On a continuation (checkpoint restore or responses), the
371+
# prior run's kwargs are preserved unless the caller
372+
# explicitly provides new kwargs.
358373
if function_invocation_kwargs is not None or client_kwargs is not None:
359374
combined_kwargs: dict[str, Any] = {}
360375
if function_invocation_kwargs is not None:
@@ -366,11 +381,12 @@ async def _run_workflow_with_tracing(
366381
client_kwargs, "client_kwargs"
367382
)
368383
self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs)
369-
elif reset_context:
384+
elif is_fresh_message_run:
370385
self._state.set(WORKFLOW_RUN_KWARGS_KEY, {})
371386
self._state.commit() # Commit immediately so kwargs are available
372387

373-
# Set streaming mode after reset
388+
# Set streaming mode (always set explicitly per run since
389+
# reset_for_new_run() no longer runs to clear it).
374390
self._runner_context.set_streaming(streaming)
375391

376392
# Execute initial setup if provided
@@ -443,7 +459,7 @@ async def _execute_with_message_or_checkpoint(
443459
if message is None and checkpoint_id is None:
444460
raise ValueError("Must provide either 'message' or 'checkpoint_id'")
445461

446-
# Handle checkpoint restoration (may be combined with message below)
462+
# Handle checkpoint restoration
447463
if checkpoint_id is not None:
448464
has_checkpointing = self._runner.context.has_checkpointing()
449465

@@ -455,10 +471,8 @@ async def _execute_with_message_or_checkpoint(
455471

456472
await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage)
457473

458-
# Handle initial message - if combined with a checkpoint_id, this
459-
# delivers a continuation message to the workflow's start executor
460-
# without clearing prior shared state (reset_context=False).
461-
if message is not None:
474+
# Handle initial message
475+
elif message is not None:
462476
executor = self.get_start_executor()
463477
await executor.execute(
464478
message,
@@ -587,13 +601,29 @@ async def _run_core(
587601
if checkpoint_storage is not None:
588602
self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage)
589603

590-
initial_executor_fn, reset_context = self._resolve_execution_mode(
604+
# Async validation: a fresh-message run (no checkpoint, no responses)
605+
# is only allowed when the runner context has fully drained from any
606+
# prior run. If it still has in-flight executor messages, the prior
607+
# run didn't complete - the caller must either resume from a
608+
# checkpoint or wait for the prior run to drain. (Pending request_info
609+
# events are intentionally NOT blocked here: a follow-up run with
610+
# message=... is the normal way to deliver a response to those
611+
# pending requests, e.g. via WorkflowAgent._process_pending_requests.)
612+
if message is not None and checkpoint_id is None and responses is None:
613+
if await self._runner.context.has_messages():
614+
raise RuntimeError(
615+
"Cannot start a new run with 'message' while in-flight executor "
616+
"messages remain from a prior run. Either resume from a checkpoint "
617+
"(checkpoint_id=...) or wait for the prior run to complete."
618+
)
619+
620+
initial_executor_fn = self._resolve_execution_mode(
591621
message, responses, checkpoint_id, checkpoint_storage
592622
)
593623

594624
async for event in self._run_workflow_with_tracing(
595625
initial_executor_fn=initial_executor_fn,
596-
reset_context=reset_context,
626+
is_fresh_message_run=(message is not None and checkpoint_id is None and responses is None),
597627
streaming=streaming,
598628
function_invocation_kwargs=function_invocation_kwargs,
599629
client_kwargs=client_kwargs,
@@ -662,13 +692,7 @@ def _validate_run_params(
662692
raise ValueError("Cannot provide both 'message' and 'responses'. Use one or the other.")
663693

664694
if message is not None and checkpoint_id is not None:
665-
# Combined message + checkpoint_id is supported: restore prior
666-
# workflow state from the checkpoint, then execute the start
667-
# executor with the new message. The workflow's shared state
668-
# (e.g. accumulated conversation history kept in custom shared
669-
# state) is preserved across the boundary because reset_context
670-
# is set to False for this combination (see _resolve_execution_mode).
671-
pass
695+
raise ValueError("Cannot provide both 'message' and 'checkpoint_id'. Use one or the other.")
672696

673697
if message is None and responses is None and checkpoint_id is None:
674698
raise ValueError(
@@ -682,12 +706,8 @@ def _resolve_execution_mode(
682706
responses: Mapping[str, Any] | None,
683707
checkpoint_id: str | None,
684708
checkpoint_storage: CheckpointStorage | None,
685-
) -> tuple[Callable[[], Awaitable[None]], bool]:
686-
"""Determine the initial executor function and reset_context flag based on parameters.
687-
688-
Returns:
689-
A tuple of (initial_executor_fn, reset_context).
690-
"""
709+
) -> Callable[[], Awaitable[None]]:
710+
"""Determine the initial executor function based on parameters."""
691711
if responses is not None:
692712
if checkpoint_id is not None:
693713
# Combined: restore checkpoint then send responses
@@ -697,13 +717,12 @@ def _resolve_execution_mode(
697717
else:
698718
# Send responses only (requires pending requests in workflow state)
699719
initial_executor_fn = functools.partial(self._send_responses_internal, responses)
700-
return initial_executor_fn, False
720+
return initial_executor_fn
701721
# Regular run or checkpoint restoration
702722
initial_executor_fn = functools.partial(
703723
self._execute_with_message_or_checkpoint, message, checkpoint_id, checkpoint_storage
704724
)
705-
reset_context = message is not None and checkpoint_id is None
706-
return initial_executor_fn, reset_context
725+
return initial_executor_fn
707726

708727
async def _restore_and_send_responses(
709728
self,

python/packages/core/tests/workflow/test_workflow.py

Lines changed: 43 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -488,8 +488,13 @@ async def handle_message(
488488
await ctx.yield_output(existing_messages.copy()) # type: ignore
489489

490490

491-
async def test_workflow_multiple_runs_no_state_collision():
492-
"""Test that running the same workflow instance multiple times doesn't have state collision."""
491+
async def test_workflow_multiple_runs_preserve_state():
492+
"""Test that running the same workflow instance multiple times preserves shared state.
493+
494+
State preservation is the new default - calling ``Workflow.run`` repeatedly
495+
on the same instance behaves like a chat agent maintaining memory across
496+
turns. Callers that want fresh state should rebuild the Workflow.
497+
"""
493498
with tempfile.TemporaryDirectory() as temp_dir:
494499
storage = FileCheckpointStorage(temp_dir)
495500

@@ -503,29 +508,45 @@ async def test_workflow_multiple_runs_no_state_collision():
503508
.build()
504509
)
505510

506-
# Run 1: Should only see messages from run 1
511+
# Run 1: Single record from run 1
507512
result1 = await workflow.run(StateTrackingMessage(data="message1", run_id="run1"))
508513
assert result1.get_final_state() == WorkflowRunState.IDLE
509514
outputs1 = result1.get_outputs()
510515
assert outputs1[0] == ["run1:message1"]
511516

512-
# Run 2: Should only see messages from run 2, not run 1
517+
# Run 2: State from run 1 persists; run 2's record appends.
513518
result2 = await workflow.run(StateTrackingMessage(data="message2", run_id="run2"))
514519
assert result2.get_final_state() == WorkflowRunState.IDLE
515520
outputs2 = result2.get_outputs()
516-
assert outputs2[0] == ["run2:message2"] # Should NOT contain run1 data
521+
assert outputs2[0] == ["run1:message1", "run2:message2"]
517522

518-
# Run 3: Should only see messages from run 3
523+
# Run 3: Same - all three accumulate.
519524
result3 = await workflow.run(StateTrackingMessage(data="message3", run_id="run3"))
520525
assert result3.get_final_state() == WorkflowRunState.IDLE
521526
outputs3 = result3.get_outputs()
522-
assert outputs3[0] == ["run3:message3"] # Should NOT contain run1 or run2 data
527+
assert outputs3[0] == ["run1:message1", "run2:message2", "run3:message3"]
528+
529+
530+
async def test_workflow_multiple_runs_no_state_collision_after_rebuild():
531+
"""Rebuilding the Workflow gives a fresh shared-state slate."""
532+
with tempfile.TemporaryDirectory() as temp_dir:
533+
storage = FileCheckpointStorage(temp_dir)
523534

524-
# Verify that each run only processed its own message
525-
# This confirms that the checkpointable context properly resets between runs
526-
assert outputs1[0] != outputs2[0]
527-
assert outputs2[0] != outputs3[0]
528-
assert outputs1[0] != outputs3[0]
535+
def _build():
536+
executor = StateTrackingExecutor(id="state_executor")
537+
return (
538+
WorkflowBuilder(start_executor=executor, checkpoint_storage=storage)
539+
.add_edge(executor, executor)
540+
.build()
541+
)
542+
543+
wf1 = _build()
544+
result1 = await wf1.run(StateTrackingMessage(data="message1", run_id="run1"))
545+
assert result1.get_outputs()[0] == ["run1:message1"]
546+
547+
wf2 = _build()
548+
result2 = await wf2.run(StateTrackingMessage(data="message2", run_id="run2"))
549+
assert result2.get_outputs()[0] == ["run2:message2"]
529550

530551

531552
async def test_workflow_checkpoint_runtime_only_configuration(
@@ -942,13 +963,16 @@ async def test_workflow_run_parameter_validation(simple_executor: Executor) -> N
942963
result = await workflow.run(test_message)
943964
assert result.get_final_state() == WorkflowRunState.IDLE
944965

945-
# Valid: message + checkpoint_id (combined restore + new input)
946-
# is supported as of the multi-turn checkpoint continuation work
947-
# (restore prior state, then deliver message to start executor with
948-
# reset_context=False). Use a fake id - we just need to confirm the
949-
# call no longer raises at the validation layer.
950-
# Note: passing a non-existent checkpoint_id will fail at restore time,
951-
# which is a different code path than the validation we're checking.
966+
# Invalid: message + checkpoint_id (mutually exclusive). Multi-turn
967+
# state preservation is handled by Workflow.run preserving state across
968+
# calls, so the host pattern is two separate calls (restore-then-run),
969+
# not a single combined call.
970+
with pytest.raises(ValueError, match="Cannot provide both 'message' and 'checkpoint_id'"):
971+
await workflow.run(test_message, checkpoint_id="some-checkpoint")
972+
973+
with pytest.raises(ValueError, match="Cannot provide both 'message' and 'checkpoint_id'"):
974+
async for _ in workflow.run(test_message, checkpoint_id="some-checkpoint", stream=True):
975+
pass
952976

953977
# Invalid: none of message or checkpoint_id
954978
with pytest.raises(ValueError, match="Must provide at least one of"):

python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py

Lines changed: 24 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -298,12 +298,33 @@ async def _handle_inner_workflow(
298298
yield response_event_stream.emit_created()
299299
yield response_event_stream.emit_in_progress()
300300

301+
# Multi-turn pattern: when we have a prior checkpoint, restore it
302+
# first (drive the workflow back to idle with prior state intact),
303+
# then make a separate call that delivers the new user input. This
304+
# depends on Workflow.run preserving shared state across calls. The
305+
# restore-only call may yield events from any pending in-flight
306+
# work in the checkpoint; we consume those internally here so they
307+
# don't surface to the response stream as duplicates.
308+
if latest_checkpoint_id is not None:
309+
if is_streaming_request:
310+
async for _ in self._agent.run(
311+
stream=True,
312+
checkpoint_id=latest_checkpoint_id,
313+
checkpoint_storage=checkpoint_storage,
314+
):
315+
pass
316+
else:
317+
await self._agent.run(
318+
stream=False,
319+
checkpoint_id=latest_checkpoint_id,
320+
checkpoint_storage=checkpoint_storage,
321+
)
322+
301323
if not is_streaming_request:
302-
# Run the agent in non-streaming mode
324+
# Run the agent in non-streaming mode with the new user input.
303325
response = await self._agent.run(
304326
input_messages,
305327
stream=False,
306-
checkpoint_id=latest_checkpoint_id,
307328
checkpoint_storage=checkpoint_storage,
308329
)
309330

@@ -320,11 +341,10 @@ async def _handle_inner_workflow(
320341
# lazily created on matching content, closed when a different type arrives.
321342
tracker = _OutputItemTracker(response_event_stream)
322343

323-
# Run the workflow agent in streaming mode
344+
# Run the workflow agent in streaming mode with the new user input.
324345
async for update in self._agent.run(
325346
input_messages,
326347
stream=True,
327-
checkpoint_id=latest_checkpoint_id,
328348
checkpoint_storage=checkpoint_storage,
329349
):
330350
for content in update.contents:

0 commit comments

Comments (0)