|
7 | 7 | import math |
8 | 8 | import os |
9 | 9 | from collections.abc import Mapping |
| 10 | +from dataclasses import dataclass |
10 | 11 | from datetime import datetime |
11 | 12 | from typing import Annotated, Any, Callable, Literal, Union, cast |
12 | 13 |
|
@@ -192,6 +193,124 @@ async def get_api_key(key: str | Callable[[], MaybeAwaitable[str]] | None) -> st |
192 | 193 | ServerEventTypeAdapter: TypeAdapter[AllRealtimeServerEvents] | None = None |
193 | 194 |
|
194 | 195 |
|
| 196 | +@dataclass(frozen=True) |
| 197 | +class _PendingResponseCreate: |
| 198 | + event_id: str |
| 199 | + target_version: int |
| 200 | + |
| 201 | + |
| 202 | +class _ResponseCreateSequencer: |
| 203 | + """Tracks local response sequencing around response.create and response.cancel.""" |
| 204 | + |
| 205 | + def __init__(self) -> None: |
| 206 | + self._ongoing_response = False |
| 207 | + self._response_control: Literal["free", "create_requested", "cancel_requested"] = "free" |
| 208 | + self._response_create_request_version = 0 |
| 209 | + self._response_create_sent_version = 0 |
| 210 | + self._response_create_event_counter = 0 |
| 211 | + self._pending_response_create_event_id: str | None = None |
| 212 | + self._condition = asyncio.Condition() |
| 213 | + |
| 214 | + @property |
| 215 | + def ongoing_response(self) -> bool: |
| 216 | + return self._ongoing_response |
| 217 | + |
| 218 | + @property |
| 219 | + def response_control(self) -> Literal["free", "create_requested", "cancel_requested"]: |
| 220 | + return self._response_control |
| 221 | + |
| 222 | + @property |
| 223 | + def pending_response_create_event_id(self) -> str | None: |
| 224 | + return self._pending_response_create_event_id |
| 225 | + |
| 226 | + def set_ongoing_response_for_test(self, value: bool) -> None: |
| 227 | + self._ongoing_response = value |
| 228 | + |
| 229 | + async def set_response_control( |
| 230 | + self, control: Literal["free", "create_requested", "cancel_requested"] |
| 231 | + ) -> None: |
| 232 | + async with self._condition: |
| 233 | + self._response_control = control |
| 234 | + self._condition.notify_all() |
| 235 | + |
| 236 | + async def mark_response_created(self) -> None: |
| 237 | + async with self._condition: |
| 238 | + self._ongoing_response = True |
| 239 | + self._pending_response_create_event_id = None |
| 240 | + self._response_control = "free" |
| 241 | + self._condition.notify_all() |
| 242 | + |
| 243 | + async def mark_response_done(self) -> None: |
| 244 | + async with self._condition: |
| 245 | + self._ongoing_response = False |
| 246 | + self._pending_response_create_event_id = None |
| 247 | + self._response_control = "free" |
| 248 | + self._condition.notify_all() |
| 249 | + |
| 250 | + async def release_waiters(self) -> None: |
| 251 | + async with self._condition: |
| 252 | + self._ongoing_response = False |
| 253 | + self._pending_response_create_event_id = None |
| 254 | + self._response_control = "free" |
| 255 | + self._condition.notify_all() |
| 256 | + |
| 257 | + async def reserve_response_create_request(self) -> int: |
| 258 | + async with self._condition: |
| 259 | + self._response_create_request_version += 1 |
| 260 | + request_version = self._response_create_request_version |
| 261 | + self._condition.notify_all() |
| 262 | + return request_version |
| 263 | + |
| 264 | + async def clear_pending_response_create(self, event_id: str | None = None) -> bool: |
| 265 | + async with self._condition: |
| 266 | + if self._response_control != "create_requested": |
| 267 | + return False |
| 268 | + if event_id is not None and self._pending_response_create_event_id != event_id: |
| 269 | + return False |
| 270 | + # Some realtime error payloads omit nested error.event_id. When that |
| 271 | + # happens, fail open so a rejected response.create does not wedge |
| 272 | + # follow-up turn sequencing forever. |
| 273 | + self._pending_response_create_event_id = None |
| 274 | + self._response_control = "free" |
| 275 | + self._condition.notify_all() |
| 276 | + return True |
| 277 | + |
| 278 | + async def wait_for_response_create_slot( |
| 279 | + self, request_version: int |
| 280 | + ) -> _PendingResponseCreate | None: |
| 281 | + while True: |
| 282 | + async with self._condition: |
| 283 | + await self._condition.wait_for( |
| 284 | + lambda: self._response_create_sent_version >= request_version |
| 285 | + or (not self._ongoing_response and self._response_control == "free") |
| 286 | + ) |
| 287 | + if self._response_create_sent_version >= request_version: |
| 288 | + return None |
| 289 | + |
| 290 | + self._response_control = "create_requested" |
| 291 | + self._response_create_event_counter += 1 |
| 292 | + event_id = f"agents_py_response_create_{self._response_create_event_counter}" |
| 293 | + self._pending_response_create_event_id = event_id |
| 294 | + return _PendingResponseCreate( |
| 295 | + event_id=event_id, |
| 296 | + target_version=self._response_create_request_version, |
| 297 | + ) |
| 298 | + |
| 299 | + async def mark_response_create_sent(self, pending: _PendingResponseCreate) -> None: |
| 300 | + async with self._condition: |
| 301 | + self._response_create_sent_version = max( |
| 302 | + self._response_create_sent_version, pending.target_version |
| 303 | + ) |
| 304 | + self._condition.notify_all() |
| 305 | + |
| 306 | + async def begin_cancel_response(self) -> bool: |
| 307 | + async with self._condition: |
| 308 | + if not self._ongoing_response or self._response_control == "cancel_requested": |
| 309 | + return False |
| 310 | + self._response_control = "cancel_requested" |
| 311 | + return True |
| 312 | + |
| 313 | + |
195 | 314 | def get_server_event_type_adapter() -> TypeAdapter[AllRealtimeServerEvents]: |
196 | 315 | global ServerEventTypeAdapter |
197 | 316 | if not ServerEventTypeAdapter: |
@@ -278,22 +397,30 @@ def __init__(self, *, transport_config: TransportConfig | None = None) -> None: |
278 | 397 | self._listeners: list[RealtimeModelListener] = [] |
279 | 398 | self._current_item_id: str | None = None |
280 | 399 | self._audio_state_tracker: ModelAudioTracker = ModelAudioTracker() |
281 | | - self._ongoing_response: bool = False |
282 | | - # Keep local response control in one place so create/cancel sequencing |
283 | | - # stays readable without multiple overlapping boolean flags. |
284 | | - self._response_control: Literal["free", "create_requested", "cancel_requested"] = "free" |
285 | | - self._response_create_request_version = 0 |
286 | | - self._response_create_sent_version = 0 |
287 | | - self._response_create_event_counter = 0 |
288 | | - self._pending_response_create_event_id: str | None = None |
289 | | - self._response_control_condition = asyncio.Condition() |
| 400 | + self._response_create_sequencer = _ResponseCreateSequencer() |
290 | 401 | self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None |
291 | 402 | self._playback_tracker: RealtimePlaybackTracker | None = None |
292 | 403 | self._created_session: OpenAISessionCreateRequest | None = None |
293 | 404 | self._server_event_type_adapter = get_server_event_type_adapter() |
294 | 405 | self._call_id: str | None = None |
295 | 406 | self._transport_config: TransportConfig | None = transport_config |
296 | 407 |
|
| 408 | + @property |
| 409 | + def _ongoing_response(self) -> bool: |
| 410 | + return self._response_create_sequencer.ongoing_response |
| 411 | + |
| 412 | + @_ongoing_response.setter |
| 413 | + def _ongoing_response(self, value: bool) -> None: |
| 414 | + self._response_create_sequencer.set_ongoing_response_for_test(value) |
| 415 | + |
| 416 | + @property |
| 417 | + def _response_control(self) -> Literal["free", "create_requested", "cancel_requested"]: |
| 418 | + return self._response_create_sequencer.response_control |
| 419 | + |
| 420 | + @property |
| 421 | + def _pending_response_create_event_id(self) -> str | None: |
| 422 | + return self._response_create_sequencer.pending_response_create_event_id |
| 423 | + |
297 | 424 | async def connect(self, options: RealtimeModelConfig) -> None: |
298 | 425 | """Establish a connection to the model and keep it alive.""" |
299 | 426 | assert self._websocket is None, "Already connected" |
@@ -483,80 +610,41 @@ async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None: |
483 | 610 | async def _set_response_control( |
484 | 611 | self, control: Literal["free", "create_requested", "cancel_requested"] |
485 | 612 | ) -> None: |
486 | | - async with self._response_control_condition: |
487 | | - self._response_control = control |
488 | | - self._response_control_condition.notify_all() |
| 613 | + await self._response_create_sequencer.set_response_control(control) |
489 | 614 |
|
490 | 615 | async def _mark_response_created(self) -> None: |
491 | | - async with self._response_control_condition: |
492 | | - self._ongoing_response = True |
493 | | - self._pending_response_create_event_id = None |
494 | | - self._response_control = "free" |
495 | | - self._response_control_condition.notify_all() |
| 616 | + await self._response_create_sequencer.mark_response_created() |
496 | 617 |
|
497 | 618 | async def _mark_response_done(self) -> None: |
498 | | - async with self._response_control_condition: |
499 | | - self._ongoing_response = False |
500 | | - self._pending_response_create_event_id = None |
501 | | - self._response_control = "free" |
502 | | - self._response_control_condition.notify_all() |
| 619 | + await self._response_create_sequencer.mark_response_done() |
503 | 620 |
|
504 | 621 | async def _release_response_waiters(self) -> None: |
505 | 622 | # Connection teardown means no response.done will arrive, so local |
506 | 623 | # response sequencing must be released explicitly. |
507 | | - async with self._response_control_condition: |
508 | | - self._ongoing_response = False |
509 | | - self._pending_response_create_event_id = None |
510 | | - self._response_control = "free" |
511 | | - self._response_control_condition.notify_all() |
| 624 | + await self._response_create_sequencer.release_waiters() |
512 | 625 |
|
513 | 626 | async def _reserve_response_create_request(self) -> int: |
514 | | - async with self._response_control_condition: |
515 | | - self._response_create_request_version += 1 |
516 | | - request_version = self._response_create_request_version |
517 | | - self._response_control_condition.notify_all() |
518 | | - return request_version |
| 627 | + return await self._response_create_sequencer.reserve_response_create_request() |
519 | 628 |
|
520 | 629 | async def _clear_pending_response_create(self, event_id: str | None = None) -> bool: |
521 | | - async with self._response_control_condition: |
522 | | - if event_id is not None and self._pending_response_create_event_id != event_id: |
523 | | - return False |
524 | | - self._pending_response_create_event_id = None |
525 | | - if self._response_control == "create_requested": |
526 | | - self._response_control = "free" |
527 | | - self._response_control_condition.notify_all() |
528 | | - return True |
| 630 | + return await self._response_create_sequencer.clear_pending_response_create(event_id) |
529 | 631 |
|
530 | 632 | async def _send_response_create_when_idle(self, request_version: int) -> None: |
531 | | - while True: |
532 | | - async with self._response_control_condition: |
533 | | - await self._response_control_condition.wait_for( |
534 | | - lambda: self._response_create_sent_version >= request_version |
535 | | - or (not self._ongoing_response and self._response_control == "free") |
536 | | - ) |
537 | | - if self._response_create_sent_version >= request_version: |
538 | | - return |
539 | | - |
540 | | - target_version = self._response_create_request_version |
541 | | - self._response_control = "create_requested" |
542 | | - self._response_create_event_counter += 1 |
543 | | - event_id = f"agents_py_response_create_{self._response_create_event_counter}" |
544 | | - self._pending_response_create_event_id = event_id |
| 633 | + pending = await self._response_create_sequencer.wait_for_response_create_slot( |
| 634 | + request_version |
| 635 | + ) |
| 636 | + if pending is None: |
| 637 | + return |
545 | 638 |
|
546 | | - try: |
547 | | - await self._send_raw_message( |
548 | | - OpenAIResponseCreateEvent(type="response.create", event_id=event_id) |
549 | | - ) |
550 | | - except BaseException: |
551 | | - await self._clear_pending_response_create(event_id) |
552 | | - raise |
| 639 | + try: |
| 640 | + await self._send_raw_message( |
| 641 | + OpenAIResponseCreateEvent(type="response.create", event_id=pending.event_id) |
| 642 | + ) |
| 643 | + except BaseException: |
| 644 | + await self._clear_pending_response_create(pending.event_id) |
| 645 | + raise |
553 | 646 |
|
554 | | - async with self._response_control_condition: |
555 | | - self._response_create_sent_version = max( |
556 | | - self._response_create_sent_version, target_version |
557 | | - ) |
558 | | - self._response_control_condition.notify_all() |
559 | | - return |
| 647 | + await self._response_create_sequencer.mark_response_create_sent(pending) |
560 | 648 |
|
561 | 649 | def _is_running_in_websocket_listener_task(self) -> bool: |
562 | 650 | current_task = asyncio.current_task() |
@@ -798,10 +886,8 @@ async def close(self) -> None: |
798 | 886 | await self._release_response_waiters() |
799 | 887 |
|
800 | 888 | async def _cancel_response(self) -> None: |
801 | | - async with self._response_control_condition: |
802 | | - if not self._ongoing_response or self._response_control == "cancel_requested": |
803 | | - return |
804 | | - self._response_control = "cancel_requested" |
| 889 | + if not await self._response_create_sequencer.begin_cancel_response(): |
| 890 | + return |
805 | 891 |
|
806 | 892 | try: |
807 | 893 | await self._send_raw_message(OpenAIResponseCancelEvent(type="response.cancel")) |
@@ -957,11 +1043,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): |
957 | 1043 | elif parsed.type == "session.updated": |
958 | 1044 | self._update_created_session(parsed.session) |
959 | 1045 | elif parsed.type == "error": |
960 | | - if ( |
961 | | - not self._ongoing_response |
962 | | - and self._response_control == "create_requested" |
963 | | - and parsed.error.event_id is not None |
964 | | - ): |
| 1046 | + if not self._ongoing_response and self._response_control == "create_requested": |
965 | 1047 | await self._clear_pending_response_create(parsed.error.event_id) |
966 | 1048 | await self._emit_event(RealtimeModelErrorEvent(error=parsed.error)) |
967 | 1049 | elif parsed.type == "conversation.item.deleted": |
|
0 commit comments