Skip to content

Commit 50b9813

Browse files
authored
Merge branch 'main' into sid/evaluate-full-response
2 parents 7c9d5b0 + 4c0c6db commit 50b9813

8 files changed

Lines changed: 325 additions & 24 deletions

File tree

src/google/adk/environment/_local_environment.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,23 +123,24 @@ async def execute(
123123
)
124124

125125
@override
126-
async def read_file(self, path: str) -> bytes:
126+
async def read_file(self, path: str | Path) -> bytes:
127127
if self._working_dir is None:
128128
raise RuntimeError('`working_dir` is not set. Call initialize() first.')
129129

130-
path = self._resolve_path(path)
131-
return await asyncio.to_thread(self._sync_read, path)
130+
resolved = self._resolve_path(path)
131+
return await asyncio.to_thread(self._sync_read, resolved)
132132

133133
@override
134-
async def write_file(self, path: str, content: str | bytes) -> None:
134+
async def write_file(self, path: str | Path, content: str | bytes) -> None:
135135
if self._working_dir is None:
136136
raise RuntimeError('`working_dir` is not set. Call initialize() first.')
137137

138-
path = self._resolve_path(path)
139-
return await asyncio.to_thread(self._sync_write, path, content)
138+
resolved = self._resolve_path(path)
139+
return await asyncio.to_thread(self._sync_write, resolved, content)
140140

141-
def _resolve_path(self, path: str) -> str:
141+
def _resolve_path(self, path: str | Path) -> str:
142142
"""Resolve a relative path against the working directory."""
143+
path = str(path)
143144
if os.path.isabs(path):
144145
return path
145146
return os.path.join(self._working_dir, path)

src/google/adk/models/apigee_llm.py

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
'object',
6161
)
6262

63+
_REFUSAL_PREFIX = '[[REFUSAL]]: '
64+
6365

