|
| 1 | +# Copyright 2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""AutoTracingPlugin helpers: arg capture, span attrs, tracing wrapper.""" |
| 16 | + |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import asyncio |
| 20 | +import dataclasses |
| 21 | +import functools |
| 22 | +import inspect |
| 23 | +import logging |
| 24 | +import re |
| 25 | +from typing import Any |
| 26 | +from typing import Callable |
| 27 | +from typing import Sequence |
| 28 | + |
| 29 | +from opentelemetry import trace as trace_api |
| 30 | + |
| 31 | +logger = logging.getLogger("google_adk." + __name__) |
| 32 | + |
| 33 | +DEFAULT_MAX_REPR_LEN = 4096 |
| 34 | +DEFAULT_MAX_RECORDED_YIELDS = 16 |
| 35 | + |
| 36 | +NamedArg = tuple[str, str] |
| 37 | +WRAPPED_ATTR = "_adk_auto_tracing_wrapped" |
| 38 | +_SELF_OR_CLS = frozenset({"self", "cls"}) |
| 39 | +_SCALAR_TYPES = frozenset({int, float, bool, str, bytes, type(None)}) |
| 40 | +_DEFAULT_REPR_RE = re.compile(r"^<.+ object at 0x[0-9a-fA-F]+>$") |
| 41 | + |
| 42 | + |
| 43 | +@dataclasses.dataclass(frozen=True) |
| 44 | +class Caps: |
| 45 | + """Bounds for captured repr strings and recorded generator yields.""" |
| 46 | + |
| 47 | + max_repr_len: int = DEFAULT_MAX_REPR_LEN |
| 48 | + max_recorded_yields: int = DEFAULT_MAX_RECORDED_YIELDS |
| 49 | + |
| 50 | + |
| 51 | +class StreamResult: |
| 52 | + """Capped sample (``items``) + true yield count (``total``) for a wrapped generator.""" |
| 53 | + |
| 54 | + def __init__(self, items: Sequence[Any], caps: Caps, total: int): |
| 55 | + self._items = items |
| 56 | + self._caps = caps |
| 57 | + self._total = total |
| 58 | + |
| 59 | + def __repr__(self) -> str: |
| 60 | + if self._total == 0: |
| 61 | + return "<generator: 0 items yielded>" |
| 62 | + sample = [safe_repr(it, self._caps) for it in self._items] |
| 63 | + suffix = ( |
| 64 | + f" ... + {self._total - len(sample)} more" |
| 65 | + if self._total > len(sample) |
| 66 | + else "" |
| 67 | + ) |
| 68 | + return ( |
| 69 | + f"<generator: {self._total} items yielded; first {len(sample)}:" |
| 70 | + f" [{', '.join(sample)}]{suffix}>" |
| 71 | + ) |
| 72 | + |
| 73 | + |
| 74 | +def safe_repr(value: Any, caps: Caps) -> str: |
| 75 | + """``repr(value)`` capped, resilient, with default-form objects summarized.""" |
| 76 | + max_len = caps.max_repr_len |
| 77 | + # Fast path: scalars never hit the default-repr regex or summary. |
| 78 | + if type(value) in _SCALAR_TYPES: |
| 79 | + r = repr(value) |
| 80 | + return ( |
| 81 | + r |
| 82 | + if len(r) <= max_len |
| 83 | + else r[:max_len] + f"...[{len(r) - max_len} more chars]" |
| 84 | + ) |
| 85 | + try: |
| 86 | + r = repr(value) |
| 87 | + except Exception as exc: # pylint: disable=broad-exception-caught |
| 88 | + logger.warning( |
| 89 | + "AutoTracingPlugin: repr() failed for %s: %s", |
| 90 | + type(value).__name__, |
| 91 | + exc, |
| 92 | + ) |
| 93 | + r = f"<unrepr-able {type(value).__name__}: {exc!r}>" |
| 94 | + if _DEFAULT_REPR_RE.match(r): |
| 95 | + r = _summarize_default(value) |
| 96 | + if len(r) > max_len: |
| 97 | + r = r[:max_len] + f"...[{len(r) - max_len} more chars]" |
| 98 | + return r |
| 99 | + |
| 100 | + |
| 101 | +def public_slot_names(cls: type) -> set[str]: |
| 102 | + """Public attr names declared in ``__slots__`` across ``cls.__mro__``. |
| 103 | +
|
| 104 | + Handles the ``__slots__ = "x"`` shorthand (must be treated as a single |
| 105 | + name, not iterated as characters). |
| 106 | + """ |
| 107 | + names: set[str] = set() |
| 108 | + for klass in cls.__mro__: |
| 109 | + slots = getattr(klass, "__slots__", None) |
| 110 | + if slots is None: |
| 111 | + continue |
| 112 | + if isinstance(slots, str): |
| 113 | + slots = (slots,) |
| 114 | + for slot in slots: |
| 115 | + if slot and not slot.startswith("_"): |
| 116 | + names.add(slot) |
| 117 | + return names |
| 118 | + |
| 119 | + |
| 120 | +def _summarize_default(value: Any) -> str: |
| 121 | + """Replaces ``<X object at 0x..>`` with a public-field summary (handles ``__slots__``).""" |
| 122 | + cls = type(value).__name__ |
| 123 | + public: list[tuple[str, Any]] = [] |
| 124 | + instance_dict = getattr(value, "__dict__", None) |
| 125 | + if isinstance(instance_dict, dict): |
| 126 | + public.extend( |
| 127 | + (k, v) for k, v in instance_dict.items() if not k.startswith("_") |
| 128 | + ) |
| 129 | + for slot_name in public_slot_names(type(value)): |
| 130 | + try: |
| 131 | + public.append((slot_name, getattr(value, slot_name))) |
| 132 | + except AttributeError: |
| 133 | + continue |
| 134 | + if not public: |
| 135 | + return f"<{cls}>" |
| 136 | + fields = [] |
| 137 | + for k, v in public: |
| 138 | + try: |
| 139 | + vr = repr(v) |
| 140 | + except Exception as exc: # pylint: disable=broad-exception-caught |
| 141 | + logger.warning( |
| 142 | + "AutoTracingPlugin: repr() failed for %s.%s (%s): %s", |
| 143 | + cls, |
| 144 | + k, |
| 145 | + type(v).__name__, |
| 146 | + exc, |
| 147 | + ) |
| 148 | + vr = f"<unrepr-able {type(v).__name__}>" |
| 149 | + fields.append(f"{k}={vr}") |
| 150 | + return f"<{cls} fields={{{', '.join(fields)}}}>" |
| 151 | + |
| 152 | + |
| 153 | +def positional_param_names(fn: Callable[..., Any]) -> tuple[str, ...]: |
| 154 | + """Returns ``fn``'s positional parameter names; ``()`` if introspection fails.""" |
| 155 | + try: |
| 156 | + return tuple( |
| 157 | + n |
| 158 | + for n, p in inspect.signature(fn).parameters.items() |
| 159 | + if p.kind |
| 160 | + in ( |
| 161 | + inspect.Parameter.POSITIONAL_ONLY, |
| 162 | + inspect.Parameter.POSITIONAL_OR_KEYWORD, |
| 163 | + ) |
| 164 | + ) |
| 165 | + except (TypeError, ValueError): |
| 166 | + return () |
| 167 | + |
| 168 | + |
| 169 | +def name_value_pairs( |
| 170 | + param_names: Sequence[str], |
| 171 | + args: tuple[Any, ...], |
| 172 | + kwargs: dict[str, Any], |
| 173 | + caps: Caps, |
| 174 | +) -> list[NamedArg]: |
| 175 | + """Returns ``[(name, repr)]`` for args + kwargs (no self/cls).""" |
| 176 | + pairs: list[NamedArg] = [] |
| 177 | + for i, v in enumerate(args): |
| 178 | + name = param_names[i] if i < len(param_names) else f"arg{i}" |
| 179 | + if name in _SELF_OR_CLS: |
| 180 | + continue |
| 181 | + pairs.append((name, safe_repr(v, caps))) |
| 182 | + for k, v in kwargs.items(): |
| 183 | + pairs.append((k, safe_repr(v, caps))) |
| 184 | + return pairs |
| 185 | + |
| 186 | + |
| 187 | +def record_io_on_span( |
| 188 | + span: trace_api.Span, |
| 189 | + pairs: Sequence[NamedArg], |
| 190 | + result: Any, |
| 191 | + exc: BaseException | None, |
| 192 | + caps: Caps, |
| 193 | +) -> None: |
| 194 | + """Writes ``adk.fn.*`` attributes onto ``span`` for the call's IO.""" |
| 195 | + s = span.set_attribute |
| 196 | + for k, v in pairs: |
| 197 | + s(f"adk.fn.arg.{k}", v) |
| 198 | + if exc is not None: |
| 199 | + s("adk.fn.exc_type", type(exc).__qualname__) |
| 200 | + s("adk.fn.exc_repr", safe_repr(exc, caps)) |
| 201 | + return |
| 202 | + s("adk.fn.return", safe_repr(result, caps)) |
| 203 | + |
| 204 | + |
| 205 | +def display_name_for(fn: Callable[..., Any]) -> str: |
| 206 | + """Returns the short (Class.method or function) name for ``fn``.""" |
| 207 | + qn = fn.__qualname__ |
| 208 | + return ".".join(qn.split(".")[-2:]) if "." in qn else qn |
| 209 | + |
| 210 | + |
| 211 | +def tracer_will_record(tracer: trace_api.Tracer) -> bool: |
| 212 | + """True iff ``tracer`` will record (not a NoOpTracer).""" |
| 213 | + return not isinstance(tracer, trace_api.NoOpTracer) |
| 214 | + |
| 215 | + |
| 216 | +def build_tracing_wrapper( |
| 217 | + fn: Callable[..., Any], |
| 218 | + tracer: trace_api.Tracer, |
| 219 | + caps: Caps, |
| 220 | +) -> Callable[..., Any]: |
| 221 | + """Returns a tracing wrapper for ``fn`` matching its sync/async/gen shape.""" |
| 222 | + # A non-recording tracer never produces IO; don't pay span/context cost. |
| 223 | + if not tracer_will_record(tracer): |
| 224 | + return fn |
| 225 | + |
| 226 | + display_name = display_name_for(fn) |
| 227 | + # inspect.signature is expensive; resolve once at wrap time. |
| 228 | + param_names = positional_param_names(fn) |
| 229 | + yield_cap = caps.max_recorded_yields |
| 230 | + |
| 231 | + def _finish(span, args, kwargs, result, exc): |
| 232 | + if not span.is_recording(): |
| 233 | + return |
| 234 | + pairs = name_value_pairs(param_names, args, kwargs, caps) |
| 235 | + record_io_on_span(span, pairs, result, exc, caps) |
| 236 | + |
| 237 | + @functools.wraps(fn) |
| 238 | + async def async_wrapper(*args, **kwargs): |
| 239 | + with tracer.start_as_current_span(display_name) as span: |
| 240 | + try: |
| 241 | + r = await fn(*args, **kwargs) |
| 242 | + except BaseException as exc: |
| 243 | + _finish(span, args, kwargs, None, exc) |
| 244 | + raise |
| 245 | + _finish(span, args, kwargs, r, None) |
| 246 | + return r |
| 247 | + |
| 248 | + @functools.wraps(fn) |
| 249 | + async def async_gen_wrapper(*args, **kwargs): |
| 250 | + with tracer.start_as_current_span(display_name) as span: |
| 251 | + items: list[Any] = [] |
| 252 | + total = 0 |
| 253 | + try: |
| 254 | + async for item in fn(*args, **kwargs): |
| 255 | + total += 1 |
| 256 | + if len(items) < yield_cap: |
| 257 | + items.append(item) |
| 258 | + yield item |
| 259 | + except BaseException as exc: |
| 260 | + _finish(span, args, kwargs, StreamResult(items, caps, total), exc) |
| 261 | + raise |
| 262 | + _finish(span, args, kwargs, StreamResult(items, caps, total), None) |
| 263 | + |
| 264 | + @functools.wraps(fn) |
| 265 | + def gen_wrapper(*args, **kwargs): |
| 266 | + with tracer.start_as_current_span(display_name) as span: |
| 267 | + items: list[Any] = [] |
| 268 | + total = 0 |
| 269 | + try: |
| 270 | + for item in fn(*args, **kwargs): |
| 271 | + total += 1 |
| 272 | + if len(items) < yield_cap: |
| 273 | + items.append(item) |
| 274 | + yield item |
| 275 | + except BaseException as exc: |
| 276 | + _finish(span, args, kwargs, StreamResult(items, caps, total), exc) |
| 277 | + raise |
| 278 | + _finish(span, args, kwargs, StreamResult(items, caps, total), None) |
| 279 | + |
| 280 | + @functools.wraps(fn) |
| 281 | + def sync_wrapper(*args, **kwargs): |
| 282 | + with tracer.start_as_current_span(display_name) as span: |
| 283 | + try: |
| 284 | + r = fn(*args, **kwargs) |
| 285 | + except BaseException as exc: |
| 286 | + _finish(span, args, kwargs, None, exc) |
| 287 | + raise |
| 288 | + _finish(span, args, kwargs, r, None) |
| 289 | + return r |
| 290 | + |
| 291 | + if inspect.isasyncgenfunction(fn): |
| 292 | + wrapper = async_gen_wrapper |
| 293 | + elif asyncio.iscoroutinefunction(fn): |
| 294 | + wrapper = async_wrapper |
| 295 | + elif inspect.isgeneratorfunction(fn): |
| 296 | + wrapper = gen_wrapper |
| 297 | + else: |
| 298 | + wrapper = sync_wrapper |
| 299 | + setattr(wrapper, WRAPPED_ATTR, True) |
| 300 | + return wrapper |
0 commit comments