Skip to content

Commit b0fed87

Browse files
authored
Merge branch 'main' into add-dapr-session
2 parents 42c802c + 648d14d commit b0fed87

9 files changed

Lines changed: 402 additions & 30 deletions

File tree

src/agents/extensions/memory/sqlalchemy_session.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,16 @@ async def clear_session(self) -> None:
319319
await sess.execute(
320320
delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
321321
)
322+
323+
@property
324+
def engine(self) -> AsyncEngine:
325+
"""Access the underlying SQLAlchemy AsyncEngine.
326+
327+
This property provides direct access to the engine for advanced use cases,
328+
such as checking connection pool status, configuring engine settings,
329+
or manually disposing the engine when needed.
330+
331+
Returns:
332+
AsyncEngine: The SQLAlchemy async engine instance.
333+
"""
334+
return self._engine

src/agents/extensions/models/litellm_model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
4545
from ...models.fake_id import FAKE_RESPONSES_ID
4646
from ...models.interface import Model, ModelTracing
47+
from ...models.openai_responses import Converter as OpenAIResponsesConverter
4748
from ...tool import Tool
4849
from ...tracing import generation_span
4950
from ...tracing.span_data import GenerationSpanData
@@ -367,15 +368,19 @@ async def _fetch_response(
367368
if isinstance(ret, litellm.types.utils.ModelResponse):
368369
return ret
369370

371+
responses_tool_choice = OpenAIResponsesConverter.convert_tool_choice(
372+
model_settings.tool_choice
373+
)
374+
if responses_tool_choice is None or responses_tool_choice is omit:
375+
responses_tool_choice = "auto"
376+
370377
response = Response(
371378
id=FAKE_RESPONSES_ID,
372379
created_at=time.time(),
373380
model=self.model,
374381
object="response",
375382
output=[],
376-
tool_choice=cast(Literal["auto", "required", "none"], tool_choice)
377-
if tool_choice is not omit
378-
else "auto",
383+
tool_choice=responses_tool_choice, # type: ignore[arg-type]
379384
top_p=model_settings.top_p,
380385
temperature=model_settings.temperature,
381386
tools=[],

src/agents/realtime/model_inputs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class RealtimeModelSendToolOutput:
9595
class RealtimeModelSendInterrupt:
9696
"""Send an interrupt to the model."""
9797

98+
force_response_cancel: bool = False
99+
"""Force sending a response.cancel event even if automatic cancellation is enabled."""
100+
98101

99102
@dataclass
100103
class RealtimeModelSendSessionUpdate:

src/agents/realtime/openai_realtime.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -395,36 +395,36 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
395395
current_item_id = playback_state.get("current_item_id")
396396
current_item_content_index = playback_state.get("current_item_content_index")
397397
elapsed_ms = playback_state.get("elapsed_ms")
398+
398399
if current_item_id is None or elapsed_ms is None:
399400
logger.debug(
400401
"Skipping interrupt. "
401402
f"Item id: {current_item_id}, "
402403
f"elapsed ms: {elapsed_ms}, "
403404
f"content index: {current_item_content_index}"
404405
)
405-
return
406-
407-
current_item_content_index = current_item_content_index or 0
408-
if elapsed_ms > 0:
409-
await self._emit_event(
410-
RealtimeModelAudioInterruptedEvent(
411-
item_id=current_item_id,
412-
content_index=current_item_content_index,
413-
)
414-
)
415-
converted = _ConversionHelper.convert_interrupt(
416-
current_item_id,
417-
current_item_content_index,
418-
int(elapsed_ms),
419-
)
420-
await self._send_raw_message(converted)
421406
else:
422-
logger.debug(
423-
"Didn't interrupt bc elapsed ms is < 0. "
424-
f"Item id: {current_item_id}, "
425-
f"elapsed ms: {elapsed_ms}, "
426-
f"content index: {current_item_content_index}"
427-
)
407+
current_item_content_index = current_item_content_index or 0
408+
if elapsed_ms > 0:
409+
await self._emit_event(
410+
RealtimeModelAudioInterruptedEvent(
411+
item_id=current_item_id,
412+
content_index=current_item_content_index,
413+
)
414+
)
415+
converted = _ConversionHelper.convert_interrupt(
416+
current_item_id,
417+
current_item_content_index,
418+
int(elapsed_ms),
419+
)
420+
await self._send_raw_message(converted)
421+
else:
422+
logger.debug(
423+
"Didn't interrupt bc elapsed ms is < 0. "
424+
f"Item id: {current_item_id}, "
425+
f"elapsed ms: {elapsed_ms}, "
426+
f"content index: {current_item_content_index}"
427+
)
428428

429429
session = self._created_session
430430
automatic_response_cancellation_enabled = (
@@ -434,12 +434,16 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
434434
and session.audio.input.turn_detection is not None
435435
and session.audio.input.turn_detection.interrupt_response is True
436436
)
437-
if not automatic_response_cancellation_enabled:
437+
should_cancel_response = event.force_response_cancel or (
438+
not automatic_response_cancellation_enabled
439+
)
440+
if should_cancel_response:
438441
await self._cancel_response()
439442

440-
self._audio_state_tracker.on_interrupted()
441-
if self._playback_tracker:
442-
self._playback_tracker.on_interrupted()
443+
if current_item_id is not None and elapsed_ms is not None:
444+
self._audio_state_tracker.on_interrupted()
445+
if self._playback_tracker:
446+
self._playback_tracker.on_interrupted()
443447

444448
async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> None:
445449
"""Send a session update to the model."""

src/agents/realtime/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool:
704704
)
705705

706706
# Interrupt the model
707-
await self._model.send_event(RealtimeModelSendInterrupt())
707+
await self._model.send_event(RealtimeModelSendInterrupt(force_response_cancel=True))
708708

709709
# Send guardrail triggered message
710710
guardrail_names = [result.guardrail.get_name() for result in triggered_results]

src/agents/run.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,15 @@ async def _start_streaming(
11381138

11391139
streamed_result.is_complete = True
11401140
finally:
1141+
if streamed_result._input_guardrails_task:
1142+
try:
1143+
await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
1144+
streamed_result
1145+
)
1146+
except Exception as e:
1147+
logger.debug(
1148+
f"Error in streamed_result finalize for agent {current_agent.name} - {e}"
1149+
)
11411150
if current_span:
11421151
current_span.finish(reset_current=True)
11431152
if streamed_result.trace:

tests/extensions/memory/test_sqlalchemy_session.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Summary,
1515
)
1616
from sqlalchemy import select, text, update
17+
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
1718
from sqlalchemy.sql import Select
1819