6466
class ApigeeLlm(Gemini):
6567
"""A BaseLlm implementation for calling Apigee proxy.
@@ -658,11 +660,14 @@ def _content_to_messages(
658660

659661
tool_calls = []
660662
content_parts = []
663+
refusals: list[str] = []
661664

662665
function_responses = []
663666

664667
for part in content.parts or []:
665-
self._process_content_part(content, part, tool_calls, content_parts)
668+
self._process_content_part(
669+
content, part, tool_calls, content_parts, refusals
670+
)
666671
if part.function_response:
667672
function_responses.append({
668673
'role': 'tool',
@@ -673,6 +678,8 @@ def _content_to_messages(
673678
return function_responses
674679

675680
message = {'role': role}
681+
if refusals:
682+
message['refusal'] = '\n'.join(refusals)
676683
if tool_calls:
677684
message['tool_calls'] = tool_calls
678685
if not content_parts:
@@ -691,6 +698,7 @@ def _process_content_part(
691698
part: types.Part,
692699
tool_calls: list[dict[str, Any]],
693700
content_parts: list[dict[str, Any]],
701+
refusals: list[str],
694702
) -> None:
695703
"""Processes a single Part and updates tool_calls or content_parts."""
696704
if content.role != 'user' and (
@@ -731,7 +739,14 @@ def _process_content_part(
731739
# Handled in the loop to return immediately
732740
pass
733741
elif part.text:
734-
content_parts.append({'type': 'text', 'text': part.text})
742+
if part.text.startswith(_REFUSAL_PREFIX):
743+
refusals.append(part.text.removeprefix(_REFUSAL_PREFIX))
744+
else:
745+
before, sep, after = part.text.partition('\n' + _REFUSAL_PREFIX)
746+
if sep:
747+
refusals.append(after)
748+
if before:
749+
content_parts.append({'type': 'text', 'text': before})
735750
elif part.inline_data:
736751
mime_type = part.inline_data.mime_type
737752
data = base64.b64encode(part.inline_data.data).decode('utf-8')
@@ -843,6 +858,7 @@ def __init__(self):
843858
self.usage = {}
844859
self.logprobs = {}
845860
self.custom_metadata = {}
861+
self._refusal_started = False
846862

847863
def process_response(self, response: dict[str, Any]) -> LlmResponse:
848864
"""Processes a complete non-streaming response."""
@@ -989,19 +1005,49 @@ def _accumulate_logprobs(self, logprobs_chunk: dict[str, Any]) -> None:
9891005
self.logprobs['refusal'] = []
9901006
self.logprobs['refusal'].extend(logprobs_chunk['refusal'])
9911007

992-
def _append_content(self, content: str, refusal: str) -> str:
993-
if content and refusal:
994-
content += '\n'
995-
content += refusal
996-
elif refusal:
997-
content = refusal
1008+
def _accumulate_content(self, choice: dict[str, Any]) -> str:
1009+
"""Processes a message or delta chunk to accumulate content and refusals.
1010+
1011+
This method extracts 'content' and 'refusal' from the chunk, updates the
1012+
accumulated state (self.content_parts), and returns the text content for
1013+
this chunk (handling prefixes and newlines if it's a refusal).
1014+
1015+
Args:
1016+
choice: A dictionary representing a message choice or a streaming delta.
1017+
1018+
Returns:
1019+
The text content to be appended or yielded for this chunk.
1020+
"""
1021+
content = choice.get('content', '')
1022+
refusal = choice.get('refusal', '')
1023+
1024+
if content and self._refusal_started:
1025+
logging.warning(
1026+
'Received content after refusal has started. Dropping content.'
1027+
)
1028+
content = ''
1029+
1030+
chunk_text = ''
9981031
if content:
999-
self.content_parts += content
1000-
return content
1032+
chunk_text += content
1033+
1034+
if refusal and not self._refusal_started:
1035+
self._refusal_started = True
1036+
if self.content_parts or chunk_text:
1037+
chunk_text += '\n'
1038+
chunk_text += _REFUSAL_PREFIX
1039+
1040+
if refusal:
1041+
chunk_text += refusal
1042+
1043+
if chunk_text:
1044+
self.content_parts += chunk_text
1045+
1046+
return chunk_text
10011047

10021048
def _add_chat_completion_chunk_delta(
10031049
self, delta: dict[str, Any]
1004-
) -> (list[types.Part], str):
1050+
) -> tuple[list[types.Part], str]:
10051051
"""Adds a chunk delta from a streaming chat completions response.
10061052
10071053
This method processes a single delta chunk from a streaming chat completions
@@ -1021,9 +1067,7 @@ def _add_chat_completion_chunk_delta(
10211067
for tool_call in delta.get('tool_calls', []):
10221068
chunk_part = self._upsert_tool_call(tool_call)
10231069
parts.append(chunk_part)
1024-
content = delta.get('content')
1025-
refusal = delta.get('refusal')
1026-
merged_content = self._append_content(content, refusal)
1070+
merged_content = self._accumulate_content(delta)
10271071
if merged_content:
10281072
parts.append(types.Part.from_text(text=merged_content))
10291073

@@ -1057,9 +1101,7 @@ def _add_chat_completion_message(
10571101
'type': 'function',
10581102
'function': function_call,
10591103
})
1060-
content = message.get('content')
1061-
refusal = message.get('refusal')
1062-
self._append_content(content, refusal)
1104+
self._accumulate_content(message)
10631105

10641106
self._get_or_create_role(message.get('role', 'model'))
10651107
return self._get_content_parts(), self.role

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import base64
1919
import inspect
2020
import logging
21+
import os
2122
from typing import Any
2223
from typing import Callable
2324
from typing import Dict
@@ -169,6 +170,16 @@ def __init__(
169170
Raises:
170171
ValueError: If mcp_tool or mcp_session_manager is None.
171172
"""
173+
174+
# --- BEGIN BOUND TOKEN PATCH ---
175+
# Set GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES to false
176+
# to disable bound token sharing. Tracking on
177+
# https://github.com/google/adk-python/issues/5361
178+
os.environ["GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES"] = (
179+
"false"
180+
)
181+
# --- END BOUND TOKEN PATCH ---
182+
172183
super().__init__(
173184
name=mcp_tool.name,
174185
description=mcp_tool.description if mcp_tool.description else "",

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import asyncio
1818
import base64
1919
import logging
20+
import os
2021
import sys
2122
from typing import Any
2223
from typing import Awaitable
@@ -158,6 +159,15 @@ def __init__(
158159
sampling_capabilities: Optional capabilities for sampling.
159160
"""
160161

162+
# --- BEGIN BOUND TOKEN PATCH ---
163+
# Set GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES to false
164+
# to disable bound token sharing. Tracking on
165+
# https://github.com/google/adk-python/issues/5361
166+
os.environ["GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES"] = (
167+
"false"
168+
)
169+
# --- END BOUND TOKEN PATCH ---
170+
161171
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
162172

163173
self._sampling_callback = sampling_callback

tests/unittests/models/test_apigee_llm.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,3 +649,86 @@ def test_parse_response_usage_metadata():
649649
assert llm_response.usage_metadata.candidates_token_count == 5
650650
assert llm_response.usage_metadata.total_token_count == 15
651651
assert llm_response.usage_metadata.thoughts_token_count == 4
652+
653+
654+
def test_parse_response_with_refusal():
655+
"""Tests that CompletionsHTTPClient parses refusal correctly."""
656+
client = CompletionsHTTPClient(base_url='http://test')
657+
658+
response_dict = {
659+
'choices': [{
660+
'message': {
661+
'role': 'assistant',
662+
'refusal': 'I refuse to answer',
663+
},
664+
'finish_reason': 'stop',
665+
}],
666+
}
667+
llm_response = client._parse_response(response_dict)
668+
assert len(llm_response.content.parts) == 1
669+
assert llm_response.content.parts[0].text == '[[REFUSAL]]: I refuse to answer'
670+
671+
response_dict_mixed = {
672+
'choices': [{
673+
'message': {
674+
'role': 'assistant',
675+
'content': 'Here is some content',
676+
'refusal': 'But I refuse to answer the rest',
677+
},
678+
'finish_reason': 'stop',
679+
}],
680+
}
681+
llm_response_mixed = client._parse_response(response_dict_mixed)
682+
assert len(llm_response_mixed.content.parts) == 1
683+
assert (
684+
llm_response_mixed.content.parts[0].text
685+
== 'Here is some content\n[[REFUSAL]]: But I refuse to answer the rest'
686+
)
687+
688+
689+
@pytest.mark.parametrize(
690+
('parts', 'expected_message'),
691+
[
692+
(
693+
[
694+
types.Part.from_text(text='[[REFUSAL]]: I refuse to answer'),
695+
types.Part.from_text(text='normal content'),
696+
],
697+
{
698+
'role': 'assistant',
699+
'refusal': 'I refuse to answer',
700+
'content': 'normal content',
701+
},
702+
),
703+
(
704+
[
705+
types.Part.from_text(
706+
text=(
707+
'Here is some content\n[[REFUSAL]]: But I refuse to'
708+
' answer the rest'
709+
)
710+
),
711+
],
712+
{
713+
'role': 'assistant',
714+
'refusal': 'But I refuse to answer the rest',
715+
'content': 'Here is some content',
716+
},
717+
),
718+
],
719+
)
720+
def test_construct_payload_with_refusal(parts, expected_message):
721+
"""Tests that CompletionsHTTPClient constructs payload with refusal correctly."""
722+
client = CompletionsHTTPClient(base_url='http://test')
723+
req = LlmRequest(
724+
model='apigee/openai/gpt-4o',
725+
contents=[
726+
types.Content(
727+
role='model',
728+
parts=parts,
729+
)
730+
],
731+
)
732+
payload = client._construct_payload(req, stream=False)
733+
messages = payload['messages']
734+
assert messages == [expected_message]

tests/unittests/models/test_completions_http_client.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from unittest import mock
1717
from unittest.mock import AsyncMock
1818

19+
from google.adk.models.apigee_llm import ChatCompletionsResponseHandler
1920
from google.adk.models.apigee_llm import CompletionsHTTPClient
2021
from google.adk.models.llm_request import LlmRequest
2122
from google.genai import types
@@ -771,3 +772,60 @@ async def mock_aiter_lines():
771772
]
772773
assert len(responses) == expected_response_count
773774
assert responses[0].content.parts[0].text == 'Hello'
775+
776+
777+
def test_process_chunk_with_refusal_streaming():
778+
handler = ChatCompletionsResponseHandler()
779+
780+
chunk1 = {
781+
'choices': [{
782+
'delta': {
783+
'role': 'assistant',
784+
'content': 'Hello',
785+
},
786+
'index': 0,
787+
}]
788+
}
789+
responses1 = list(handler.process_chunk(chunk1))
790+
assert len(responses1) == 1
791+
assert responses1[0].content.parts[0].text == 'Hello'
792+
793+
chunk2 = {
794+
'choices': [{
795+
'delta': {
796+
'refusal': 'I refuse',
797+
},
798+
'index': 0,
799+
}]
800+
}
801+
responses2 = list(handler.process_chunk(chunk2))
802+
assert len(responses2) == 1
803+
assert responses2[0].content.parts[0].text == '\n[[REFUSAL]]: I refuse'
804+
805+
chunk3 = {
806+
'choices': [{
807+
'delta': {
808+
'refusal': ' to answer',
809+
},
810+
'index': 0,
811+
}]
812+
}
813+
responses3 = list(handler.process_chunk(chunk3))
814+
assert len(responses3) == 1
815+
assert responses3[0].content.parts[0].text == ' to answer'
816+
817+
chunk4 = {
818+
'choices': [{
819+
'delta': {},
820+
'finish_reason': 'stop',
821+
'index': 0,
822+
}]
823+
}
824+
responses4 = list(handler.process_chunk(chunk4))
825+
assert len(responses4) == 2
826+
final_response = responses4[1]
827+
assert final_response.finish_reason == types.FinishReason.STOP
828+
assert (
829+
final_response.content.parts[0].text
830+
== 'Hello\n[[REFUSAL]]: I refuse to answer'
831+
)

0 commit comments

Comments
 (0)