Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/realtime/app/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def get_weather(city: str) -> str:
f"{RECOMMENDED_PROMPT_PREFIX} "
"You are a helpful triaging agent. You can use your tools to delegate questions to other appropriate agents."
),
tools=[get_weather],
handoffs=[faq_agent, realtime_handoff(seat_booking_agent)],
)

Expand Down
60 changes: 54 additions & 6 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def __init__(self, *, transport_config: TransportConfig | None = None) -> None:
self._current_item_id: str | None = None
self._audio_state_tracker: ModelAudioTracker = ModelAudioTracker()
self._ongoing_response: bool = False
# Keep local response control in one place so create/cancel sequencing
# stays readable without multiple overlapping boolean flags.
self._response_control: Literal["free", "create_requested", "cancel_requested"] = "free"
self._response_control_condition = asyncio.Condition()
self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
self._playback_tracker: RealtimePlaybackTracker | None = None
self._created_session: OpenAISessionCreateRequest | None = None
Expand Down Expand Up @@ -427,14 +431,17 @@ async def _listen_for_messages(self):

except websockets.exceptions.ConnectionClosedOK:
Comment thread
seratch marked this conversation as resolved.
# Normal connection closure - no exception event needed
await self._release_response_waiters()
logger.debug("WebSocket connection closed normally")
except websockets.exceptions.ConnectionClosed as e:
await self._release_response_waiters()
await self._emit_event(
RealtimeModelExceptionEvent(
exception=e, context="WebSocket connection closed unexpectedly"
)
)
except Exception as e:
await self._release_response_waiters()
await self._emit_event(
RealtimeModelExceptionEvent(
exception=e, context="WebSocket error in message listener"
Expand Down Expand Up @@ -469,10 +476,41 @@ async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None:
payload = event.model_dump_json(exclude_unset=True)
await self._websocket.send(payload)

async def _set_response_control(
    self, control: Literal["free", "create_requested", "cancel_requested"]
) -> None:
    """Store a new response-control state and wake every task waiting on it.

    Args:
        control: The value to assign to ``_response_control``.
    """
    condition = self._response_control_condition
    async with condition:
        # Mutate while holding the condition so waiters re-evaluate their
        # predicate against the value written here.
        self._response_control = control
        condition.notify_all()

async def _mark_response_created(self) -> None:
    """Record that a response is now in progress and reset response control.

    The ongoing flag is raised before the control state is reset, so tasks
    woken by the reset still see a response in progress.
    """
    self._ongoing_response = True
    await self._set_response_control(control="free")

async def _mark_response_done(self) -> None:
    """Record that no response is in progress and reset response control.

    The ongoing flag is cleared before the control state is reset, so tasks
    woken by the reset see that no response is in progress.
    """
    self._ongoing_response = False
    await self._set_response_control(control="free")

async def _release_response_waiters(self) -> None:
    """Unblock every task waiting to send ``response.create``.

    Waiters in ``_send_response_create_when_idle`` proceed only when no
    response is ongoing AND the control state is "free". Resetting the
    control state alone would leave them blocked forever if the connection
    closes mid-response, because no ``response.done`` will arrive to clear
    the ongoing flag. Both pieces of state are therefore reset here.
    """
    async with self._response_control_condition:
        self._ongoing_response = False
        self._response_control = "free"
        self._response_control_condition.notify_all()
Comment thread
seratch marked this conversation as resolved.
Outdated

async def _send_response_create_when_idle(self) -> None:
    """Send ``response.create`` once no response is ongoing or pending.

    Waits until ``_ongoing_response`` is false and the control state is
    "free", then claims the slot by switching the state to
    "create_requested" before sending.

    Raises:
        Exception: Whatever the send raises; the control state is reset to
            "free" before the exception propagates.
    """
    condition = self._response_control_condition

    def is_idle() -> bool:
        return self._response_control == "free" and not self._ongoing_response

    async with condition:
        await condition.wait_for(is_idle)
        # Claim the slot while still holding the condition so two waiters
        # cannot both pass the idle check.
        self._response_control = "create_requested"

    try:
        await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
    except Exception:
        # The request never went out; give the slot back to other waiters.
        await self._set_response_control("free")
        raise

async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None:
    """Send the user's input as a conversation item, then request a response.

    The ``response.create`` request is sent exactly once, through the idle
    gate, so it is delayed while another response is still in progress
    instead of being sent immediately and then sent again.

    Args:
        event: The user input to convert and send.
    """
    converted = _ConversionHelper.convert_user_input_to_item_create(event)
    await self._send_raw_message(converted)
    await self._send_response_create_when_idle()

async def _send_audio(self, event: RealtimeModelSendAudio) -> None:
converted = _ConversionHelper.convert_audio_to_input_audio_buffer_append(event)
Expand All @@ -499,7 +537,7 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_item))

if event.start_response:
await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
await self._send_response_create_when_idle()
Comment thread
seratch marked this conversation as resolved.
Outdated

