-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Expand file tree
/
Copy pathgemini_llm_connection.py
More file actions
233 lines (201 loc) · 8 KB
/
gemini_llm_connection.py
File metadata and controls
233 lines (201 loc) · 8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import AsyncGenerator
from typing import Union
from google.genai import live
from google.genai import types
from .base_llm_connection import BaseLlmConnection
from .llm_response import LlmResponse
# Module-level logger; its name is this module's name prefixed with
# 'google_adk.'.
logger = logging.getLogger('google_adk.' + __name__)
# Input types accepted by GeminiLlmConnection.send_realtime: a media blob or
# an explicit activity start/end signal.
RealtimeInput = Union[types.Blob, types.ActivityStart, types.ActivityEnd]
class GeminiLlmConnection(BaseLlmConnection):
    """The Gemini model connection.

    Thin adapter around a `google.genai` live `AsyncSession`: outgoing
    history, content and realtime input are forwarded to the session, and
    incoming live messages are translated into `LlmResponse` objects.
    """

    def __init__(self, gemini_session: live.AsyncSession):
        """Stores the live session that all other methods delegate to.

        Args:
          gemini_session: The live session used to talk to the model.
        """
        self._gemini_session = gemini_session

    async def send_history(self, history: list[types.Content]):
        """Sends the conversation history to the gemini model.

        You call this method right after setting up the model connection.
        The model will respond if the last content is from user, otherwise it
        will wait for new user input before responding.

        Only contents whose first part carries text are forwarded. If nothing
        is left after filtering, nothing is sent and an info message is logged.

        Args:
          history: The conversation history to send to the model.
        """
        # TODO: Remove this filter and translate unary contents to streaming
        # contents properly.

        # We ignore any audio from user during the agent transfer phase.
        contents = [
            content
            for content in history
            if content.parts and content.parts[0].text
        ]

        if contents:
            await self._gemini_session.send(
                input=types.LiveClientContent(
                    turns=contents,
                    # Only ask the model to answer when the user spoke last.
                    turn_complete=contents[-1].role == 'user',
                ),
            )
        else:
            logger.info('no content is sent')

    async def send_content(self, content: types.Content):
        """Sends a user content to the gemini model.

        The model will respond immediately upon receiving the content.
        If you send function responses, all parts in the content should be
        function responses.

        Args:
          content: The content to send to the model. Must have at least one
            part.

        Raises:
          AssertionError: If `content` has no parts.
        """
        assert content.parts, 'content must contain at least one part'
        if content.parts[0].function_response:
            # The first part decides the message kind: all parts have to be
            # function responses.
            function_responses = [
                part.function_response for part in content.parts
            ]
            logger.debug(
                'Sending LLM function response: %s', function_responses
            )
            await self._gemini_session.send(
                input=types.LiveClientToolResponse(
                    function_responses=function_responses
                ),
            )
        else:
            logger.debug('Sending LLM new content %s', content)
            await self._gemini_session.send(
                input=types.LiveClientContent(
                    turns=[content],
                    turn_complete=True,
                )
            )

    # NOTE: the parameter name `input` shadows the builtin; it is kept because
    # callers may pass it by keyword.
    async def send_realtime(self, input: RealtimeInput):
        """Sends a chunk of audio or a frame of video to the model in realtime.

        Blobs are dumped to a plain dict and sent through `send`; activity
        start/end signals are sent through `send_realtime_input`.

        Args:
          input: The input to send to the model.

        Raises:
          ValueError: If `input` is not a Blob, ActivityStart or ActivityEnd.
        """
        if isinstance(input, types.Blob):
            input_blob = input.model_dump()
            logger.debug('Sending LLM Blob: %s', input_blob)
            await self._gemini_session.send(input=input_blob)
        elif isinstance(input, types.ActivityStart):
            logger.debug('Sending LLM activity start signal')
            await self._gemini_session.send_realtime_input(
                activity_start=input
            )
        elif isinstance(input, types.ActivityEnd):
            logger.debug('Sending LLM activity end signal')
            await self._gemini_session.send_realtime_input(activity_end=input)
        else:
            raise ValueError('Unsupported input type: %s' % type(input))

    def __build_full_text_response(self, text: str) -> LlmResponse:
        """Builds a full (non-partial) text response.

        The text should not be partial, and the returned LlmResponse is not
        marked as partial.

        Args:
          text: The text to be included in the response.

        Returns:
          An LlmResponse containing the full text as a single model-role part.
        """
        return LlmResponse(
            content=types.Content(
                role='model',
                parts=[types.Part.from_text(text=text)],
            ),
        )

    async def receive(self) -> AsyncGenerator[LlmResponse, None]:
        """Receives the model response using the llm server connection.

        Text pieces (model text parts and output transcriptions) are yielded
        as partial responses and also accumulated; the accumulated text is
        flushed as one full text response when a non-text, non-audio model
        part, a completed turn, an interruption or a tool call arrives.
        Iteration stops after a message with `turn_complete` set.

        Yields:
          LlmResponse: The model response.
        """
        # Text accumulated from partial events since the last flush.
        text = ''
        async for message in self._gemini_session.receive():
            logger.debug('Got LLM Live message: %s', message)
            server_content = message.server_content
            if server_content:
                content = server_content.model_turn
                if content and content.parts:
                    llm_response = LlmResponse(
                        content=content,
                        interrupted=server_content.interrupted,
                    )
                    if content.parts[0].text:
                        text += content.parts[0].text
                        llm_response.partial = True
                    # Don't yield the merged text event when receiving audio
                    # data.
                    elif text and not content.parts[0].inline_data:
                        yield self.__build_full_text_response(text)
                        text = ''
                    yield llm_response

                input_transcription = server_content.input_transcription
                if input_transcription and input_transcription.text:
                    # Surface what the user said as a user-role response.
                    parts = [
                        types.Part.from_text(text=input_transcription.text)
                    ]
                    yield LlmResponse(
                        content=types.Content(role='user', parts=parts)
                    )

                output_transcription = server_content.output_transcription
                if output_transcription and output_transcription.text:
                    # TODO: Right now, we just support output_transcription
                    # without changing interface and data protocol. Later, we
                    # can consider to support output_transcription as a
                    # separate field in LlmResponse.

                    # Transcription is always considered as partial event.
                    # We rely on other control signals to determine when to
                    # yield the full text response (turn_complete,
                    # interrupted, or tool_call).
                    text += output_transcription.text
                    parts = [
                        types.Part.from_text(text=output_transcription.text)
                    ]
                    yield LlmResponse(
                        content=types.Content(role='model', parts=parts),
                        partial=True,
                    )

                if server_content.turn_complete:
                    if text:
                        yield self.__build_full_text_response(text)
                        text = ''
                    yield LlmResponse(
                        turn_complete=True,
                        interrupted=server_content.interrupted,
                    )
                    # The turn is over: stop consuming messages.
                    break

                # In case of empty content or parts, we still surface it.
                # In case it's an interrupted message, we merge the previous
                # partial text. Otherwise we don't merge, because content can
                # be None when the model safety threshold is triggered.
                if server_content.interrupted and text:
                    yield self.__build_full_text_response(text)
                    text = ''
                yield LlmResponse(interrupted=server_content.interrupted)

            if message.tool_call:
                # Flush pending text before handing over the function calls.
                if text:
                    yield self.__build_full_text_response(text)
                    text = ''
                parts = [
                    types.Part(function_call=function_call)
                    for function_call in message.tool_call.function_calls
                ]
                yield LlmResponse(
                    content=types.Content(role='model', parts=parts)
                )

            if message.session_resumption_update:
                logger.info(
                    'Received session resumption message: %s', message
                )
                yield LlmResponse(
                    live_session_resumption_update=(
                        message.session_resumption_update
                    )
                )

    async def close(self):
        """Closes the llm server connection."""
        await self._gemini_session.close()