Skip to content

Commit 24dd908

Browse files
Merge branch 'main' into fix/lite-llm-serialization-error
2 parents f9f9222 + 179380f commit 24dd908

6 files changed

Lines changed: 162 additions & 14 deletions

File tree

.github/workflows/pre-commit.yml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
name: pre-commit
16+
17+
on:
18+
push:
19+
branches: [main, v2]
20+
paths:
21+
- '**.py'
22+
- '.pre-commit-config.yaml'
23+
- 'pyproject.toml'
24+
pull_request:
25+
branches: [main, v2]
26+
paths:
27+
- '**.py'
28+
- '.pre-commit-config.yaml'
29+
- 'pyproject.toml'
30+
31+
jobs:
32+
pre-commit:
33+
runs-on: ubuntu-latest
34+
steps:
35+
- name: Checkout Code
36+
uses: actions/checkout@v6
37+
38+
- name: Run pre-commit checks
39+
uses: pre-commit/action@v3.0.1

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ dev = [
9090
"flit>=3.10.0",
9191
"isort>=6.0.0",
9292
"mypy>=1.15.0",
93+
"pre-commit>=4.0.0",
9394
"pyink>=25.12.0",
9495
"pylint>=2.6.0",
9596
# go/keep-sorted end

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1949,6 +1949,7 @@ def __init__(
19491949
table_id: Optional[str] = None,
19501950
config: Optional[BigQueryLoggerConfig] = None,
19511951
location: str = "US",
1952+
credentials: Optional[google.auth.credentials.Credentials] = None,
19521953
**kwargs,
19531954
) -> None:
19541955
"""Initializes the instance.
@@ -1959,6 +1960,8 @@ def __init__(
19591960
table_id: BigQuery table ID (optional, overrides config).
19601961
config: BigQueryLoggerConfig (optional).
19611962
location: BigQuery location (default: "US").
1963+
credentials: Google Auth credentials (optional). If None, uses
1964+
Application Default Credentials.
19621965
**kwargs: Additional configuration parameters for BigQueryLoggerConfig.
19631966
"""
19641967
super().__init__(name="bigquery_agent_analytics")
@@ -1985,6 +1988,7 @@ def __init__(
19851988
self._startup_error: Optional[Exception] = None
19861989
self._is_shutting_down = False
19871990
self._setup_lock = None
1991+
self._credentials = credentials
19881992
self.client = None
19891993
self._loop_state_by_loop: dict[asyncio.AbstractEventLoop, _LoopState] = {}
19901994
self._write_stream_name = None # Resolved stream name
@@ -2097,15 +2101,16 @@ async def _get_loop_state(self) -> _LoopState:
20972101
# grpc.aio clients are loop-bound, so we create one per event loop.
20982102

20992103
def get_credentials():
2100-
creds, project_id = google.auth.default(
2104+
creds, _ = google.auth.default(
21012105
scopes=["https://www.googleapis.com/auth/cloud-platform"]
21022106
)
2103-
return creds, project_id
2107+
return creds
21042108

2105-
creds, project_id = await loop.run_in_executor(
2106-
self._executor, get_credentials
2107-
)
2108-
quota_project_id = getattr(creds, "quota_project_id", None)
2109+
if self._credentials is None:
2110+
self._credentials = await loop.run_in_executor(
2111+
self._executor, get_credentials
2112+
)
2113+
quota_project_id = getattr(self._credentials, "quota_project_id", None)
21092114
options = (
21102115
client_options.ClientOptions(quota_project_id=quota_project_id)
21112116
if quota_project_id
@@ -2119,7 +2124,7 @@ def get_credentials():
21192124
client_info = gapic_client_info.ClientInfo(user_agent=" ".join(user_agents))
21202125

21212126
write_client = BigQueryWriteAsyncClient(
2122-
credentials=creds,
2127+
credentials=self._credentials,
21232128
client_info=client_info,
21242129
client_options=options,
21252130
)
@@ -2173,7 +2178,9 @@ async def _lazy_setup(self, **kwargs) -> None:
21732178
self.client = await loop.run_in_executor(
21742179
self._executor,
21752180
lambda: bigquery.Client(
2176-
project=self.project_id, location=self.location
2181+
project=self.project_id,
2182+
location=self.location,
2183+
credentials=self._credentials,
21772184
),
21782185
)
21792186

@@ -2193,7 +2200,9 @@ async def _lazy_setup(self, **kwargs) -> None:
21932200
self.project_id,
21942201
self.config.gcs_bucket_name,
21952202
self._executor,
2196-
storage_client=kwargs.get("storage_client"),
2203+
storage_client=storage.Client(
2204+
project=self.project_id, credentials=self._credentials
2205+
),
21972206
)
21982207

21992208
self.parser = HybridContentParser(

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@
4848
_USAGE_METADATA_CUSTOM_METADATA_KEY = '_usage_metadata'
4949

5050

51+
def _quote_filter_literal(value: str) -> str:
52+
"""Quotes filter values so embedded metacharacters stay inside the literal."""
53+
escaped_value = value.replace('\\', '\\\\').replace('"', '\\"')
54+
return f'"{escaped_value}"'
55+
56+
5157
def _set_internal_custom_metadata(
5258
metadata_dict: dict[str, Any], *, key: str, value: dict[str, Any]
5359
) -> None:
@@ -228,7 +234,7 @@ async def list_sessions(
228234
sessions = []
229235
config = {}
230236
if user_id is not None:
231-
config['filter'] = f'user_id="{user_id}"'
237+
config['filter'] = f'user_id={_quote_filter_literal(user_id)}'
232238
sessions_iterator = await api_client.agent_engines.sessions.list(
233239
name=f'reasoningEngines/{reasoning_engine_id}',
234240
config=config,

tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,18 @@ def tool_context(invocation_context):
104104
return tool_context_lib.ToolContext(invocation_context=invocation_context)
105105

106106

107+
class FakeCredentials(google.auth.credentials.Credentials):
108+
109+
def __init__(self):
110+
pass
111+
112+
def refresh(self, request):
113+
pass
114+
115+
107116
@pytest.fixture
108117
def mock_auth_default():
109-
mock_creds = mock.create_autospec(
110-
google.auth.credentials.Credentials, instance=True, spec_set=True
111-
)
118+
mock_creds = FakeCredentials()
112119
with mock.patch.object(
113120
google.auth,
114121
"default",
@@ -2000,6 +2007,62 @@ async def test_no_quota_project_when_creds_lack_it(
20002007
_, kwargs = mock_bq_write_cls.call_args
20012008
assert kwargs["client_options"] is None
20022009

2010+
@pytest.mark.asyncio
2011+
async def test_custom_credentials_used(
2012+
self,
2013+
mock_to_arrow_schema,
2014+
mock_asyncio_to_thread,
2015+
):
2016+
"""Verify custom credentials are used and default auth is not called."""
2017+
mock_custom_creds = mock.create_autospec(
2018+
google.auth.credentials.Credentials, instance=True, spec_set=True
2019+
)
2020+
mock_custom_creds.quota_project_id = "custom-quota-project"
2021+
2022+
config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig(
2023+
gcs_bucket_name="test-bucket",
2024+
create_views=False,
2025+
)
2026+
2027+
with mock.patch.object(
2028+
google.auth,
2029+
"default",
2030+
autospec=True,
2031+
) as mock_auth_default:
2032+
with mock.patch.object(
2033+
bigquery_agent_analytics_plugin,
2034+
"BigQueryWriteAsyncClient",
2035+
autospec=True,
2036+
) as mock_bq_write_cls:
2037+
with mock.patch(
2038+
"google.cloud.bigquery.Client", autospec=True
2039+
) as mock_bq_cls:
2040+
with mock.patch(
2041+
"google.cloud.storage.Client", autospec=True
2042+
) as mock_storage_cls:
2043+
async with managed_plugin(
2044+
project_id=PROJECT_ID,
2045+
dataset_id=DATASET_ID,
2046+
table_id=TABLE_ID,
2047+
credentials=mock_custom_creds,
2048+
config=config,
2049+
) as plugin:
2050+
await plugin._ensure_started()
2051+
2052+
mock_auth_default.assert_not_called()
2053+
2054+
mock_bq_write_cls.assert_called_once()
2055+
_, kwargs = mock_bq_write_cls.call_args
2056+
assert kwargs["credentials"] == mock_custom_creds
2057+
2058+
mock_bq_cls.assert_called_once()
2059+
_, kwargs = mock_bq_cls.call_args
2060+
assert kwargs["credentials"] == mock_custom_creds
2061+
2062+
mock_storage_cls.assert_called_once()
2063+
_, kwargs = mock_storage_cls.call_args
2064+
assert kwargs["credentials"] == mock_custom_creds
2065+
20032066
@pytest.mark.asyncio
20042067
async def test_pickle_safety(self, mock_auth_default, mock_bq_client):
20052068
"""Test that the plugin can be pickled safely."""

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def __init__(self) -> None:
375375
self.agent_engines.sessions.events.list.side_effect = self._list_events
376376
self.agent_engines.sessions.events.append.side_effect = self._append_event
377377
self.last_create_session_config: dict[str, Any] = {}
378+
self.last_list_sessions_config: dict[str, Any] = {}
378379

379380
async def __aenter__(self):
380381
"""Enters the asynchronous context."""
@@ -391,8 +392,9 @@ async def _get_session(self, name: str):
391392
raise api_core_exceptions.NotFound(f'Session not found: {session_id}')
392393

393394
async def _list_sessions(self, name: str, config: dict[str, Any]):
395+
self.last_list_sessions_config = config
394396
filter_val = config.get('filter', '')
395-
user_id_match = re.search(r'user_id="([^"]+)"', filter_val)
397+
user_id_match = re.search(r'user_id="((?:\\.|[^"])*)"', filter_val)
396398
if user_id_match:
397399
user_id = user_id_match.group(1)
398400
if user_id == 'user_with_pages':
@@ -877,6 +879,34 @@ async def test_list_sessions_all_users():
877879
}
878880

879881

882+
@pytest.mark.asyncio
883+
@pytest.mark.usefixtures('mock_get_api_client')
884+
@pytest.mark.parametrize(
885+
('payload', 'expected_filter'),
886+
[
887+
(
888+
'attacker" OR user_id!=""',
889+
'user_id="attacker\\" OR user_id!=\\"\\""',
890+
),
891+
('\\', 'user_id="\\\\"'),
892+
('', 'user_id=""'),
893+
],
894+
)
895+
async def test_list_sessions_quotes_user_id_filter(
896+
mock_api_client_instance, payload, expected_filter
897+
):
898+
session_service = mock_vertex_ai_session_service()
899+
900+
sessions = await session_service.list_sessions(
901+
app_name='123', user_id=payload
902+
)
903+
904+
assert sessions.sessions == []
905+
assert mock_api_client_instance.last_list_sessions_config == {
906+
'filter': expected_filter
907+
}
908+
909+
880910
@pytest.mark.asyncio
881911
@pytest.mark.usefixtures('mock_get_api_client')
882912
async def test_create_session():

0 commit comments

Comments
 (0)