Skip to content

Commit bcf38fa

Browse files
haiyuan-eng-google and copybara-github
authored and committed
feat: Enhance BigQuery plugin schema upgrades and error reporting
This change introduces several improvements to the BigQuery Agent Analytics Plugin: * **Fix 1 (High):** Error callbacks (`on_model_error_callback`, `on_tool_error_callback`) now emit `status="ERROR"` instead of defaulting to `"OK"`. * **Fix 2 (Medium):** Schema upgrade now detects missing sub-fields in nested RECORD columns via a new recursive helper. The version label is now stamped only after the `update_table` call succeeds, ensuring failures can be retried. * **Fix 3 (Medium):** Multi-loop `shutdown()` now drains batch processors on non-current event loops using `run_coroutine_threadsafe` before closing transports. * **Fix 4 (Medium):** Session state is truncated before logging to prevent oversized payloads. * **Fix 5 (Low):** String system prompts are now truncated during content parsing. * **Fix 6 (Low):** Removed the unused `_HITL_TOOL_NAMES` frozenset. Co-authored-by: Haiyuan Cao <haiyuan@google.com> PiperOrigin-RevId: 879147684
1 parent feefadf commit bcf38fa

File tree

2 files changed

+576
-23
lines changed

2 files changed

+576
-23
lines changed

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 178 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@
2828
import logging
2929
import mimetypes
3030
import os
31+
32+
# Enable gRPC fork support so child processes created via os.fork()
33+
# can safely create new gRPC channels. Must be set before grpc's
34+
# C-core is loaded (which happens through the google.api_core
35+
# imports below). setdefault respects any explicit user override.
36+
os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "1")
37+
3138
import random
3239
import time
3340
from types import MappingProxyType
@@ -76,19 +83,29 @@
7683
_SCHEMA_VERSION = "1"
7784
_SCHEMA_VERSION_LABEL_KEY = "adk_schema_version"
7885

79-
# Human-in-the-loop (HITL) tool names that receive additional
80-
# dedicated event types alongside the normal TOOL_* events.
81-
_HITL_TOOL_NAMES = frozenset({
82-
"adk_request_credential",
83-
"adk_request_confirmation",
84-
"adk_request_input",
85-
})
8686
_HITL_EVENT_MAP = MappingProxyType({
8787
"adk_request_credential": "HITL_CREDENTIAL_REQUEST",
8888
"adk_request_confirmation": "HITL_CONFIRMATION_REQUEST",
8989
"adk_request_input": "HITL_INPUT_REQUEST",
9090
})
9191

92+
# Track all living plugin instances so the fork handler can reset
93+
# them proactively in the child, before _ensure_started runs.
94+
_LIVE_PLUGINS: weakref.WeakSet = weakref.WeakSet()
95+
96+
97+
def _after_fork_in_child() -> None:
98+
"""Reset every living plugin instance after os.fork()."""
99+
for plugin in list(_LIVE_PLUGINS):
100+
try:
101+
plugin._reset_runtime_state()
102+
except Exception:
103+
pass
104+
105+
106+
if hasattr(os, "register_at_fork"):
107+
os.register_at_fork(after_in_child=_after_fork_in_child)
108+
92109

