forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgemma_llm.py
More file actions
407 lines (328 loc) · 13 KB
/
gemma_llm.py
File metadata and controls
407 lines (328 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import cached_property
import json
import logging
import re
from typing import Any
from typing import AsyncGenerator
from google.adk.models.google_llm import Gemini
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types
from google.genai.types import Content
from google.genai.types import FunctionDeclaration
from google.genai.types import Part
from pydantic import AliasChoices
from pydantic import BaseModel
from pydantic import Field
from pydantic import ValidationError
from typing_extensions import override
logger = logging.getLogger('google_adk.' + __name__)
_MARKDOWN_CODE_BLOCK_RE = re.compile(
r'```(?:(json|tool_code))?\s*(.*?)\s*```', re.DOTALL
)
class GemmaFunctionCallingMixin:
"""Mixin providing function calling support for Gemma models.
Gemma models don't have native function calling support, so this mixin
provides the logic to:
1. Convert function declarations to system instruction prompts
2. Convert function call/response parts to text in the conversation
3. Extract function calls from model text responses
"""
def _move_function_calls_into_system_instruction(
self, llm_request: LlmRequest
) -> None:
"""Converts function declarations to system instructions for Gemma."""
# Convert function calls/responses in contents to text
new_contents: list[Content] = []
for content_item in llm_request.contents:
(
new_parts_for_content,
has_function_response_part,
has_function_call_part,
) = _convert_content_parts_for_gemma(content_item)
if has_function_response_part:
if new_parts_for_content:
new_contents.append(Content(role='user', parts=new_parts_for_content))
elif has_function_call_part:
if new_parts_for_content:
new_contents.append(
Content(role='model', parts=new_parts_for_content)
)
else:
new_contents.append(content_item)
llm_request.contents = new_contents
if not llm_request.config.tools:
return
all_function_declarations: list[FunctionDeclaration] = []
for tool_item in llm_request.config.tools:
if isinstance(tool_item, types.Tool) and tool_item.function_declarations:
all_function_declarations.extend(tool_item.function_declarations)
if all_function_declarations:
system_instruction = _build_gemma_function_system_instruction(
all_function_declarations
)
llm_request.append_instructions([system_instruction])
llm_request.config.tools = []
def _extract_function_calls_from_response(
self, llm_response: LlmResponse
) -> None:
"""Extracts function calls from Gemma text responses."""
if llm_response.partial or (llm_response.turn_complete is True):
return
if not llm_response.content:
return
if not llm_response.content.parts:
return
if len(llm_response.content.parts) > 1:
return
response_text = llm_response.content.parts[0].text
if not response_text:
return
try:
json_candidate = None
block_match = _MARKDOWN_CODE_BLOCK_RE.search(response_text)
if block_match:
json_candidate = block_match.group(2).strip()
else:
found, json_text = _get_last_valid_json_substring(response_text)
if found:
json_candidate = json_text
if not json_candidate:
return
function_call_parsed = GemmaFunctionCallModel.model_validate_json(
json_candidate
)
function_call = types.FunctionCall(
name=function_call_parsed.name,
args=function_call_parsed.parameters,
)
function_call_part = Part(function_call=function_call)
llm_response.content.parts = [function_call_part]
except (json.JSONDecodeError, ValidationError) as e:
logger.debug(
'Error attempting to parse JSON into function call. Leaving as text'
' response. %s',
e,
)
except Exception as e:
logger.warning(
'Error processing Gemma function call response: %s',
e,
exc_info=True,
)
class GemmaFunctionCallModel(BaseModel):
"""Flexible Pydantic model for parsing inline Gemma function call responses."""
name: str = Field(validation_alias=AliasChoices('name', 'function'))
parameters: dict[str, Any] = Field(
validation_alias=AliasChoices('parameters', 'args')
)
class Gemma(GemmaFunctionCallingMixin, Gemini):
"""Integration for Gemma models exposed via the Gemini API.
Only Gemma 3 models are supported at this time. For agentic use cases,
use of gemma-3-27b-it and gemma-3-12b-it are strongly recommended.
For full documentation, see: https://ai.google.dev/gemma/docs/core/
NOTE: Gemma does **NOT** support system instructions. Any system instructions
will be replaced with an initial *user* prompt in the LLM request. If system
instructions change over the course of agent execution, the initial content
**SHOULD** be replaced. Special care is warranted here.
See:
https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions
NOTE: Gemma's function calling support is limited. It does not have full
access to the
same built-in tools as Gemini. It also does not have special API support for
tools and
functions. Rather, tools must be passed in via a `user` prompt, and extracted
from model
responses based on approximate shape.
NOTE: Vertex AI API support for Gemma is not currently included. This **ONLY**
supports
usage via the Gemini API.
"""
model: str = (
'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it]
)
def __repr__(self) -> str:
return f'{self.__class__.__name__}(model="{self.model}")'
@classmethod
@override
def supported_models(cls) -> list[str]:
"""Provides the list of supported models.
Returns:
A list of supported models.
"""
return [
r'gemma-3.*',
]
@cached_property
def _api_backend(self) -> GoogleLLMVariant:
return GoogleLLMVariant.GEMINI_API
@override
async def _preprocess_request(self, llm_request: LlmRequest) -> None:
self._move_function_calls_into_system_instruction(llm_request=llm_request)
if system_instruction := llm_request.config.system_instruction:
contents = llm_request.contents
instruction_content = Content(
role='user', parts=[Part.from_text(text=system_instruction)]
)
# NOTE: if history is preserved, we must include the system instructions ONLY once at the beginning
# of any chain of contents.
if contents:
if contents[0] != instruction_content:
# only prepend if it hasn't already been done
llm_request.contents = [instruction_content] + contents
llm_request.config.system_instruction = None
return await super()._preprocess_request(llm_request)
@override
async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
) -> AsyncGenerator[LlmResponse, None]:
"""Sends a request to the Gemma model.
Args:
llm_request: LlmRequest, the request to send to the Gemini model.
stream: bool = False, whether to do streaming call.
Yields:
LlmResponse: The model response.
"""
# print(f'{llm_request=}')
assert llm_request.model.startswith('gemma-'), (
f'Requesting a non-Gemma model ({llm_request.model}) with the Gemma LLM'
' is not supported.'
)
async for response in super().generate_content_async(llm_request, stream):
self._extract_function_calls_from_response(response)
yield response
def _convert_content_parts_for_gemma(
content_item: Content,
) -> tuple[list[Part], bool, bool]:
"""Converts function call/response parts within a content item to text parts.
Args:
content_item: The original Content item.
Returns:
A tuple containing:
- A list of new Part objects with function calls/responses converted to text.
- A boolean indicating if any function response parts were found.
- A boolean indicating if any function call parts were found.
"""
new_parts: list[Part] = []
has_function_response_part = False
has_function_call_part = False
for part in content_item.parts:
if func_response := part.function_response:
has_function_response_part = True
response_text = (
f'Invoking tool `{func_response.name}` produced:'
f' `{json.dumps(func_response.response)}`.'
)
new_parts.append(Part.from_text(text=response_text))
elif func_call := part.function_call:
has_function_call_part = True
new_parts.append(
Part.from_text(text=func_call.model_dump_json(exclude_none=True))
)
else:
new_parts.append(part)
return new_parts, has_function_response_part, has_function_call_part
def _build_gemma_function_system_instruction(
function_declarations: list[FunctionDeclaration],
) -> str:
"""Constructs the system instruction string for Gemma function calling."""
if not function_declarations:
return ''
system_instruction_prefix = 'You have access to the following functions:\n['
instruction_parts = []
for func in function_declarations:
instruction_parts.append(func.model_dump_json(exclude_none=True))
separator = ',\n'
system_instruction = (
f'{system_instruction_prefix}{separator.join(instruction_parts)}\n]\n'
)
system_instruction += (
'When you call a function, you MUST respond in the format of: '
"""{"name": function name, "parameters": dictionary of argument name and its value}\n"""
'When you call a function, you MUST NOT include any other text in the'
' response.\n'
)
return system_instruction
def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]:
"""Attempts to find and return the last valid JSON object in a string.
This function is designed to extract JSON that might be embedded in a larger
text, potentially with introductory or concluding remarks. It will always choose
the last block of valid json found within the supplied text (if it exists).
Args:
text: The input string to search for JSON objects.
Returns:
A tuple:
- bool: True if a valid JSON substring was found, False otherwise.
- str | None: The last valid JSON substring found, or None if none was
found.
"""
decoder = json.JSONDecoder()
last_json_str = None
start_pos = 0
while start_pos < len(text):
try:
first_brace_index = text.index('{', start_pos)
_, end_index = decoder.raw_decode(text[first_brace_index:])
last_json_str = text[first_brace_index : first_brace_index + end_index]
start_pos = first_brace_index + end_index
except json.JSONDecodeError:
start_pos = first_brace_index + 1
except ValueError:
break
if last_json_str:
return True, last_json_str
return False, None
try:
from google.adk.models.lite_llm import LiteLlm # noqa: F401
except ImportError as e:
logger.debug('LiteLlm not available; Gemma3Ollama will not be defined: %s', e)
LiteLlm = None
if LiteLlm is not None:
class Gemma3Ollama(GemmaFunctionCallingMixin, LiteLlm):
"""Integration for Gemma 3 models running locally via Ollama.
This enables fully local agent workflows using Gemma 3 models.
Requires Ollama to be running with a Gemma 3 model pulled.
Example:
ollama pull gemma3:12b
model = Gemma3Ollama(model="ollama/gemma3:12b")
"""
def __init__(self, model: str = 'ollama/gemma3:12b', **kwargs):
super().__init__(model=model, **kwargs)
def __repr__(self) -> str:
return f'{self.__class__.__name__}(model="{self.model}")'
@classmethod
@override
def supported_models(cls) -> list[str]:
return [
r'ollama/gemma3.*',
]
@override
async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
) -> AsyncGenerator[LlmResponse, None]:
"""Sends a request to Gemma via Ollama/LiteLLM.
Args:
llm_request: LlmRequest, the request to send.
stream: bool = False, whether to do streaming call.
Yields:
LlmResponse: The model response.
"""
self._move_function_calls_into_system_instruction(llm_request)
async for response in super().generate_content_async(llm_request, stream):
self._extract_function_calls_from_response(response)
yield response