Skip to content

Commit 55b4d59

Browse files
committed
Merge remote feat/adk-compat-smooth-upgrade into feat/google-adk-v2-upgrade
2 parents fa1ddc9 + 0214038 commit 55b4d59

10 files changed

Lines changed: 330 additions & 41 deletions

File tree

tests/test_adk_compat.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from types import SimpleNamespace
16+
17+
from veadk.memory.short_term_memory_backends.mysql_backend import MysqlSTMBackend
18+
from veadk.memory.short_term_memory_backends.postgresql_backend import PostgreSqlSTMBackend
19+
from veadk.memory.short_term_memory_backends.sqlite_backend import SQLiteSTMBackend
20+
from veadk.tracing.telemetry.attributes.extractors.tool_attributes_extractors import (
21+
tool_gen_ai_tool_output,
22+
)
23+
from veadk.tracing.telemetry.attributes.extractors.types import ToolAttributesParams
24+
import veadk.utils.adk_compat as adk_compat
25+
26+
27+
def test_get_previous_interaction_id_missing_field():
28+
llm_request = SimpleNamespace()
29+
assert adk_compat.get_previous_interaction_id(llm_request) is None
30+
31+
32+
def test_get_previous_interaction_id_with_field():
33+
llm_request = SimpleNamespace(previous_interaction_id="interaction_123")
34+
assert adk_compat.get_previous_interaction_id(llm_request) == "interaction_123"
35+
36+
37+
def test_get_event_function_calls_from_getter():
38+
expected_calls = [SimpleNamespace(name="tool_a")]
39+
40+
class Event:
41+
def get_function_calls(self):
42+
return expected_calls
43+
44+
calls = adk_compat.get_event_function_calls(Event())
45+
assert calls == expected_calls
46+
47+
48+
def test_get_event_function_calls_fallback_to_parts():
49+
part1 = SimpleNamespace(function_call=SimpleNamespace(name="tool_1"))
50+
part2 = SimpleNamespace(function_call=None)
51+
event = SimpleNamespace(content=SimpleNamespace(parts=[part1, part2]))
52+
53+
calls = adk_compat.get_event_function_calls(event)
54+
assert len(calls) == 1
55+
assert calls[0].name == "tool_1"
56+
57+
58+
def test_get_event_function_calls_getter_error_fallback_to_parts():
59+
class Event:
60+
content = SimpleNamespace(parts=[SimpleNamespace(function_call="fallback_call")])
61+
62+
def get_function_calls(self):
63+
raise RuntimeError("broken getter")
64+
65+
calls = adk_compat.get_event_function_calls(Event())
66+
assert calls == ["fallback_call"]
67+
68+
69+
def test_get_event_function_responses_fallback_to_parts():
70+
part = SimpleNamespace(function_response=SimpleNamespace(name="tool_resp"))
71+
event = SimpleNamespace(content=SimpleNamespace(parts=[part]))
72+
73+
responses = adk_compat.get_event_function_responses(event)
74+
assert len(responses) == 1
75+
assert responses[0].name == "tool_resp"
76+
77+
78+
def test_mysql_backend_url_respects_async_driver_flag(monkeypatch):
79+
monkeypatch.setattr(
80+
"veadk.memory.short_term_memory_backends.mysql_backend.should_use_async_db_drivers",
81+
lambda: True,
82+
)
83+
backend = MysqlSTMBackend()
84+
assert backend._db_url.startswith("mysql+aiomysql://")
85+
86+
monkeypatch.setattr(
87+
"veadk.memory.short_term_memory_backends.mysql_backend.should_use_async_db_drivers",
88+
lambda: False,
89+
)
90+
backend = MysqlSTMBackend()
91+
assert backend._db_url.startswith("mysql+pymysql://")
92+
93+
94+
def test_postgresql_backend_url_respects_async_driver_flag(monkeypatch):
95+
monkeypatch.setattr(
96+
"veadk.memory.short_term_memory_backends.postgresql_backend.should_use_async_db_drivers",
97+
lambda: True,
98+
)
99+
backend = PostgreSqlSTMBackend()
100+
assert backend._db_url.startswith("postgresql+asyncpg://")
101+
102+
monkeypatch.setattr(
103+
"veadk.memory.short_term_memory_backends.postgresql_backend.should_use_async_db_drivers",
104+
lambda: False,
105+
)
106+
backend = PostgreSqlSTMBackend()
107+
assert backend._db_url.startswith("postgresql://")
108+
109+
110+
def test_sqlite_backend_url_respects_async_driver_flag(monkeypatch, tmp_path):
111+
db_file = tmp_path / "compat-test.db"
112+
113+
monkeypatch.setattr(
114+
"veadk.memory.short_term_memory_backends.sqlite_backend.should_use_async_db_drivers",
115+
lambda: True,
116+
)
117+
backend = SQLiteSTMBackend(local_path=str(db_file))
118+
assert backend._db_url.startswith("sqlite+aiosqlite:///")
119+
120+
monkeypatch.setattr(
121+
"veadk.memory.short_term_memory_backends.sqlite_backend.should_use_async_db_drivers",
122+
lambda: False,
123+
)
124+
backend = SQLiteSTMBackend(local_path=str(db_file))
125+
assert backend._db_url.startswith("sqlite:///")
126+
127+
128+
def test_tool_output_extractor_accepts_dict_response():
129+
function_response_event = SimpleNamespace(
130+
content=SimpleNamespace(
131+
parts=[
132+
SimpleNamespace(
133+
function_response={
134+
"id": "id_1",
135+
"name": "tool_name",
136+
"response": {"ok": True},
137+
}
138+
)
139+
]
140+
)
141+
)
142+
params = ToolAttributesParams(
143+
tool=SimpleNamespace(name="tool_name"),
144+
args={},
145+
function_response_event=function_response_event,
146+
)
147+
148+
response = tool_gen_ai_tool_output(params)
149+
assert '"name": "tool_name"' in response.content
150+
151+
152+
def test_tool_output_extractor_accepts_object_response():
153+
function_response_event = SimpleNamespace(
154+
content=SimpleNamespace(
155+
parts=[
156+
SimpleNamespace(
157+
function_response=SimpleNamespace(
158+
id="id_2",
159+
name="tool_obj",
160+
response={"status": "done"},
161+
)
162+
)
163+
]
164+
)
165+
)
166+
params = ToolAttributesParams(
167+
tool=SimpleNamespace(name="tool_obj"),
168+
args={},
169+
function_response_event=function_response_event,
170+
)
171+
172+
response = tool_gen_ai_tool_output(params)
173+
assert '"name": "tool_obj"' in response.content

