Skip to content

Commit 099990e

Browse files
committed
fix review comments
1 parent 80ad0d1 commit 099990e

2 files changed

Lines changed: 200 additions & 76 deletions

File tree

src/agents/realtime/openai_realtime.py

Lines changed: 158 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import math
88
import os
99
from collections.abc import Mapping
10+
from dataclasses import dataclass
1011
from datetime import datetime
1112
from typing import Annotated, Any, Callable, Literal, Union, cast
1213

@@ -192,6 +193,124 @@ async def get_api_key(key: str | Callable[[], MaybeAwaitable[str]] | None) -> st
192193
ServerEventTypeAdapter: TypeAdapter[AllRealtimeServerEvents] | None = None
193194

194195

196+
@dataclass(frozen=True)
class _PendingResponseCreate:
    """Ticket describing a response.create that has been granted a send slot.

    Instances are produced by
    ``_ResponseCreateSequencer.wait_for_response_create_slot`` and handed back
    to ``_ResponseCreateSequencer.mark_response_create_sent`` once the event
    has actually been written to the transport.
    """

    # Client-generated id to attach to the outgoing response.create event
    # (``agents_py_response_create_<counter>``).
    event_id: str
    # Latest reserved request version at the moment the slot was granted; a
    # successful send satisfies every reservation up to and including it.
    target_version: int


