Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/realtime/app/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def get_weather(city: str) -> str:
f"{RECOMMENDED_PROMPT_PREFIX} "
"You are a helpful triaging agent. You can use your tools to delegate questions to other appropriate agents."
),
tools=[get_weather],
handoffs=[faq_agent, realtime_handoff(seat_booking_agent)],
)

Expand Down
148 changes: 142 additions & 6 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,14 @@ def __init__(self, *, transport_config: TransportConfig | None = None) -> None:
self._current_item_id: str | None = None
self._audio_state_tracker: ModelAudioTracker = ModelAudioTracker()
self._ongoing_response: bool = False
# Keep local response control in one place so create/cancel sequencing
# stays readable without multiple overlapping boolean flags.
self._response_control: Literal["free", "create_requested", "cancel_requested"] = "free"
self._response_create_request_version = 0
self._response_create_sent_version = 0
self._response_create_event_counter = 0
self._pending_response_create_event_id: str | None = None
self._response_control_condition = asyncio.Condition()
self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
self._playback_tracker: RealtimePlaybackTracker | None = None
self._created_session: OpenAISessionCreateRequest | None = None
Expand Down Expand Up @@ -427,14 +435,17 @@ async def _listen_for_messages(self):

except websockets.exceptions.ConnectionClosedOK:
Comment thread
seratch marked this conversation as resolved.
# Normal connection closure - no exception event needed
await self._release_response_waiters()
logger.debug("WebSocket connection closed normally")
except websockets.exceptions.ConnectionClosed as e:
await self._release_response_waiters()
await self._emit_event(
RealtimeModelExceptionEvent(
exception=e, context="WebSocket connection closed unexpectedly"
)
)
except Exception as e:
await self._release_response_waiters()
await self._emit_event(
RealtimeModelExceptionEvent(
exception=e, context="WebSocket error in message listener"
Expand Down Expand Up @@ -469,10 +480,120 @@ async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None:
payload = event.model_dump_json(exclude_unset=True)
await self._websocket.send(payload)

async def _set_response_control(
    self, control: Literal["free", "create_requested", "cancel_requested"]
) -> None:
    """Store a new response-control state and wake every waiter.

    Args:
        control: The value to record in ``_response_control``.
    """
    condition = self._response_control_condition
    async with condition:
        self._response_control = control
        # Anything blocked on the condition re-checks its predicate now.
        condition.notify_all()

async def _mark_response_created(self) -> None:
    """Record that a response is now in flight.

    Sets ``_ongoing_response`` to True, forgets any pending response.create
    event id, returns local control to "free", and wakes all waiters.
    """
    condition = self._response_control_condition
    async with condition:
        self._pending_response_create_event_id = None
        self._response_control = "free"
        self._ongoing_response = True
        condition.notify_all()

async def _mark_response_done(self) -> None:
    """Record that the current response has finished.

    Sets ``_ongoing_response`` to False, forgets any pending response.create
    event id, returns local control to "free", and wakes all waiters.
    """
    condition = self._response_control_condition
    async with condition:
        self._pending_response_create_event_id = None
        self._response_control = "free"
        self._ongoing_response = False
        condition.notify_all()

async def _release_response_waiters(self) -> None:
    """Reset all local response sequencing state and wake every waiter.

    Connection teardown means no response.done will arrive, so local
    response sequencing must be released explicitly.
    """
    condition = self._response_control_condition
    async with condition:
        self._response_control = "free"
        self._pending_response_create_event_id = None
        self._ongoing_response = False
        condition.notify_all()

async def _reserve_response_create_request(self) -> int:
    """Allocate the next response.create request version.

    Returns:
        The freshly incremented ``_response_create_request_version``.
    """
    condition = self._response_control_condition
    async with condition:
        reserved = self._response_create_request_version + 1
        self._response_create_request_version = reserved
        condition.notify_all()
    return reserved

async def _clear_pending_response_create(self, event_id: str | None = None) -> bool:
    """Drop the pending response.create marker, optionally only if it matches.

    Args:
        event_id: When given, the pending marker is cleared only if it equals
            this id. When None, the marker is cleared unconditionally.

    Returns:
        True if the pending marker was cleared, False if ``event_id`` did not
        match the currently pending event id.
    """
    condition = self._response_control_condition
    async with condition:
        matches = event_id is None or self._pending_response_create_event_id == event_id
        if matches:
            self._pending_response_create_event_id = None
            # Only undo our own "create_requested" state; any other state is
            # left untouched.
            if self._response_control == "create_requested":
                self._response_control = "free"
            condition.notify_all()
    return matches

async def _send_response_create_when_idle(self, request_version: int) -> None:
    """Send ``response.create`` once no response is active, coalescing requests.

    Blocks until either a send covering ``request_version`` has already been
    recorded (in which case nothing is sent) or there is no ongoing response
    and local control is "free". A single send covers every request version
    reserved up to the moment the send is claimed.

    Args:
        request_version: The reserved request version this caller needs to be
            covered by a sent ``response.create``.

    Raises:
        BaseException: Anything raised by ``_send_raw_message`` is re-raised
            after the pending-create marker has been cleared.
    """
    condition = self._response_control_condition

    async with condition:
        await condition.wait_for(
            lambda: self._response_create_sent_version >= request_version
            or (not self._ongoing_response and self._response_control == "free")
        )
        if self._response_create_sent_version >= request_version:
            # Another sender already covered this request.
            return

        # Claim the send: snapshot the newest reserved version so this one
        # event also satisfies every request reserved so far.
        target_version = self._response_create_request_version
        self._response_control = "create_requested"
        self._response_create_event_counter += 1
        event_id = f"agents_py_response_create_{self._response_create_event_counter}"
        self._pending_response_create_event_id = event_id

    # The send happens outside the lock; the failure path below re-acquires
    # the same condition inside _clear_pending_response_create.
    try:
        await self._send_raw_message(
            OpenAIResponseCreateEvent(type="response.create", event_id=event_id)
        )
    except BaseException:
        # Includes cancellation: release the claim so other waiters can proceed.
        await self._clear_pending_response_create(event_id)
        raise

    async with condition:
        self._response_create_sent_version = max(
            self._response_create_sent_version, target_version
        )
        condition.notify_all()

def _is_running_in_websocket_listener_task(self) -> bool:
    """Return True when the caller is executing inside ``_websocket_task``."""
    running = asyncio.current_task()
    if running is None:
        return False
    return running is self._websocket_task

async def _send_response_create_in_background(self, request_version: int) -> None:
    """Run the deferred response.create send and never let it raise.

    Cancellation and a closed websocket are only logged at debug level. An
    ``AssertionError`` whose message is exactly "Not connected" is dropped
    silently. Every other exception is reported through a
    ``RealtimeModelExceptionEvent``.

    Args:
        request_version: The reserved request version to send for.
    """
    try:
        await self._send_response_create_when_idle(request_version)
    except asyncio.CancelledError:
        logger.debug("Deferred response.create task was cancelled")
    except websockets.exceptions.ConnectionClosed:
        logger.debug("Skipping deferred response.create because the websocket is closed")
    except Exception as exc:
        not_connected = isinstance(exc, AssertionError) and str(exc) == "Not connected"
        if not not_connected:
            await self._emit_event(
                RealtimeModelExceptionEvent(
                    exception=exc, context="Error sending deferred response.create"
                )
            )

async def _start_response_create(self, request_version: int) -> None:
    """Send response.create for ``request_version``, inline or in the background.

    When called from the websocket listener task the send is scheduled as a
    separate task instead of being awaited; otherwise it is awaited directly.

    Args:
        request_version: The reserved request version to send for.
    """
    if not self._is_running_in_websocket_listener_task():
        await self._send_response_create_when_idle(request_version)
        return

    # NOTE(review): presumably deferred because awaiting here would block the
    # listener from handling the events that release the wait - confirm.
    task = asyncio.create_task(self._send_response_create_in_background(request_version))

    # The event loop keeps only weak references to tasks, so an unreferenced
    # task can be garbage-collected before it finishes. Hold a strong
    # reference until the task completes.
    pending: set[asyncio.Task[None]] | None = getattr(
        self, "_deferred_response_create_tasks", None
    )
    if pending is None:
        pending = set()
        self._deferred_response_create_tasks = pending
    pending.add(task)
    task.add_done_callback(pending.discard)

async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None:
    """Send a user input item and request a response for it.

    The item is sent immediately. The follow-up ``response.create`` goes
    through the reserved-version path only; a direct send here would issue
    the event twice and bypass the idle check in
    ``_send_response_create_when_idle``.

    Args:
        event: The user input to convert into a conversation item.
    """
    converted = _ConversionHelper.convert_user_input_to_item_create(event)
    await self._send_raw_message(converted)
    request_version = await self._reserve_response_create_request()
    await self._start_response_create(request_version)

async def _send_audio(self, event: RealtimeModelSendAudio) -> None:
converted = _ConversionHelper.convert_audio_to_input_audio_buffer_append(event)
Expand All @@ -499,7 +620,8 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_item))