veadk/evaluation/base_evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from google.genai import types
2626
from pydantic import BaseModel
2727

28+
from veadk.utils.adk_compat import get_event_function_calls
2829
from veadk.utils.misc import formatted_timestamp
2930

3031

@@ -556,8 +557,8 @@ async def generate_actual_outputs(self):
556557
and event.content.parts
557558
):
558559
final_response = event.content
559-
elif event.get_function_calls():
560-
for call in event.get_function_calls():
560+
else:
561+
for call in get_event_function_calls(event):
561562
tool_uses.append(call)
562563
tok = time.time()
563564
_latency = str((tok - tik) * 1000)

veadk/integrations/ve_identity/auth_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@ async def event_generator():
294294
new_message=message,
295295
run_config=RunConfig(streaming_mode=stream_mode),
296296
):
297-
if event.get_function_calls():
298-
for function_call in event.get_function_calls():
297+
if get_event_function_calls(event):
298+
for function_call in get_event_function_calls(event):
299299
logger.debug(f"Function call: {function_call}")
300300
elif event.content is not None:
301301
yield event.content.parts[0].text

veadk/memory/short_term_memory_backends/mysql_backend.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
from functools import cached_property
1616
from typing import Any
1717

18-
from google.adk import version as adk_version
1918
from google.adk.sessions import (
2019
BaseSessionService,
2120
DatabaseSessionService,
2221
)
23-
from packaging.version import parse as parse_version
2422
from pydantic import Field
2523
from typing_extensions import override
2624
from urllib.parse import quote_plus
@@ -30,6 +28,7 @@
3028
from veadk.memory.short_term_memory_backends.base_backend import (
3129
BaseShortTermMemoryBackend,
3230
)
31+
from veadk.utils.adk_compat import should_use_async_db_drivers
3332

3433

3534
class MysqlSTMBackend(BaseShortTermMemoryBackend):
@@ -39,10 +38,10 @@ class MysqlSTMBackend(BaseShortTermMemoryBackend):
3938
def model_post_init(self, context: Any) -> None:
4039
encoded_username = quote_plus(self.mysql_config.user)
4140
encoded_password = quote_plus(self.mysql_config.password)
42-
if parse_version(adk_version.__version__) < parse_version("1.19.0"):
43-
self._db_url = f"mysql+pymysql://{encoded_username}:{encoded_password}@{self.mysql_config.host}/{self.mysql_config.database}"
44-
else:
41+
if should_use_async_db_drivers():
4542
self._db_url = f"mysql+aiomysql://{encoded_username}:{encoded_password}@{self.mysql_config.host}/{self.mysql_config.database}"
43+
else:
44+
self._db_url = f"mysql+pymysql://{encoded_username}:{encoded_password}@{self.mysql_config.host}/{self.mysql_config.database}"
4645

4746
@cached_property
4847
@override

veadk/memory/short_term_memory_backends/postgresql_backend.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@
1616
from typing import Any
1717
from urllib.parse import quote_plus
1818

19-
from google.adk import version as adk_version
2019
from google.adk.sessions import (
2120
BaseSessionService,
2221
DatabaseSessionService,
2322
)
24-
from packaging.version import parse as parse_version
2523
from pydantic import Field
2624
from typing_extensions import override
2725