1920
pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed
@@ -390,3 +391,56 @@ async def recording_execute(statement: Any, *args: Any, **kwargs: Any) -> Any:
390391

391392
assert _item_ids(retrieved_full) == ["rs_first", "msg_second"]
392393
assert _item_ids(retrieved_limited) == ["rs_first", "msg_second"]
394+
395+
396+
async def test_engine_property_from_url():
397+
"""Test that the engine property returns the AsyncEngine from from_url."""
398+
session_id = "engine_property_test"
399+
session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
400+
401+
# Verify engine property returns an AsyncEngine instance
402+
assert isinstance(session.engine, AsyncEngine)
403+
404+
# Verify we can use the engine for advanced operations
405+
# For example, check pool status
406+
assert session.engine.pool is not None
407+
408+
# Verify we can manually dispose the engine
409+
await session.engine.dispose()
410+
411+
412+
async def test_engine_property_from_external_engine():
413+
"""Test that the engine property returns the external engine."""
414+
session_id = "external_engine_test"
415+
416+
# Create engine externally
417+
external_engine = create_async_engine(DB_URL)
418+
419+
# Create session with external engine
420+
session = SQLAlchemySession(session_id, engine=external_engine, create_tables=True)
421+
422+
# Verify engine property returns the same engine instance
423+
assert session.engine is external_engine
424+
425+
# Verify we can use the engine
426+
assert isinstance(session.engine, AsyncEngine)
427+
428+
# Clean up - user is responsible for disposing external engine
429+
await external_engine.dispose()
430+
431+
432+
async def test_engine_property_is_read_only():
433+
"""Test that the engine property cannot be modified."""
434+
session_id = "readonly_engine_test"
435+
session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
436+
437+
# Verify engine property exists
438+
assert hasattr(session, "engine")
439+
440+
# Verify it's a property (read-only, cannot be set)
441+
# Type ignore needed because mypy correctly detects this is read-only
442+
with pytest.raises(AttributeError):
443+
session.engine = create_async_engine(DB_URL) # type: ignore[misc]
444+
445+
# Clean up
446+
await session.engine.dispose()