def _get_playback_state(self) -> RealtimePlaybackState:
if self._playback_tracker:
Expand Down Expand Up @@ -663,6 +701,7 @@ async def _handle_conversation_item(

async def close(self) -> None:
"""Close the session."""
await self._release_response_waiters()
if self._websocket:
await self._websocket.close()
self._websocket = None
Expand All @@ -675,9 +714,16 @@ async def close(self) -> None:
self._websocket_task = None

async def _cancel_response(self) -> None:
    """Send ``response.cancel`` for the ongoing response, at most once.

    Does nothing when no response is ongoing or a cancel has already been
    requested. ``_ongoing_response`` is deliberately left unchanged here;
    it is cleared by ``_mark_response_done`` when ``response.done`` is
    handled, so a new ``response.create`` is not sent before then.

    Raises:
        Exception: Whatever the send raises; the control state is reset to
            "free" before the exception propagates.
    """
    async with self._response_control_condition:
        if not self._ongoing_response or self._response_control == "cancel_requested":
            return
        self._response_control = "cancel_requested"

    try:
        await self._send_raw_message(OpenAIResponseCancelEvent(type="response.cancel"))
    except Exception:
        # The cancel never went out, so allow a later cancel attempt.
        await self._set_response_control("free")
        raise

async def _handle_ws_event(self, event: dict[str, Any]):
await self._emit_event(RealtimeModelRawServerEvent(data=event))
Expand Down Expand Up @@ -816,17 +862,19 @@ async def _handle_ws_event(self, event: dict[str, Any]):
if not automatic_response_cancellation_enabled:
await self._cancel_response()
elif parsed.type == "response.created":
self._ongoing_response = True
await self._mark_response_created()
await self._emit_event(RealtimeModelTurnStartedEvent())
elif parsed.type == "response.done":
self._ongoing_response = False
await self._mark_response_done()
await self._emit_event(RealtimeModelTurnEndedEvent())
elif parsed.type == "session.created":
await self._send_tracing_config(self._tracing_config)
self._update_created_session(parsed.session)
elif parsed.type == "session.updated":
self._update_created_session(parsed.session)
elif parsed.type == "error":
if not self._ongoing_response and self._response_control == "create_requested":
await self._set_response_control("free")
await self._emit_event(RealtimeModelErrorEvent(error=parsed.error))
elif parsed.type == "conversation.item.deleted":
await self._emit_event(RealtimeModelItemDeletedEvent(item_id=parsed.item_id))
Expand Down
67 changes: 66 additions & 1 deletion tests/realtime/test_openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ async def test_send_event_dispatch(self, model, monkeypatch):
monkeypatch.setattr(model, "_send_raw_message", send_raw)

await model.send_event(RealtimeModelSendUserInput(user_input="hi"))
await model._mark_response_done()
await model.send_event(RealtimeModelSendAudio(audio=b"a", commit=False))
await model.send_event(RealtimeModelSendAudio(audio=b"a", commit=True))
await model.send_event(
Expand All @@ -736,7 +737,7 @@ async def test_interrupt_force_cancel_overrides_auto_cancellation(self, model, m
"""Interrupt should send response.cancel even when auto cancel is enabled."""
model._audio_state_tracker.set_audio_format("pcm16")
model._audio_state_tracker.on_audio_delta("item_1", 0, b"\x00" * 4800)
model._ongoing_response = True
await model._mark_response_created()
model._created_session = SimpleNamespace(
audio=SimpleNamespace(
input=SimpleNamespace(turn_detection=SimpleNamespace(interrupt_response=True))
Expand All @@ -753,7 +754,12 @@ async def test_interrupt_force_cancel_overrides_auto_cancellation(self, model, m
assert send_raw.await_count == 2
payload_types = {call.args[0].type for call in send_raw.call_args_list}
assert payload_types == {"conversation.item.truncate", "response.cancel"}
assert model._ongoing_response is True
assert model._response_control == "cancel_requested"

await model._mark_response_done()
assert model._ongoing_response is False
assert model._response_control == "free"
assert model._audio_state_tracker.get_last_audio_item() is None

@pytest.mark.asyncio
Expand All @@ -780,6 +786,65 @@ async def test_interrupt_respects_auto_cancellation_when_not_forced(self, model,
assert all(call.args[0].type != "response.cancel" for call in send_raw.call_args_list)
assert model._ongoing_response is True

@pytest.mark.asyncio
async def test_send_user_input_waits_for_response_done_before_response_create(
    self, model, monkeypatch
):
    """User input sent during an active turn defers response.create until the turn ends."""
    sent_types: list[str] = []

    async def record_send(event):
        sent_types.append(event.type)

    monkeypatch.setattr(model, "_send_raw_message", record_send)
    # Put the model into the "response in progress" state.
    await model._mark_response_created()

    user_input = RealtimeModelSendUserInput(user_input="hi")
    pending = asyncio.create_task(model._send_user_input(user_input))
    await asyncio.sleep(0)

    # The item goes out right away, but the response request is held back.
    assert sent_types == ["conversation.item.create"]
    assert not pending.done()

    await model._mark_response_done()
    await asyncio.wait_for(pending, timeout=1)

    assert sent_types == ["conversation.item.create", "response.create"]

@pytest.mark.asyncio
async def test_tool_output_start_response_waits_for_response_done_before_response_create(
    self, model, monkeypatch
):
    """A tool output with start_response=True defers response.create until the turn ends."""
    sent_types: list[str] = []

    async def record_send(event):
        sent_types.append(event.type)

    monkeypatch.setattr(model, "_send_raw_message", record_send)
    monkeypatch.setattr(model, "_emit_event", AsyncMock())
    # Put the model into the "response in progress" state.
    await model._mark_response_created()

    tool_call = RealtimeModelToolCallEvent(name="t", call_id="c", arguments="{}")
    tool_output = RealtimeModelSendToolOutput(
        tool_call=tool_call,
        output="ok",
        start_response=True,
    )
    pending = asyncio.create_task(model._send_tool_output(tool_output))
    await asyncio.sleep(0)

    # The response request is held back while the turn is still active.
    assert "response.create" not in sent_types
    assert not pending.done()

    await model._mark_response_done()
    await asyncio.wait_for(pending, timeout=1)

    assert sent_types[-1] == "response.create"

def test_add_remove_listener_and_tools_conversion(self, model):
listener = AsyncMock()
model.add_listener(listener)
Expand Down
Loading