Skip to content

Commit 60f7591

Browse files
authored
Merge branch 'main' into fix-cli-robustness
2 parents 2b7aafa + 179380f commit 60f7591

13 files changed

Lines changed: 1176 additions & 37 deletions

.github/workflows/pre-commit.yml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
name: pre-commit
16+
17+
on:
18+
push:
19+
branches: [main, v2]
20+
paths:
21+
- '**.py'
22+
- '.pre-commit-config.yaml'
23+
- 'pyproject.toml'
24+
pull_request:
25+
branches: [main, v2]
26+
paths:
27+
- '**.py'
28+
- '.pre-commit-config.yaml'
29+
- 'pyproject.toml'
30+
31+
jobs:
32+
pre-commit:
33+
runs-on: ubuntu-latest
34+
steps:
35+
- name: Checkout Code
36+
uses: actions/checkout@v6
37+
38+
- name: Run pre-commit checks
39+
uses: pre-commit/action@v3.0.1

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ dev = [
9090
"flit>=3.10.0",
9191
"isort>=6.0.0",
9292
"mypy>=1.15.0",
93+
"pre-commit>=4.0.0",
9394
"pyink>=25.12.0",
9495
"pylint>=2.6.0",
9596
# go/keep-sorted end

src/google/adk/features/_feature_registry.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ class FeatureName(str, Enum):
4141
GOOGLE_CREDENTIALS_CONFIG = "GOOGLE_CREDENTIALS_CONFIG"
4242
GOOGLE_TOOL = "GOOGLE_TOOL"
4343
JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL"
44+
# Private (leading underscore): not part of the public API surface.
45+
# GE flips this on by setting the env var
46+
# `ADK_ENABLE_MCP_GRACEFUL_ERROR_HANDLING=1`; nothing should import this
47+
# enum member by name. Keeping it private avoids a backward-compat
48+
# obligation for what is intended as a temporary, internal kill-switch.
49+
_MCP_GRACEFUL_ERROR_HANDLING = "MCP_GRACEFUL_ERROR_HANDLING"
4450
PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING"
4551
PUBSUB_TOOL_CONFIG = "PUBSUB_TOOL_CONFIG"
4652
PUBSUB_TOOLSET = "PUBSUB_TOOLSET"
@@ -131,6 +137,9 @@ class FeatureConfig:
131137
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL: FeatureConfig(
132138
FeatureStage.WIP, default_on=False
133139
),
140+
FeatureName._MCP_GRACEFUL_ERROR_HANDLING: FeatureConfig(
141+
FeatureStage.EXPERIMENTAL, default_on=False
142+
),
134143
FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig(
135144
FeatureStage.EXPERIMENTAL, default_on=True
136145
),

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1949,6 +1949,7 @@ def __init__(
19491949
table_id: Optional[str] = None,
19501950
config: Optional[BigQueryLoggerConfig] = None,
19511951
location: str = "US",
1952+
credentials: Optional[google.auth.credentials.Credentials] = None,
19521953
**kwargs,
19531954
) -> None:
19541955
"""Initializes the instance.
@@ -1959,6 +1960,8 @@ def __init__(
19591960
table_id: BigQuery table ID (optional, overrides config).
19601961
config: BigQueryLoggerConfig (optional).
19611962
location: BigQuery location (default: "US").
1963+
credentials: Google Auth credentials (optional). If None, uses
1964+
Application Default Credentials.
19621965
**kwargs: Additional configuration parameters for BigQueryLoggerConfig.
19631966
"""
19641967
super().__init__(name="bigquery_agent_analytics")
@@ -1985,6 +1988,7 @@ def __init__(
19851988
self._startup_error: Optional[Exception] = None
19861989
self._is_shutting_down = False
19871990
self._setup_lock = None
1991+
self._credentials = credentials
19881992
self.client = None
19891993
self._loop_state_by_loop: dict[asyncio.AbstractEventLoop, _LoopState] = {}
19901994
self._write_stream_name = None # Resolved stream name
@@ -2097,15 +2101,16 @@ async def _get_loop_state(self) -> _LoopState:
20972101
# grpc.aio clients are loop-bound, so we create one per event loop.
20982102

20992103
def get_credentials():
2100-
creds, project_id = google.auth.default(
2104+
creds, _ = google.auth.default(
21012105
scopes=["https://www.googleapis.com/auth/cloud-platform"]
21022106
)
2103-
return creds, project_id
2107+
return creds
21042108

2105-
creds, project_id = await loop.run_in_executor(
2106-
self._executor, get_credentials
2107-
)
2108-
quota_project_id = getattr(creds, "quota_project_id", None)
2109+
if self._credentials is None:
2110+
self._credentials = await loop.run_in_executor(
2111+
self._executor, get_credentials
2112+
)
2113+
quota_project_id = getattr(self._credentials, "quota_project_id", None)
21092114
options = (
21102115
client_options.ClientOptions(quota_project_id=quota_project_id)
21112116
if quota_project_id
@@ -2119,7 +2124,7 @@ def get_credentials():
21192124
client_info = gapic_client_info.ClientInfo(user_agent=" ".join(user_agents))
21202125

21212126
write_client = BigQueryWriteAsyncClient(
2122-
credentials=creds,
2127+
credentials=self._credentials,
21232128
client_info=client_info,
21242129
client_options=options,
21252130
)
@@ -2173,7 +2178,9 @@ async def _lazy_setup(self, **kwargs) -> None:
21732178
self.client = await loop.run_in_executor(
21742179
self._executor,
21752180
lambda: bigquery.Client(
2176-
project=self.project_id, location=self.location
2181+
project=self.project_id,
2182+
location=self.location,
2183+
credentials=self._credentials,
21772184
),
21782185
)
21792186

@@ -2193,7 +2200,9 @@ async def _lazy_setup(self, **kwargs) -> None:
21932200
self.project_id,
21942201
self.config.gcs_bucket_name,
21952202
self._executor,
2196-
storage_client=kwargs.get("storage_client"),
2203+
storage_client=storage.Client(
2204+
project=self.project_id, credentials=self._credentials
2205+
),
21972206
)
21982207

21992208
self.parser = HybridContentParser(

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@
4848
_USAGE_METADATA_CUSTOM_METADATA_KEY = '_usage_metadata'
4949

5050

51+
def _quote_filter_literal(value: str) -> str:
52+
"""Quotes filter values so embedded metacharacters stay inside the literal."""
53+
escaped_value = value.replace('\\', '\\\\').replace('"', '\\"')
54+
return f'"{escaped_value}"'
55+
56+
5157
def _set_internal_custom_metadata(
5258
metadata_dict: dict[str, Any], *, key: str, value: dict[str, Any]
5359
) -> None:
@@ -228,7 +234,7 @@ async def list_sessions(
228234
sessions = []
229235
config = {}
230236
if user_id is not None:
231-
config['filter'] = f'user_id="{user_id}"'
237+
config['filter'] = f'user_id={_quote_filter_literal(user_id)}'
232238
sessions_iterator = await api_client.agent_engines.sessions.list(
233239
name=f'reasoningEngines/{reasoning_engine_id}',
234240
config=config,

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
from pydantic import BaseModel
4545
from pydantic import ConfigDict
4646

47+
from ...features import FeatureName
48+
from ...features import is_feature_enabled
4749
from .session_context import SessionContext
4850

4951
logger = logging.getLogger('google_adk.' + __name__)
@@ -237,11 +239,18 @@ def __init__(
237239
self._connection_params = connection_params
238240
self._errlog = errlog
239241

240-
# Session pool: maps session keys to (session, exit_stack, loop) tuples
242+
# Session pool: maps session keys to (session, exit_stack, loop) tuples.
243+
# Kept as a tuple for backward-compatibility with downstream tests
244+
# that construct or unpack entries directly.
241245
self._sessions: Dict[
242246
str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop]
243247
] = {}
244248

249+
# Sibling pool: maps session keys to their SessionContext. Stored
250+
# separately from `_sessions` so the tuple shape above stays stable.
251+
# Used by McpTool to access `_run_guarded` for transport-crash detection.
252+
self._session_contexts: Dict[str, SessionContext] = {}
253+
245254
# Map of event loops to their respective locks to prevent race conditions
246255
# across different event loops in session creation.
247256
self._session_lock_map: dict[asyncio.AbstractEventLoop, asyncio.Lock] = {}
@@ -323,6 +332,26 @@ def _is_session_disconnected(self, session: ClientSession) -> bool:
323332
"""
324333
return session._read_stream._closed or session._write_stream._closed
325334

335+
def _get_session_context(
336+
self, headers: Optional[Dict[str, str]] = None
337+
) -> Optional[SessionContext]:
338+
"""Returns the SessionContext for the session matching the given headers.
339+
340+
Note: This method reads from the session-context pool without acquiring
341+
``_session_lock``. This is safe because it is called immediately after
342+
``create_session()`` (which populates the entry under the lock) within
343+
the same task, and dict reads are atomic in CPython.
344+
345+
Args:
346+
headers: Optional headers used to identify the session.
347+
348+
Returns:
349+
The SessionContext if a matching session exists, None otherwise.
350+
"""
351+
merged_headers = self._merge_headers(headers)
352+
session_key = self._generate_session_key(merged_headers)
353+
return self._session_contexts.get(session_key)
354+
326355
async def _cleanup_session(
327356
self,
328357
session_key: str,
@@ -378,6 +407,10 @@ def cleanup_done(f: asyncio.Future):
378407
finally:
379408
if session_key in self._sessions:
380409
del self._sessions[session_key]
410+
# Also drop the SessionContext reference so we don't leak the
411+
# SessionContext after its underlying session is gone.
412+
if session_key in self._session_contexts:
413+
del self._session_contexts[session_key]
381414

382415
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
383416
"""Creates an MCP client based on the connection parameters.
@@ -453,15 +486,30 @@ async def create_session(
453486
if session_key in self._sessions:
454487
session, exit_stack, stored_loop = self._sessions[session_key]
455488

456-
# Check if the existing session is still connected and bound to the current loop
489+
# Check if the existing session is still connected and bound to
490+
# the current loop. When the feature flag is on, we ALSO check the
491+
# SessionContext's background task: a crashed transport can leave
492+
# the session's read/write streams open even though the underlying
493+
# task has already died (e.g. after a 4xx/5xx HTTP response).
494+
# Without that extra check, callers would reuse a dead session and
495+
# hang on the next call. The check is gated because it triggers
496+
# session re-creation in some test mocks where `_task` looks
497+
# "not alive" but the streams are otherwise reusable.
457498
current_loop = asyncio.get_running_loop()
458-
if stored_loop is current_loop and not self._is_session_disconnected(
459-
session
499+
if is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING): # pylint: disable=protected-access
500+
ctx = self._session_contexts.get(session_key)
501+
ctx_alive = ctx is None or ctx._is_task_alive # pylint: disable=protected-access
502+
else:
503+
ctx_alive = True # Pre-fix: do not consult task aliveness
504+
if (
505+
stored_loop is current_loop
506+
and not self._is_session_disconnected(session)
507+
and ctx_alive
460508
):
461509
# Session is still good, return it
462510
return session
463511
else:
464-
# Session is disconnected or from a different loop, clean it up
512+
# Session is disconnected, dead, or from a different loop; clean up.
465513
logger.info(
466514
'Cleaning up session (disconnected or different loop): %s',
467515
session_key,
@@ -485,26 +533,32 @@ async def create_session(
485533
client = self._create_client(merged_headers)
486534
is_stdio = isinstance(self._connection_params, StdioConnectionParams)
487535

536+
session_context = SessionContext(
537+
client=client,
538+
timeout=timeout_in_seconds,
539+
sse_read_timeout=sse_read_timeout_in_seconds,
540+
is_stdio=is_stdio,
541+
sampling_callback=self._sampling_callback,
542+
sampling_capabilities=self._sampling_capabilities,
543+
)
544+
488545
session = await asyncio.wait_for(
489-
exit_stack.enter_async_context(
490-
SessionContext(
491-
client=client,
492-
timeout=timeout_in_seconds,
493-
sse_read_timeout=sse_read_timeout_in_seconds,
494-
is_stdio=is_stdio,
495-
sampling_callback=self._sampling_callback,
496-
sampling_capabilities=self._sampling_capabilities,
497-
)
498-
),
546+
exit_stack.enter_async_context(session_context),
499547
timeout=timeout_in_seconds,
500548
)
501549

502-
# Store session, exit stack, and loop in the pool
550+
# Store session, exit stack, and loop in the pool. The pool storage
551+
# remains a tuple for backward-compatibility with downstream tests
552+
# that construct or unpack entries directly.
503553
self._sessions[session_key] = (
504554
session,
505555
exit_stack,
506556
asyncio.get_running_loop(),
507557
)
558+
# Track the SessionContext in a sibling dict so McpTool can call
559+
# `_run_guarded` on it. Stored separately to avoid changing the
560+
# shape of `_sessions` (internal, but unpacked directly by downstream tests).
561+
self._session_contexts[session_key] = session_context
508562
logger.debug('Created new session: %s', session_key)
509563
return session
510564

@@ -524,6 +578,7 @@ def __getstate__(self):
524578
state = self.__dict__.copy()
525579
# Remove unpicklable entries or those that shouldn't persist across pickle
526580
state['_sessions'] = {}
581+
state['_session_contexts'] = {}
527582
state['_session_lock_map'] = {}
528583

529584
# Locks and file-like objects cannot be pickled
@@ -537,6 +592,7 @@ def __setstate__(self, state):
537592
self.__dict__.update(state)
538593
# Re-initialize members that were not pickled
539594
self._sessions = {}
595+
self._session_contexts = {}
540596
self._session_lock_map = {}
541597
self._lock_map_lock = threading.Lock()
542598
# If _errlog was removed during pickling, default to sys.stderr

0 commit comments

Comments
 (0)