class _ResponseCreateSequencer:
    """Tracks local response sequencing around response.create and response.cancel.

    State kept here:

    * ``ongoing_response`` - whether a response is currently considered active.
    * ``response_control`` - ``"create_requested"`` while a granted
      response.create slot is outstanding, ``"cancel_requested"`` after
      ``begin_cancel_response`` succeeded, ``"free"`` otherwise.
    * request / sent version counters - every reservation bumps the request
      version; a completed send raises the sent version to the ticket's
      ``target_version`` so that older reservations do not send again.

    All mutations that other coroutines may be waiting on happen while holding
    a single ``asyncio.Condition``.
    """

    def __init__(self) -> None:
        self._ongoing_response = False
        self._response_control: Literal["free", "create_requested", "cancel_requested"] = "free"
        self._response_create_request_version = 0
        self._response_create_sent_version = 0
        self._response_create_event_counter = 0
        self._pending_response_create_event_id: str | None = None
        self._condition = asyncio.Condition()

    @property
    def ongoing_response(self) -> bool:
        """Whether a response is currently marked as in progress."""
        return self._ongoing_response

    @property
    def response_control(self) -> Literal["free", "create_requested", "cancel_requested"]:
        """Current create/cancel control state."""
        return self._response_control

    @property
    def pending_response_create_event_id(self) -> str | None:
        """Event id of the outstanding response.create slot, if any."""
        return self._pending_response_create_event_id

    def set_ongoing_response_for_test(self, value: bool) -> None:
        """Overwrite the ongoing-response flag without locking or notifying."""
        self._ongoing_response = value

    def _reset_and_notify(self, *, ongoing_response: bool) -> None:
        """Drop any pending create, free the control state and wake all waiters.

        The caller must already hold ``self._condition``.
        """
        self._ongoing_response = ongoing_response
        self._pending_response_create_event_id = None
        self._response_control = "free"
        self._condition.notify_all()

    async def set_response_control(
        self, control: Literal["free", "create_requested", "cancel_requested"]
    ) -> None:
        """Set the control state and wake every waiter so it re-checks."""
        async with self._condition:
            self._response_control = control
            self._condition.notify_all()

    async def mark_response_created(self) -> None:
        """Record that a response is now active and release the control state."""
        async with self._condition:
            self._reset_and_notify(ongoing_response=True)

    async def mark_response_done(self) -> None:
        """Record that the active response finished and release the control state."""
        async with self._condition:
            self._reset_and_notify(ongoing_response=False)

    async def release_waiters(self) -> None:
        """Clear all response bookkeeping and wake every waiter."""
        async with self._condition:
            self._reset_and_notify(ongoing_response=False)

    async def reserve_response_create_request(self) -> int:
        """Reserve the next response.create request version and return it."""
        async with self._condition:
            self._response_create_request_version += 1
            request_version = self._response_create_request_version
            self._condition.notify_all()
            return request_version

    async def clear_pending_response_create(self, event_id: str | None = None) -> bool:
        """Release an outstanding response.create slot.

        Args:
            event_id: Id of the response.create being cleared. ``None`` matches
                whatever slot is currently outstanding.

        Returns:
            ``True`` if the slot was released; ``False`` if no create is
            outstanding or ``event_id`` names a different one.
        """
        async with self._condition:
            if self._response_control != "create_requested":
                return False
            if event_id is not None and self._pending_response_create_event_id != event_id:
                return False
            # Some realtime error payloads omit nested error.event_id. When that
            # happens, fail open so a rejected response.create does not wedge
            # follow-up turn sequencing forever.
            self._pending_response_create_event_id = None
            self._response_control = "free"
            self._condition.notify_all()
            return True

    async def wait_for_response_create_slot(
        self, request_version: int
    ) -> _PendingResponseCreate | None:
        """Wait until ``request_version`` may send a response.create.

        Blocks until either a send covering ``request_version`` has already
        been recorded, or no response is ongoing and the control state is
        ``"free"``.

        Returns:
            ``None`` when an earlier send already covers ``request_version``;
            otherwise a ticket for the slot that was just claimed (the control
            state is ``"create_requested"`` on return).
        """
        async with self._condition:
            await self._condition.wait_for(
                lambda: self._response_create_sent_version >= request_version
                or (not self._ongoing_response and self._response_control == "free")
            )
            if self._response_create_sent_version >= request_version:
                return None

            # Claim the slot while still holding the lock so no other waiter
            # can observe the "free" state in between.
            self._response_control = "create_requested"
            self._response_create_event_counter += 1
            event_id = f"agents_py_response_create_{self._response_create_event_counter}"
            self._pending_response_create_event_id = event_id
            return _PendingResponseCreate(
                event_id=event_id,
                target_version=self._response_create_request_version,
            )

    async def mark_response_create_sent(self, pending: _PendingResponseCreate) -> None:
        """Record that ``pending`` was sent, covering reservations up to its target."""
        async with self._condition:
            # max() keeps the sent version monotonic even if tickets complete
            # out of order.
            self._response_create_sent_version = max(
                self._response_create_sent_version, pending.target_version
            )
            self._condition.notify_all()

    async def begin_cancel_response(self) -> bool:
        """Try to move into ``"cancel_requested"``.

        Returns:
            ``False`` if no response is ongoing or a cancel is already
            requested; ``True`` after the state has been switched.
        """
        async with self._condition:
            if not self._ongoing_response or self._response_control == "cancel_requested":
                return False
            self._response_control = "cancel_requested"
            return True
312+
313+
195314
def get_server_event_type_adapter() -> TypeAdapter[AllRealtimeServerEvents]:
196315
global ServerEventTypeAdapter
197316
if not ServerEventTypeAdapter:
@@ -278,22 +397,30 @@ def __init__(self, *, transport_config: TransportConfig | None = None) -> None:
278397
self._listeners: list[RealtimeModelListener] = []
279398
self._current_item_id: str | None = None
280399
self._audio_state_tracker: ModelAudioTracker = ModelAudioTracker()
281-
self._ongoing_response: bool = False
282-
# Keep local response control in one place so create/cancel sequencing
283-
# stays readable without multiple overlapping boolean flags.
284-
self._response_control: Literal["free", "create_requested", "cancel_requested"] = "free"
285-
self._response_create_request_version = 0
286-
self._response_create_sent_version = 0
287-
self._response_create_event_counter = 0
288-
self._pending_response_create_event_id: str | None = None
289-
self._response_control_condition = asyncio.Condition()
400+
self._response_create_sequencer = _ResponseCreateSequencer()
290401
self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
291402
self._playback_tracker: RealtimePlaybackTracker | None = None
292403
self._created_session: OpenAISessionCreateRequest | None = None
293404
self._server_event_type_adapter = get_server_event_type_adapter()
294405
self._call_id: str | None = None
295406
self._transport_config: TransportConfig | None = transport_config
296407