if event.start_response:
await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
request_version = await self._reserve_response_create_request()
await self._start_response_create(request_version)

def _get_playback_state(self) -> RealtimePlaybackState:
if self._playback_tracker:
Expand Down Expand Up @@ -673,11 +795,19 @@ async def close(self) -> None:
except asyncio.CancelledError:
pass
self._websocket_task = None
await self._release_response_waiters()

async def _cancel_response(self) -> None:
    """Ask the server to cancel the active response, at most once.

    Does nothing when no response is ongoing or when a cancel has already
    been requested. ``_ongoing_response`` is deliberately left untouched
    here: it is cleared under the condition lock by ``_mark_response_done``
    (or ``_release_response_waiters``), not by this method.

    Raises:
        Exception: Re-raised from ``_send_raw_message`` after local control
            has been reset to "free".
    """
    async with self._response_control_condition:
        if not self._ongoing_response or self._response_control == "cancel_requested":
            return
        self._response_control = "cancel_requested"

    # The send happens outside the lock; the failure path re-acquires the
    # same condition inside _set_response_control.
    try:
        await self._send_raw_message(OpenAIResponseCancelEvent(type="response.cancel"))
    except Exception:
        await self._set_response_control("free")
        raise

async def _handle_ws_event(self, event: dict[str, Any]):
await self._emit_event(RealtimeModelRawServerEvent(data=event))
Expand Down Expand Up @@ -816,17 +946,23 @@ async def _handle_ws_event(self, event: dict[str, Any]):
if not automatic_response_cancellation_enabled:
await self._cancel_response()
elif parsed.type == "response.created":
self._ongoing_response = True
await self._mark_response_created()
await self._emit_event(RealtimeModelTurnStartedEvent())
elif parsed.type == "response.done":
self._ongoing_response = False
await self._mark_response_done()
await self._emit_event(RealtimeModelTurnEndedEvent())
elif parsed.type == "session.created":
await self._send_tracing_config(self._tracing_config)
self._update_created_session(parsed.session)
elif parsed.type == "session.updated":
self._update_created_session(parsed.session)
elif parsed.type == "error":
if (
not self._ongoing_response
and self._response_control == "create_requested"
and parsed.error.event_id is not None
):
await self._clear_pending_response_create(parsed.error.event_id)
Comment thread
seratch marked this conversation as resolved.
Outdated
await self._emit_event(RealtimeModelErrorEvent(error=parsed.error))
elif parsed.type == "conversation.item.deleted":
await self._emit_event(RealtimeModelItemDeletedEvent(item_id=parsed.item_id))
Expand Down
Loading
Loading