Skip to content

Commit ae19308

Browse files
authored
feat: add stateful model support for server-side conversation management (#2004)
1 parent 424224d commit ae19308

23 files changed

+521
-27
lines changed

src/strands/agent/agent.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from ..hooks.registry import TEvent
4646
from ..interrupt import _InterruptState
4747
from ..models.bedrock import BedrockModel
48-
from ..models.model import Model
48+
from ..models.model import Model, _ModelPlugin
4949
from ..plugins import Plugin
5050
from ..plugins.registry import _PluginRegistry
5151
from ..session.session_manager import SessionManager
@@ -68,6 +68,7 @@
6868
from .base import AgentBase
6969
from .conversation_manager import (
7070
ConversationManager,
71+
NullConversationManager,
7172
SlidingWindowConversationManager,
7273
)
7374
from .state import AgentState
@@ -229,7 +230,19 @@ def __init__(
229230
else:
230231
self.callback_handler = callback_handler
231232

232-
self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager()
233+
if self.model.stateful and conversation_manager is not None:
234+
raise ValueError(
235+
"conversation_manager cannot be used with a stateful model. "
236+
"The model manages conversation state server-side."
237+
)
238+
239+
self.conversation_manager: ConversationManager
240+
if self.model.stateful:
241+
self.conversation_manager = NullConversationManager()
242+
elif conversation_manager:
243+
self.conversation_manager = conversation_manager
244+
else:
245+
self.conversation_manager = SlidingWindowConversationManager()
233246

234247
# Process trace attributes to ensure they're of compatible types
235248
self.trace_attributes: dict[str, AttributeValue] = {}
@@ -282,6 +295,9 @@ def __init__(
282295

283296
self._interrupt_state = _InterruptState()
284297

298+
# Runtime state for model providers (e.g., server-side response ids)
299+
self._model_state: dict[str, Any] = {}
300+
285301
# Initialize lock for guarding concurrent invocations
286302
# Using threading.Lock instead of asyncio.Lock because run_async() creates
287303
# separate event loops in different threads, so asyncio.Lock wouldn't work
@@ -327,6 +343,9 @@ def __init__(
327343
for hook in hooks:
328344
self.hooks.add_hook(hook)
329345

346+
# Register built-in plugins
347+
self._plugin_registry.add_and_init(_ModelPlugin())
348+
330349
if plugins:
331350
for plugin in plugins:
332351
self._plugin_registry.add_and_init(plugin)

src/strands/event_loop/event_loop.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ async def _handle_model_execution(
338338
system_prompt_content=agent._system_prompt_content,
339339
tool_choice=structured_output_context.tool_choice,
340340
invocation_state=invocation_state,
341+
model_state=agent._model_state,
341342
cancel_signal=agent._cancel_signal,
342343
):
343344
yield event

src/strands/event_loop/streaming.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,7 @@ async def stream_messages(
463463
tool_choice: Any | None = None,
464464
system_prompt_content: list[SystemContentBlock] | None = None,
465465
invocation_state: dict[str, Any] | None = None,
466+
model_state: dict[str, Any] | None = None,
466467
cancel_signal: threading.Event | None = None,
467468
**kwargs: Any,
468469
) -> AsyncGenerator[TypedEvent, None]:
@@ -477,6 +478,7 @@ async def stream_messages(
477478
system_prompt_content: The authoritative system prompt content blocks that always contains the
478479
system prompt data.
479480
invocation_state: Caller-provided state/context that was passed to the agent when it was invoked.
481+
model_state: Runtime state for model providers (e.g., server-side response ids).
480482
cancel_signal: Optional threading.Event to check for cancellation during streaming.
481483
**kwargs: Additional keyword arguments for future extensibility.
482484
@@ -495,6 +497,7 @@ async def stream_messages(
495497
tool_choice=tool_choice,
496498
system_prompt_content=system_prompt_content,
497499
invocation_state=invocation_state,
500+
model_state=model_state,
498501
)
499502

500503
async for event in process_stream(chunks, start_time, cancel_signal):

src/strands/models/model.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,19 @@
44
import logging
55
from collections.abc import AsyncGenerator, AsyncIterable
66
from dataclasses import dataclass
7-
from typing import Any, Literal, TypeVar
7+
from typing import TYPE_CHECKING, Any, Literal, TypeVar
88

99
from pydantic import BaseModel
1010

11+
from ..hooks.events import AfterInvocationEvent
12+
from ..plugins.plugin import Plugin
1113
from ..types.content import Messages, SystemContentBlock
1214
from ..types.streaming import StreamEvent
1315
from ..types.tools import ToolChoice, ToolSpec
1416

17+
if TYPE_CHECKING:
18+
from ..agent.agent import Agent
19+
1520
logger = logging.getLogger(__name__)
1621

1722
T = TypeVar("T", bound=BaseModel)
@@ -37,6 +42,15 @@ class Model(abc.ABC):
3742
standardized way to configure and process requests for different AI model providers.
3843
"""
3944

45+
@property
46+
def stateful(self) -> bool:
47+
"""Whether the model manages conversation state server-side.
48+
49+
Returns:
50+
False by default. Model providers that support server-side state should override this.
51+
"""
52+
return False
53+
4054
@abc.abstractmethod
4155
# pragma: no cover
4256
def update_config(self, **model_config: Any) -> None:
@@ -115,3 +129,34 @@ def stream(
115129
ModelThrottledException: When the model service is throttling requests from the client.
116130
"""
117131
pass
132+
133+
134+
class _ModelPlugin(Plugin):
135+
"""Plugin that manages model-related lifecycle hooks."""
136+
137+
@property
138+
def name(self) -> str:
139+
"""A stable string identifier for this plugin."""
140+
return "strands:model"
141+
142+
@staticmethod
143+
def _on_after_invocation(event: AfterInvocationEvent) -> None:
144+
"""Handle post-invocation model management tasks.
145+
146+
Performs the following:
147+
- Clears messages when the model is managing conversation state server-side.
148+
"""
149+
if event.agent.model.stateful:
150+
event.agent.messages.clear()
151+
logger.debug(
152+
"response_id=<%s> | cleared messages for server-managed conversation",
153+
event.agent._model_state.get("response_id"),
154+
)
155+
156+
def init_agent(self, agent: "Agent") -> None:
157+
"""Register model lifecycle hooks with the agent.
158+
159+
Args:
160+
agent: The agent instance to register hooks with.
161+
"""
162+
agent.add_hook(self._on_after_invocation, AfterInvocationEvent)

src/strands/models/openai_responses.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,8 @@
11
"""OpenAI model provider using the Responses API.
22
3-
The Responses API is OpenAI's newer API that differs from the Chat Completions API in several key ways:
3+
Note: Built-in tools (web search, code interpreter, file search) are not yet supported.
44
5-
1. The Responses API can maintain conversation state server-side through "previous_response_id",
6-
while Chat Completions is stateless and requires sending full conversation history each time.
7-
Note: This implementation currently only implements the stateless approach.
8-
9-
2. Responses API uses "input" (list of items) instead of "messages", and system
10-
prompts are passed as "instructions" rather than a system role message.
11-
12-
3. Responses API supports built-in tools (web search, code interpreter, file search)
13-
Note: These are not yet implemented in this provider.
14-
15-
- Docs: https://platform.openai.com/docs/api-reference/responses
5+
Docs: https://platform.openai.com/docs/api-reference/responses
166
"""
177

188
import base64
@@ -132,10 +122,14 @@ class OpenAIResponsesConfig(TypedDict, total=False):
132122
params: Model parameters (e.g., max_output_tokens, temperature, etc.).
133123
For a complete list of supported parameters, see
134124
https://platform.openai.com/docs/api-reference/responses/create.
125+
stateful: Whether to enable server-side conversation state management.
126+
When True, the server stores conversation history and the client does not need to
127+
send the full message history with each request. Defaults to False.
135128
"""
136129

137130
model_id: str
138131
params: dict[str, Any] | None
132+
stateful: bool
139133

140134
def __init__(
141135
self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIResponsesConfig]
@@ -153,6 +147,15 @@ def __init__(
153147

154148
logger.debug("config=<%s> | initializing", self.config)
155149

150+
@property
151+
@override
152+
def stateful(self) -> bool:
153+
"""Whether server-side conversation storage is enabled.
154+
155+
Derived from the ``stateful`` configuration option.
156+
"""
157+
return bool(self.config.get("stateful"))
158+
156159
@override
157160
def update_config(self, **model_config: Unpack[OpenAIResponsesConfig]) -> None: # type: ignore[override]
158161
"""Update the OpenAI Responses API model configuration with the provided arguments.
@@ -180,6 +183,7 @@ async def stream(
180183
system_prompt: str | None = None,
181184
*,
182185
tool_choice: ToolChoice | None = None,
186+
model_state: dict[str, Any] | None = None,
183187
**kwargs: Any,
184188
) -> AsyncGenerator[StreamEvent, None]:
185189
"""Stream conversation with the OpenAI Responses API model.
@@ -189,6 +193,7 @@ async def stream(
189193
tool_specs: List of tool specifications to make available to the model.
190194
system_prompt: System prompt to provide context to the model.
191195
tool_choice: Selection strategy for tool invocation.
196+
model_state: Runtime state for model providers (e.g., server-side response ids).
192197
**kwargs: Additional keyword arguments for future extensibility.
193198
194199
Yields:
@@ -199,7 +204,7 @@ async def stream(
199204
ModelThrottledException: If the request is throttled by OpenAI (rate limits).
200205
"""
201206
logger.debug("formatting request for OpenAI Responses API")
202-
request = self._format_request(messages, tool_specs, system_prompt, tool_choice)
207+
request = self._format_request(messages, tool_specs, system_prompt, tool_choice, model_state)
203208
logger.debug("formatted request=<%s>", request)
204209

205210
logger.debug("invoking OpenAI Responses API model")
@@ -219,7 +224,14 @@ async def stream(
219224

220225
async for event in response:
221226
if hasattr(event, "type"):
222-
if event.type == "response.reasoning_text.delta":
227+
if event.type == "response.created":
228+
# Capture response id for server-side conversation chaining
229+
if hasattr(event, "response"):
230+
response_id = getattr(event.response, "id", None)
231+
if model_state is not None and response_id:
232+
model_state["response_id"] = response_id
233+
234+
elif event.type == "response.reasoning_text.delta":
223235
# Reasoning content streaming (for o1/o3 reasoning models)
224236
chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
225237
for chunk in chunks:
@@ -383,6 +395,7 @@ def _format_request(
383395
tool_specs: list[ToolSpec] | None = None,
384396
system_prompt: str | None = None,
385397
tool_choice: ToolChoice | None = None,
398+
model_state: dict[str, Any] | None = None,
386399
) -> dict[str, Any]:
387400
"""Format an OpenAI Responses API compatible response streaming request.
388401
@@ -391,6 +404,7 @@ def _format_request(
391404
tool_specs: List of tool specifications to make available to the model.
392405
system_prompt: System prompt to provide context to the model.
393406
tool_choice: Selection strategy for tool invocation.
407+
model_state: Runtime state for model providers (e.g., server-side response ids).
394408
395409
Returns:
396410
An OpenAI Responses API compatible response streaming request.
@@ -400,13 +414,18 @@ def _format_request(
400414
format.
401415
"""
402416
input_items = self._format_request_messages(messages)
403-
request = {
417+
request: dict[str, Any] = {
404418
"model": self.config["model_id"],
405419
"input": input_items,
406420
"stream": True,
407421
**cast(dict[str, Any], self.config.get("params", {})),
422+
"store": self.stateful,
408423
}
409424

425+
response_id = model_state.get("response_id") if model_state else None
426+
if response_id and self.stateful:
427+
request["previous_response_id"] = response_id
428+
410429
if system_prompt:
411430
request["instructions"] = system_prompt
412431

src/strands/multiagent/graph.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ class GraphNode:
170170
execution_time: int = 0
171171
_initial_messages: Messages = field(default_factory=list, init=False)
172172
_initial_state: AgentState = field(default_factory=AgentState, init=False)
173+
_initial_model_state: dict[str, Any] = field(default_factory=dict, init=False)
173174

174175
def __post_init__(self) -> None:
175176
"""Capture initial executor state after initialization."""
@@ -180,6 +181,9 @@ def __post_init__(self) -> None:
180181
if hasattr(self.executor, "state") and hasattr(self.executor.state, "get"):
181182
self._initial_state = AgentState(self.executor.state.get())
182183

184+
if hasattr(self.executor, "_model_state"):
185+
self._initial_model_state = copy.deepcopy(self.executor._model_state)
186+
183187
def reset_executor_state(self) -> None:
184188
"""Reset GraphNode executor state to initial state when graph was created.
185189
@@ -192,6 +196,9 @@ def reset_executor_state(self) -> None:
192196
if hasattr(self.executor, "state"):
193197
self.executor.state = AgentState(self._initial_state.get())
194198

199+
if hasattr(self.executor, "_model_state"):
200+
self.executor._model_state = copy.deepcopy(self._initial_model_state)
201+
195202
# Reset execution status
196203
self.execution_status = Status.PENDING
197204
self.result = None
@@ -639,6 +646,7 @@ def _activate_interrupt(
639646
"interrupt_state": node.executor._interrupt_state.to_dict(),
640647
"state": node.executor.state.get(),
641648
"messages": node.executor.messages,
649+
"model_state": node.executor._model_state,
642650
}
643651
)
644652

@@ -1074,6 +1082,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
10741082
node.executor.messages = node_context["messages"]
10751083
node.executor.state = AgentState(node_context["state"])
10761084
node.executor._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"])
1085+
node.executor._model_state = node_context.get("model_state", {})
10771086

10781087
return node_responses
10791088

src/strands/multiagent/swarm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,14 @@ class SwarmNode:
6969
swarm: Optional["Swarm"] = None
7070
_initial_messages: Messages = field(default_factory=list, init=False)
7171
_initial_state: AgentState = field(default_factory=AgentState, init=False)
72+
_initial_model_state: dict[str, Any] = field(default_factory=dict, init=False)
7273

7374
def __post_init__(self) -> None:
7475
"""Capture initial executor state after initialization."""
7576
# Deep copy the initial messages and state to preserve them
7677
self._initial_messages = copy.deepcopy(self.executor.messages)
7778
self._initial_state = AgentState(self.executor.state.get())
79+
self._initial_model_state = copy.deepcopy(self.executor._model_state)
7880

7981
def __hash__(self) -> int:
8082
"""Return hash for SwarmNode based on node_id."""
@@ -104,10 +106,12 @@ def reset_executor_state(self) -> None:
104106
self.executor.messages = context["messages"]
105107
self.executor.state = AgentState(context["state"])
106108
self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"])
109+
self.executor._model_state = context.get("model_state", {})
107110
return
108111

109112
self.executor.messages = copy.deepcopy(self._initial_messages)
110113
self.executor.state = AgentState(self._initial_state.get())
114+
self.executor._model_state = copy.deepcopy(self._initial_model_state)
111115

112116

113117
@dataclass
@@ -697,6 +701,7 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M
697701
"interrupt_state": node.executor._interrupt_state.to_dict(),
698702
"state": node.executor.state.get(),
699703
"messages": node.executor.messages,
704+
"model_state": node.executor._model_state,
700705
}
701706

702707
self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})

0 commit comments

Comments (0)