@@ -30,6 +28,7 @@
3028
from veadk.memory.short_term_memory_backends.base_backend import (
3129
BaseShortTermMemoryBackend,
3230
)
31+
from veadk.utils.adk_compat import should_use_async_db_drivers
3332

3433

3534
class PostgreSqlSTMBackend(BaseShortTermMemoryBackend):
@@ -39,10 +38,10 @@ class PostgreSqlSTMBackend(BaseShortTermMemoryBackend):
3938
def model_post_init(self, context: Any) -> None:
4039
encoded_username = quote_plus(self.postgresql_config.user)
4140
encoded_password = quote_plus(self.postgresql_config.password)
42-
if parse_version(adk_version.__version__) < parse_version("1.19.0"):
43-
self._db_url = f"postgresql://{encoded_username}:{encoded_password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}"
44-
else:
41+
if should_use_async_db_drivers():
4542
self._db_url = f"postgresql+asyncpg://{encoded_username}:{encoded_password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}"
43+
else:
44+
self._db_url = f"postgresql://{encoded_username}:{encoded_password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}"
4645

4746
@cached_property
4847
@override

veadk/memory/short_term_memory_backends/sqlite_backend.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,16 @@
1717
from functools import cached_property
1818
from typing import Any
1919

20-
from google.adk import version as adk_version
2120
from google.adk.sessions import (
2221
BaseSessionService,
2322
DatabaseSessionService,
2423
)
25-
from packaging.version import parse as parse_version
2624
from typing_extensions import override
2725

2826
from veadk.memory.short_term_memory_backends.base_backend import (
2927
BaseShortTermMemoryBackend,
3028
)
29+
from veadk.utils.adk_compat import should_use_async_db_drivers
3130

3231

3332
class SQLiteSTMBackend(BaseShortTermMemoryBackend):
@@ -41,10 +40,10 @@ def model_post_init(self, context: Any) -> None:
4140
conn = sqlite3.connect(self.local_path)
4241
conn.close()
4342

44-
if parse_version(adk_version.__version__) < parse_version("1.19.0"):
45-
self._db_url = f"sqlite:///{self.local_path}"
46-
else:
43+
if should_use_async_db_drivers():
4744
self._db_url = f"sqlite+aiosqlite:///{self.local_path}"
45+
else:
46+
self._db_url = f"sqlite:///{self.local_path}"
4847

4948
@cached_property
5049
@override

veadk/models/ark_llm.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@
6060

6161
from veadk.config import settings
6262
from veadk.consts import DEFAULT_VIDEO_MODEL_API_BASE
63+
from veadk.utils.adk_compat import (
64+
get_previous_interaction_id,
65+
llm_request_has_field,
66+
)
6367
from veadk.utils.logger import get_logger
6468

6569
logger = get_logger(__name__)
@@ -580,8 +584,8 @@ def record_logs(raw_response: ArkTypeResponse):
580584
f"Status: `{raw_response.status}`. "
581585
f"{error_message}"
582586
)
583-
except Exception as e:
584-
logger.error(f"Failed to record ark logs: {e}")
587+
except Exception:
588+
logger.exception("Failed to record Ark response logs")
585589

586590

587591
def event_to_generate_content_response(
@@ -703,7 +707,7 @@ class ArkLlm(Gemini):
703707

704708
def __init__(self, **kwargs):
705709
# adk version check
706-
if "previous_interaction_id" not in LlmRequest.model_fields:
710+
if not llm_request_has_field("previous_interaction_id"):
707711
raise ImportError(
708712
"If using the ResponsesAPI, "
709713
"please upgrade the version of google-adk to `1.21.0` or higher with the command: "
@@ -746,8 +750,8 @@ async def generate_content_async(
746750
# ------------------------------------------------------ #
747751
# get previous_response_id
748752
previous_response_id = None
749-
if self.enable_responses_cache and llm_request.previous_interaction_id:
750-
previous_response_id = llm_request.previous_interaction_id
753+
if self.enable_responses_cache:
754+
previous_response_id = get_previous_interaction_id(llm_request)
751755
responses_args = {
752756
"model": self.model,
753757
"instructions": instructions,
@@ -786,15 +790,17 @@ async def generate_content_async(
786790
responses_args.copy(), stream=stream
787791
):
788792
yield llm_response
789-
except Exception as retry_e:
790-
logger.error(f"Retry failed in generate_content_async: {retry_e}")
791-
raise retry_e
793+
except Exception:
794+
logger.exception(
795+
"Retry without previous_response_id failed in Ark Responses API"
796+
)
797+
raise
792798
else:
793-
logger.error(f"Error in generate_content_async: {e}")
794-
raise e
795-
except Exception as e:
796-
logger.error(f"Error in generate_content_async: {e}")
797-
raise e
799+
logger.exception("Ark Responses API request failed")
800+
raise
801+
except Exception:
802+
logger.exception("Unexpected error in Ark Responses API generation")
803+
raise
798804

799805
async def generate_content_via_responses(
800806
self, responses_args: dict, stream: bool = False

0 commit comments

Comments
 (0)