408+
@property
def _ongoing_response(self) -> bool:
    """Report the sequencer's ongoing-response flag."""
    sequencer = self._response_create_sequencer
    return sequencer.ongoing_response

@_ongoing_response.setter
def _ongoing_response(self, value: bool) -> None:
    """Forward direct assignments to the sequencer's test-only setter."""
    sequencer = self._response_create_sequencer
    sequencer.set_ongoing_response_for_test(value)
415+
416+
@property
def _response_control(self) -> Literal["free", "create_requested", "cancel_requested"]:
    """Report the sequencer's current create/cancel control state."""
    sequencer = self._response_create_sequencer
    return sequencer.response_control
419+
420+
@property
def _pending_response_create_event_id(self) -> str | None:
    """Report the event id of the sequencer's outstanding response.create, if any."""
    sequencer = self._response_create_sequencer
    return sequencer.pending_response_create_event_id
423+
297424
async def connect(self, options: RealtimeModelConfig) -> None:
298425
"""Establish a connection to the model and keep it alive."""
299426
assert self._websocket is None, "Already connected"
@@ -483,80 +610,41 @@ async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None:
483610
async def _set_response_control(
484611
self, control: Literal["free", "create_requested", "cancel_requested"]
485612
) -> None:
486-
async with self._response_control_condition:
487-
self._response_control = control
488-
self._response_control_condition.notify_all()
613+
await self._response_create_sequencer.set_response_control(control)
489614

490615
async def _mark_response_created(self) -> None:
491-
async with self._response_control_condition:
492-
self._ongoing_response = True
493-
self._pending_response_create_event_id = None
494-
self._response_control = "free"
495-
self._response_control_condition.notify_all()
616+
await self._response_create_sequencer.mark_response_created()
496617

497618
async def _mark_response_done(self) -> None:
498-
async with self._response_control_condition:
499-
self._ongoing_response = False
500-
self._pending_response_create_event_id = None
501-
self._response_control = "free"
502-
self._response_control_condition.notify_all()
619+
await self._response_create_sequencer.mark_response_done()
503620

504621
async def _release_response_waiters(self) -> None:
505622
# Connection teardown means no response.done will arrive, so local
506623
# response sequencing must be released explicitly.
507-
async with self._response_control_condition:
508-
self._ongoing_response = False
509-
self._pending_response_create_event_id = None
510-
self._response_control = "free"
511-
self._response_control_condition.notify_all()
624+
await self._response_create_sequencer.release_waiters()
512625

513626
async def _reserve_response_create_request(self) -> int:
514-
async with self._response_control_condition:
515-
self._response_create_request_version += 1
516-
request_version = self._response_create_request_version
517-
self._response_control_condition.notify_all()
518-
return request_version
627+
return await self._response_create_sequencer.reserve_response_create_request()
519628

520629
async def _clear_pending_response_create(self, event_id: str | None = None) -> bool:
521-
async with self._response_control_condition:
522-
if event_id is not None and self._pending_response_create_event_id != event_id:
523-
return False
524-
self._pending_response_create_event_id = None
525-
if self._response_control == "create_requested":
526-
self._response_control = "free"
527-
self._response_control_condition.notify_all()
528-
return True
630+
return await self._response_create_sequencer.clear_pending_response_create(event_id)
529631

530632
async def _send_response_create_when_idle(self, request_version: int) -> None:
531-
while True:
532-
async with self._response_control_condition:
533-
await self._response_control_condition.wait_for(
534-
lambda: self._response_create_sent_version >= request_version
535-
or (not self._ongoing_response and self._response_control == "free")
536-
)
537-
if self._response_create_sent_version >= request_version:
538-
return
539-
540-
target_version = self._response_create_request_version
541-
self._response_control = "create_requested"
542-
self._response_create_event_counter += 1
543-
event_id = f"agents_py_response_create_{self._response_create_event_counter}"
544-
self._pending_response_create_event_id = event_id
633+
pending = await self._response_create_sequencer.wait_for_response_create_slot(
634+
request_version
635+
)
636+
if pending is None:
637+
return
545638

