|
28 | 28 | from livekit.plugins.google.realtime.api_proto import ClientEvents, LiveAPIModels, Voice |
29 | 29 |
|
30 | 30 | from ..log import logger |
31 | | -from ..utils import create_tools_config, get_tool_results_for_realtime |
| 31 | +from ..utils import create_function_response, create_tools_config, get_tool_results_for_realtime |
32 | 32 | from ..version import __version__ |
33 | 33 |
|
34 | 34 | INPUT_AUDIO_SAMPLE_RATE = 16000 |
|
44 | 44 |
|
45 | 45 | lk_google_debug = int(os.getenv("LK_GOOGLE_DEBUG", 0)) |
46 | 46 |
|
| 47 | +# stop rejecting tool calls after this many in a row to avoid a loop (tool_choice="none") |
| 48 | +MAX_TOOL_CALL_REJECTIONS = 3 |
| 49 | + |
47 | 50 | # Known VertexAI models for the Live API |
48 | 51 | # See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/live-api |
49 | 52 | KNOWN_VERTEXAI_MODELS: frozenset[str] = frozenset( |
@@ -148,6 +151,7 @@ class _RealtimeOptions: |
148 | 151 | api_version: NotGivenOr[str] = NOT_GIVEN |
149 | 152 | tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN |
150 | 153 | tool_response_scheduling: NotGivenOr[types.FunctionResponseScheduling] = NOT_GIVEN |
| 154 | + tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN |
151 | 155 | thinking_config: NotGivenOr[types.ThinkingConfig] = NOT_GIVEN |
152 | 156 | session_resumption: NotGivenOr[types.SessionResumptionConfig] = NOT_GIVEN |
153 | 157 | credentials: google.auth.credentials.Credentials | None = None |
@@ -488,6 +492,10 @@ def __init__(self, realtime_model: RealtimeModel) -> None: |
488 | 492 | self._session_should_close = asyncio.Event() |
489 | 493 | self._response_created_futures: dict[str, asyncio.Future[llm.GenerationCreatedEvent]] = {} |
490 | 494 | self._pending_generation_fut: asyncio.Future[llm.GenerationCreatedEvent] | None = None |
| 495 | + # number of tool calls rejected in the current tool_choice="none" turn; non-zero also |
| 496 | + # means we're draining that turn's trailing events (which have no generation to attach |
| 497 | + # to). reset when the next generation starts. |
| 498 | + self._rejected_tool_calls = 0 |
491 | 499 |
|
492 | 500 | self._session_resumption_handle: str | None = ( |
493 | 501 | self._opts.session_resumption.handle |
@@ -557,7 +565,19 @@ def update_options( |
557 | 565 | # no need to restart |
558 | 566 |
|
559 | 567 | if is_given(tool_choice): |
560 | | - logger.warning("tool_choice is not supported by the Google Realtime API.") |
| 568 | + # no per-response tool_choice on Gemini; "none" is emulated by rejecting any tool |
| 569 | + # call emitted during the turn (see _reject_tool_calls). |
| 570 | + self._opts.tool_choice = tool_choice |
| 571 | + if tool_choice == "none": |
| 572 | + logger.warning( |
| 573 | + "the Google Realtime API has no tool_choice='none'; tool calls emitted " |
| 574 | + "this turn will be rejected so the model replies directly." |
| 575 | + ) |
| 576 | + elif tool_choice not in (None, "auto"): |
| 577 | + logger.warning( |
| 578 | + f"tool_choice='{tool_choice}' is not supported by the Google Realtime API, " |
| 579 | + "falling back to 'auto'." |
| 580 | + ) |
561 | 581 |
|
562 | 582 | if should_restart: |
563 | 583 | self._mark_restart_needed() |
@@ -1045,6 +1065,13 @@ async def _recv_task(self, session: AsyncSession) -> None: |
1045 | 1065 | part["inline_data"] = "<audio>" |
1046 | 1066 | logger.debug("<<< received response", extra={"response": resp_copy}) |
1047 | 1067 |
|
| 1068 | + if response.tool_call and self._opts.tool_choice == "none": |
| 1069 | + # reject without opening a generation, so the pending generate_reply |
| 1070 | + # stays bound to the model's eventual reply and tools stay suppressed |
| 1071 | + # for the whole turn. |
| 1072 | + self._reject_tool_calls(response.tool_call.function_calls or []) |
| 1073 | + continue |
| 1074 | + |
1048 | 1075 | if not self._current_generation or self._current_generation._done: |
1049 | 1076 | if (sc := response.server_content) and sc.interrupted: |
1050 | 1077 | # two cases an interrupted event is sent without an active generation |
@@ -1163,6 +1190,7 @@ def _build_connect_config(self) -> types.LiveConnectConfig: |
1163 | 1190 | return conf |
1164 | 1191 |
|
1165 | 1192 | def _start_new_generation(self) -> None: |
| 1193 | + self._rejected_tool_calls = 0 |
1166 | 1194 | if self._current_generation and not self._current_generation._done: |
1167 | 1195 | logger.warning("starting new generation while another is active. Finalizing previous.") |
1168 | 1196 | self._mark_current_generation_done() |
@@ -1214,7 +1242,13 @@ def _start_new_generation(self) -> None: |
1214 | 1242 | def _handle_server_content(self, server_content: types.LiveServerContent) -> None: |
1215 | 1243 | current_gen = self._current_generation |
1216 | 1244 | if not current_gen: |
1217 | | - logger.warning("received server content but no active generation.") |
| 1245 | + if self._rejected_tool_calls: |
| 1246 | + logger.debug( |
| 1247 | + "ignoring server content from a rejected tool call turn", |
| 1248 | + extra={"server_content": server_content.model_dump_json(exclude_none=True)}, |
| 1249 | + ) |
| 1250 | + else: |
| 1251 | + logger.warning("received server content but no active generation.") |
1218 | 1252 | return |
1219 | 1253 |
|
1220 | 1254 | if model_turn := server_content.model_turn: |
@@ -1332,6 +1366,38 @@ def _handle_input_speech_stopped(self) -> None: |
1332 | 1366 | llm.InputSpeechStoppedEvent(user_transcription_enabled=False), |
1333 | 1367 | ) |
1334 | 1368 |
|
| 1369 | + def _reject_tool_calls(self, function_calls: list[types.FunctionCall]) -> None: |
| 1370 | + if not function_calls: |
| 1371 | + return |
| 1372 | + |
| 1373 | + self._rejected_tool_calls += 1 |
| 1374 | + extra = {"functions": [fnc_call.name for fnc_call in function_calls]} |
| 1375 | + if self._rejected_tool_calls > MAX_TOOL_CALL_REJECTIONS: |
| 1376 | + # stop responding to break the loop; the user can still interrupt by voice |
| 1377 | + if self._rejected_tool_calls == MAX_TOOL_CALL_REJECTIONS + 1: |
| 1378 | + logger.error( |
| 1379 | + "model keeps calling tools despite tool_choice='none'; " |
| 1380 | + f"stopping after {MAX_TOOL_CALL_REJECTIONS} rejections to avoid a loop", |
| 1381 | + extra=extra, |
| 1382 | + ) |
| 1383 | + return |
| 1384 | + |
| 1385 | + logger.warning("rejecting tool call requested while tool_choice='none'", extra=extra) |
| 1386 | + responses = [ |
| 1387 | + create_function_response( |
| 1388 | + llm.FunctionCallOutput( |
| 1389 | + name=fnc_call.name or "", |
| 1390 | + call_id=fnc_call.id or "", |
| 1391 | + output="Tool calls are disabled for this turn, respond to the user directly.", |
| 1392 | + is_error=True, |
| 1393 | + ), |
| 1394 | + vertexai=self._opts.vertexai, |
| 1395 | + tool_response_scheduling=self._opts.tool_response_scheduling, |
| 1396 | + ) |
| 1397 | + for fnc_call in function_calls |
| 1398 | + ] |
| 1399 | + self._send_client_event(types.LiveClientToolResponse(function_responses=responses)) |
| 1400 | + |
1335 | 1401 | def _handle_tool_calls(self, tool_call: types.LiveServerToolCall) -> None: |
1336 | 1402 | if not self._current_generation: |
1337 | 1403 | logger.warning("received tool call but no active generation.") |
@@ -1361,7 +1427,10 @@ def _handle_tool_call_cancellation( |
1361 | 1427 | def _handle_usage_metadata(self, usage_metadata: types.UsageMetadata) -> None: |
1362 | 1428 | current_gen = self._current_generation |
1363 | 1429 | if not current_gen: |
1364 | | - logger.warning("no active generation to report metrics for") |
| 1430 | + if self._rejected_tool_calls: |
| 1431 | + logger.debug("ignoring usage metadata from a rejected tool call turn") |
| 1432 | + else: |
| 1433 | + logger.warning("no active generation to report metrics for") |
1365 | 1434 | return |
1366 | 1435 |
|
1367 | 1436 | ttft = ( |
|
0 commit comments