2828import logging
2929import mimetypes
3030import os
31+
# Enable gRPC fork support so child processes created via os.fork()
# can safely create new gRPC channels. Must be set before grpc's
# C-core is loaded (which happens through the google.api_core
# imports below). setdefault respects any explicit user override
# already present in the environment.
os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "1")
37+
3138import random
3239import time
3340from types import MappingProxyType
7683_SCHEMA_VERSION = "1"
7784_SCHEMA_VERSION_LABEL_KEY = "adk_schema_version"
7885
79- # Human-in-the-loop (HITL) tool names that receive additional
80- # dedicated event types alongside the normal TOOL_* events.
81- _HITL_TOOL_NAMES = frozenset ({
82- "adk_request_credential" ,
83- "adk_request_confirmation" ,
84- "adk_request_input" ,
85- })
8686_HITL_EVENT_MAP = MappingProxyType ({
8787 "adk_request_credential" : "HITL_CREDENTIAL_REQUEST" ,
8888 "adk_request_confirmation" : "HITL_CONFIRMATION_REQUEST" ,
8989 "adk_request_input" : "HITL_INPUT_REQUEST" ,
9090})
9191
92+ # Track all living plugin instances so the fork handler can reset
93+ # them proactively in the child, before _ensure_started runs.
94+ _LIVE_PLUGINS : weakref .WeakSet = weakref .WeakSet ()
95+
96+
def _after_fork_in_child() -> None:
  """Reset every living plugin instance after os.fork().

  Registered via ``os.register_at_fork`` so it runs in the child
  process immediately after a fork, before ``_ensure_started`` can
  touch inherited (and now unusable) gRPC state.

  Iterates over a snapshot of the WeakSet because resetting a plugin
  may release references and mutate the set during iteration.
  Failures are logged at debug level and never propagated: a fork
  handler must not raise into unrelated application code.
  """
  for plugin in list(_LIVE_PLUGINS):
    try:
      plugin._reset_runtime_state()
    except Exception:  # Best-effort cleanup; never break the fork.
      logger.debug(
          "Failed to reset plugin %r after fork", plugin, exc_info=True
      )
104+
105+
# os.register_at_fork exists only where os.fork() does (POSIX,
# Python 3.7+); guard so import still succeeds on platforms such as
# Windows where the hook is absent.
if hasattr(os, "register_at_fork"):
  os.register_at_fork(after_in_child=_after_fork_in_child)
