Skip to content

Commit 4b2881c

Browse files
authored
feat: expose Responses WebSocket keepalive options (#3080)
1 parent a47b7ea commit 4b2881c

7 files changed

Lines changed: 138 additions & 13 deletions

File tree

src/agents/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@
8383
from .models.openai_agent_registration import OpenAIAgentRegistrationConfig
8484
from .models.openai_chatcompletions import OpenAIChatCompletionsModel
8585
from .models.openai_provider import OpenAIProvider
86-
from .models.openai_responses import OpenAIResponsesModel, OpenAIResponsesWSModel
86+
from .models.openai_responses import (
87+
OpenAIResponsesModel,
88+
OpenAIResponsesWebSocketOptions,
89+
OpenAIResponsesWSModel,
90+
)
8791
from .prompts import DynamicPromptFunction, GenerateDynamicPromptData, Prompt
8892
from .repl import run_demo_loop
8993
from .responses_websocket_session import ResponsesWebSocketSession, responses_websocket_session
@@ -527,6 +531,7 @@ def enable_verbose_stdout_logging():
527531
"set_default_openai_client",
528532
"set_default_openai_api",
529533
"set_default_openai_responses_transport",
534+
"OpenAIResponsesWebSocketOptions",
530535
"set_default_openai_harness",
531536
"set_default_openai_agent_registration",
532537
"responses_websocket_session",

src/agents/models/multi_provider.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .interface import Model, ModelProvider
99
from .openai_agent_registration import OpenAIAgentRegistrationConfig
1010
from .openai_provider import OpenAIProvider
11+
from .openai_responses import OpenAIResponsesWebSocketOptions
1112

1213
MultiProviderOpenAIPrefixMode = Literal["alias", "model_id"]
1314
MultiProviderUnknownPrefixMode = Literal["error", "model_id"]
@@ -86,6 +87,7 @@ def __init__(
8687
openai_prefix_mode: MultiProviderOpenAIPrefixMode = "alias",
8788
unknown_prefix_mode: MultiProviderUnknownPrefixMode = "error",
8889
openai_agent_registration: OpenAIAgentRegistrationConfig | None = None,
90+
openai_responses_websocket_options: OpenAIResponsesWebSocketOptions | None = None,
8991
) -> None:
9092
"""Create a new OpenAI provider.
9193
@@ -117,6 +119,8 @@ def __init__(
117119
such as ``openrouter/openai/gpt-4o``.
118120
openai_agent_registration: Optional agent registration configuration for the OpenAI
119121
provider.
122+
openai_responses_websocket_options: Optional low-level websocket keepalive options for
123+
the OpenAI Responses websocket transport.
120124
"""
121125
self.provider_map = provider_map
122126
self.openai_provider = OpenAIProvider(
@@ -129,6 +133,7 @@ def __init__(
129133
use_responses=openai_use_responses,
130134
use_responses_websocket=openai_use_responses_websocket,
131135
agent_registration=openai_agent_registration,
136+
responses_websocket_options=openai_responses_websocket_options,
132137
)
133138
self._openai_prefix_mode = self._validate_openai_prefix_mode(openai_prefix_mode)
134139
self._unknown_prefix_mode = self._validate_unknown_prefix_mode(unknown_prefix_mode)

src/agents/models/openai_provider.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
resolve_openai_agent_registration_config,
1717
)
1818
from .openai_chatcompletions import OpenAIChatCompletionsModel
19-
from .openai_responses import OpenAIResponsesModel, OpenAIResponsesWSModel
19+
from .openai_responses import (
20+
OpenAIResponsesModel,
21+
OpenAIResponsesWebSocketOptions,
22+
OpenAIResponsesWSModel,
23+
)
2024

2125
# This is kept for backward compatibility but using get_default_model() method is recommended.
2226
DEFAULT_MODEL: str = "gpt-4o"
@@ -49,6 +53,7 @@ def __init__(
4953
use_responses: bool | None = None,
5054
use_responses_websocket: bool | None = None,
5155
agent_registration: OpenAIAgentRegistrationConfig | None = None,
56+
responses_websocket_options: OpenAIResponsesWebSocketOptions | None = None,
5257
) -> None:
5358
"""Create a new OpenAI provider.
5459
@@ -67,6 +72,8 @@ def __init__(
6772
use_responses_websocket: Whether to use websocket transport for the OpenAI responses
6873
API.
6974
agent_registration: Optional agent registration configuration.
75+
responses_websocket_options: Optional low-level websocket keepalive options for the
76+
OpenAI Responses websocket transport.
7077
"""
7178
if openai_client is not None:
7279
assert api_key is None and base_url is None and websocket_base_url is None, (
@@ -95,6 +102,7 @@ def __init__(
95102
self._responses_transport = _openai_shared.get_default_openai_responses_transport()
96103
# Backward-compatibility shim for internal tests/diagnostics that inspect the legacy flag.
97104
self._use_responses_websocket = self._responses_transport == "websocket"
105+
self._responses_websocket_options = responses_websocket_options
98106

99107
# Reuse websocket model wrappers so websocket transport can keep a persistent connection
100108
# when callers pass model names as strings through a shared provider.
@@ -214,17 +222,22 @@ def get_model(self, model_name: str | None) -> Model:
214222
if not self._use_responses:
215223
return OpenAIChatCompletionsModel(model=resolved_model_name, openai_client=client)
216224

217-
responses_model_type = (
218-
OpenAIResponsesWSModel if use_websocket_transport else OpenAIResponsesModel
219-
)
220-
model = responses_model_type(
225+
if use_websocket_transport:
226+
model = OpenAIResponsesWSModel(
227+
model=resolved_model_name,
228+
openai_client=client,
229+
model_is_explicit=model_is_explicit,
230+
websocket_options=self._responses_websocket_options,
231+
)
232+
if loop_cache is not None:
233+
loop_cache[cache_key] = model
234+
return model
235+
236+
model = OpenAIResponsesModel(
221237
model=resolved_model_name,
222238
openai_client=client,
223239
model_is_explicit=model_is_explicit,
224240
)
225-
if use_websocket_transport:
226-
if loop_cache is not None:
227-
loop_cache[cache_key] = model
228241
return model
229242

230243
async def aclose(self) -> None:

src/agents/models/openai_responses.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
from openai.types.responses.response_prompt_param import ResponsePromptParam
3232
from openai.types.responses.tool_param import LocalShell
33+
from typing_extensions import NotRequired
3334

3435
from .. import _debug
3536
from .._tool_identity import (
@@ -191,6 +192,24 @@ class _WebsocketRequestTimeouts:
191192
recv: float | None
192193

193194

195+
class OpenAIResponsesWebSocketOptions(TypedDict):
196+
"""Low-level OpenAI Responses websocket connection options."""
197+
198+
ping_interval: NotRequired[float | None]
199+
"""Time in seconds between keepalive pings sent by the client.
200+
201+
The underlying ``websockets`` library usually defaults to 20.0. Set to ``None`` to
202+
disable keepalive pings.
203+
"""
204+
205+
ping_timeout: NotRequired[float | None]
206+
"""Time in seconds to wait for a pong response before disconnecting.
207+
208+
Set to ``None`` to keep pings enabled but disable heartbeat timeouts during large latency
209+
spikes.
210+
"""
211+
212+
194213
class _ResponseStreamWithRequestId:
195214
"""Wrap an SDK event stream and retain the originating request ID."""
196215

@@ -911,10 +930,14 @@ def __init__(
911930
openai_client: AsyncOpenAI,
912931
*,
913932
model_is_explicit: bool = True,
933+
websocket_options: OpenAIResponsesWebSocketOptions | None = None,
914934
) -> None:
915935
super().__init__(
916936
model=model, openai_client=openai_client, model_is_explicit=model_is_explicit
917937
)
938+
self._websocket_options = cast(
939+
OpenAIResponsesWebSocketOptions, dict(websocket_options or {})
940+
)
918941
self._ws_connection: Any | None = None
919942
self._ws_connection_identity: tuple[str, tuple[tuple[str, str], ...]] | None = None
920943
self._ws_connection_loop_ref: weakref.ReferenceType[asyncio.AbstractEventLoop] | None = None
@@ -1531,12 +1554,20 @@ async def _open_websocket_connection(
15311554
"Install `websockets` or `openai[realtime]`."
15321555
) from exc
15331556

1557+
connect_kwargs: dict[str, Any] = {
1558+
"user_agent_header": None,
1559+
"additional_headers": dict(headers),
1560+
"max_size": None,
1561+
"open_timeout": connect_timeout,
1562+
}
1563+
if "ping_interval" in self._websocket_options:
1564+
connect_kwargs["ping_interval"] = self._websocket_options["ping_interval"]
1565+
if "ping_timeout" in self._websocket_options:
1566+
connect_kwargs["ping_timeout"] = self._websocket_options["ping_timeout"]
1567+
15341568
return await connect(
15351569
ws_url,
1536-
user_agent_header=None,
1537-
additional_headers=dict(headers),
1538-
max_size=None,
1539-
open_timeout=connect_timeout,
1570+
**connect_kwargs,
15401571
)
15411572

15421573

src/agents/responses_websocket_session.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
MultiProviderUnknownPrefixMode,
1414
)
1515
from .models.openai_provider import OpenAIProvider
16+
from .models.openai_responses import OpenAIResponsesWebSocketOptions
1617
from .result import RunResult, RunResultStreaming
1718
from .run import Runner
1819
from .run_config import RunConfig
@@ -86,6 +87,7 @@ async def responses_websocket_session(
8687
project: str | None = None,
8788
openai_prefix_mode: MultiProviderOpenAIPrefixMode = "alias",
8889
unknown_prefix_mode: MultiProviderUnknownPrefixMode = "error",
90+
responses_websocket_options: OpenAIResponsesWebSocketOptions | None = None,
8991
) -> AsyncIterator[ResponsesWebSocketSession]:
9092
"""Create a shared OpenAI Responses websocket session for multiple Runner calls.
9193
@@ -99,6 +101,9 @@ async def responses_websocket_session(
99101
configured OpenAI-compatible endpoint expects literal namespaced model IDs instead of the SDK's
100102
historical routing-prefix behavior.
101103
104+
Pass ``responses_websocket_options`` to customize low-level websocket keepalive behavior such
105+
as ``ping_interval`` and ``ping_timeout``.
106+
102107
Drain or close streamed iterators before the context exits. Exiting the context while a
103108
websocket request is still in flight may force-close the shared connection.
104109
"""
@@ -112,6 +117,7 @@ async def responses_websocket_session(
112117
openai_use_responses_websocket=True,
113118
openai_prefix_mode=openai_prefix_mode,
114119
unknown_prefix_mode=unknown_prefix_mode,
120+
openai_responses_websocket_options=responses_websocket_options,
115121
)
116122
provider = model_provider.openai_provider
117123
session = ResponsesWebSocketSession(

tests/test_config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from agents import (
10+
responses_websocket_session,
1011
set_default_openai_api,
1112
set_default_openai_client,
1213
set_default_openai_key,
@@ -170,6 +171,36 @@ async def test_openai_provider_reuses_websocket_model_instance_for_same_model_na
170171
assert model1 is model2
171172

172173

174+
@pytest.mark.asyncio
175+
async def test_openai_provider_passes_responses_websocket_options_to_model():
176+
class DummyAsyncOpenAI:
177+
pass
178+
179+
provider = OpenAIProvider(
180+
use_responses=True,
181+
use_responses_websocket=True,
182+
openai_client=DummyAsyncOpenAI(), # type: ignore[arg-type]
183+
responses_websocket_options={"ping_interval": 30.0, "ping_timeout": None},
184+
)
185+
186+
model = provider.get_model("gpt-4")
187+
188+
assert isinstance(model, OpenAIResponsesWSModel)
189+
assert model._websocket_options == {"ping_interval": 30.0, "ping_timeout": None}
190+
191+
192+
@pytest.mark.asyncio
193+
async def test_responses_websocket_session_passes_keepalive_options_to_provider():
194+
async with responses_websocket_session(
195+
api_key="test-key",
196+
responses_websocket_options={"ping_interval": None, "ping_timeout": None},
197+
) as session:
198+
assert session.provider._responses_websocket_options == {
199+
"ping_interval": None,
200+
"ping_timeout": None,
201+
}
202+
203+
173204
def test_openai_provider_does_not_reuse_non_websocket_model_instances():
174205
provider = OpenAIProvider(use_responses=True, use_responses_websocket=False)
175206

tests/test_openai_responses.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,40 @@ async def fake_open(
15801580
assert ws.sent_messages[1]["previous_response_id"] == "resp-1"
15811581

15821582

1583+
@pytest.mark.asyncio
1584+
async def test_websocket_model_passes_keepalive_options_to_connect(monkeypatch):
1585+
import websockets.asyncio.client as websockets_client
1586+
1587+
client = DummyWSClient()
1588+
model = OpenAIResponsesWSModel(
1589+
model="gpt-4",
1590+
openai_client=client, # type: ignore[arg-type]
1591+
websocket_options={"ping_interval": 45.0, "ping_timeout": None},
1592+
)
1593+
ws = DummyWSConnection([])
1594+
captured_kwargs: dict[str, Any] = {}
1595+
1596+
async def fake_connect(ws_url: str, **kwargs: Any) -> DummyWSConnection:
1597+
captured_kwargs["ws_url"] = ws_url
1598+
captured_kwargs.update(kwargs)
1599+
return ws
1600+
1601+
monkeypatch.setattr(websockets_client, "connect", fake_connect)
1602+
1603+
opened = await model._open_websocket_connection(
1604+
"wss://example.test/v1/responses",
1605+
{"Authorization": "Bearer test-key"},
1606+
connect_timeout=10.0,
1607+
)
1608+
1609+
assert opened is ws
1610+
assert captured_kwargs["ws_url"] == "wss://example.test/v1/responses"
1611+
assert captured_kwargs["additional_headers"] == {"Authorization": "Bearer test-key"}
1612+
assert captured_kwargs["open_timeout"] == 10.0
1613+
assert captured_kwargs["ping_interval"] == 45.0
1614+
assert captured_kwargs["ping_timeout"] is None
1615+
1616+
15831617
@pytest.mark.allow_call_model_methods
15841618
def test_websocket_model_reconnects_when_reused_from_different_event_loop(monkeypatch):
15851619
client = DummyWSClient()

0 commit comments

Comments (0)