tests/realtime/test_openai_realtime.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from types import SimpleNamespace
23
from typing import Any, cast
34
from unittest.mock import AsyncMock, Mock, patch
45

@@ -509,6 +510,59 @@ async def test_send_event_dispatch(self, model, monkeypatch):
509510
# session update -> 1
510511
assert send_raw.await_count == 8
511512

513+
@pytest.mark.asyncio
514+
async def test_interrupt_force_cancel_overrides_auto_cancellation(self, model, monkeypatch):
515+
"""Interrupt should send response.cancel even when auto cancel is enabled."""
516+
model._audio_state_tracker.set_audio_format("pcm16")
517+
model._audio_state_tracker.on_audio_delta("item_1", 0, b"\x00" * 4800)
518+
model._ongoing_response = True
519+
model._created_session = SimpleNamespace(
520+
audio=SimpleNamespace(
521+
input=SimpleNamespace(
522+
turn_detection=SimpleNamespace(interrupt_response=True)
523+
)
524+
)
525+
)
526+
527+
send_raw = AsyncMock()
528+
emit_event = AsyncMock()
529+
monkeypatch.setattr(model, "_send_raw_message", send_raw)
530+
monkeypatch.setattr(model, "_emit_event", emit_event)
531+
532+
await model._send_interrupt(RealtimeModelSendInterrupt(force_response_cancel=True))
533+
534+
assert send_raw.await_count == 2
535+
payload_types = {call.args[0].type for call in send_raw.call_args_list}
536+
assert payload_types == {"conversation.item.truncate", "response.cancel"}
537+
assert model._ongoing_response is False
538+
assert model._audio_state_tracker.get_last_audio_item() is None
539+
540+
@pytest.mark.asyncio
541+
async def test_interrupt_respects_auto_cancellation_when_not_forced(self, model, monkeypatch):
542+
"""Interrupt should avoid sending response.cancel when relying on automatic cancellation."""
543+
model._audio_state_tracker.set_audio_format("pcm16")
544+
model._audio_state_tracker.on_audio_delta("item_1", 0, b"\x00" * 4800)
545+
model._ongoing_response = True
546+
model._created_session = SimpleNamespace(
547+
audio=SimpleNamespace(
548+
input=SimpleNamespace(
549+
turn_detection=SimpleNamespace(interrupt_response=True)
550+
)
551+
)
552+
)
553+
554+
send_raw = AsyncMock()
555+
emit_event = AsyncMock()
556+
monkeypatch.setattr(model, "_send_raw_message", send_raw)
557+
monkeypatch.setattr(model, "_emit_event", emit_event)
558+
559+
await model._send_interrupt(RealtimeModelSendInterrupt())
560+
561+
assert send_raw.await_count == 1
562+
assert send_raw.call_args_list[0].args[0].type == "conversation.item.truncate"
563+
assert all(call.args[0].type != "response.cancel" for call in send_raw.call_args_list)
564+
assert model._ongoing_response is True
565+
512566
def test_add_remove_listener_and_tools_conversion(self, model):
513567
listener = AsyncMock()
514568
model.add_listener(listener)

0 commit comments

Comments
 (0)