108+
92109
93110def _safe_callback (func ):
94111 """Decorator that catches and logs exceptions in plugin callbacks.
@@ -1407,7 +1424,10 @@ def process_text(t: str) -> tuple[str, bool]:
14071424 if content .config and getattr (content .config , "system_instruction" , None ):
14081425 si = content .config .system_instruction
14091426 if isinstance (si , str ):
1410- json_payload ["system_prompt" ] = si
1427+ truncated_si , trunc = process_text (si )
1428+ if trunc :
1429+ is_truncated = True
1430+ json_payload ["system_prompt" ] = truncated_si
14111431 else :
14121432 summary , parts , trunc = await self ._parse_content_object (si )
14131433 if trunc :
@@ -1855,6 +1875,7 @@ def __init__(
18551875 self ._schema = None
18561876 self .arrow_schema = None
18571877 self ._init_pid = os .getpid ()
1878+ _LIVE_PLUGINS .add (self )
18581879
18591880 def _cleanup_stale_loop_states (self ) -> None :
18601881 """Removes entries for event loops that have been closed."""
@@ -2142,9 +2163,73 @@ def _ensure_schema_exists(self) -> None:
21422163 exc_info = True ,
21432164 )
21442165
2166+ @staticmethod
2167+ def _schema_fields_match (
2168+ existing : list [bq_schema .SchemaField ],
2169+ desired : list [bq_schema .SchemaField ],
2170+ ) -> tuple [
2171+ list [bq_schema .SchemaField ],
2172+ list [bq_schema .SchemaField ],
2173+ ]:
2174+ """Compares existing vs desired schema fields recursively.
2175+
2176+ Returns:
2177+ A tuple of (new_top_level_fields, updated_record_fields).
2178+ ``new_top_level_fields`` are fields in *desired* that are
2179+ entirely absent from *existing*.
2180+ ``updated_record_fields`` are RECORD fields that exist in
2181+ both but have new sub-fields in *desired*; each entry is a
2182+ copy of the existing field with the missing sub-fields
2183+ appended.
2184+ """
2185+ existing_by_name = {f .name : f for f in existing }
2186+ new_fields : list [bq_schema .SchemaField ] = []
2187+ updated_records : list [bq_schema .SchemaField ] = []
2188+
2189+ for desired_field in desired :
2190+ existing_field = existing_by_name .get (desired_field .name )
2191+ if existing_field is None :
2192+ new_fields .append (desired_field )
2193+ elif (
2194+ desired_field .field_type == "RECORD"
2195+ and existing_field .field_type == "RECORD"
2196+ and desired_field .fields
2197+ ):
2198+ # Recurse into nested RECORD fields.
2199+ sub_new , sub_updated = (
2200+ BigQueryAgentAnalyticsPlugin ._schema_fields_match (
2201+ list (existing_field .fields ),
2202+ list (desired_field .fields ),
2203+ )
2204+ )
2205+ if sub_new or sub_updated :
2206+ # Build a merged sub-field list.
2207+ merged_sub = list (existing_field .fields )
2208+ # Replace updated nested records in-place.
2209+ updated_names = {f .name for f in sub_updated }
2210+ merged_sub = [
2211+ next (u for u in sub_updated if u .name == f .name )
2212+ if f .name in updated_names
2213+ else f
2214+ for f in merged_sub
2215+ ]
2216+ # Append entirely new sub-fields.
2217+ merged_sub .extend (sub_new )
2218+ # Rebuild via API representation to preserve all
2219+ # existing field attributes (policy_tags, etc.).
2220+ api_repr = existing_field .to_api_repr ()
2221+ api_repr ["fields" ] = [sf .to_api_repr () for sf in merged_sub ]
2222+ updated_records .append (bq_schema .SchemaField .from_api_repr (api_repr ))
2223+
2224+ return new_fields , updated_records
2225+
21452226 def _maybe_upgrade_schema (self , existing_table : bigquery .Table ) -> None :
21462227 """Adds missing columns to an existing table (additive only).
21472228
2229+ Handles nested RECORD fields by recursing into sub-fields.
2230+ The version label is only stamped after a successful update
2231+ so that a failed attempt is retried on the next run.
2232+
21482233 Args:
21492234 existing_table: The current BigQuery table object.
21502235 """
@@ -2154,24 +2239,43 @@ def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None:
21542239 if stored_version == _SCHEMA_VERSION :
21552240 return
21562241
2157- existing_names = {f .name for f in existing_table .schema }
2158- new_fields = [f for f in self ._schema if f .name not in existing_names ]
2242+ new_fields , updated_records = self ._schema_fields_match (
2243+ list (existing_table .schema ), list (self ._schema )
2244+ )
21592245
2160- if new_fields :
2161- merged = list (existing_table .schema ) + new_fields
2246+ if new_fields or updated_records :
2247+ # Build merged top-level schema.
2248+ updated_names = {f .name for f in updated_records }
2249+ merged = [
2250+ next (u for u in updated_records if u .name == f .name )
2251+ if f .name in updated_names
2252+ else f
2253+ for f in existing_table .schema
2254+ ]
2255+ merged .extend (new_fields )
21622256 existing_table .schema = merged
2257+
2258+ change_desc = []
2259+ if new_fields :
2260+ change_desc .append (f"new columns { [f .name for f in new_fields ]} " )
2261+ if updated_records :
2262+ change_desc .append (
2263+ f"updated RECORD fields { [f .name for f in updated_records ]} "
2264+ )
21632265 logger .info (
2164- "Auto-upgrading table %s: adding columns %s" ,
2266+ "Auto-upgrading table %s: %s" ,
21652267 self .full_table_id ,
2166- [ f . name for f in new_fields ] ,
2268+ ", " . join ( change_desc ) ,
21672269 )
21682270
2169- # Always stamp the version label so we skip on next run.
2170- labels = dict (existing_table .labels or {})
2171- labels [_SCHEMA_VERSION_LABEL_KEY ] = _SCHEMA_VERSION
2172- existing_table .labels = labels
2173-
21742271 try :
2272+ # Stamp the version label inside the try block so that
2273+ # on failure the label is NOT persisted and the next run
2274+ # retries the upgrade.
2275+ labels = dict (existing_table .labels or {})
2276+ labels [_SCHEMA_VERSION_LABEL_KEY ] = _SCHEMA_VERSION
2277+ existing_table .labels = labels
2278+
21752279 update_fields = ["schema" , "labels" ]
21762280 self .client .update_table (existing_table , update_fields )
21772281 except Exception as e :
@@ -2243,6 +2347,22 @@ async def shutdown(self, timeout: float | None = None) -> None:
22432347 if loop in self ._loop_state_by_loop :
22442348 await self ._loop_state_by_loop [loop ].batch_processor .shutdown (timeout = t )
22452349
2350+ # 1b. Drain batch processors on other (non-current) loops.
2351+ for other_loop , state in self ._loop_state_by_loop .items ():
2352+ if other_loop is loop or other_loop .is_closed ():
2353+ continue
2354+ try :
2355+ future = asyncio .run_coroutine_threadsafe (
2356+ state .batch_processor .shutdown (timeout = t ),
2357+ other_loop ,
2358+ )
2359+ future .result (timeout = t )
2360+ except Exception :
2361+ logger .warning (
2362+ "Could not drain batch processor on loop %s" ,
2363+ other_loop ,
2364+ )
2365+
22462366 # 2. Close clients for all states
22472367 for state in self ._loop_state_by_loop .values ():
22482368 if state .write_client and getattr (
@@ -2298,6 +2418,38 @@ def _reset_runtime_state(self) -> None:
22982418 process. Pure-data fields like ``_schema`` and
22992419 ``arrow_schema`` are kept because they are safe across fork.
23002420 """
2421+ logger .warning (
2422+ "Fork detected (parent PID %s, child PID %s). Resetting"
2423+ " gRPC state for BigQuery analytics plugin. Note: gRPC"
2424+ " bidirectional streaming (used by the BigQuery Storage"
2425+ " Write API) is not fork-safe. If writes hang or time"
2426+ " out, configure the 'spawn' start method at your program"
2427+ " entry-point before creating child processes:"
2428+ " multiprocessing.set_start_method('spawn')" ,
2429+ self ._init_pid ,
2430+ os .getpid (),
2431+ )
2432+ # Best-effort: close inherited gRPC channels so broken
2433+ # finalizers don't interfere with newly created channels.
2434+ # For grpc.aio channels, close() is a coroutine. We cannot
2435+ # await here (called from sync context / fork handler), so
2436+ # we skip async channels and only close sync ones.
2437+ for loop_state in self ._loop_state_by_loop .values ():
2438+ wc = getattr (loop_state , "write_client" , None )
2439+ transport = getattr (wc , "transport" , None )
2440+ if transport is not None :
2441+ try :
2442+ channel = getattr (transport , "_grpc_channel" , None )
2443+ if channel is not None and hasattr (channel , "close" ):
2444+ result = channel .close ()
2445+ # If close() returned a coroutine (grpc.aio channel),
2446+ # discard it to avoid unawaited-coroutine warnings.
2447+ if asyncio .iscoroutine (result ):
2448+ result .close ()
2449+ except Exception :
2450+ pass
2451+
2452+ # Clear all runtime state.
23012453 self ._setup_lock = None
23022454 self .client = None
23032455 self ._loop_state_by_loop = {}
@@ -2442,7 +2594,11 @@ def _enrich_attributes(
24422594 # Include session state if non-empty (contains user-set metadata
24432595 # like gchat thread-id, customer_id, etc.)
24442596 if session .state :
2445- session_meta ["state" ] = dict (session .state )
2597+ truncated_state , _ = _recursive_smart_truncate (
2598+ dict (session .state ),
2599+ self .config .max_content_length ,
2600+ )
2601+ session_meta ["state" ] = truncated_state
24462602 attrs ["session_metadata" ] = session_meta
24472603 except Exception :
24482604 pass
@@ -2988,6 +3144,7 @@ async def on_model_error_callback(
29883144 "LLM_ERROR" ,
29893145 callback_context ,
29903146 event_data = EventData (
3147+ status = "ERROR" ,
29913148 error_message = str (error ),
29923149 latency_ms = duration ,
29933150 span_id_override = None if has_ambient else span_id ,
@@ -3110,6 +3267,7 @@ async def on_tool_error_callback(
31103267 raw_content = content_dict ,
31113268 is_truncated = is_truncated ,
31123269 event_data = EventData (
3270+ status = "ERROR" ,
31133271 error_message = str (error ),
31143272 latency_ms = duration ,
31153273 span_id_override = None if has_ambient else span_id ,
0 commit comments