Skip to content

Commit 23f68bc

Browse files
authored
Merge branch 'main' into main
2 parents 6ea267c + 5cfef01 commit 23f68bc

12 files changed

Lines changed: 491 additions & 335 deletions

src/google/adk/auth/auth_handler.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@
3636
AUTHLIB_AVAILABLE = False
3737

3838

39+
def _normalize_oauth_scopes(
40+
scopes: dict[str, str] | list[str] | None,
41+
) -> list[str]:
42+
"""Normalize OAuth scopes into the list shape expected by authlib."""
43+
if not scopes:
44+
return []
45+
if isinstance(scopes, dict):
46+
return list(scopes.keys())
47+
return list(scopes)
48+
49+
3950
class AuthHandler:
4051
"""A handler that handles the auth flow in Agent Development Kit to help
4152
orchestrate the credential request and response flow (e.g. OAuth flow)
@@ -164,7 +175,7 @@ def generate_auth_uri(
164175

165176
if isinstance(auth_scheme, OpenIdConnectWithConfig):
166177
authorization_endpoint = auth_scheme.authorization_endpoint
167-
scopes = auth_scheme.scopes
178+
scopes = _normalize_oauth_scopes(auth_scheme.scopes)
168179
else:
169180
authorization_endpoint = (
170181
auth_scheme.flows.implicit
@@ -176,17 +187,20 @@ def generate_auth_uri(
176187
or auth_scheme.flows.password
177188
and auth_scheme.flows.password.tokenUrl
178189
)
179-
scopes = (
180-
auth_scheme.flows.implicit
181-
and auth_scheme.flows.implicit.scopes
182-
or auth_scheme.flows.authorizationCode
183-
and auth_scheme.flows.authorizationCode.scopes
184-
or auth_scheme.flows.clientCredentials
185-
and auth_scheme.flows.clientCredentials.scopes
186-
or auth_scheme.flows.password
187-
and auth_scheme.flows.password.scopes
188-
)
189-
scopes = list(scopes.keys())
190+
if auth_scheme.flows.implicit:
191+
scopes = _normalize_oauth_scopes(auth_scheme.flows.implicit.scopes)
192+
elif auth_scheme.flows.authorizationCode:
193+
scopes = _normalize_oauth_scopes(
194+
auth_scheme.flows.authorizationCode.scopes
195+
)
196+
elif auth_scheme.flows.clientCredentials:
197+
scopes = _normalize_oauth_scopes(
198+
auth_scheme.flows.clientCredentials.scopes
199+
)
200+
elif auth_scheme.flows.password:
201+
scopes = _normalize_oauth_scopes(auth_scheme.flows.password.scopes)
202+
else:
203+
scopes = []
190204

191205
client = OAuth2Session(
192206
auth_credential.oauth2.client_id,

src/google/adk/evaluation/final_response_match_v2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,14 @@ def aggregate_invocation_results(
237237
continue
238238
num_evaluated += 1
239239
num_valid += result.score
240+
241+
if num_evaluated == 0:
242+
return EvaluationResult(
243+
overall_score=None,
244+
overall_eval_status=EvalStatus.NOT_EVALUATED,
245+
per_invocation_results=per_invocation_results,
246+
)
247+
240248
overall_score = num_valid / num_evaluated
241249
return EvaluationResult(
242250
overall_score=overall_score,

src/google/adk/planners/plan_re_act_planner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def process_planning_response(
7171
# Split the response into reasoning and final answer parts.
7272
self._handle_non_function_call_parts(response_parts[i], preserved_parts)
7373

74-
if first_fc_part_index > 0:
74+
if first_fc_part_index >= 0:
7575
j = first_fc_part_index + 1
7676
while j < len(response_parts):
7777
if response_parts[j].function_call:

src/google/adk/runners.py

Lines changed: 8 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,6 @@ def _find_active_task_isolation_scope(session) -> Optional[str]:
106106
return None
107107

108108

109-
def _is_tool_call_or_response(event: Event) -> bool:
110-
return bool(event.get_function_calls() or event.get_function_responses())
111-
112-
113109
def _get_function_responses_from_content(
114110
content: types.Content,
115111
) -> list[types.FunctionResponse]:
@@ -120,21 +116,6 @@ def _get_function_responses_from_content(
120116
]
121117

122118

123-
def _is_transcription(event: Event) -> bool:
124-
return (
125-
event.input_transcription is not None
126-
or event.output_transcription is not None
127-
)
128-
129-
130-
def _has_non_empty_transcription_text(
131-
transcription: types.Transcription,
132-
) -> bool:
133-
return bool(
134-
transcription and transcription.text and transcription.text.strip()
135-
)
136-
137-
138119
def _apply_run_config_custom_metadata(
139120
event: Event, run_config: RunConfig | None
140121
) -> None:
@@ -1404,22 +1385,6 @@ async def _exec_with_plugin(
14041385
yield early_exit_event
14051386
else:
14061387
# Step 2: Otherwise continue with normal execution
1407-
# Note for live/bidi:
1408-
# the transcription may arrive later than the action(function call
1409-
# event and thus function response event). In this case, the order of
1410-
# transcription and function call event will be wrong if we just
1411-
# append as it arrives. To address this, we should check if there is
1412-
# transcription going on. If there is transcription going on, we
1413-
# should hold on appending the function call event until the
1414-
# transcription is finished. The transcription in progress can be
1415-
# identified by checking if the transcription event is partial. When
1416-
# the next transcription event is not partial, it means the previous
1417-
# transcription is finished. Then if there is any buffered function
1418-
# call event, we should append them after this finished(non-partial)
1419-
# transcription event.
1420-
buffered_events: list[Event] = []
1421-
is_transcribing: bool = False
1422-
14231388
async with aclosing(execute_fn(invocation_context)) as agen:
14241389
async for event in agen:
14251390
_apply_run_config_custom_metadata(
@@ -1437,50 +1402,14 @@ async def _exec_with_plugin(
14371402
)
14381403

14391404
if is_live_call:
1440-
if event.partial and _is_transcription(event):
1441-
is_transcribing = True
1442-
if is_transcribing and _is_tool_call_or_response(event):
1443-
# only buffer function call and function response event which is
1444-
# non-partial
1445-
buffered_events.append(output_event)
1446-
continue
1447-
# Note for live/bidi: for audio response, it's considered as
1448-
# non-partial event(event.partial=None)
1449-
# event.partial=False and event.partial=None are considered as
1450-
# non-partial event; event.partial=True is considered as partial
1451-
# event.
1452-
if event.partial is not True:
1453-
if _is_transcription(event) and (
1454-
_has_non_empty_transcription_text(event.input_transcription)
1455-
or _has_non_empty_transcription_text(
1456-
event.output_transcription
1457-
)
1458-
):
1459-
# transcription end signal, append buffered events
1460-
is_transcribing = False
1461-
logger.debug(
1462-
'Appending transcription finished event: %s', event
1463-
)
1464-
if self._should_append_event(event, is_live_call):
1465-
await self.session_service.append_event(
1466-
session=invocation_context.session, event=output_event
1467-
)
1468-
1469-
for buffered_event in buffered_events:
1470-
logger.debug('Appending buffered event: %s', buffered_event)
1471-
await self.session_service.append_event(
1472-
session=invocation_context.session, event=buffered_event
1473-
)
1474-
yield buffered_event # yield buffered events to caller
1475-
buffered_events = []
1476-
else:
1477-
# non-transcription event or empty transcription event, for
1478-
# example, event that stores blob reference, should be appended.
1479-
if self._should_append_event(event, is_live_call):
1480-
logger.debug('Appending non-buffered event: %s', event)
1481-
await self.session_service.append_event(
1482-
session=invocation_context.session, event=output_event
1483-
)
1405+
# Skip partial transcriptions for Live
1406+
if event.partial is not True and self._should_append_event(
1407+
event, is_live_call
1408+
):
1409+
logger.debug('Appending live event: %s', output_event)
1410+
await self.session_service.append_event(
1411+
session=invocation_context.session, event=output_event
1412+
)
14841413
else:
14851414
if event.partial is not True:
14861415
await self.session_service.append_event(

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from ..features import FeatureName
3838
from ..features import is_feature_enabled
3939
from ..utils.variant_utils import GoogleLLMVariant
40+
from ._gemini_schema_util import _sanitize_schema_formats_for_gemini
4041

4142
_py_type_2_schema_type = {
4243
'str': types.Type.STRING,
@@ -365,8 +366,13 @@ def from_function_with_options(
365366
param
366367
)
367368

369+
sanitized_schema = json_schema_dict
370+
if variant == GoogleLLMVariant.GEMINI_API:
371+
sanitized_schema = _sanitize_schema_formats_for_gemini(
372+
json_schema_dict
373+
)
368374
parameters_json_schema[name] = types.Schema.model_validate(
369-
json_schema_dict
375+
sanitized_schema
370376
)
371377
if param.default is not inspect.Parameter.empty:
372378
if param.default is not None:

src/google/adk/tools/_function_parameter_parse_util.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,15 @@ def _generate_json_schema_for_parameter(
123123
) -> dict[str, Any]:
124124
"""Generates a JSON schema for a parameter using pydantic.TypeAdapter."""
125125

126-
param_schema_adapter = pydantic.TypeAdapter(
127-
param.annotation,
128-
config=pydantic.ConfigDict(arbitrary_types_allowed=True),
129-
)
126+
if inspect.isclass(param.annotation) and issubclass(
127+
param.annotation, pydantic.BaseModel
128+
):
129+
param_schema_adapter = pydantic.TypeAdapter(param.annotation)
130+
else:
131+
param_schema_adapter = pydantic.TypeAdapter(
132+
param.annotation,
133+
config=pydantic.ConfigDict(arbitrary_types_allowed=True),
134+
)
130135
json_schema_dict = param_schema_adapter.json_schema()
131136
json_schema_dict = _add_unevaluated_items_to_fixed_len_tuple_schema(
132137
json_schema_dict

tests/unittests/auth/test_auth_handler.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from fastapi.openapi.models import APIKeyIn
2323
from fastapi.openapi.models import OAuth2
2424
from fastapi.openapi.models import OAuthFlowAuthorizationCode
25+
from fastapi.openapi.models import OAuthFlowClientCredentials
2526
from fastapi.openapi.models import OAuthFlows
2627
from google.adk.auth.auth_credential import AuthCredential
2728
from google.adk.auth.auth_credential import AuthCredentialTypes
@@ -273,6 +274,35 @@ def test_generate_auth_uri_openid(
273274
assert "client_id=mock_client_id" in result.oauth2.auth_uri
274275
assert result.oauth2.state == "mock_state"
275276

277+
@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
278+
def test_generate_auth_uri_client_credentials_with_missing_scopes(
279+
self, oauth2_credentials
280+
):
281+
"""Test client credentials flow tolerates missing scopes."""
282+
auth_scheme = OAuth2(
283+
flows=OAuthFlows(
284+
clientCredentials=OAuthFlowClientCredentials(
285+
tokenUrl="https://example.com/oauth2/token"
286+
)
287+
)
288+
)
289+
auth_scheme.flows.clientCredentials.scopes = None
290+
291+
config = AuthConfig(
292+
auth_scheme=auth_scheme,
293+
raw_auth_credential=oauth2_credentials,
294+
exchanged_auth_credential=oauth2_credentials.model_copy(deep=True),
295+
)
296+
297+
handler = AuthHandler(config)
298+
result = handler.generate_auth_uri()
299+
300+
assert (
301+
result.oauth2.auth_uri
302+
== "https://example.com/oauth2/token?client_id=mock_client_id&scope="
303+
)
304+
assert result.oauth2.state == "mock_state"
305+
276306
@patch("google.adk.auth.auth_handler.OAuth2Session")
277307
def test_generate_auth_uri_pkce(
278308
self, mock_oauth2_session, oauth2_auth_scheme, oauth2_credentials

tests/unittests/evaluation/test_final_response_match_v2.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,34 @@ def test_aggregate_invocation_results():
561561
# Only 4 / 8 invocations are evaluated, and 2 / 4 are valid.
562562
assert aggregated_result.overall_score == 0.5
563563
assert aggregated_result.overall_eval_status == EvalStatus.PASSED
564+
565+
566+
def test_aggregate_invocation_results_none_evaluated():
567+
evaluator = _create_test_evaluator_gemini(threshold=0.5)
568+
569+
actual_invocation, expected_invocation = _create_test_invocations(
570+
"candidate text", "reference text"
571+
)
572+
573+
per_invocation_results = [
574+
PerInvocationResult(
575+
actual_invocation=actual_invocation,
576+
expected_invocation=expected_invocation,
577+
score=None,
578+
eval_status=EvalStatus.NOT_EVALUATED,
579+
),
580+
PerInvocationResult(
581+
actual_invocation=actual_invocation,
582+
expected_invocation=expected_invocation,
583+
score=1.0,
584+
eval_status=EvalStatus.NOT_EVALUATED,
585+
),
586+
]
587+
588+
aggregated_result = evaluator.aggregate_invocation_results(
589+
per_invocation_results
590+
)
591+
592+
assert aggregated_result.overall_score is None
593+
assert aggregated_result.overall_eval_status == EvalStatus.NOT_EVALUATED
594+
assert aggregated_result.per_invocation_results == per_invocation_results
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for PlanReActPlanner.process_planning_response."""
16+
17+
from google.adk.planners.plan_re_act_planner import PlanReActPlanner
18+
from google.genai import types
19+
20+
21+
def _function_call_names(parts):
22+
return [p.function_call.name for p in parts if p.function_call]
23+
24+
25+
def test_preserves_all_leading_parallel_function_calls():
26+
"""Parallel function calls at the start of the response must all survive.
27+
28+
Regression test: the trailing-group guard used ``> 0``, so when the first
29+
part was a function call (index 0) the loop that collects the rest of the
30+
parallel call group never ran and every call after the first was dropped.
31+
"""
32+
planner = PlanReActPlanner()
33+
response_parts = [
34+
types.Part.from_function_call(name="get_weather", args={"city": "SF"}),
35+
types.Part.from_function_call(name="get_time", args={"city": "SF"}),
36+
]
37+
38+
result = planner.process_planning_response(
39+
callback_context=None, response_parts=response_parts
40+
)
41+
42+
assert _function_call_names(result) == ["get_weather", "get_time"]
43+
44+
45+
def test_preserves_parallel_function_calls_after_leading_text():
46+
"""The same parallel group is preserved when text comes first."""
47+
planner = PlanReActPlanner()
48+
response_parts = [
49+
types.Part(text="Let me look that up."),
50+
types.Part.from_function_call(name="get_weather", args={"city": "SF"}),
51+
types.Part.from_function_call(name="get_time", args={"city": "SF"}),
52+
]
53+
54+
result = planner.process_planning_response(
55+
callback_context=None, response_parts=response_parts
56+
)
57+
58+
assert _function_call_names(result) == ["get_weather", "get_time"]

0 commit comments

Comments
 (0)