546-
try:
547-
await self._send_raw_message(
548-
OpenAIResponseCreateEvent(type="response.create", event_id=event_id)
549-
)
550-
except BaseException:
551-
await self._clear_pending_response_create(event_id)
552-
raise
639+
try:
640+
await self._send_raw_message(
641+
OpenAIResponseCreateEvent(type="response.create", event_id=pending.event_id)
642+
)
643+
except BaseException:
644+
await self._clear_pending_response_create(pending.event_id)
645+
raise
553646

554-
async with self._response_control_condition:
555-
self._response_create_sent_version = max(
556-
self._response_create_sent_version, target_version
557-
)
558-
self._response_control_condition.notify_all()
559-
return
647+
await self._response_create_sequencer.mark_response_create_sent(pending)
560648

561649
def _is_running_in_websocket_listener_task(self) -> bool:
562650
current_task = asyncio.current_task()
@@ -798,10 +886,8 @@ async def close(self) -> None:
798886
await self._release_response_waiters()
799887

800888
async def _cancel_response(self) -> None:
801-
async with self._response_control_condition:
802-
if not self._ongoing_response or self._response_control == "cancel_requested":
803-
return
804-
self._response_control = "cancel_requested"
889+
if not await self._response_create_sequencer.begin_cancel_response():
890+
return
805891

806892
try:
807893
await self._send_raw_message(OpenAIResponseCancelEvent(type="response.cancel"))
@@ -957,11 +1043,7 @@ async def _handle_ws_event(self, event: dict[str, Any]):
9571043
elif parsed.type == "session.updated":
9581044
self._update_created_session(parsed.session)
9591045
elif parsed.type == "error":
960-
if (
961-
not self._ongoing_response
962-
and self._response_control == "create_requested"
963-
and parsed.error.event_id is not None
964-
):
1046+
if not self._ongoing_response and self._response_control == "create_requested":
9651047
await self._clear_pending_response_create(parsed.error.event_id)
9661048
await self._emit_event(RealtimeModelErrorEvent(error=parsed.error))
9671049
elif parsed.type == "conversation.item.deleted":

tests/realtime/test_openai_realtime.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,48 @@ async def fake_send_raw(event):
10531053
"response.create",
10541054
]
10551055

1056+
@pytest.mark.asyncio
async def test_missing_error_event_id_releases_in_flight_response_create(
    self, model, monkeypatch
):
    """A server error lacking error.event_id must still free a pending response.create."""
    sent_types: list[str] = []

    async def record_outgoing(event):
        sent_types.append(event.type)

    monkeypatch.setattr(model, "_send_raw_message", record_outgoing)
    monkeypatch.setattr(model, "_emit_event", AsyncMock())

    # Error event whose nested "error" object carries no event_id of its own.
    error_without_nested_event_id = {
        "type": "error",
        "event_id": "event_err_missing_nested",
        "error": {
            "type": "invalid_request_error",
            "code": "bad_response_create",
            "message": "bad response.create",
        },
    }
    one_turn = ["conversation.item.create", "response.create"]

    # First turn leaves a response.create in flight.
    await model._send_user_input(RealtimeModelSendUserInput(user_input="first"))
    assert model._response_control == "create_requested"
    assert model._pending_response_create_event_id is not None

    # The error must release the in-flight create.
    await model._handle_ws_event(error_without_nested_event_id)
    assert model._response_control == "free"
    assert model._pending_response_create_event_id is None

    # A follow-up turn is therefore able to send its own response.create.
    await model._send_user_input(RealtimeModelSendUserInput(user_input="second"))
    assert sent_types == one_turn + one_turn
1097+
10561098
@pytest.mark.asyncio
10571099
async def test_release_response_waiters_clears_active_response_state(self, model):
10581100
"""Releasing waiters should also clear local active-response bookkeeping."""

0 commit comments

Comments
 (0)