93110
def _safe_callback(func):
94111
"""Decorator that catches and logs exceptions in plugin callbacks.
@@ -1407,7 +1424,10 @@ def process_text(t: str) -> tuple[str, bool]:
14071424
if content.config and getattr(content.config, "system_instruction", None):
14081425
si = content.config.system_instruction
14091426
if isinstance(si, str):
1410-
json_payload["system_prompt"] = si
1427+
truncated_si, trunc = process_text(si)
1428+
if trunc:
1429+
is_truncated = True
1430+
json_payload["system_prompt"] = truncated_si
14111431
else:
14121432
summary, parts, trunc = await self._parse_content_object(si)
14131433
if trunc:
@@ -1855,6 +1875,7 @@ def __init__(
18551875
self._schema = None
18561876
self.arrow_schema = None
18571877
self._init_pid = os.getpid()
1878+
_LIVE_PLUGINS.add(self)
18581879

18591880
def _cleanup_stale_loop_states(self) -> None:
18601881
"""Removes entries for event loops that have been closed."""
@@ -2142,9 +2163,73 @@ def _ensure_schema_exists(self) -> None:
21422163
exc_info=True,
21432164
)
21442165

2166+
@staticmethod
2167+
def _schema_fields_match(
2168+
existing: list[bq_schema.SchemaField],
2169+
desired: list[bq_schema.SchemaField],
2170+
) -> tuple[
2171+
list[bq_schema.SchemaField],
2172+
list[bq_schema.SchemaField],
2173+
]:
2174+
"""Compares existing vs desired schema fields recursively.
2175+
2176+
Returns:
2177+
A tuple of (new_top_level_fields, updated_record_fields).
2178+
``new_top_level_fields`` are fields in *desired* that are
2179+
entirely absent from *existing*.
2180+
``updated_record_fields`` are RECORD fields that exist in
2181+
both but have new sub-fields in *desired*; each entry is a
2182+
copy of the existing field with the missing sub-fields
2183+
appended.
2184+
"""
2185+
existing_by_name = {f.name: f for f in existing}
2186+
new_fields: list[bq_schema.SchemaField] = []
2187+
updated_records: list[bq_schema.SchemaField] = []
2188+
2189+
for desired_field in desired:
2190+
existing_field = existing_by_name.get(desired_field.name)
2191+
if existing_field is None:
2192+
new_fields.append(desired_field)
2193+
elif (
2194+
desired_field.field_type == "RECORD"
2195+
and existing_field.field_type == "RECORD"
2196+
and desired_field.fields
2197+
):
2198+
# Recurse into nested RECORD fields.
2199+
sub_new, sub_updated = (
2200+
BigQueryAgentAnalyticsPlugin._schema_fields_match(
2201+
list(existing_field.fields),
2202+
list(desired_field.fields),
2203+
)
2204+
)
2205+
if sub_new or sub_updated:
2206+
# Build a merged sub-field list.
2207+
merged_sub = list(existing_field.fields)
2208+
# Replace updated nested records in-place.
2209+
updated_names = {f.name for f in sub_updated}
2210+
merged_sub = [
2211+
next(u for u in sub_updated if u.name == f.name)
2212+
if f.name in updated_names
2213+
else f
2214+
for f in merged_sub
2215+
]
2216+
# Append entirely new sub-fields.
2217+
merged_sub.extend(sub_new)
2218+
# Rebuild via API representation to preserve all
2219+
# existing field attributes (policy_tags, etc.).
2220+
api_repr = existing_field.to_api_repr()
2221+
api_repr["fields"] = [sf.to_api_repr() for sf in merged_sub]
2222+
updated_records.append(bq_schema.SchemaField.from_api_repr(api_repr))
2223+
2224+
return new_fields, updated_records
2225+
21452226
def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None:
21462227
"""Adds missing columns to an existing table (additive only).
21472228
2229+
Handles nested RECORD fields by recursing into sub-fields.
2230+
The version label is only stamped after a successful update
2231+
so that a failed attempt is retried on the next run.
2232+
21482233
Args:
21492234
existing_table: The current BigQuery table object.
21502235
"""
@@ -2154,24 +2239,43 @@ def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None:
21542239
if stored_version == _SCHEMA_VERSION:
21552240
return
21562241

2157-
existing_names = {f.name for f in existing_table.schema}
2158-
new_fields = [f for f in self._schema if f.name not in existing_names]
2242+
new_fields, updated_records = self._schema_fields_match(
2243+
list(existing_table.schema), list(self._schema)
2244+
)
21592245

2160-
if new_fields:
2161-
merged = list(existing_table.schema) + new_fields
2246+
if new_fields or updated_records:
2247+
# Build merged top-level schema.
2248+
updated_names = {f.name for f in updated_records}
2249+
merged = [
2250+
next(u for u in updated_records if u.name == f.name)
2251+
if f.name in updated_names
2252+
else f
2253+
for f in existing_table.schema
2254+
]
2255+
merged.extend(new_fields)
21622256
existing_table.schema = merged
2257+
2258+
change_desc = []
2259+
if new_fields:
2260+
change_desc.append(f"new columns {[f.name for f in new_fields]}")
2261+
if updated_records:
2262+
change_desc.append(
2263+
f"updated RECORD fields {[f.name for f in updated_records]}"
2264+
)
21632265
logger.info(
2164-
"Auto-upgrading table %s: adding columns %s",
2266+
"Auto-upgrading table %s: %s",
21652267
self.full_table_id,
2166-
[f.name for f in new_fields],
2268+
", ".join(change_desc),
21672269
)
21682270

2169-
# Always stamp the version label so we skip on next run.
2170-
labels = dict(existing_table.labels or {})
2171-
labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION
2172-
existing_table.labels = labels
2173-
21742271
try:
2272+
# Stamp the version label inside the try block so that
2273+
# on failure the label is NOT persisted and the next run
2274+
# retries the upgrade.
2275+
labels = dict(existing_table.labels or {})
2276+
labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION
2277+
existing_table.labels = labels
2278+
21752279
update_fields = ["schema", "labels"]
21762280
self.client.update_table(existing_table, update_fields)
21772281
except Exception as e:
@@ -2243,6 +2347,22 @@ async def shutdown(self, timeout: float | None = None) -> None:
22432347
if loop in self._loop_state_by_loop:
22442348
await self._loop_state_by_loop[loop].batch_processor.shutdown(timeout=t)
22452349

2350+
# 1b. Drain batch processors on other (non-current) loops.
2351+
for other_loop, state in self._loop_state_by_loop.items():
2352+
if other_loop is loop or other_loop.is_closed():
2353+
continue
2354+
try:
2355+
future = asyncio.run_coroutine_threadsafe(
2356+
state.batch_processor.shutdown(timeout=t),
2357+
other_loop,
2358+
)
2359+
future.result(timeout=t)
2360+
except Exception:
2361+
logger.warning(
2362+
"Could not drain batch processor on loop %s",
2363+
other_loop,
2364+
)
2365+
22462366
# 2. Close clients for all states
22472367
for state in self._loop_state_by_loop.values():
22482368
if state.write_client and getattr(
@@ -2298,6 +2418,38 @@ def _reset_runtime_state(self) -> None:
22982418
process. Pure-data fields like ``_schema`` and
22992419
``arrow_schema`` are kept because they are safe across fork.
23002420
"""
2421+
logger.warning(
2422+
"Fork detected (parent PID %s, child PID %s). Resetting"
2423+
" gRPC state for BigQuery analytics plugin. Note: gRPC"
2424+
" bidirectional streaming (used by the BigQuery Storage"
2425+
" Write API) is not fork-safe. If writes hang or time"
2426+
" out, configure the 'spawn' start method at your program"
2427+
" entry-point before creating child processes:"
2428+
" multiprocessing.set_start_method('spawn')",
2429+
self._init_pid,
2430+
os.getpid(),
2431+
)
2432+
# Best-effort: close inherited gRPC channels so broken
2433+
# finalizers don't interfere with newly created channels.
2434+
# For grpc.aio channels, close() is a coroutine. We cannot
2435+
# await here (called from sync context / fork handler), so
2436+
# we skip async channels and only close sync ones.
2437+
for loop_state in self._loop_state_by_loop.values():
2438+
wc = getattr(loop_state, "write_client", None)
2439+
transport = getattr(wc, "transport", None)
2440+
if transport is not None:
2441+
try:
2442+
channel = getattr(transport, "_grpc_channel", None)
2443+
if channel is not None and hasattr(channel, "close"):
2444+
result = channel.close()
2445+
# If close() returned a coroutine (grpc.aio channel),
2446+
# discard it to avoid unawaited-coroutine warnings.
2447+
if asyncio.iscoroutine(result):
2448+
result.close()
2449+
except Exception:
2450+
pass
2451+
2452+
# Clear all runtime state.
23012453
self._setup_lock = None
23022454
self.client = None
23032455
self._loop_state_by_loop = {}
@@ -2442,7 +2594,11 @@ def _enrich_attributes(
24422594
# Include session state if non-empty (contains user-set metadata
24432595
# like gchat thread-id, customer_id, etc.)
24442596
if session.state:
2445-
session_meta["state"] = dict(session.state)
2597+
truncated_state, _ = _recursive_smart_truncate(
2598+
dict(session.state),
2599+
self.config.max_content_length,
2600+
)
2601+
session_meta["state"] = truncated_state
24462602
attrs["session_metadata"] = session_meta
24472603
except Exception:
24482604
pass
@@ -2988,6 +3144,7 @@ async def on_model_error_callback(
29883144
"LLM_ERROR",
29893145
callback_context,
29903146
event_data=EventData(
3147+
status="ERROR",
29913148
error_message=str(error),
29923149
latency_ms=duration,
29933150
span_id_override=None if has_ambient else span_id,
@@ -3110,6 +3267,7 @@ async def on_tool_error_callback(
31103267
raw_content=content_dict,
31113268
is_truncated=is_truncated,
31123269
event_data=EventData(
3270+
status="ERROR",
31133271
error_message=str(error),
31143272
latency_ms=duration,
31153273
span_id_override=None if has_ambient else span_id,

0 commit comments

Comments (0)