diff --git a/gigaevo/llm/bandit.py b/gigaevo/llm/bandit.py index b63fe5c0..d8a4084c 100644 --- a/gigaevo/llm/bandit.py +++ b/gigaevo/llm/bandit.py @@ -9,16 +9,25 @@ from __future__ import annotations from collections import deque +from collections.abc import AsyncIterator, Iterator from dataclasses import dataclass, field from enum import Enum import math from typing import TYPE_CHECKING, Any +from langchain_core.language_models import LanguageModelInput +from langchain_core.messages import BaseMessage +from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI from loguru import logger import numpy as np -from gigaevo.llm.models import MultiModelRouter, _StructuredOutputRouter +from gigaevo.llm.call_outcome import BanditAction, classify_call_result +from gigaevo.llm.models import ( + MultiModelRouter, + _remember_selected_model, + _StructuredOutputRouter, +) from gigaevo.utils.trackers.base import LogWriter if TYPE_CHECKING: @@ -263,15 +272,181 @@ def __init__( # -- selection ---------------------------------------------------------- def _select(self) -> tuple[ChatOpenAI, str]: - """Select a model via UCB1 and record the pull.""" + """Select a model via UCB1 and record the pull. + + Also writes the selected name into the ``_selected_model`` ContextVar + so downstream consumers (``BaseAgent.acall_llm``, ``MutationAgent``, + ``InsightsAgent``, ...) can read it via ``get_selected_model()``. + ``MultiModelRouter._select`` does this; the bandit override must + preserve the contract or ``state['metadata']['model_used']`` would + carry whatever value the previous non-bandit selection happened to + leave behind in the ContextVar. + """ name = self._bandit.select() self._bandit.record_pull(name) + _remember_selected_model(name) tid = self._current_task_id() if tid is not None: self._task_model_map[tid] = name idx = self.model_names.index(name) return self.models[idx], name + # -- dispatch ----------------------------------------------------------- + + def _inject_failure_reward(self, exc: BaseException, arm_name: str) -> None: + """Classify a failed LLM call and dispatch on its bandit action. + + ``_select`` records the pull before the LLM call, so without this + hook a failure would inflate ``total_pulls`` for the arm with no + matching window entry — the UCB1 confidence term shrinks for that + arm and the bandit underexplores flaky models. The classifier maps + the exception to an outcome whose ``OUTCOME_ACTION`` is currently + ``INJECT_ZERO_REWARD`` for every failure variant; we normalize a + zero reward and append it to the arm's window. + """ + result = classify_call_result(exc, model_name=arm_name) + if result.action is BanditAction.DEFER_TO_OUTCOME: + # SUCCESS — never reachable here (we are inside the except), + # but the action lookup keeps the contract honest. + return + if arm_name not in self._bandit.arms: + # Mirrors the on_mutation_outcome guard: an unknown arm name + # would otherwise raise KeyError inside update_reward. 
+ logger.debug( + "[BanditModelRouter] Skipping zero-reward injection for " + "unknown arm {!r} (outcome={})", + arm_name, + result.outcome.value, + ) + return + normalized = self._reward_normalizer.normalize(0.0) + self._bandit.update_reward(arm_name, normalized) + logger.debug( + "[BanditModelRouter] Zero reward injected for {} | outcome={} exception={}", + arm_name, + result.outcome.value, + result.exception_class, + ) + + def _safe_inject_failure_reward(self, exc: BaseException, arm_name: str) -> None: + """Best-effort wrapper around ``_inject_failure_reward``. + + Mirrors ``_StructuredOutputRouter._maybe_fire_failure_hook``: the + ledger-symmetry update is observability-only and must never mask the + original LLM exception (or its traceback). Any error inside the hook + is swallowed with a debug log so ``raise`` at the call site still + re-raises the real failure. + """ + try: + self._inject_failure_reward(exc, arm_name) + except Exception as hook_exc: # noqa: BLE001 — observability-only + logger.debug( + "[BanditModelRouter] Failure-reward injection itself raised " + "for arm {!r}: {!r}. Original exception preserved.", + arm_name, + hook_exc, + ) + + def _safe_track(self, response: Any, name: str) -> None: + """Best-effort token-tracker call. + + ``self._tracker.track`` reads ``response.usage_metadata`` and writes + to a ``LogWriter``. A telemetry-side bug (malformed token_usage from + a hostile provider, broken writer, etc.) must not propagate to the + caller — the LLM call already succeeded, and on the bandit success + path the reward is *deferred* to ``on_mutation_outcome``, so an + exception here would both lose the response AND leave the bandit + unable to associate the deferred reward (the caller never returns). + """ + try: + self._tracker.track(response, name) + except Exception as track_exc: # noqa: BLE001 — telemetry only + logger.debug( + "[BanditModelRouter] Token tracking failed for arm {!r}: {!r}. " + "LLM response preserved; reward still deferred to " + "on_mutation_outcome.", + name, + track_exc, + ) + + def invoke( + self, + input: LanguageModelInput, + config: RunnableConfig | None = None, + **kwargs: Any, + ) -> BaseMessage: + model, name = self._select() + try: + response = model.invoke(input, self._config(config, name), **kwargs) + except BaseException as exc: + self._safe_inject_failure_reward(exc, name) + raise + self._safe_track(response, name) + return response + + async def ainvoke( + self, + input: LanguageModelInput, + config: RunnableConfig | None = None, + **kwargs: Any, + ) -> BaseMessage: + model, name = self._select() + try: + response = await model.ainvoke(input, self._config(config, name), **kwargs) + except BaseException as exc: + self._safe_inject_failure_reward(exc, name) + raise + self._safe_track(response, name) + return response + + def stream( + self, + input: LanguageModelInput, + config: RunnableConfig | None = None, + **kwargs: Any, + ) -> Iterator[BaseMessage]: + """Streaming counterpart to :meth:`invoke` with ledger-symmetry guard. + + ``MultiModelRouter.stream`` records the pull (via ``_select``) and + then yields from ``model.stream`` without try/except. A mid-stream + failure would leave ``total_pulls`` and the reward window out of + step exactly as the unwrapped ``invoke`` path used to. We mirror + the ``invoke`` contract: classify any exception via the bandit + hook, then re-raise. 
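+
+        Consumer-visible shape of the guarantee (a sketch; the prompt is
+        arbitrary and the pull/reward counters refer to
+        ``get_bandit_stats()``):
+
+            chunks = router.stream("prompt")
+            next(chunks)  # first chunk: the pull is already recorded
+            # a mid-stream provider failure injects a zero reward for the
+            # selected arm and then re-raises the original exception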
+ """ + model, name = self._select() + last = None + try: + for chunk in model.stream(input, self._config(config, name), **kwargs): + last = chunk + yield chunk + except BaseException as exc: + self._safe_inject_failure_reward(exc, name) + raise + if last is not None: + self._safe_track(last, name) + + async def astream( + self, + input: LanguageModelInput, + config: RunnableConfig | None = None, + **kwargs: Any, + ) -> AsyncIterator[BaseMessage]: + """Async streaming counterpart to :meth:`ainvoke`. See :meth:`stream`.""" + model, name = self._select() + last = None + try: + async for chunk in model.astream( + input, self._config(config, name), **kwargs + ): + last = chunk + yield chunk + except BaseException as exc: + self._safe_inject_failure_reward(exc, name) + raise + if last is not None: + self._safe_track(last, name) + # -- mutation outcome --------------------------------------------------- def on_mutation_outcome( @@ -350,6 +525,10 @@ def with_structured_output(self, schema: Any, **kwargs) -> _StructuredOutputRout def _bandit_select() -> tuple[Any, str]: name = self._bandit.select() self._bandit.record_pull(name) + # Mirror MultiModelRouter._select: publish the selection to the + # ContextVar so ``get_selected_model()`` reads the actual arm + # name during structured-output dispatch. + _remember_selected_model(name) tid = self._current_task_id() if tid is not None: self._task_model_map[tid] = name @@ -364,4 +543,5 @@ def _bandit_select() -> tuple[Any, str]: self._tracker, task_model_map=self._task_model_map, select_override=_bandit_select, + failure_hook=self._inject_failure_reward, ) diff --git a/gigaevo/llm/call_outcome.py b/gigaevo/llm/call_outcome.py new file mode 100644 index 00000000..1b5c59bb --- /dev/null +++ b/gigaevo/llm/call_outcome.py @@ -0,0 +1,847 @@ +"""LLM call outcome classification — pure, exception-driven, exhaustive. + +================================================================================ +CONTEXT +================================================================================ + +Every LLM call dispatched through the router stack either returns a usable +response or raises. Callers (``BanditModelRouter``, retry shims, telemetry) +need to react to the *kind* of failure, not the specific exception class: +a 429 rate-limit deserves a zero reward but no breaker trip; a 401 auth +error is operator-side and should not be confused with arm failure; a +length-limit truncation is the arm's fault; a langchain +``OutputParserException`` wrapping an ``openai.RateLimitError`` should be +classified by the *root* cause, not by the wrapper. + +This module is the one classifier. Every other piece of router/bandit +plumbing consumes ``LLMCallResult`` rather than rolling its own try/except +ladder. 
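+
+A minimal sketch of the intended call-site shape (illustrative; the real
+consumers live in ``gigaevo/llm/bandit.py``, and ``bandit`` / ``arm`` are
+stand-in names for the caller's reward store and routed model name):
+
+    try:
+        response = model.invoke(prompt)
+    except BaseException as exc:
+        result = classify_call_result(exc, model_name=arm)
+        if result.action is BanditAction.INJECT_ZERO_REWARD:
+            bandit.update_reward(arm, 0.0)
+        raise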
+
+================================================================================
+STATE MODEL — 12 outcomes, exhaustive
+================================================================================
+
+    SUCCESS            call returned a usable response
+    RATE_LIMITED       HTTP 429 / provider rate-limit signal
+    SERVER_5XX         HTTP 5xx that is not a more specific signal
+    TIMEOUT            asyncio / httpx / openai timeout (HTTP 408/504 too)
+    NETWORK_ERROR      connection refused, DNS, TLS, broken pipe, protocol
+    AUTH_FAILED        HTTP 401 / 403 / invalid API key
+    BAD_REQUEST        HTTP 4xx not classified as a more specific outcome
+    CONTEXT_OVERFLOW   prompt exceeded model context window (HTTP 413, or
+                       explicit ``ContextOverflowError`` family)
+    OUTPUT_TRUNCATED   successful call but the response was cut off at the
+                       token budget (``LengthFinishReasonError`` family)
+    CONTENT_FILTERED   model refused / moderation blocked the response
+                       (``OpenAIRefusalError`` / ``OpenAIModerationError`` /
+                       ``ContentFilterFinishReasonError`` family)
+    PARSE_FAILED       response received but parsing raised
+                       (pydantic ``ValidationError``, ``OutputParserException``,
+                       ``json.JSONDecodeError``, ``APIResponseValidationError``,
+                       ``httpx.DecodingError``)
+    OTHER_EXCEPTION    anything ``classify_call_result`` cannot name
+
+The set is EXHAUSTIVE: ``classify_call_result`` is a total function from
+``BaseException | None`` to ``LLMCallResult``. New providers add MRO names
+to the existing rule tables — they do not add outcome classes unless the
+bandit contract changes.
+
+================================================================================
+CLASSIFICATION (total function)
+================================================================================
+
+The classifier walks the ``__cause__`` / ``__context__`` chain
+(cycle-protected via an ``id()`` set: chains are finite in practice, but a
+self-referential chain must not hang the classifier) and classifies each
+frame. Outcomes from the chain are then collapsed via priority ranking (see
+``_OUTCOME_PRIORITY``) so a wrapper exception cannot mask a more informative
+root cause — a langchain ``OutputParserException`` whose ``__cause__`` is an
+``openai.RateLimitError`` lands on ``RATE_LIMITED``, not ``PARSE_FAILED``.
+
+Per-frame classification (``_classify_one``), in order:
+
+    1. MRO-based specific overrides (context overflow, content filter,
+       output-truncated) — these outrank HTTP status because the surface
+       exception is more specific than the status code (e.g. a 400 that
+       is really ``OpenAIContextOverflowError``).
+    2. HTTP status lookup (``_STATUS_TO_OUTCOME``) — authoritative for
+       common codes (401, 403, 408, 413, 429, 504); range fallthroughs
+       cover the rest of 4xx (→ BAD_REQUEST) and 5xx (→ SERVER_5XX).
+    3. MRO-based fingerprint match for TIMEOUT / NETWORK_ERROR /
+       PARSE_FAILED.
+    4. ``isinstance`` against built-ins (``TimeoutError``,
+       ``ConnectionError``) — catches ``asyncio.TimeoutError`` (same
+       class as ``TimeoutError`` since 3.11) and langchain-openai's
+       ``StreamChunkTimeoutError`` (extends builtin ``TimeoutError``).
+    5. Context-overflow text marker as a last-resort signal.
+    6. Otherwise ``OTHER_EXCEPTION``.
+
+The classifier never imports a provider SDK (openai, httpx, langchain-openai,
+langchain-core) and never type-checks against SDK exception classes; pydantic
+is imported, but only for the frozen result schema, and its ``ValidationError``
+is matched by class name like every other fingerprint. Classification inspects
+``type(exc).__mro__`` class names, exception attributes (``status_code``,
+``response.status_code``, ``response.headers``) and ``str(exc)`` only.
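+
+Worked example of the cause-chain collapse (``FakeRateLimit`` is a local
+stand-in class defined only for this sketch; the classifier keys on its
+``status_code`` attribute, not its name):
+
+    class FakeRateLimit(Exception):
+        status_code = 429
+
+    try:
+        try:
+            raise FakeRateLimit("429 from provider")
+        except FakeRateLimit as root:
+            raise ValueError("could not parse response") from root
+    except ValueError as exc:
+        result = classify_call_result(exc)
+
+    assert result.outcome is LLMCallOutcome.RATE_LIMITED  # root outranks wrapper
+    assert result.cause_chain == ("ValueError", "FakeRateLimit")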
+ +================================================================================ +BANDIT ACTION TABLE +================================================================================ + + SUCCESS → DEFER_TO_OUTCOME (reward arrives via on_mutation_outcome) + every other outcome → INJECT_ZERO_REWARD + +The fine-grained distinction between "arm's fault" (PARSE_FAILED, +CONTENT_FILTERED, OUTPUT_TRUNCATED) and "operator's / infra's fault" +(AUTH_FAILED, RATE_LIMITED, SERVER_5XX, NETWORK_ERROR, TIMEOUT, +CONTEXT_OVERFLOW, BAD_REQUEST) is preserved for telemetry and future +circuit-breakers. The bandit ledger currently treats them the same +because mean-reward signal is self-correcting: an AUTH_FAILED storm +penalizes every arm equally, so relative ranking is preserved. + +================================================================================ +INVARIANTS (enforced at module load + tested in tests/llm/test_call_outcome.py) +================================================================================ + + I1. classify_call_result(None).outcome is SUCCESS; no other input + produces SUCCESS. + I2. ``OUTCOME_ACTION`` is closed under ``LLMCallOutcome`` — every + enum member is a key (asserted at module load). + I3. SUCCESS is the only outcome with action DEFER_TO_OUTCOME; every + other outcome has action INJECT_ZERO_REWARD (asserted at module + load). + I4. ``_OUTCOME_PRIORITY`` is a permutation of ``LLMCallOutcome`` + members (asserted at module load). + I5. HTTP status precedence: a 4xx/5xx with class-name fingerprint + for CONTEXT_OVERFLOW / CONTENT_FILTERED / OUTPUT_TRUNCATED wins + over the bare status classification; otherwise status wins over + generic class-name fingerprints. + I6. Cause-chain priority: a wrapper exception with a more + informative cause classifies by the cause, not the wrapper. +""" + +from __future__ import annotations + +from collections.abc import Iterator, Mapping +from enum import StrEnum +import math +import re +from types import MappingProxyType +from typing import Final + +from pydantic import BaseModel, ConfigDict, Field + +# ============================================================================ +# Outcome enums +# ============================================================================ + + +class LLMCallOutcome(StrEnum): + """Terminal outcome of one LLM call attempt. 
Exhaustive over the failure + modes the router stack currently observes against openai 2.x, + httpx 0.28.x, langchain-openai 1.2.x, langchain-core 1.4.x, and + litellm 1.x (rule tables include the litellm class names that do not + inherit a known base — see ``_CONTEXT_OVERFLOW_MRO_NAMES`` and + ``_CONTENT_FILTER_MRO_NAMES`` comments).""" + + SUCCESS = "success" + RATE_LIMITED = "rate_limited" + SERVER_5XX = "server_5xx" + TIMEOUT = "timeout" + NETWORK_ERROR = "network_error" + AUTH_FAILED = "auth_failed" + BAD_REQUEST = "bad_request" + CONTEXT_OVERFLOW = "context_overflow" + OUTPUT_TRUNCATED = "output_truncated" + CONTENT_FILTERED = "content_filtered" + PARSE_FAILED = "parse_failed" + OTHER_EXCEPTION = "other_exception" + + +class BanditAction(StrEnum): + """Per-outcome bandit ledger reaction.""" + + DEFER_TO_OUTCOME = "defer_to_outcome" + INJECT_ZERO_REWARD = "inject_zero_reward" + + +# ============================================================================ +# Outcome → action mapping (closed under LLMCallOutcome) +# ============================================================================ + + +_OUTCOME_ACTION_DRAFT: dict[LLMCallOutcome, BanditAction] = { + LLMCallOutcome.SUCCESS: BanditAction.DEFER_TO_OUTCOME, + LLMCallOutcome.RATE_LIMITED: BanditAction.INJECT_ZERO_REWARD, + LLMCallOutcome.SERVER_5XX: BanditAction.INJECT_ZERO_REWARD, + LLMCallOutcome.TIMEOUT: BanditAction.INJECT_ZERO_REWARD, + LLMCallOutcome.NETWORK_ERROR: BanditAction.INJECT_ZERO_REWARD, + LLMCallOutcome.AUTH_FAILED: BanditAction.INJECT_ZERO_REWARD, + LLMCallOutcome.BAD_REQUEST: BanditAction.INJECT_ZERO_REWARD, + LLMCallOutcome.CONTEXT_OVERFLOW: BanditAction.INJECT_ZERO_REWARD, + LLMCallOutcome.OUTPUT_TRUNCATED: BanditAction.INJECT_ZERO_REWARD, + LLMCallOutcome.CONTENT_FILTERED: BanditAction.INJECT_ZERO_REWARD, + LLMCallOutcome.PARSE_FAILED: BanditAction.INJECT_ZERO_REWARD, + LLMCallOutcome.OTHER_EXCEPTION: BanditAction.INJECT_ZERO_REWARD, +} + +OUTCOME_ACTION: Final[Mapping[LLMCallOutcome, BanditAction]] = MappingProxyType( + _OUTCOME_ACTION_DRAFT +) + +ZERO_REWARD_OUTCOMES: Final[frozenset[LLMCallOutcome]] = frozenset( + outcome + for outcome, action in OUTCOME_ACTION.items() + if action is BanditAction.INJECT_ZERO_REWARD +) + + +# ============================================================================ +# Cause-chain priority order — lower index wins +# ============================================================================ +# Rationale per row: +# RATE_LIMITED — most actionable infra signal (back off) +# AUTH_FAILED — kill-switch: every arm will fail until the key is fixed +# CONTEXT_OVERFLOW — prompt-level, immediate; do not retry blindly +# TIMEOUT — infra-level transient +# SERVER_5XX — infra-level, possibly transient +# NETWORK_ERROR — infra-level, possibly transient +# BAD_REQUEST — request-level, retries unlikely to help +# OUTPUT_TRUNCATED — arm-level, prompt or budget tweak needed +# CONTENT_FILTERED — arm-level +# PARSE_FAILED — arm-level (wrapper of other failures often hits here) +# OTHER_EXCEPTION — unclassified +# SUCCESS — never appears in a cause walk; placed last for closure + + +_OUTCOME_PRIORITY: Final[tuple[LLMCallOutcome, ...]] = ( + LLMCallOutcome.RATE_LIMITED, + LLMCallOutcome.AUTH_FAILED, + LLMCallOutcome.CONTEXT_OVERFLOW, + LLMCallOutcome.TIMEOUT, + LLMCallOutcome.SERVER_5XX, + LLMCallOutcome.NETWORK_ERROR, + LLMCallOutcome.BAD_REQUEST, + LLMCallOutcome.OUTPUT_TRUNCATED, + LLMCallOutcome.CONTENT_FILTERED, + LLMCallOutcome.PARSE_FAILED, + LLMCallOutcome.OTHER_EXCEPTION, + 
LLMCallOutcome.SUCCESS, +) + +_OUTCOME_PRIORITY_INDEX: Final[Mapping[LLMCallOutcome, int]] = MappingProxyType( + {outcome: i for i, outcome in enumerate(_OUTCOME_PRIORITY)} +) + + +# ============================================================================ +# Module-load invariants +# ============================================================================ +# These trip at import time if the tables fall out of sync with the enum — +# a regression that would otherwise leak into telemetry silently. + + +assert set(OUTCOME_ACTION) == set(LLMCallOutcome), ( + "OUTCOME_ACTION must cover every LLMCallOutcome member exactly once" +) +assert OUTCOME_ACTION[LLMCallOutcome.SUCCESS] is BanditAction.DEFER_TO_OUTCOME, ( + "SUCCESS must map to DEFER_TO_OUTCOME" +) +assert all( + OUTCOME_ACTION[o] is BanditAction.INJECT_ZERO_REWARD + for o in LLMCallOutcome + if o is not LLMCallOutcome.SUCCESS +), "Every non-SUCCESS outcome must map to INJECT_ZERO_REWARD" +assert set(_OUTCOME_PRIORITY) == set(LLMCallOutcome), ( + "_OUTCOME_PRIORITY must be a permutation of LLMCallOutcome members" +) +assert len(_OUTCOME_PRIORITY) == len(LLMCallOutcome), ( + "_OUTCOME_PRIORITY must contain each outcome exactly once" +) + + +def _assert_disjoint( + *named_sets: tuple[str, frozenset[str]], +) -> None: + """Module-load guard: every class-name set must be pairwise disjoint so a + single class never classifies as two outcomes.""" + seen: dict[str, str] = {} + for set_name, names in named_sets: + for name in names: + if name in seen: + raise AssertionError( + f"Class name {name!r} appears in both {seen[name]!r} and " + f"{set_name!r} — rule tables must be disjoint" + ) + seen[name] = set_name + + +# ============================================================================ +# Classification rule tables (private) +# ============================================================================ +# Class-name sets are matched against the FULL MRO chain of the exception +# (``type(exc).__mro__``), so a subclass of a known base is still recognised +# without naming the leaf. All class names below have been verified against +# the installed package sources for openai 2.36.0, httpx 0.28.1, +# langchain-openai 1.2.1, langchain-core 1.4.0. + + +# HTTP status → outcome (explicit codes only; range fallthroughs handled in code) +_STATUS_TO_OUTCOME: Final[Mapping[int, LLMCallOutcome]] = MappingProxyType( + { + 401: LLMCallOutcome.AUTH_FAILED, + 403: LLMCallOutcome.AUTH_FAILED, + 408: LLMCallOutcome.TIMEOUT, + 413: LLMCallOutcome.CONTEXT_OVERFLOW, + 429: LLMCallOutcome.RATE_LIMITED, + 504: LLMCallOutcome.TIMEOUT, + } +) + +# Context-overflow-specific exception class names (openai + langchain-core + litellm). +# Verified MROs: +# OpenAIContextOverflowError → BadRequestError → APIStatusError → ContextOverflowError +# OpenAIAPIContextOverflowError → APIError → ContextOverflowError +# ContextOverflowError → LangChainException +# litellm.ContextWindowExceededError → BadRequestError → APIStatusError (status_code=400) +# — litellm does NOT inherit from langchain's ContextOverflowError, so the +# class name must appear here explicitly; otherwise the 400 status would +# misclassify as BAD_REQUEST. +_CONTEXT_OVERFLOW_MRO_NAMES: Final[frozenset[str]] = frozenset( + { + "ContextOverflowError", # langchain_core base + "OpenAIContextOverflowError", # langchain-openai + "OpenAIAPIContextOverflowError", # langchain-openai + "ContextWindowExceededError", # litellm + } +) + +# Successful call but output cut off at the token budget. 
+# LengthFinishReasonError → OpenAIError (openai 2.x) +_OUTPUT_TRUNCATED_MRO_NAMES: Final[frozenset[str]] = frozenset( + { + "LengthFinishReasonError", + } +) + +# Refusal / content filter / moderation. +# OpenAIRefusalError → Exception +# OpenAIModerationError → RuntimeError +# ContentFilterFinishReasonError → OpenAIError +# litellm.ContentPolicyViolationError → BadRequestError (status_code=400) +# — provider-side refusal surfaced as a 400; without this entry it would +# fall through to BAD_REQUEST and lose the "arm-level refusal" signal. +_CONTENT_FILTER_MRO_NAMES: Final[frozenset[str]] = frozenset( + { + "OpenAIRefusalError", + "OpenAIModerationError", + "ContentFilterFinishReasonError", + "ContentPolicyViolationError", # litellm + } +) + +# Timeout exception class names. ``TimeoutError`` (builtin / 3.11+ alias for +# ``asyncio.TimeoutError``) is also covered via isinstance below. +# httpx.TimeoutException, ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout +# openai.APITimeoutError +# langchain_openai.StreamChunkTimeoutError (extends builtin TimeoutError) +_TIMEOUT_MRO_NAMES: Final[frozenset[str]] = frozenset( + { + "TimeoutError", + "TimeoutException", + "ConnectTimeout", + "ReadTimeout", + "WriteTimeout", + "PoolTimeout", + "APITimeoutError", + "StreamChunkTimeoutError", + } +) + +# Network exception class names. ``ConnectionError`` builtin caught via isinstance. +# httpx.NetworkError, ConnectError, ReadError, WriteError, CloseError, +# ProtocolError, LocalProtocolError, RemoteProtocolError, ProxyError, +# UnsupportedProtocol +# openai.APIConnectionError +_NETWORK_MRO_NAMES: Final[frozenset[str]] = frozenset( + { + "ConnectionError", + "NetworkError", + "ConnectError", + "ReadError", + "WriteError", + "CloseError", + "ProtocolError", + "LocalProtocolError", + "RemoteProtocolError", + "ProxyError", + "UnsupportedProtocol", + "APIConnectionError", + "ConnectionRefusedError", + "ConnectionResetError", + } +) + +# Parser / decoder failures. +# pydantic.ValidationError → ValueError +# json.JSONDecodeError → ValueError +# langchain_core.exceptions.OutputParserException → ValueError, LangChainException +# openai.APIResponseValidationError → APIError (carries status_code) +# httpx.DecodingError → RequestError +_PARSE_MRO_NAMES: Final[frozenset[str]] = frozenset( + { + "ValidationError", + "JSONDecodeError", + "OutputParserException", + "APIResponseValidationError", + "DecodingError", + } +) + +# Free-text markers for context overflow when no explicit class fired. These +# are common error-message fragments emitted by providers for input that +# exceeds the model context window. +_CONTEXT_OVERFLOW_MARKERS: Final[tuple[str, ...]] = ( + "context length", + "context window", + "context_length_exceeded", + "maximum context length", + "prompt is too long", + "string too long", + "input is too long", + "reduce the length", +) + +_assert_disjoint( + ("_CONTEXT_OVERFLOW_MRO_NAMES", _CONTEXT_OVERFLOW_MRO_NAMES), + ("_OUTPUT_TRUNCATED_MRO_NAMES", _OUTPUT_TRUNCATED_MRO_NAMES), + ("_CONTENT_FILTER_MRO_NAMES", _CONTENT_FILTER_MRO_NAMES), + ("_TIMEOUT_MRO_NAMES", _TIMEOUT_MRO_NAMES), + ("_NETWORK_MRO_NAMES", _NETWORK_MRO_NAMES), + ("_PARSE_MRO_NAMES", _PARSE_MRO_NAMES), +) + + +# ============================================================================ +# Result schema +# ============================================================================ + + +class LLMCallResult(BaseModel): + """Pydantic record of a single LLM call result. Pure data — no behavior. 
+
+    Frozen + extra-forbidden + strict so the schema is a closed contract
+    that telemetry and the bandit can rely on without defensive copying.
+
+    Field semantics:
+        ``outcome``              classified outcome (exhaustive over
+                                 ``LLMCallOutcome``)
+        ``exception_class``      ``type(exc).__name__`` of the surface
+                                 exception. ``None`` only when ``outcome
+                                 is SUCCESS``.
+        ``exception_module``     full module path of the surface exception.
+                                 ``None`` only when ``outcome is SUCCESS`` or
+                                 ``__module__`` could not be coerced to ``str``.
+        ``http_status``          status code recovered from the exception, if
+                                 any. ``None`` means "no status was present on
+                                 the exception" — NOT "status was zero".
+        ``retry_after_seconds``  Retry-After header value when the response
+                                 carries one and the parse succeeds. Three
+                                 distinct values:
+                                   * ``None`` — header absent, malformed,
+                                     non-finite, or negative; PR-B callers
+                                     should treat this as "no hint, use
+                                     default backoff".
+                                   * ``0.0`` — header present and equal to
+                                     zero; provider explicitly said "retry
+                                     immediately". PR-B callers MUST NOT
+                                     conflate this with ``None``.
+                                   * positive float — recommended sleep in
+                                     seconds. Any finite non-negative value
+                                     passes through; the consumer enforces
+                                     its own maximum acceptable sleep.
+        ``message``              ``str(exc)`` (or ``None`` when ``__str__``
+                                 raises or returns empty). UTF-16 surrogates
+                                 are scrubbed to ``U+FFFD`` so the field is
+                                 always safe for ``model_dump_json`` and
+                                 downstream JSON consumers (langfuse, tracker
+                                 writes).
+        ``cause_chain``          Tuple of class names walked through
+                                 ``__cause__`` / ``__context__`` with cycle
+                                 protection. Index ``0`` is the SURFACE
+                                 exception's class name; index ``1``+ are
+                                 (cause OR context, cause preferred) frames
+                                 walked from the surface outward. The tuple
+                                 is non-empty for every failure outcome and
+                                 empty only for ``SUCCESS``.
+        ``model_name``           Bandit-arm name attached by the caller for
+                                 telemetry. Never inspected by the classifier
+                                 — PR-B passes the routed model identifier
+                                 here so the result can be correlated with
+                                 the arm that produced it.
+
+    PR-B integration contract:
+        Callers wrap an LLM call in ``try``/``except BaseException``, pass
+        the caught exception (or ``None`` on success) to
+        ``classify_call_result``, then branch on ``result.action``:
+        ``DEFER_TO_OUTCOME`` (success — reward arrives via
+        ``on_mutation_outcome``) versus ``INJECT_ZERO_REWARD`` (any
+        failure). For retry / backoff decisions, callers branch on
+        ``result.outcome`` (rate-limited vs auth-failed vs parse-failed
+        all differ) and use ``result.retry_after_seconds`` when set.
+    """
+
+    model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
+
+    outcome: LLMCallOutcome
+    exception_class: str | None = None
+    exception_module: str | None = None
+    http_status: int | None = Field(default=None, ge=100, le=599)
+    retry_after_seconds: float | None = Field(default=None, ge=0.0)
+    message: str | None = None
+    cause_chain: tuple[str, ...] = ()
+    model_name: str | None = None
+
+    @property
+    def action(self) -> BanditAction:
+        """Bandit ledger reaction for this outcome — pure lookup."""
+        return OUTCOME_ACTION[self.outcome]
+
+    @property
+    def is_failure(self) -> bool:
+        return self.outcome is not LLMCallOutcome.SUCCESS
+
+
+# ============================================================================
+# Public API
+# ============================================================================
+
+
+def classify_call_result(
+    exc: BaseException | None,
+    *,
+    model_name: str | None = None,
+) -> LLMCallResult:
+ + ``exc is None`` ⇔ outcome is ``SUCCESS``. For any non-None exception + the result's outcome is one of the eleven failure variants. + + Args: + exc: the exception raised by the LLM call, or ``None`` for success. + model_name: optional bandit-arm name attached to the result for + telemetry — never inspected by the classifier itself. + + Returns: + Frozen ``LLMCallResult`` whose ``outcome`` is exhaustive over + ``LLMCallOutcome``. + """ + if exc is None: + return LLMCallResult(outcome=LLMCallOutcome.SUCCESS, model_name=model_name) + + outcome = _walk_and_classify(exc) + exc_name, exc_module = _safe_class_metadata(exc) + return LLMCallResult( + outcome=outcome, + exception_class=exc_name, + exception_module=exc_module, + http_status=_extract_status(exc), + retry_after_seconds=_extract_retry_after(exc), + message=_safe_message(exc), + cause_chain=_collect_cause_chain(exc), + model_name=model_name, + ) + + +# ============================================================================ +# Internals +# ============================================================================ + + +def _iter_cause_chain(exc: BaseException) -> Iterator[BaseException]: + """Yield ``exc`` and walk ``__cause__`` / ``__context__`` with cycle + protection. ``__cause__`` wins over ``__context__`` (explicit + ``raise X from Y`` is more informative than the implicit context). + + ``__cause__`` / ``__context__`` are normally C-level slots on + ``BaseException``, but a subclass may override them with a descriptor + that raises (legal Python). Access is funnelled through + ``_safe_getattr`` so a hostile property cannot break classifier + totality. A non-BaseException value (reachable via property overrides + that return arbitrary objects) terminates the walk rather than + poisoning the cycle-detection set.""" + seen: set[int] = set() + cur: BaseException | None = exc + while cur is not None and id(cur) not in seen: + seen.add(id(cur)) + yield cur + cause = _safe_getattr(cur, "__cause__") + nxt = cause if cause is not None else _safe_getattr(cur, "__context__") + cur = nxt if isinstance(nxt, BaseException) else None + + +def _walk_and_classify(exc: BaseException) -> LLMCallOutcome: + """Classify each exception in the cause chain and return the + highest-priority outcome — ensures a wrapper exception cannot mask + a more informative root cause.""" + outcomes = [_classify_one(frame) for frame in _iter_cause_chain(exc)] + return min(outcomes, key=_OUTCOME_PRIORITY_INDEX.__getitem__) + + +def _classify_one(exc: BaseException) -> LLMCallOutcome: + """Classify a single exception in isolation. Returns OTHER_EXCEPTION + when no rule matches.""" + mro_names = _mro_class_names(exc) + message = _safe_message(exc) or "" + + # 1. MRO-based specific overrides (more informative than HTTP status). + if mro_names & _CONTEXT_OVERFLOW_MRO_NAMES: + return LLMCallOutcome.CONTEXT_OVERFLOW + if mro_names & _CONTENT_FILTER_MRO_NAMES: + return LLMCallOutcome.CONTENT_FILTERED + if mro_names & _OUTPUT_TRUNCATED_MRO_NAMES: + return LLMCallOutcome.OUTPUT_TRUNCATED + + # 2. HTTP status (authoritative for well-defined codes). + status = _extract_status(exc) + if status is not None: + explicit = _STATUS_TO_OUTCOME.get(status) + if explicit is not None: + return explicit + if 400 <= status < 500: + # Catch-all for 400 / 404 / 409 / 422 / etc. that did not get + # a more specific override above. Check the overflow text marker + # before defaulting to BAD_REQUEST — some providers return 400 + # with no dedicated exception class for context overflow. 
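+            # (e.g. a bare 400 whose message reads "prompt is too long:
+            #  120000 tokens"; the message shape here is illustrative.)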
+ if _matches_overflow_marker(message): + return LLMCallOutcome.CONTEXT_OVERFLOW + return LLMCallOutcome.BAD_REQUEST + if 500 <= status < 600: + return LLMCallOutcome.SERVER_5XX + if 100 <= status < 400: + # Defensive: exception that surfaced despite a success-ish status + # almost always comes from a downstream parser stage. + return LLMCallOutcome.PARSE_FAILED + + # 3. MRO-based class fingerprints (subclass-aware via full MRO walk). + if mro_names & _TIMEOUT_MRO_NAMES: + return LLMCallOutcome.TIMEOUT + if mro_names & _NETWORK_MRO_NAMES: + return LLMCallOutcome.NETWORK_ERROR + if mro_names & _PARSE_MRO_NAMES: + return LLMCallOutcome.PARSE_FAILED + + # 4. Built-in hierarchy fallback. ``TimeoutError`` aliases + # ``asyncio.TimeoutError`` since 3.11 and is also the base for + # ``StreamChunkTimeoutError``. ``ConnectionError`` covers + # ``ConnectionRefusedError``, ``ConnectionResetError``, + # ``ConnectionAbortedError``. + if isinstance(exc, TimeoutError): + return LLMCallOutcome.TIMEOUT + if isinstance(exc, ConnectionError): + return LLMCallOutcome.NETWORK_ERROR + + # 5. Free-text context-overflow marker (last resort). + if _matches_overflow_marker(message): + return LLMCallOutcome.CONTEXT_OVERFLOW + + return LLMCallOutcome.OTHER_EXCEPTION + + +def _mro_class_names(exc: BaseException) -> frozenset[str]: + """Names of every class in the exception's MRO, frozen for set ops. + + A metaclass may override ``__mro__`` with a property that raises + (legal Python). The classifier's totality contract forbids letting + that propagate, so the access is wrapped and a hostile MRO collapses + to an empty fingerprint set — every MRO-based branch in + ``_classify_one`` then naturally skips, and the exception falls + through to the status-code / isinstance / OTHER paths.""" + mro: object = _safe_getattr(type(exc), "__mro__") + if mro is None: + return frozenset() + try: + return frozenset(cls.__name__ for cls in mro) # type: ignore[union-attr] + except Exception: + return frozenset() + + +def _safe_str_attr(obj: object, attr: str) -> str | None: + """Return ``getattr(obj, attr)`` coerced to ``str`` with paranoid + fallback. A metaclass that sets ``__module__`` to an int (legal in + Python) would otherwise inject a non-str into the result schema and + raise ``ValidationError`` from inside the classifier, violating the + documented totality contract.""" + raw = _safe_getattr(obj, attr) + if raw is None: + return None + if isinstance(raw, str): + return raw + try: + return str(raw) + except Exception: + return None + + +def _safe_class_metadata(exc: BaseException) -> tuple[str | None, str | None]: + """Return ``(name, module)`` for ``type(exc)`` with defensive coercion.""" + cls = type(exc) + return _safe_str_attr(cls, "__name__"), _safe_str_attr(cls, "__module__") + + +def _safe_getattr(obj: object, name: str) -> object | None: + """``getattr(obj, name, None)`` that also suppresses non-AttributeError + exceptions raised by hostile property/descriptor implementations. + + The stdlib ``getattr(..., default)`` only catches AttributeError; if a + property's getter raises (e.g.) ``RuntimeError``, the call propagates. + Provider SDK exception classes are well-behaved, but we touch attributes + on whatever object the caller hands us, so the classifier must not + crash on a misbehaving third-party descriptor.""" + try: + return getattr(obj, name, None) + except Exception: + return None + + +def _extract_status(exc: BaseException) -> int | None: + """Recover an HTTP status code from common SDK attributes. 
+ + ``type(x) is int`` rather than ``isinstance(x, int)`` so a ``bool`` + masquerading as int (Python quirk) is rejected. Restricts to the + 100-599 HTTP range so a stray ``status_code = -1`` sentinel is also + rejected.""" + candidate = _safe_getattr(exc, "status_code") + if type(candidate) is int and 100 <= candidate <= 599: + return candidate + response = _safe_getattr(exc, "response") + if response is not None: + rs = _safe_getattr(response, "status_code") + if type(rs) is int and 100 <= rs <= 599: + return rs + return None + + +def _extract_retry_after(exc: BaseException) -> float | None: + """Extract a Retry-After header value (seconds) from the exception's + response, if present. Returns ``None`` when the header is missing, + malformed, or the response object lacks a ``headers`` attribute. + + Only handles the integer-seconds form. HTTP-date form is not + recognised — callers should treat ``None`` as "no hint available".""" + response = _safe_getattr(exc, "response") + if response is None: + return None + headers = _safe_getattr(response, "headers") + if headers is None: + return None + get = _safe_getattr(headers, "get") + if not callable(get): + return None + # httpx Headers is case-insensitive; a plain dict isn't, so try both. + try: + raw: object = get("retry-after") + if raw is None: + raw = get("Retry-After") + except Exception: + return None + if raw is None or isinstance(raw, bool): + return None + if type(raw) is int or type(raw) is float: + value = float(raw) + elif isinstance(raw, str): + try: + value = float(raw) + except ValueError: + return None + else: + return None + # Reject non-finite values: ``"inf"`` / ``"Infinity"`` / ``"nan"`` / + # numbers that overflow ``float`` would otherwise propagate as + # ``math.inf`` or ``math.nan`` into a downstream sleep budget and + # diverge from telemetry (JSON serializers emit ``null`` for + # non-finite floats, so the in-memory value and the wire value + # would disagree silently). Negative values are rejected because + # the field semantics are "seconds to wait". Any finite non-negative + # value passes through; the consumer is the right arbiter of a + # maximum acceptable sleep. + if not math.isfinite(value) or value < 0.0: + return None + return value + + +def _safe_message(exc: BaseException) -> str | None: + """``str(exc)`` with paranoid handling of pathological ``__str__``. + + UTF-16 surrogate code points (``U+D800``-``U+DFFF``) are replaced with + ``U+FFFD``. Python ``str`` permits surrogates but UTF-8 encoders and + JSON serializers refuse them; without this scrub the classifier could + return a result whose ``message`` field crashes ``model_dump_json`` + downstream. Provider response messages occasionally include broken + surrogates from misbehaving tokenizers or sliced multibyte echoes. + """ + try: + text = str(exc) + except Exception: + return None + if not text: + return None + return _strip_surrogates(text) + + +def _strip_surrogates(text: str) -> str: + """Replace every UTF-16 surrogate code point with ``U+FFFD``. + + Replaces unconditionally rather than discriminating "lone" from "valid + pair": Python ``str`` is sequence-of-code-points, not UTF-16, so a + high+low pair is two independent code points that UTF-8 still rejects, + not one astral character. 
Real astral characters live in a single code
+    point above ``U+FFFF`` and never appear as a surrogate pair in Python
+    ``str``."""
+    return _SURROGATE_RE.sub("�", text)
+
+
+_SURROGATE_RE = re.compile(r"[\ud800-\udfff]")
+
+
+def _matches_overflow_marker(message: str) -> bool:
+    """``True`` when *message* contains a context-overflow marker substring."""
+    if not message:
+        return False
+    lowered = message.lower()
+    return any(marker in lowered for marker in _CONTEXT_OVERFLOW_MARKERS)
+
+
+def _collect_cause_chain(exc: BaseException) -> tuple[str, ...]:
+    """Class names of every exception in the cause chain."""
+    return tuple(type(frame).__name__ for frame in _iter_cause_chain(exc))
+
+
+# ============================================================================
+# Public invariant documentation (programmatic access for tests)
+# ============================================================================
+
+
+INVARIANTS: Final[tuple[str, ...]] = (
+    "classify_call_result(None).outcome is SUCCESS; no other input produces SUCCESS.",
+    "OUTCOME_ACTION is closed under LLMCallOutcome — every member is a key.",
+    "SUCCESS is the only outcome with action DEFER_TO_OUTCOME; every other "
+    "outcome has action INJECT_ZERO_REWARD.",
+    "_OUTCOME_PRIORITY is a permutation of LLMCallOutcome members.",
+    "HTTP status precedence: a 4xx/5xx with class-name fingerprint for "
+    "CONTEXT_OVERFLOW / CONTENT_FILTERED / OUTPUT_TRUNCATED wins over the "
+    "bare status classification; otherwise status wins over generic "
+    "class-name fingerprints.",
+    "Cause-chain priority: a wrapper exception with a more informative "
+    "cause classifies by the cause, not the wrapper.",
+    "The classifier never imports a provider SDK (openai, httpx, langchain-*) "
+    "and never type-checks against SDK classes; pydantic is imported only for "
+    "the result schema. It inspects MRO class names and exception attributes "
+    "only.",
+    "LLMCallResult is frozen, extra-forbidden, and strict — the result is a "
+    "closed contract.",
+)
+
+
+__all__ = (
+    "BanditAction",
+    "INVARIANTS",
+    "LLMCallOutcome",
+    "LLMCallResult",
+    "OUTCOME_ACTION",
+    "ZERO_REWARD_OUTCOMES",
+    "classify_call_result",
+)
diff --git a/gigaevo/llm/models.py b/gigaevo/llm/models.py
index 7a810372..dc1a4bf6 100644
--- a/gigaevo/llm/models.py
+++ b/gigaevo/llm/models.py
@@ -278,6 +278,7 @@ def __init__(
         tracker: TokenTracker,
         task_model_map: dict[int, str] | None = None,
         select_override: Callable[[], tuple[Any, str]] | None = None,
+        failure_hook: Callable[[BaseException, str], None] | None = None,
     ):
         self._models = models
         self._names = model_names
@@ -286,6 +287,12 @@ def __init__(
         self._tracker = tracker
         self._task_model_map = task_model_map
         self._select_override = select_override
+        # Called when ``model.{,a}invoke`` raises. ``BanditModelRouter`` uses
+        # it to inject a zero reward into the ledger so a failed pull does
+        # not silently inflate ``total_pulls`` without a matching window
+        # entry. The hook receives the exception and the selected arm name;
+        # it must not re-raise (the original exception still propagates).
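+        # A minimal conforming hook (illustrative; ``ledger`` is a stand-in
+        # for whatever reward store the owner wires in):
+        #
+        #     def _hook(exc: BaseException, arm: str) -> None:
+        #         ledger.update_reward(arm, 0.0)  # record; never re-raise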
+ self._failure_hook = failure_hook def _select(self) -> tuple[Any, str]: if self._select_override is not None: @@ -307,20 +314,60 @@ def _config( def _process(self, response: dict, name: str) -> Any: if raw := response.get("raw"): self._tracker.track(raw, name) - return response.get("parsed") + parsing_error = response.get("parsing_error") + parsed = response.get("parsed") + if parsing_error is not None and parsed is None: + # ``include_raw=True`` makes langchain surface schema-validation + # failures as ``response['parsing_error']`` with ``parsed=None`` + # instead of raising. Returning ``None`` here would silently + # bypass the caller's ``try / except`` and the bandit's + # failure_hook would never fire — the pull was recorded by + # ``_select`` but the reward window would never get a matching + # entry. Raise the parsing_error so the call site routes it + # through the existing failure path. + raise parsing_error + return parsed def invoke( self, input: LanguageModelInput, config: RunnableConfig | None = None, **kwargs ) -> Any: model, name = self._select() - return self._process( - model.invoke(input, self._config(config, name), **kwargs), name - ) + try: + response = model.invoke(input, self._config(config, name), **kwargs) + # ``_process`` runs the token tracker and unwraps the parsed + # Pydantic object. Either step can raise (telemetry-side bug, + # malformed structured response, missing parsed field). Treat + # those failures as call failures for ledger-symmetry purposes + # so the failure_hook fires. + return self._process(response, name) + except BaseException as exc: + self._maybe_fire_failure_hook(exc, name) + raise async def ainvoke( self, input: LanguageModelInput, config: RunnableConfig | None = None, **kwargs ) -> Any: model, name = self._select() - return self._process( - await model.ainvoke(input, self._config(config, name), **kwargs), name - ) + try: + response = await model.ainvoke(input, self._config(config, name), **kwargs) + return self._process(response, name) + except BaseException as exc: + self._maybe_fire_failure_hook(exc, name) + raise + + def _maybe_fire_failure_hook(self, exc: BaseException, name: str) -> None: + if self._failure_hook is None: + return + try: + self._failure_hook(exc, name) + except Exception as hook_exc: # noqa: BLE001 — observability-only + # The hook is observability-only; it must never swallow or + # mutate the original exception. Suppress any hook-side error + # so the caller still sees the real LLM failure — but emit a + # warning so a buggy hook does not silently lose telemetry. + logger.warning( + "[_StructuredOutputRouter] failure_hook for arm {!r} raised " + "{!r}; original LLM exception preserved.", + name, + hook_exc, + ) diff --git a/gigaevo/llm/token_tracking.py b/gigaevo/llm/token_tracking.py index 2e912852..39fea0e6 100644 --- a/gigaevo/llm/token_tracking.py +++ b/gigaevo/llm/token_tracking.py @@ -7,6 +7,36 @@ from gigaevo.utils.trackers.base import LogWriter +def _coerce_int(value: Any) -> int: + """Coerce a token-count value to ``int``, defaulting to ``0``. + + Hostile providers occasionally return ``None``, strings, floats with + non-integer fractions, or arbitrary objects in their ``usage`` payload. + Token counts are summed downstream into a Pydantic ``int`` field + (``TokenUsage.cumulative``), so an uncoerced non-int would raise on the + first ``cum.context += usage.context`` add. 
Coerce defensively and + swallow conversion errors as ``0`` — token telemetry is observability- + only and must never crash the LLM call site. + """ + if isinstance(value, bool): + # ``bool`` is a subclass of ``int`` — accept silently as 0/1 for + # the rare provider that returns a flag instead of a count. + return int(value) + if isinstance(value, int): + return value + if isinstance(value, float): + try: + return int(value) + except (ValueError, OverflowError): + return 0 + if isinstance(value, str): + try: + return int(value) + except ValueError: + return 0 + return 0 + + class TokenUsage(BaseModel): """Token counts for a single LLM call.""" @@ -17,33 +47,43 @@ class TokenUsage(BaseModel): @classmethod def from_response(cls, response: Any) -> "TokenUsage | None": - """Extract token usage from LLM response metadata.""" - if not hasattr(response, "response_metadata") or not response.response_metadata: + """Extract token usage from LLM response metadata. + + Defensive against hostile / malformed payloads: a provider that + returns a string (or any non-dict) in ``completion_tokens_details``, + ``token_usage``, or ``usage`` must not propagate an ``AttributeError`` + into the call site. The bandit's ``_safe_track`` already swallows + such failures, but direct callers (``MultiModelRouter`` in + ``gigaevo/llm/models.py``) have no second-level guard, so the + hardening lives here. + """ + metadata = getattr(response, "response_metadata", None) + if not metadata or not isinstance(metadata, dict): return None - usage = response.response_metadata.get( - "token_usage" - ) or response.response_metadata.get("usage") - if not usage: + usage = metadata.get("token_usage") or metadata.get("usage") + if not usage or not isinstance(usage, dict): return None - # Extract reasoning tokens - try multiple possible field names/structures + # Extract reasoning tokens - try multiple possible field names/structures. + # Each branch tolerates non-dict / non-int values from hostile providers. 
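+        # e.g. a payload this must survive without raising (illustrative):
+        #     {"prompt_tokens": "12", "completion_tokens": None,
+        #      "completion_tokens_details": "oops", "total_tokens": 14.0}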
reasoning = 0 # OpenAI o1/o3 style: completion_tokens_details.reasoning_tokens - if details := usage.get("completion_tokens_details"): - reasoning = details.get("reasoning_tokens", 0) or 0 + details = usage.get("completion_tokens_details") + if isinstance(details, dict): + reasoning = _coerce_int(details.get("reasoning_tokens")) # Direct field (some providers) if not reasoning: - reasoning = usage.get("reasoning_tokens", 0) or 0 + reasoning = _coerce_int(usage.get("reasoning_tokens")) # Qwen/thinking models might use different names if not reasoning: - reasoning = usage.get("thinking_tokens", 0) or 0 + reasoning = _coerce_int(usage.get("thinking_tokens")) return cls( - context=usage.get("prompt_tokens", 0), - generated=usage.get("completion_tokens", 0), + context=_coerce_int(usage.get("prompt_tokens")), + generated=_coerce_int(usage.get("completion_tokens")), reasoning=reasoning, - total=usage.get("total_tokens", 0), + total=_coerce_int(usage.get("total_tokens")), ) diff --git a/tests/evolution/test_bandit.py b/tests/evolution/test_bandit.py index 85273556..da249f7a 100644 --- a/tests/evolution/test_bandit.py +++ b/tests/evolution/test_bandit.py @@ -1307,3 +1307,835 @@ async def test_mget_called_with_exact_parent_ids(self) -> None: await op.on_program_ingested(child, mock_storage) mock_storage.mget.assert_called_once_with(parent_ids) + + +# --------------------------------------------------------------------------- +# Classifier-driven failure dispatch through invoke / ainvoke +# --------------------------------------------------------------------------- + + +class TestBanditFailureDispatchViaClassifier: + """``_select`` records the pull before the LLM call. A failure between + those two points used to inflate ``total_pulls`` with no matching reward + entry, shrinking the UCB1 confidence term for the failing arm and + underexploring flaky models. The new dispatch wraps the LLM call, + classifies the exception via ``classify_call_result``, and injects a + zero reward via ``_inject_failure_reward`` so pulls and the reward + window stay in step on every failure path. 
The exception still + propagates.""" + + def _router_with_flaky_arm( + self, exc: BaseException + ) -> tuple[BanditModelRouter, MagicMock]: + flaky = MagicMock() + flaky.model_name = "flaky" + flaky.invoke = MagicMock(side_effect=exc) + flaky.ainvoke = AsyncMock(side_effect=exc) + flaky.with_structured_output = MagicMock(return_value=MagicMock()) + + healthy = MagicMock() + healthy.model_name = "healthy" + healthy.with_structured_output = MagicMock(return_value=MagicMock()) + + router = BanditModelRouter( + [flaky, healthy], + [0.5, 0.5], + fitness_key="score", + higher_is_better=True, + ) + router._langfuse = None + router._bandit.select = lambda: "flaky" # type: ignore[assignment] + return router, flaky + + def test_sync_invoke_failure_records_zero_reward_and_propagates( + self, + ) -> None: + router, _flaky = self._router_with_flaky_arm(RuntimeError("rate limited")) + + with pytest.raises(RuntimeError, match="rate limited"): + router.invoke("hello") + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + async def test_async_ainvoke_failure_records_zero_reward_and_propagates( + self, + ) -> None: + router, _flaky = self._router_with_flaky_arm(RuntimeError("rate limited")) + + with pytest.raises(RuntimeError, match="rate limited"): + await router.ainvoke("hello") + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + async def test_repeated_ainvoke_failures_keep_ledgers_in_step(self) -> None: + router, _flaky = self._router_with_flaky_arm(RuntimeError("boom")) + + for _ in range(7): + with pytest.raises(RuntimeError): + await router.ainvoke("hello") + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 7 + assert stats["flaky"]["window_size"] == 7 + + async def test_successful_call_does_not_inject_immediate_reward(self) -> None: + # The success path defers the reward to on_mutation_outcome, which + # runs later with the fitness result. + model = MagicMock() + model.model_name = "ok" + model.ainvoke = AsyncMock(return_value=MagicMock()) + model.with_structured_output = MagicMock(return_value=MagicMock()) + + router = BanditModelRouter( + [model], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + + await router.ainvoke("hello") + + stats = router.get_bandit_stats() + assert stats["ok"]["total_pulls"] == 1 + # No reward entry yet — on_mutation_outcome drives the real reward. + assert stats["ok"]["window_size"] == 0 + + def test_structured_output_failure_also_injects_zero_reward(self) -> None: + # The bandit's with_structured_output wires the failure_hook through + # to _StructuredOutputRouter so the structured-output dispatch path + # gets the same ledger-symmetry guarantee. 
+ flaky = MagicMock() + flaky.model_name = "flaky" + flaky.with_structured_output = MagicMock(return_value=flaky) + flaky.invoke = MagicMock(side_effect=RuntimeError("structured failure")) + + router = BanditModelRouter( + [flaky], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + + structured = router.with_structured_output(dict) + with pytest.raises(RuntimeError, match="structured failure"): + structured.invoke("hello") + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + def test_inject_failure_reward_skips_unknown_arm_silently(self) -> None: + # Defense-in-depth symmetry with on_mutation_outcome's unknown-arm + # guard. _select cannot normally return a name outside + # self._bandit.arms, but if a future caller invokes + # _inject_failure_reward directly with a stale name (loaded from a + # snapshot, hand-built in a test, etc.) the helper must not raise + # KeyError on top of the original exception. + router, _ = self._router_with_flaky_arm(RuntimeError("boom")) + + # Should not raise. + router._inject_failure_reward(RuntimeError("orig"), "not_an_arm") + + stats = router.get_bandit_stats() + assert stats["flaky"]["window_size"] == 0 + assert stats["healthy"]["window_size"] == 0 + + +# --------------------------------------------------------------------------- +# Re-raise integrity: cause chain, context, traceback survive failure hook +# --------------------------------------------------------------------------- + + +class TestBanditFailureReRaiseIntegrity: + """The classifier dispatch wraps the call in ``except BaseException`` and + bare-``raise``s after injecting the zero reward. The original exception + object, its traceback, ``__cause__`` (explicit ``raise X from Y``), and + ``__context__`` (implicit chain) must all survive untouched — otherwise + higher-level retry / logging layers cannot tell the real failure apart + from a bandit-side error.""" + + def _make_router_raising(self, exc: BaseException) -> BanditModelRouter: + flaky = MagicMock() + flaky.model_name = "flaky" + flaky.invoke = MagicMock(side_effect=exc) + flaky.ainvoke = AsyncMock(side_effect=exc) + flaky.with_structured_output = MagicMock(return_value=MagicMock()) + router = BanditModelRouter( + [flaky], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + router._bandit.select = lambda: "flaky" # type: ignore[assignment] + return router + + def test_sync_invoke_preserves_explicit_cause_chain(self) -> None: + root = ValueError("root cause") + try: + raise RuntimeError("surface") from root + except RuntimeError as e: + surface = e + router = self._make_router_raising(surface) + + with pytest.raises(RuntimeError, match="surface") as exc_info: + router.invoke("hello") + + # Same exception object identity is the strictest possible check. 
+ assert exc_info.value is surface + assert exc_info.value.__cause__ is root + assert exc_info.value.__traceback__ is not None + + async def test_async_ainvoke_preserves_explicit_cause_chain(self) -> None: + root = ValueError("root cause async") + try: + raise RuntimeError("surface async") from root + except RuntimeError as e: + surface = e + router = self._make_router_raising(surface) + + with pytest.raises(RuntimeError, match="surface async") as exc_info: + await router.ainvoke("hello") + assert exc_info.value is surface + assert exc_info.value.__cause__ is root + + +# --------------------------------------------------------------------------- +# Hardened failure-hook: classifier-internal errors must not mask LLM exc +# --------------------------------------------------------------------------- + + +class TestBanditFailureHookErrorContainment: + """If the classifier itself raises (e.g. a future schema change that + refuses some attribute), ``_inject_failure_reward`` must not let the new + exception replace the original LLM failure. Same for downstream + normalizer / bandit calls. The original exception is what the caller + asked to handle; the bandit hook is observability only.""" + + def _make_router(self) -> BanditModelRouter: + flaky = MagicMock() + flaky.model_name = "flaky" + flaky.invoke = MagicMock(side_effect=RuntimeError("real failure")) + flaky.with_structured_output = MagicMock(return_value=MagicMock()) + router = BanditModelRouter( + [flaky], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + router._bandit.select = lambda: "flaky" # type: ignore[assignment] + return router + + def test_classifier_raising_does_not_mask_real_exception(self) -> None: + router = self._make_router() + with patch( + "gigaevo.llm.bandit.classify_call_result", + side_effect=ValueError("classifier exploded"), + ): + with pytest.raises(RuntimeError, match="real failure"): + router.invoke("hello") + + +# --------------------------------------------------------------------------- +# Structured-output: _process raising after a successful invoke +# --------------------------------------------------------------------------- + + +class TestStructuredOutputProcessFailureFiresHook: + """``_StructuredOutputRouter.invoke`` wraps ``model.invoke`` in + try/except so the failure_hook fires on transport errors. But the + response goes through ``_process`` afterward to extract the parsed + Pydantic object, and the response dict may contain a parser exception + that the langchain wrapper surfaces as a ``response['parsing_error']`` + or by raising directly. If ``_process`` raises, the bandit ledger + must still be told (otherwise we have an inflated pull count with no + matching reward — exactly the desync the wiring was supposed to + prevent).""" + + def test_structured_process_failure_fires_failure_hook(self) -> None: + flaky = MagicMock() + flaky.model_name = "flaky" + + # Return a malformed response that crashes _process. The simplest + # repro: response.get("raw") evaluates fine, but then we patch the + # tracker.track to raise — same effect, no need to mock pydantic. 
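+        # (The parsing_error path is an equally valid repro: an invoke
+        #  returning {"raw": ..., "parsed": None, "parsing_error": err}
+        #  makes _process raise err directly; illustrative, not exercised
+        #  here.)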
+ flaky.with_structured_output = MagicMock(return_value=flaky) + flaky.invoke = MagicMock(return_value={"raw": MagicMock(), "parsed": None}) + + router = BanditModelRouter( + [flaky], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + broken_tracker = MagicMock() + broken_tracker.track = MagicMock(side_effect=RuntimeError("track exploded")) + router._tracker = broken_tracker # type: ignore[assignment] + + structured = router.with_structured_output(dict) + with pytest.raises(RuntimeError, match="track exploded"): + structured.invoke("hello") + + # The bandit ledger must be in step: one pull, one reward injection. + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + +# --------------------------------------------------------------------------- +# Success-path tracker exception must not leave ledger out of step +# --------------------------------------------------------------------------- + + +class TestBanditSuccessPathTrackerFailure: + """The success path defers reward to ``on_mutation_outcome``. But if + ``self._tracker.track`` raises (malformed token_usage from a hostile + provider, telemetry-side bug, etc.) the exception leaks back to the + caller without the bandit having recorded a reward — same desync as + the original failure case. The fix: tracker errors should not be + treated as bandit failures (the LLM call succeeded), but the caller + deserves a usable response. We swallow tracker errors and continue.""" + + def test_sync_success_with_tracker_exception_returns_response(self) -> None: + response = MagicMock() + model = MagicMock() + model.model_name = "ok" + model.invoke = MagicMock(return_value=response) + model.with_structured_output = MagicMock(return_value=MagicMock()) + + router = BanditModelRouter( + [model], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + broken_tracker = MagicMock() + broken_tracker.track = MagicMock(side_effect=RuntimeError("telemetry exploded")) + router._tracker = broken_tracker # type: ignore[assignment] + + # Tracker errors are telemetry; the caller must still get the LLM + # response and the bandit must record the pull. The deferred reward + # arrives via on_mutation_outcome. + result = router.invoke("hello") + assert result is response + + stats = router.get_bandit_stats() + assert stats["ok"]["total_pulls"] == 1 + # No reward yet — deferred to on_mutation_outcome. 
+ assert stats["ok"]["window_size"] == 0 + + async def test_async_success_with_tracker_exception_returns_response( + self, + ) -> None: + response = MagicMock() + model = MagicMock() + model.model_name = "ok" + model.ainvoke = AsyncMock(return_value=response) + model.with_structured_output = MagicMock(return_value=MagicMock()) + + router = BanditModelRouter( + [model], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + broken_tracker = MagicMock() + broken_tracker.track = MagicMock(side_effect=RuntimeError("telemetry exploded")) + router._tracker = broken_tracker # type: ignore[assignment] + + result = await router.ainvoke("hello") + assert result is response + + stats = router.get_bandit_stats() + assert stats["ok"]["total_pulls"] == 1 + assert stats["ok"]["window_size"] == 0 + + +# --------------------------------------------------------------------------- +# ContextVar regression: get_selected_model after bandit routing +# --------------------------------------------------------------------------- + + +class TestBanditContextVarPropagation: + """``MultiModelRouter._select`` calls ``_remember_selected_model`` so that + downstream consumers (``MutationAgent``, ``BaseAgent``) can read the + selected model via ``get_selected_model()``. ``BanditModelRouter._select`` + was wired without that call, so any agent stack that consumes + ``get_selected_model()`` would see a stale value (whatever the last + non-bandit selection left in the ContextVar, or ``None``). + """ + + async def test_get_selected_model_returns_bandit_arm(self) -> None: + from gigaevo.llm.models import get_selected_model + + models = _make_mock_models(["arm_a", "arm_b"]) + router = BanditModelRouter( + models, [0.5, 0.5], fitness_key="score", higher_is_better=True + ) + + async def _run() -> str | None: + router._bandit.select = lambda: "arm_b" # type: ignore[assignment] + router._select() + return get_selected_model() + + result = await _run() + assert result == "arm_b" + + async def test_structured_select_override_sets_context_var(self) -> None: + from gigaevo.llm.models import get_selected_model + + models = _make_mock_models(["arm_a", "arm_b"]) + router = BanditModelRouter( + models, [0.5, 0.5], fitness_key="score", higher_is_better=True + ) + + async def _run() -> str | None: + router._bandit.select = lambda: "arm_a" # type: ignore[assignment] + structured = router.with_structured_output(dict) + structured._select() + return get_selected_model() + + result = await _run() + assert result == "arm_a" + + +class TestBanditStreamingFailureDispatch: + """``stream`` and ``astream`` are inherited from ``MultiModelRouter``. + Without an override on ``BanditModelRouter`` they call ``_select`` (which + records the pull through the bandit's overridden ``_select``) but then + iterate ``model.{,a}stream`` with no try/except — a mid-stream failure + would inflate ``total_pulls`` for the failing arm without a matching + window entry, exactly the asymmetry the classifier wiring exists to + eliminate. 
Streaming must follow the same ledger-symmetry contract as + ``invoke``/``ainvoke``.""" + + def _flaky_streaming_router( + self, exc: BaseException + ) -> tuple[BanditModelRouter, MagicMock]: + flaky = MagicMock() + flaky.model_name = "flaky" + + def _sync_stream(*_args, **_kwargs): + raise exc + + async def _async_stream(*_args, **_kwargs): + raise exc + yield # pragma: no cover — unreachable, makes this an async generator + + flaky.stream = MagicMock(side_effect=_sync_stream) + flaky.astream = _async_stream + flaky.with_structured_output = MagicMock(return_value=MagicMock()) + + router = BanditModelRouter( + [flaky], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + return router, flaky + + def test_sync_stream_failure_records_zero_reward_and_propagates(self) -> None: + router, _flaky = self._flaky_streaming_router(RuntimeError("stream boom")) + + with pytest.raises(RuntimeError, match="stream boom"): + # ``stream`` is a generator — exhaust it to trigger the call. + for _ in router.stream("hello"): + pass + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + # Without the override the window would still be empty here, leaving + # ``total_pulls`` and ``window_size`` permanently out of step. + assert stats["flaky"]["window_size"] == 1 + + async def test_async_astream_failure_records_zero_reward_and_propagates( + self, + ) -> None: + router, _flaky = self._flaky_streaming_router(RuntimeError("astream boom")) + + with pytest.raises(RuntimeError, match="astream boom"): + async for _ in router.astream("hello"): + pass + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + +class TestInjectFailureRewardCannotMaskOriginalException: + """``BanditModelRouter.invoke``/``ainvoke`` call ``_inject_failure_reward`` + from inside an ``except`` block. If the hook itself raises (a logger + blowing up, a corrupted normalizer, a classifier regression, etc.) the + naive ``raise`` at the end would surface the *hook's* exception instead + of the original LLM failure — and the original traceback would be lost. 
+ The ``_StructuredOutputRouter`` path already protects against this via + ``_maybe_fire_failure_hook``; the direct path must follow suit.""" + + def _router_with_broken_hook(self) -> BanditModelRouter: + model = MagicMock() + model.model_name = "m" + model.invoke = MagicMock(side_effect=RuntimeError("real failure")) + model.ainvoke = AsyncMock(side_effect=RuntimeError("real failure")) + model.with_structured_output = MagicMock(return_value=MagicMock()) + + router = BanditModelRouter( + [model], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + + def _explode(_exc, _name): + raise RuntimeError("hook bug — must not mask original") + + router._inject_failure_reward = _explode # type: ignore[assignment] + return router + + def test_sync_invoke_surfaces_original_exception_even_if_hook_explodes( + self, + ) -> None: + router = self._router_with_broken_hook() + with pytest.raises(RuntimeError, match="real failure"): + router.invoke("hi") + + async def test_ainvoke_surfaces_original_exception_even_if_hook_explodes( + self, + ) -> None: + router = self._router_with_broken_hook() + with pytest.raises(RuntimeError, match="real failure"): + await router.ainvoke("hi") + + +# --------------------------------------------------------------------------- +# Structured-output: failure_hook errors are warned (not silently swallowed) +# --------------------------------------------------------------------------- + + +class TestStructuredOutputFailureHookErrorsAreLogged: + """``_StructuredOutputRouter._maybe_fire_failure_hook`` suppresses any + exception raised by the hook so the original LLM failure still + propagates. Previously the suppression was silent (`except Exception: + pass`), losing telemetry whenever the hook itself had a bug. The + suppression must remain (the hook is observability-only and must not + mask the real failure), but the suppressed error has to be visible at + warning level so a hook regression does not vanish into the void.""" + + def test_warning_emitted_when_failure_hook_raises(self) -> None: + from gigaevo.llm.models import _StructuredOutputRouter + + flaky = MagicMock() + flaky.invoke = MagicMock(side_effect=RuntimeError("real failure")) + + def _explode_hook(_exc: BaseException, _name: str) -> None: + raise RuntimeError("hook bug") + + router = _StructuredOutputRouter( + [flaky], + ["m"], + [1.0], + None, + MagicMock(), + failure_hook=_explode_hook, + ) + + with patch("gigaevo.llm.models.logger.warning") as mock_warning: + with pytest.raises(RuntimeError, match="real failure"): + router.invoke("hi") + + # The hook exception must have been logged at warning level so a + # hook regression is visible in operator logs. + assert mock_warning.called + # The warning payload must reference the hook exception so the + # operator can identify the broken hook from logs alone. + call_args = mock_warning.call_args + payload = repr(call_args) + assert "hook bug" in payload or "RuntimeError" in payload + + +# --------------------------------------------------------------------------- +# Pre-_select exceptions: ledger invariant must hold +# --------------------------------------------------------------------------- + + +class TestPreSelectFailureLedgerInvariant: + """If ``_select`` itself raises (e.g. a corrupted bandit state where + ``self._bandit.select()`` blows up), the LLM call never happens and the + try/except inside ``invoke`` never engages. 
The invariant the wiring
+    promises is "pulls and rewards stay in step": if no pull was recorded
+    (``record_pull`` is called *inside* ``_select``), no reward must be
+    injected either. We verify this explicitly so that a future refactor
+    moving ``record_pull`` outside ``_select`` cannot quietly break the
+    invariant."""
+
+    def test_select_failure_before_record_pull_leaves_ledger_clean(self) -> None:
+        models = _make_mock_models(["arm_a"])
+        router = BanditModelRouter(
+            models, [1.0], fitness_key="score", higher_is_better=True
+        )
+        router._langfuse = None
+
+        # Force bandit.select to raise *before* record_pull has a chance.
+        def _explode() -> str:
+            raise RuntimeError("bandit corrupted")
+
+        router._bandit.select = _explode  # type: ignore[assignment]
+
+        with pytest.raises(RuntimeError, match="bandit corrupted"):
+            router.invoke("hello")
+
+        # No pull recorded, no reward injected — ledgers in step.
+        stats = router.get_bandit_stats()
+        assert stats["arm_a"]["total_pulls"] == 0
+        assert stats["arm_a"]["window_size"] == 0
+
+    async def test_async_select_failure_before_record_pull_leaves_ledger_clean(
+        self,
+    ) -> None:
+        models = _make_mock_models(["arm_a"])
+        router = BanditModelRouter(
+            models, [1.0], fitness_key="score", higher_is_better=True
+        )
+        router._langfuse = None
+
+        def _explode() -> str:
+            raise RuntimeError("bandit corrupted")
+
+        router._bandit.select = _explode  # type: ignore[assignment]
+
+        with pytest.raises(RuntimeError, match="bandit corrupted"):
+            await router.ainvoke("hello")
+
+        stats = router.get_bandit_stats()
+        assert stats["arm_a"]["total_pulls"] == 0
+        assert stats["arm_a"]["window_size"] == 0
+
+
+# ---------------------------------------------------------------------------
+# Engine / mutation-operator integration: no double-injection on failure
+# ---------------------------------------------------------------------------
+
+
+class TestBanditNoDoubleInjectionOnFailure:
+    """Audit hypothesis 3: when an LLM call inside the mutation operator
+    fails, the bandit must receive **exactly one** ledger entry — the
+    immediate ``_inject_failure_reward`` from ``ainvoke``. A failed
+    mutation never reaches persistence (``generate_and_persist_mutation``
+    returns ``None`` on ``MutationError``), which means the engine never
+    fires ``_notify_hook`` and therefore never calls
+    ``on_program_ingested`` → ``on_mutation_outcome``. This test pins
+    that invariant: a future refactor that, say, persists a placeholder
+    program for failed mutations would break it by causing a second
+    reward entry to land in the window for the same pull.
+    """
+
+    async def test_failed_ainvoke_injects_exactly_one_zero_reward(
+        self,
+    ) -> None:
+        flaky = MagicMock()
+        flaky.model_name = "flaky"
+        flaky.ainvoke = AsyncMock(side_effect=RuntimeError("provider 500"))
+        flaky.with_structured_output = MagicMock(return_value=MagicMock())
+        router = BanditModelRouter(
+            [flaky], [1.0], fitness_key="score", higher_is_better=True
+        )
+        router._langfuse = None
+
+        with pytest.raises(RuntimeError, match="provider 500"):
+            await router.ainvoke("hello")
+
+        stats = router.get_bandit_stats()
+        assert stats["flaky"]["total_pulls"] == 1
+        # Exactly one reward entry — the failure-path zero injection.
+        assert stats["flaky"]["window_size"] == 1
+        # The *raw* injected reward is zero; the normalizer is still in
+        # warmup, so the stored normalized value is 0.5. What we pin here
+        # is that exactly one entry landed for the failed pull.
+ assert len(router._bandit.arms["flaky"].rewards) == 1 + + async def test_on_mutation_outcome_not_invoked_when_failed_ainvoke( + self, + ) -> None: + """The engine only fires ``on_mutation_outcome`` for persisted + programs. A failed LLM call → ``MutationError`` → no persistence + → no outcome callback. We simulate the production path: the + bandit sees the failure injection, and ``on_mutation_outcome`` + is never invoked for that same call. + """ + flaky = MagicMock() + flaky.model_name = "flaky" + flaky.ainvoke = AsyncMock(side_effect=RuntimeError("timeout")) + flaky.with_structured_output = MagicMock(return_value=MagicMock()) + router = BanditModelRouter( + [flaky], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + + seen_outcomes: list[MutationOutcome] = [] + original = router.on_mutation_outcome + + def _spy(program, parents, outcome=MutationOutcome.ACCEPTED): # noqa: ANN001 + seen_outcomes.append(outcome) + return original(program, parents, outcome=outcome) + + router.on_mutation_outcome = _spy # type: ignore[assignment] + + with pytest.raises(RuntimeError, match="timeout"): + await router.ainvoke("hello") + + # No on_mutation_outcome call — only the failure-path injection. + assert seen_outcomes == [] + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + async def test_success_then_explicit_outcome_yields_exactly_one_reward( + self, + ) -> None: + """Symmetric to the negative case: a successful ainvoke defers + reward to ``on_mutation_outcome``. After the engine calls it once + per accepted program, the window holds exactly one entry — no + accidental double-injection from a stale failure path.""" + ok = MagicMock() + ok.model_name = "ok" + ok.ainvoke = AsyncMock(return_value=MagicMock()) + ok.with_structured_output = MagicMock(return_value=MagicMock()) + router = BanditModelRouter( + [ok], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + router._tracker = MagicMock() # avoid touching real tracker.track + + await router.ainvoke("hello") + + # Deferred — nothing in window yet. + stats = router.get_bandit_stats() + assert stats["ok"]["total_pulls"] == 1 + assert stats["ok"]["window_size"] == 0 + + # Engine path: parent + accepted child with mutation_model metadata. + parent = Program(code="x=0") + parent.metrics["score"] = 5.0 + child = Program(code="x=1") + child.set_metadata("mutation_model", "ok") + child.metrics["score"] = 7.0 + router.on_mutation_outcome(child, [parent], outcome=MutationOutcome.ACCEPTED) + + stats = router.get_bandit_stats() + assert stats["ok"]["total_pulls"] == 1 + # Exactly one reward — the deferred outcome reward, not a double. + assert stats["ok"]["window_size"] == 1 + + +# --------------------------------------------------------------------------- +# Concurrent dispatch stress: success/failure mix invariants +# --------------------------------------------------------------------------- + + +class TestBanditConcurrentMixedDispatch: + """Audit hypothesis 7: under heavy concurrency, the asyncio task-id + based context propagation in ``_task_model_map`` must not blur + pulls across tasks, and the ledger invariants must hold. + + Each ``ainvoke`` records a pull at ``_select`` time. On failure the + classifier injects exactly one zero reward; on success the reward + is deferred to ``on_mutation_outcome`` (which we do *not* fire in + this test). 
So after N parallel calls with F failures and S=N-F + successes: + + total_pulls == N + window_size == F (only failure-path injections land in the window) + + A regression where the task-id mapping leaked one task's failure + into another task's success path would break the second assertion. + """ + + async def test_32_concurrent_mixed_success_failure_invariants( + self, + ) -> None: + import random as _rnd + + _rnd.seed(0xC0FFEE) + outcomes = [_rnd.random() < 0.5 for _ in range(32)] # True = failure + expected_failures = sum(outcomes) + + model = MagicMock() + model.model_name = "arm" + model.with_structured_output = MagicMock(return_value=MagicMock()) + + call_idx = -1 + lock = asyncio.Lock() + + async def _ainvoke(*args, **kwargs): # noqa: ANN001 + nonlocal call_idx + async with lock: + call_idx += 1 + will_fail = outcomes[call_idx] + # Yield to the event loop so calls truly interleave. + await asyncio.sleep(0) + if will_fail: + raise RuntimeError("flake") + return MagicMock() + + model.ainvoke = _ainvoke + + router = BanditModelRouter( + [model], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + router._tracker = MagicMock() + + async def _one() -> bool: + try: + await router.ainvoke("x") + return False + except RuntimeError: + return True + + results = await asyncio.wait_for( + asyncio.gather(*[_one() for _ in range(32)]), + timeout=10.0, + ) + observed_failures = sum(results) + + # Sanity: each task observed the outcome that was scheduled for it. + assert observed_failures == expected_failures + + stats = router.get_bandit_stats() + assert stats["arm"]["total_pulls"] == 32 + # Only failures inject — successes defer to on_mutation_outcome. + assert stats["arm"]["window_size"] == expected_failures + + async def test_concurrent_all_failures_keep_ledger_in_step(self) -> None: + """Pure-failure stress: total_pulls == window_size == N.""" + model = MagicMock() + model.model_name = "arm" + model.ainvoke = AsyncMock(side_effect=RuntimeError("dead")) + model.with_structured_output = MagicMock(return_value=MagicMock()) + + router = BanditModelRouter( + [model], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + + async def _one() -> None: + with pytest.raises(RuntimeError): + await router.ainvoke("x") + + await asyncio.wait_for( + asyncio.gather(*[_one() for _ in range(24)]), + timeout=10.0, + ) + + stats = router.get_bandit_stats() + assert stats["arm"]["total_pulls"] == 24 + assert stats["arm"]["window_size"] == 24 + + async def test_concurrent_all_successes_defer_all_rewards(self) -> None: + """Pure-success stress: total_pulls == N, window_size == 0 + because every reward is deferred to ``on_mutation_outcome``.""" + model = MagicMock() + model.model_name = "arm" + model.ainvoke = AsyncMock(return_value=MagicMock()) + model.with_structured_output = MagicMock(return_value=MagicMock()) + + router = BanditModelRouter( + [model], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + router._tracker = MagicMock() + + async def _one() -> None: + await router.ainvoke("x") + + await asyncio.wait_for( + asyncio.gather(*[_one() for _ in range(24)]), + timeout=10.0, + ) + + stats = router.get_bandit_stats() + assert stats["arm"]["total_pulls"] == 24 + assert stats["arm"]["window_size"] == 0 diff --git a/tests/llm/test_agent_bandit_integration.py b/tests/llm/test_agent_bandit_integration.py new file mode 100644 index 00000000..0add1680 --- /dev/null +++ b/tests/llm/test_agent_bandit_integration.py @@ -0,0 +1,418 
@@ +"""Regression-lock integration tests: agent layer + ``BanditModelRouter``. + +The bandit-classifier wiring is exercised by unit tests in +``tests/evolution/test_bandit.py`` against the router in isolation. These +tests verify the *agent-to-bandit* contract end to end: when an agent +constructs a structured-output chain on top of a real +``BanditModelRouter`` and the underlying call raises, the bandit's +failure hook must still fire — i.e. no intermediate ``Runnable`` in the +agent layer swallows the exception before the bandit can record a +zero-reward injection. + +A regression in any agent that introduces an internal retry/fallback +wrapper, or that catches the LLM exception before it reaches the bandit +hook, would surface here. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +from langchain_core.messages import HumanMessage +import pytest + +from gigaevo.llm.agents.insights import InsightsAgent +from gigaevo.llm.agents.lineage import LineageAgent +from gigaevo.llm.agents.mutation import MutationAgent, MutationStructuredOutput +from gigaevo.llm.agents.scoring import ScoringAgent +from gigaevo.llm.bandit import BanditModelRouter +from gigaevo.programs.metrics.context import MetricsContext, MetricSpec +from gigaevo.programs.metrics.formatter import MetricsFormatter +from gigaevo.programs.program import Program + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _ctx() -> MetricsContext: + return MetricsContext( + specs={ + "score": MetricSpec( + description="primary", + is_primary=True, + higher_is_better=True, + lower_bound=0.0, + upper_bound=1.0, + sentinel_value=-1.0, + ) + } + ) + + +def _bandit( + *, + ainvoke_side_effect: BaseException | None = None, + invoke_side_effect: BaseException | None = None, + ainvoke_return: Any = None, +) -> tuple[BanditModelRouter, MagicMock]: + """Build a one-arm ``BanditModelRouter`` whose underlying model can be + flaky. + + Returns the router and the underlying mock so tests can override + behaviour per call. ``with_structured_output`` returns the *same* + mock so the inner ``model.ainvoke``/``invoke`` is exercised by the + structured-output dispatch path. 
+ """ + model = MagicMock() + model.model_name = "flaky" + if ainvoke_side_effect is not None: + model.ainvoke = AsyncMock(side_effect=ainvoke_side_effect) + else: + model.ainvoke = AsyncMock(return_value=ainvoke_return) + if invoke_side_effect is not None: + model.invoke = MagicMock(side_effect=invoke_side_effect) + else: + model.invoke = MagicMock(return_value=ainvoke_return) + model.with_structured_output = MagicMock(return_value=model) + + router = BanditModelRouter( + [model], [1.0], fitness_key="score", higher_is_better=True + ) + router._langfuse = None + router._tracker = MagicMock() + return router, model + + +def _program(metrics: dict | None = None, code: str = "def f(): return 0") -> Program: + p = Program(code=code) + if metrics: + p.add_metrics(metrics) + return p + + +# --------------------------------------------------------------------------- +# MutationAgent <-> BanditModelRouter +# --------------------------------------------------------------------------- + + +class TestMutationAgentBanditWiring: + """``MutationAgent.acall_llm`` catches LLM exceptions and turns them + into ``state["error"]`` so the LangGraph chain returns an empty-code + parsed_output rather than aborting the DAG. The bandit's failure + hook must still fire *before* that catch, so the ledger stays in + step (``total_pulls == window_size``) even though the agent layer + swallows the exception for downstream sanity.""" + + @pytest.mark.asyncio + async def test_transport_failure_fires_bandit_hook_through_agent(self) -> None: + router, _model = _bandit(ainvoke_side_effect=RuntimeError("rate limit")) + agent = MutationAgent( + llm=router, + system_prompt="sys", + user_prompt_template="Mutate {count}:\n{parent_blocks}", + mutation_mode="rewrite", + ) + + state = { + "input": [_program()], + "mutation_mode": "rewrite", + "messages": [HumanMessage(content="hi")], + "llm_response": None, + "final_code": "", + "mutation_label": "", + } + result = await agent.acall_llm(state) # type: ignore[arg-type] + + # Agent-side: exception swallowed into state["error"]. + assert result["llm_response"] is None + assert "rate limit" in result["error"] + + # Bandit-side: failure hook fired exactly once → ledger in step. + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + @pytest.mark.asyncio + async def test_repeated_failures_keep_ledger_in_step_through_agent(self) -> None: + router, _model = _bandit(ainvoke_side_effect=RuntimeError("boom")) + agent = MutationAgent( + llm=router, + system_prompt="sys", + user_prompt_template="Mutate {count}:\n{parent_blocks}", + mutation_mode="rewrite", + ) + + for _ in range(5): + state = { + "input": [_program()], + "mutation_mode": "rewrite", + "messages": [HumanMessage(content="hi")], + "llm_response": None, + "final_code": "", + "mutation_label": "", + } + await agent.acall_llm(state) # type: ignore[arg-type] + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 5 + assert stats["flaky"]["window_size"] == 5 + + @pytest.mark.asyncio + async def test_success_path_defers_reward_through_agent(self) -> None: + # On the success path the bandit defers the reward to + # on_mutation_outcome, so total_pulls advances but window_size + # stays at 0. Confirms no over-injection from the agent layer. 
+ # The bandit forces ``include_raw=True`` on the underlying + # ``with_structured_output`` call, so the mock must return the + # langchain dict shape (raw, parsed, parsing_error) — not the + # bare pydantic object. + parsed = MutationStructuredOutput( + archetype="x", + justification="y", + insights_used=[], + code="def f(): return 1", + ) + success_response = { + "raw": MagicMock(), + "parsed": parsed, + "parsing_error": None, + } + router, _model = _bandit(ainvoke_return=success_response) + agent = MutationAgent( + llm=router, + system_prompt="sys", + user_prompt_template="Mutate {count}:\n{parent_blocks}", + mutation_mode="rewrite", + ) + + state = { + "input": [_program()], + "mutation_mode": "rewrite", + "messages": [HumanMessage(content="hi")], + "llm_response": None, + "final_code": "", + "mutation_label": "", + } + await agent.acall_llm(state) # type: ignore[arg-type] + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + # Reward deferred to on_mutation_outcome → no entry in window yet. + assert stats["flaky"]["window_size"] == 0 + + +# --------------------------------------------------------------------------- +# InsightsAgent / LineageAgent / ScoringAgent <-> BanditModelRouter +# --------------------------------------------------------------------------- + + +class TestStructuredAgentsBanditWiring: + """Non-mutation agents inherit ``base.acall_llm`` which does *not* + swallow exceptions. The bandit's failure hook must fire and the + exception must propagate out of ``agent.arun`` so the DAG runner can + discard the program. This locks in that no agent has accidentally + added a retry/fallback wrapper that would swallow the failure + before the hook fires.""" + + def _insights_agent(self, router: BanditModelRouter) -> InsightsAgent: + return InsightsAgent( + llm=router, + system_prompt_template="sys", + user_prompt_template="code={code} metrics={metrics} errors={error_section} max={max_insights}", + max_insights=3, + metrics_formatter=MetricsFormatter(_ctx()), + ) + + def _lineage_agent(self, router: BanditModelRouter) -> LineageAgent: + return LineageAgent( + llm=router, + system_prompt="sys", + user_prompt_template=( + "task={task_description} m={metric_name} d={metric_description} " + "delta={delta} h={higher_is_better_text} interp={delta_interpretation} " + "pe={parent_errors} ce={child_errors} am={additional_metrics} " + "db={diff_blocks} pc={parent_code}" + ), + task_description="t", + metrics_formatter=MetricsFormatter(_ctx()), + ) + + def _scoring_agent(self, router: BanditModelRouter) -> ScoringAgent: + return ScoringAgent( + llm=router, + system_prompt="sys", + user_prompt_template="code={code} trait={trait_description} max={max_score}", + trait_description="novelty", + max_score=1.0, + ) + + @pytest.mark.asyncio + async def test_insights_agent_failure_fires_bandit_hook(self) -> None: + router, _model = _bandit(ainvoke_side_effect=RuntimeError("boom")) + agent = self._insights_agent(router) + + program = _program(metrics={"score": 0.5}) + + # Failure propagates out of agent.arun (base.acall_llm does not + # swallow, parse_response is never reached). 
+ with pytest.raises(RuntimeError, match="boom"): + await agent.arun(program) + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + @pytest.mark.asyncio + async def test_scoring_agent_failure_fires_bandit_hook(self) -> None: + router, _model = _bandit(ainvoke_side_effect=RuntimeError("scoring boom")) + agent = self._scoring_agent(router) + + program = _program() + + with pytest.raises(RuntimeError, match="scoring boom"): + await agent.arun(program) + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + @pytest.mark.asyncio + async def test_lineage_agent_failure_fires_bandit_hook(self) -> None: + router, _model = _bandit(ainvoke_side_effect=RuntimeError("lineage boom")) + agent = self._lineage_agent(router) + + parent = _program(metrics={"score": 0.4}, code="def f(): return 0") + child = _program(metrics={"score": 0.6}, code="def f(): return 1") + + with pytest.raises(RuntimeError, match="lineage boom"): + await agent.arun(parents=[parent], program=child) + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + +# --------------------------------------------------------------------------- +# Documented gap: silent ``parsed=None`` from ``include_raw=True`` chain +# --------------------------------------------------------------------------- + + +class TestStructuredOutputParsingErrorFiresHook: + """``BanditModelRouter.with_structured_output`` passes + ``include_raw=True`` to the underlying langchain wrapper, which makes + a schema-validation failure surface as ``response['parsing_error']`` + with ``parsed=None`` instead of raising. Previously + ``_StructuredOutputRouter._process`` returned ``None`` silently and + the bandit's failure_hook never fired — the pull was recorded but + the reward window never got a matching entry. ``_process`` now + raises ``parsing_error`` so the existing ``try / except`` routes + through the failure_hook.""" + + @pytest.mark.asyncio + async def test_parsing_error_fires_hook_and_propagates(self) -> None: + silent_none = { + "raw": MagicMock(), + "parsed": None, + "parsing_error": ValueError("schema validation failed"), + } + router, _model = _bandit(ainvoke_return=silent_none) + + agent = MutationAgent( + llm=router, + system_prompt="sys", + user_prompt_template="Mutate {count}:\n{parent_blocks}", + mutation_mode="rewrite", + ) + + state = { + "input": [_program()], + "mutation_mode": "rewrite", + "messages": [HumanMessage(content="hi")], + "llm_response": None, + "final_code": "", + "mutation_label": "", + } + result = await agent.acall_llm(state) # type: ignore[arg-type] + + # MutationAgent.acall_llm now sees the raised ValueError and + # records it in state["error"]. The bandit hook fired inside + # _StructuredOutputRouter so the window has the matching entry. + assert result["llm_response"] is None + assert "schema validation failed" in result["error"] + + stats = router.get_bandit_stats() + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 1 + + @pytest.mark.asyncio + async def test_successful_parsed_does_not_raise_or_fire_hook(self) -> None: + # Positive: a clean parse passes through with no hook fire. 
+ parsed_ok = { + "raw": MagicMock(), + "parsed": MagicMock(archetype="rewrite", code="def f(): pass"), + "parsing_error": None, + } + router, _model = _bandit(ainvoke_return=parsed_ok) + + agent = MutationAgent( + llm=router, + system_prompt="sys", + user_prompt_template="Mutate {count}:\n{parent_blocks}", + mutation_mode="rewrite", + ) + + state = { + "input": [_program()], + "mutation_mode": "rewrite", + "messages": [HumanMessage(content="hi")], + "llm_response": None, + "final_code": "", + "mutation_label": "", + } + result = await agent.acall_llm(state) # type: ignore[arg-type] + + assert result["llm_response"] is not None + + stats = router.get_bandit_stats() + # Success defers to on_mutation_outcome; no immediate reward. + assert stats["flaky"]["total_pulls"] == 1 + assert stats["flaky"]["window_size"] == 0 + + @pytest.mark.asyncio + async def test_parsed_none_without_error_passes_through(self) -> None: + # Negative: both parsed and parsing_error are None. Degenerate + # but legal langchain shape (caller asked for structured output + # but the model returned empty). Pass-through with parsed=None; + # no exception, no hook fire. The caller is responsible for + # handling None. + empty = {"raw": MagicMock(), "parsed": None, "parsing_error": None} + router, _model = _bandit(ainvoke_return=empty) + + agent = MutationAgent( + llm=router, + system_prompt="sys", + user_prompt_template="Mutate {count}:\n{parent_blocks}", + mutation_mode="rewrite", + ) + + state = { + "input": [_program()], + "mutation_mode": "rewrite", + "messages": [HumanMessage(content="hi")], + "llm_response": None, + "final_code": "", + "mutation_label": "", + } + result = await agent.acall_llm(state) # type: ignore[arg-type] + + # Without a parsing_error, _process returns None. MutationAgent + # crashes on the subsequent attribute access and records its + # own error, but the bandit ledger semantics differ: this is + # not a parse-error path so no failure_hook fires from + # _StructuredOutputRouter; the agent's own try/except records + # the downstream attribute error. + assert result["llm_response"] is None + # The agent caught AttributeError on the None-returned parsed. + assert result["error"] is not None diff --git a/tests/llm/test_call_outcome.py b/tests/llm/test_call_outcome.py new file mode 100644 index 00000000..1edbf907 --- /dev/null +++ b/tests/llm/test_call_outcome.py @@ -0,0 +1,799 @@ +"""Tests for the LLM call outcome classifier. + +Every classification case below was verified against the actually-installed +provider sources (openai 2.36.0, httpx 0.28.1, langchain-openai 1.2.1, +langchain-core 1.4.0) at the time of writing. If a future SDK release +renames a class, the offending case will fail with a clear message +pointing at the missing fingerprint. 
+""" + +from __future__ import annotations + +from typing import Final + +import httpx +from langchain_core.exceptions import OutputParserException +from langchain_openai.chat_models._client_utils import StreamChunkTimeoutError +from langchain_openai.chat_models.base import ( + OpenAIAPIContextOverflowError, + OpenAIContextOverflowError, + OpenAIRefusalError, +) +from langchain_openai.middleware.openai_moderation import OpenAIModerationError +import openai +from pydantic import ValidationError +import pytest + +from gigaevo.llm.call_outcome import ( + INVARIANTS, + OUTCOME_ACTION, + ZERO_REWARD_OUTCOMES, + BanditAction, + LLMCallOutcome, + LLMCallResult, + classify_call_result, +) + +_REQ: Final[httpx.Request] = httpx.Request( + "POST", "https://api.openai.com/v1/chat/completions" +) + + +def _resp(status: int, headers: dict[str, str] | None = None) -> httpx.Response: + return httpx.Response(status, headers=headers or {}, request=_REQ) + + +# --------------------------------------------------------------------------- +# Module-load invariants +# --------------------------------------------------------------------------- + + +class TestModuleInvariants: + def test_outcome_action_covers_every_outcome(self) -> None: + assert set(OUTCOME_ACTION) == set(LLMCallOutcome) + + def test_success_is_the_only_defer(self) -> None: + assert OUTCOME_ACTION[LLMCallOutcome.SUCCESS] is BanditAction.DEFER_TO_OUTCOME + for outcome in LLMCallOutcome: + if outcome is LLMCallOutcome.SUCCESS: + continue + assert OUTCOME_ACTION[outcome] is BanditAction.INJECT_ZERO_REWARD, ( + f"{outcome} unexpectedly defers" + ) + + def test_zero_reward_set_excludes_success(self) -> None: + assert LLMCallOutcome.SUCCESS not in ZERO_REWARD_OUTCOMES + assert len(ZERO_REWARD_OUTCOMES) == len(LLMCallOutcome) - 1 + + def test_invariants_documentation_present(self) -> None: + assert len(INVARIANTS) >= 6 + for line in INVARIANTS: + assert isinstance(line, str) and line + + +# --------------------------------------------------------------------------- +# Success path +# --------------------------------------------------------------------------- + + +class TestSuccess: + def test_none_is_success(self) -> None: + result = classify_call_result(None) + assert result.outcome is LLMCallOutcome.SUCCESS + assert result.is_failure is False + assert result.action is BanditAction.DEFER_TO_OUTCOME + assert result.exception_class is None + assert result.http_status is None + assert result.cause_chain == () + + def test_none_carries_model_name(self) -> None: + result = classify_call_result(None, model_name="gpt-4") + assert result.model_name == "gpt-4" + assert result.outcome is LLMCallOutcome.SUCCESS + + +# --------------------------------------------------------------------------- +# openai 2.x APIStatusError subclasses — status_code authoritative +# --------------------------------------------------------------------------- + + +class TestOpenAIStatusErrors: + def test_bad_request_400(self) -> None: + result = classify_call_result( + openai.BadRequestError("x", response=_resp(400), body=None) + ) + assert result.outcome is LLMCallOutcome.BAD_REQUEST + assert result.http_status == 400 + + def test_authentication_error_401(self) -> None: + result = classify_call_result( + openai.AuthenticationError("x", response=_resp(401), body=None) + ) + assert result.outcome is LLMCallOutcome.AUTH_FAILED + assert result.http_status == 401 + + def test_permission_denied_403(self) -> None: + result = classify_call_result( + openai.PermissionDeniedError("x", 
response=_resp(403), body=None) + ) + assert result.outcome is LLMCallOutcome.AUTH_FAILED + assert result.http_status == 403 + + def test_not_found_404_falls_to_bad_request(self) -> None: + result = classify_call_result( + openai.NotFoundError("x", response=_resp(404), body=None) + ) + assert result.outcome is LLMCallOutcome.BAD_REQUEST + + def test_conflict_409_falls_to_bad_request(self) -> None: + result = classify_call_result( + openai.ConflictError("x", response=_resp(409), body=None) + ) + assert result.outcome is LLMCallOutcome.BAD_REQUEST + + def test_unprocessable_entity_422_falls_to_bad_request(self) -> None: + result = classify_call_result( + openai.UnprocessableEntityError("x", response=_resp(422), body=None) + ) + assert result.outcome is LLMCallOutcome.BAD_REQUEST + + def test_rate_limit_429(self) -> None: + result = classify_call_result( + openai.RateLimitError("x", response=_resp(429), body=None) + ) + assert result.outcome is LLMCallOutcome.RATE_LIMITED + assert result.http_status == 429 + + def test_internal_server_error_500(self) -> None: + result = classify_call_result( + openai.InternalServerError("x", response=_resp(500), body=None) + ) + assert result.outcome is LLMCallOutcome.SERVER_5XX + + def test_unknown_5xx_classifies_as_server(self) -> None: + result = classify_call_result( + openai.APIStatusError("x", response=_resp(503), body=None) + ) + assert result.outcome is LLMCallOutcome.SERVER_5XX + + +# --------------------------------------------------------------------------- +# openai connection/timeout (no status_code) +# --------------------------------------------------------------------------- + + +class TestOpenAIConnectionAndTimeout: + def test_api_connection_error(self) -> None: + result = classify_call_result(openai.APIConnectionError(request=_REQ)) + assert result.outcome is LLMCallOutcome.NETWORK_ERROR + assert result.http_status is None + + def test_api_timeout_error(self) -> None: + result = classify_call_result(openai.APITimeoutError(request=_REQ)) + assert result.outcome is LLMCallOutcome.TIMEOUT + assert result.http_status is None + + +# --------------------------------------------------------------------------- +# httpx layer +# --------------------------------------------------------------------------- + + +class TestHttpxExceptions: + @pytest.mark.parametrize( + "exc_factory", + [ + lambda: httpx.ConnectTimeout("boom", request=_REQ), + lambda: httpx.ReadTimeout("boom", request=_REQ), + lambda: httpx.WriteTimeout("boom", request=_REQ), + lambda: httpx.PoolTimeout("boom", request=_REQ), + ], + ) + def test_timeouts(self, exc_factory) -> None: + result = classify_call_result(exc_factory()) + assert result.outcome is LLMCallOutcome.TIMEOUT + + @pytest.mark.parametrize( + "exc_factory", + [ + lambda: httpx.ConnectError("x", request=_REQ), + lambda: httpx.ReadError("x", request=_REQ), + lambda: httpx.WriteError("x", request=_REQ), + lambda: httpx.CloseError("x", request=_REQ), + lambda: httpx.RemoteProtocolError("x", request=_REQ), + lambda: httpx.LocalProtocolError("x", request=_REQ), + lambda: httpx.ProxyError("x", request=_REQ), + lambda: httpx.UnsupportedProtocol("x", request=_REQ), + ], + ) + def test_network_errors(self, exc_factory) -> None: + result = classify_call_result(exc_factory()) + assert result.outcome is LLMCallOutcome.NETWORK_ERROR + + def test_decoding_error_is_parse_failed(self) -> None: + result = classify_call_result(httpx.DecodingError("bad json", request=_REQ)) + assert result.outcome is LLMCallOutcome.PARSE_FAILED + + +# 
--------------------------------------------------------------------------- +# Context overflow (langchain-core base + langchain-openai specializations) +# --------------------------------------------------------------------------- + + +class TestContextOverflow: + def test_openai_context_overflow_error(self) -> None: + # OpenAIContextOverflowError inherits from BadRequestError(400) AND + # ContextOverflowError; MRO match must override the 400→BAD_REQUEST + # classification. + result = classify_call_result( + OpenAIContextOverflowError("overflow", response=_resp(400), body=None) + ) + assert result.outcome is LLMCallOutcome.CONTEXT_OVERFLOW + + def test_openai_api_context_overflow_error(self) -> None: + class _Fake(OpenAIAPIContextOverflowError): + def __init__(self) -> None: + pass + + assert classify_call_result(_Fake()).outcome is LLMCallOutcome.CONTEXT_OVERFLOW + + def test_text_marker_fallback(self) -> None: + # A generic ValueError whose message advertises context overflow + # still classifies correctly — covers providers that raise raw + # ValueError without a dedicated class. + result = classify_call_result( + ValueError("Your prompt is too long for this model") + ) + assert result.outcome is LLMCallOutcome.CONTEXT_OVERFLOW + + def test_text_marker_on_400_overrides_bad_request(self) -> None: + # A real 400 whose message names overflow should classify as + # CONTEXT_OVERFLOW, not BAD_REQUEST. + result = classify_call_result( + openai.BadRequestError( + "This model's maximum context length is 8192 tokens", + response=_resp(400), + body=None, + ) + ) + assert result.outcome is LLMCallOutcome.CONTEXT_OVERFLOW + + +# --------------------------------------------------------------------------- +# Output truncated / content filter +# --------------------------------------------------------------------------- + + +class TestOutputAndContentSignals: + def test_openai_refusal_error(self) -> None: + result = classify_call_result(OpenAIRefusalError("refused")) + assert result.outcome is LLMCallOutcome.CONTENT_FILTERED + + def test_openai_moderation_error(self) -> None: + # The real constructor takes keyword args we can't easily fake; + # subclass to bypass. + class _FakeModeration(OpenAIModerationError): + def __init__(self) -> None: + pass + + result = classify_call_result(_FakeModeration()) + assert result.outcome is LLMCallOutcome.CONTENT_FILTERED + + def test_length_finish_reason(self) -> None: + # LengthFinishReasonError needs a ChatCompletion to construct; subclass. + from openai import LengthFinishReasonError + + class _FakeLength(LengthFinishReasonError): + def __init__(self) -> None: + pass + + result = classify_call_result(_FakeLength()) + assert result.outcome is LLMCallOutcome.OUTPUT_TRUNCATED + + def test_content_filter_finish_reason(self) -> None: + from openai import ContentFilterFinishReasonError + + result = classify_call_result(ContentFilterFinishReasonError()) + assert result.outcome is LLMCallOutcome.CONTENT_FILTERED + + +# --------------------------------------------------------------------------- +# Built-in / asyncio fallbacks +# --------------------------------------------------------------------------- + + +class TestBuiltinFallbacks: + def test_asyncio_timeout_error_is_builtin_timeout(self) -> None: + # Since Python 3.11, asyncio.TimeoutError is an alias for TimeoutError; + # the classifier handles either route via the builtin hierarchy. 
+ import asyncio + + assert asyncio.TimeoutError is TimeoutError + result = classify_call_result(TimeoutError()) + assert result.outcome is LLMCallOutcome.TIMEOUT + + def test_stream_chunk_timeout_via_builtin_hierarchy(self) -> None: + # StreamChunkTimeoutError extends builtin TimeoutError. + result = classify_call_result(StreamChunkTimeoutError(30.0)) + assert result.outcome is LLMCallOutcome.TIMEOUT + + def test_builtin_connection_refused(self) -> None: + result = classify_call_result(ConnectionRefusedError("refused")) + assert result.outcome is LLMCallOutcome.NETWORK_ERROR + + def test_builtin_connection_reset(self) -> None: + result = classify_call_result(ConnectionResetError("reset")) + assert result.outcome is LLMCallOutcome.NETWORK_ERROR + + +# --------------------------------------------------------------------------- +# Parser failures +# --------------------------------------------------------------------------- + + +class TestParseFailures: + def test_output_parser_exception_alone(self) -> None: + result = classify_call_result(OutputParserException("bad parse")) + assert result.outcome is LLMCallOutcome.PARSE_FAILED + + def test_pydantic_validation_error(self) -> None: + from pydantic import BaseModel + + class _M(BaseModel): + x: int + + with pytest.raises(ValidationError) as exc_info: + _M(x="not an int") # type: ignore[arg-type] + result = classify_call_result(exc_info.value) + assert result.outcome is LLMCallOutcome.PARSE_FAILED + + def test_json_decode_error(self) -> None: + import json + + with pytest.raises(json.JSONDecodeError) as exc_info: + json.loads("not json") + result = classify_call_result(exc_info.value) + assert result.outcome is LLMCallOutcome.PARSE_FAILED + + +# --------------------------------------------------------------------------- +# Cause-chain priority +# --------------------------------------------------------------------------- + + +class TestCauseChainPriority: + def test_parser_wrapping_rate_limit_classifies_as_rate_limited(self) -> None: + # langchain wraps API errors in OutputParserException; bandit cares + # about the root cause, not the wrapper. + inner = openai.RateLimitError("rate", response=_resp(429), body=None) + outer = OutputParserException("parser saw bad output") + outer.__cause__ = inner + result = classify_call_result(outer) + assert result.outcome is LLMCallOutcome.RATE_LIMITED + assert "OutputParserException" in result.cause_chain + assert "RateLimitError" in result.cause_chain + + def test_chain_uses_context_when_cause_missing(self) -> None: + # Implicit context (during-handling-of) is followed when __cause__ is None. + inner = openai.AuthenticationError("no key", response=_resp(401), body=None) + outer = RuntimeError("wrap") + outer.__context__ = inner + result = classify_call_result(outer) + assert result.outcome is LLMCallOutcome.AUTH_FAILED + + def test_cause_priority_picks_more_actionable_outcome(self) -> None: + # If chain has both TIMEOUT and RATE_LIMITED, RATE_LIMITED wins + # because back-off is more actionable than retry. + rl = openai.RateLimitError("rate", response=_resp(429), body=None) + to = openai.APITimeoutError(request=_REQ) + to.__cause__ = rl + result = classify_call_result(to) + assert result.outcome is LLMCallOutcome.RATE_LIMITED + + def test_self_referential_cycle_terminates(self) -> None: + e1 = ValueError("a") + e2 = ValueError("b") + e1.__cause__ = e2 + e2.__cause__ = e1 + result = classify_call_result(e1) + # No crash, chain length bounded. 
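+        # A cycle-safe walk usually combines an id-set with a hard cap; a
+        # sketch of that shape (illustrative, not the shipped code):
+        #     seen: set[int] = set()
+        #     while exc is not None and id(exc) not in seen and len(chain) < 8:
+        #         seen.add(id(exc))
+        #         chain.append(type(exc).__name__)
+        #         exc = exc.__cause__ or exc.__context__
+        # Either the id-set or the cap alone suffices to terminate.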
+ assert len(result.cause_chain) <= 8 + assert result.outcome is LLMCallOutcome.OTHER_EXCEPTION + + +# --------------------------------------------------------------------------- +# Retry-After extraction +# --------------------------------------------------------------------------- + + +class TestRetryAfter: + def test_retry_after_integer_seconds(self) -> None: + result = classify_call_result( + openai.RateLimitError( + "rate", response=_resp(429, {"retry-after": "17"}), body=None + ) + ) + assert result.retry_after_seconds == 17.0 + + def test_retry_after_float_seconds(self) -> None: + result = classify_call_result( + openai.RateLimitError( + "rate", response=_resp(429, {"retry-after": "1.5"}), body=None + ) + ) + assert result.retry_after_seconds == 1.5 + + def test_retry_after_missing(self) -> None: + result = classify_call_result( + openai.RateLimitError("rate", response=_resp(429), body=None) + ) + assert result.retry_after_seconds is None + + def test_retry_after_http_date_not_parsed(self) -> None: + # HTTP-date form is not supported — caller treats None as "no hint". + result = classify_call_result( + openai.RateLimitError( + "rate", + response=_resp(429, {"retry-after": "Wed, 21 Oct 2026 07:28:00 GMT"}), + body=None, + ) + ) + assert result.retry_after_seconds is None + + def test_retry_after_negative_rejected(self) -> None: + result = classify_call_result( + openai.RateLimitError( + "rate", response=_resp(429, {"retry-after": "-5"}), body=None + ) + ) + assert result.retry_after_seconds is None + + +# --------------------------------------------------------------------------- +# Schema invariants +# --------------------------------------------------------------------------- + + +class TestSchemaInvariants: + def test_result_is_frozen(self) -> None: + result = classify_call_result(None) + with pytest.raises(ValidationError): + result.outcome = LLMCallOutcome.TIMEOUT # type: ignore[misc] + + def test_extra_fields_forbidden(self) -> None: + with pytest.raises(ValidationError): + LLMCallResult(outcome=LLMCallOutcome.SUCCESS, bogus="x") # type: ignore[call-arg] + + def test_http_status_range_enforced(self) -> None: + with pytest.raises(ValidationError): + LLMCallResult(outcome=LLMCallOutcome.SERVER_5XX, http_status=42) + + def test_action_property_returns_mapping_value(self) -> None: + assert classify_call_result(None).action is BanditAction.DEFER_TO_OUTCOME + assert ( + classify_call_result(RuntimeError("x")).action + is BanditAction.INJECT_ZERO_REWARD + ) + + def test_is_failure_property(self) -> None: + assert classify_call_result(None).is_failure is False + assert classify_call_result(RuntimeError("x")).is_failure is True + + +# --------------------------------------------------------------------------- +# Defensive / hostile inputs +# --------------------------------------------------------------------------- + + +class TestHostileInputs: + def test_bool_masquerading_as_status_code_rejected(self) -> None: + # ``True`` is an int in Python; the extractor must refuse it. 
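+        # A plausible guard shape (illustrative; the shipped extractor may
+        # differ): isinstance(v, int) and not isinstance(v, bool)
+        # and 100 <= v <= 599.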
+ class _SneakyBool(Exception): + status_code = True # type: ignore[assignment] + + result = classify_call_result(_SneakyBool("x")) + assert result.http_status is None + + def test_negative_status_code_rejected(self) -> None: + class _Negative(Exception): + status_code = -1 + + result = classify_call_result(_Negative("x")) + assert result.http_status is None + + def test_str_exception_that_raises(self) -> None: + class _BadStr(Exception): + def __str__(self) -> str: + raise RuntimeError("__str__ blew up") + + # Must not crash; message ends up None. + result = classify_call_result(_BadStr()) + assert result.message is None + assert result.outcome is LLMCallOutcome.OTHER_EXCEPTION + + def test_message_pass_through_full_length(self) -> None: + # Provider-side response is already bounded by max_tokens; we never + # second-guess it. Whatever ``str(exc)`` returns lands verbatim. + long = "y" * 5000 + result = classify_call_result(RuntimeError(long)) + assert result.message == long + + def test_unknown_runtime_error_is_other(self) -> None: + assert ( + classify_call_result(RuntimeError("mystery")).outcome + is LLMCallOutcome.OTHER_EXCEPTION + ) + + +# --------------------------------------------------------------------------- +# Audit-discovered defects: regression guards +# --------------------------------------------------------------------------- + + +class TestSurrogateInMessage: + """The classifier returns ``LLMCallResult.message`` as ``str(exc)``. A + provider exception text containing a UTF-16 surrogate code point would + otherwise survive into the result and crash any downstream UTF-8 + encoder or JSON serializer (``orjson``, ``model_dump_json``, asyncpg + TEXT write). ``_safe_message`` replaces surrogates with U+FFFD.""" + + def test_lone_high_surrogate_replaced(self) -> None: + result = classify_call_result(RuntimeError("a\ud83dz")) + assert result.message is not None + assert "\ud83d" not in result.message + result.message.encode("utf-8") # must not raise + + def test_lone_low_surrogate_replaced(self) -> None: + result = classify_call_result(RuntimeError("a\udc00z")) + assert result.message is not None + result.message.encode("utf-8") + + def test_adjacent_surrogate_codepoints_both_replaced(self) -> None: + # Two distinct codepoints, not a valid astral pair in Python str. + result = classify_call_result(RuntimeError(chr(0xD800) + chr(0xDC00))) + assert result.message is not None + for ch in result.message: + cp = ord(ch) + assert not (0xD800 <= cp <= 0xDFFF), f"surrogate U+{cp:04X} survived" + result.message.encode("utf-8") + + def test_real_astral_emoji_preserved(self) -> None: + result = classify_call_result(RuntimeError("a😀z")) + assert result.message == "a😀z" + + def test_model_dump_json_succeeds_after_classify(self) -> None: + result = classify_call_result(RuntimeError("a\ud83dz")) + # Must not raise PydanticSerializationError / UnicodeEncodeError. + result.model_dump_json() + + +class TestRetryAfterNonFiniteRejected: + """``Retry-After: inf`` or any non-finite parse poisons the bandit's + sleep budget and diverges from telemetry (JSON emits ``null`` for + non-finite floats, in-memory readers see ``math.inf``). 
+ ``_extract_retry_after`` must return ``None``.""" + + @pytest.mark.parametrize( + "header_value", + ["inf", "Infinity", "-inf", "nan", "NaN", "1" + "0" * 400], + ) + def test_non_finite_header_rejected(self, header_value: str) -> None: + result = classify_call_result( + openai.RateLimitError( + "rate", response=_resp(429, {"retry-after": header_value}), body=None + ) + ) + assert result.retry_after_seconds is None + assert result.outcome is LLMCallOutcome.RATE_LIMITED # outcome unaffected + + def test_finite_header_still_accepted(self) -> None: + result = classify_call_result( + openai.RateLimitError( + "rate", response=_resp(429, {"retry-after": "30"}), body=None + ) + ) + assert result.retry_after_seconds == 30.0 + + def test_large_finite_header_passes_through(self) -> None: + # The classifier parses; capping is the consumer's policy choice. + result = classify_call_result( + openai.RateLimitError( + "rate", + response=_resp(429, {"retry-after": "100000"}), + body=None, + ) + ) + assert result.retry_after_seconds == 100000.0 + + +class TestHostileClassMetadata: + """``type(exc).__module__`` and ``__name__`` can be set to non-str + values via metaclass shenanigans (legal in Python). The classifier + must coerce defensively and never raise ``ValidationError`` from + inside its own body — that would violate the documented totality + contract.""" + + def test_int_module_does_not_break_totality(self) -> None: + class _Weird(Exception): + pass + + _Weird.__module__ = 42 # type: ignore[assignment] + result = classify_call_result(_Weird("x")) + # Coerced via str() — exception_module is "42" (string), not int. + assert result.exception_module == "42" + assert result.exception_class == "_Weird" + assert result.outcome is LLMCallOutcome.OTHER_EXCEPTION + + def test_none_module_handled(self) -> None: + class _Weird(Exception): + pass + + _Weird.__module__ = None # type: ignore[assignment] + result = classify_call_result(_Weird("x")) + assert result.exception_module is None + assert result.exception_class == "_Weird" + + # The metaclass-raises-on-__module__ scenario is covered by + # TestTotalityWithHostileExceptionPlumbing.test_hostile_mro_via_metaclass_does_not_propagate. + + +class TestTotalityWithHostileExceptionPlumbing: + """``classify_call_result`` is documented total: every constructible + ``BaseException`` returns an ``LLMCallResult`` without raising. Python + permits class authors to override ``__cause__`` / ``__context__`` / + ``__mro__`` with descriptors that raise non-AttributeError. The + classifier touches these attributes on whatever object the caller + hands us, so it must defend against the hostile path.""" + + def test_hostile_cause_property_does_not_propagate(self) -> None: + class _HostileCause(Exception): + @property + def __cause__(self): # type: ignore[override] + raise RuntimeError("evil cause descriptor") + + # Must not raise — totality contract. 
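+        # Totality forces every chain read behind a guard, e.g. (sketch,
+        # not the shipped code):
+        #     try:
+        #         nxt = exc.__cause__ or exc.__context__
+        #     except BaseException:
+        #         nxt = None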
+ result = classify_call_result(_HostileCause("x")) + assert result.outcome is LLMCallOutcome.OTHER_EXCEPTION + assert result.exception_class == "_HostileCause" + + def test_hostile_context_property_does_not_propagate(self) -> None: + class _HostileContext(Exception): + @property + def __context__(self): # type: ignore[override] + raise RuntimeError("evil context descriptor") + + result = classify_call_result(_HostileContext("x")) + assert result.outcome is LLMCallOutcome.OTHER_EXCEPTION + assert result.exception_class == "_HostileContext" + + def test_hostile_mro_via_metaclass_does_not_propagate(self) -> None: + class _BadMeta(type): + @property + def __mro__(cls): # type: ignore[override] + raise RuntimeError("evil mro descriptor") + + class _E(Exception, metaclass=_BadMeta): + pass + + result = classify_call_result(_E("x")) + # No MRO available -> all class-fingerprint branches skipped; falls + # through to OTHER_EXCEPTION (or a built-in isinstance match). + assert result.outcome is LLMCallOutcome.OTHER_EXCEPTION + + def test_base_exception_subclasses_total(self) -> None: + # The signature is ``BaseException | None``; KeyboardInterrupt / + # SystemExit / GeneratorExit are constructible BaseExceptions and + # must not crash the classifier. + for cls in (KeyboardInterrupt, SystemExit, GeneratorExit): + result = classify_call_result(cls("x")) + assert result.outcome is LLMCallOutcome.OTHER_EXCEPTION + assert result.exception_class == cls.__name__ + + +# --------------------------------------------------------------------------- +# litellm fingerprint coverage +# --------------------------------------------------------------------------- +# Rather than depend on the heavy litellm import in the test process, we +# synthesize exception classes whose name + MRO shape mirror the real +# litellm exceptions verified at audit time: +# litellm.ContextWindowExceededError → BadRequestError (status_code=400) +# litellm.ContentPolicyViolationError → BadRequestError (status_code=400) +# Without the rule-table additions in this commit, both classify as +# BAD_REQUEST because the 400 status wins. + + +class _LiteLLMContextWindowExceededError(Exception): + """Shape-equivalent stand-in for litellm.ContextWindowExceededError. + + The classifier matches on the leaf class name via the MRO walk, so the + real class name on the synthesized type is what counts for the test. + Status code 400 is set on the instance to mirror the litellm class-level + default.""" + + +_LiteLLMContextWindowExceededError.__name__ = "ContextWindowExceededError" + + +class _LiteLLMContentPolicyViolationError(Exception): + pass + + +_LiteLLMContentPolicyViolationError.__name__ = "ContentPolicyViolationError" + + +class TestLiteLLMRuleCoverage: + def test_context_window_exceeded_classifies_as_context_overflow(self) -> None: + exc = _LiteLLMContextWindowExceededError("ctx window exceeded") + exc.status_code = 400 # type: ignore[attr-defined] + result = classify_call_result(exc) + # The MRO-based override beats the 400 status fallback. 
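+        # Hedged sketch of that precedence (``_NAME_OUTCOME_RULES`` is a
+        # hypothetical name, not the real identifier):
+        #     for klass in type(exc).__mro__:
+        #         if klass.__name__ in _NAME_OUTCOME_RULES:  # checked first
+        #             return _NAME_OUTCOME_RULES[klass.__name__]
+        #     # ...the status-code fallback only runs afterwards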
+        assert result.outcome is LLMCallOutcome.CONTEXT_OVERFLOW
+        assert result.http_status == 400
+
+    def test_content_policy_violation_classifies_as_content_filtered(self) -> None:
+        exc = _LiteLLMContentPolicyViolationError("policy violation")
+        exc.status_code = 400  # type: ignore[attr-defined]
+        result = classify_call_result(exc)
+        assert result.outcome is LLMCallOutcome.CONTENT_FILTERED
+        assert result.http_status == 400
+
+
+# ---------------------------------------------------------------------------
+# PR-B documented invariants
+# ---------------------------------------------------------------------------
+# These pin the contract PR-B will rely on (cause_chain ordering,
+# retry_after_seconds tri-state, action lookup for every outcome).
+
+
+class TestPRBContract:
+    def test_cause_chain_starts_with_surface_exception(self) -> None:
+        try:
+            try:
+                raise ValueError("inner")
+            except ValueError as inner:
+                raise RuntimeError("outer") from inner
+        except RuntimeError as surface:
+            result = classify_call_result(surface)
+        # Index 0 must be the SURFACE exception's class name; the walk
+        # proceeds from there back toward the root cause.
+        assert result.cause_chain[0] == "RuntimeError"
+        assert result.cause_chain[-1] == "ValueError"
+
+    def test_retry_after_zero_is_distinct_from_none(self) -> None:
+        zero = openai.RateLimitError(
+            "x", response=_resp(429, {"retry-after": "0"}), body=None
+        )
+        absent = openai.RateLimitError("x", response=_resp(429), body=None)
+        zero_result = classify_call_result(zero)
+        absent_result = classify_call_result(absent)
+        # ``0.0`` means "retry immediately"; ``None`` means "no hint".
+        # PR-B must distinguish these two.
+        assert zero_result.retry_after_seconds == 0.0
+        assert absent_result.retry_after_seconds is None
+
+    def test_action_lookup_for_every_outcome(self) -> None:
+        # PR-B will call ``result.action`` after ``classify_call_result``;
+        # the property must resolve for every enum member.
+        for outcome in LLMCallOutcome:
+            stub = LLMCallResult(outcome=outcome)
+            assert stub.action in {
+                BanditAction.DEFER_TO_OUTCOME,
+                BanditAction.INJECT_ZERO_REWARD,
+            }
+
+    def test_json_roundtrip_preserves_pr_b_relevant_fields(self) -> None:
+        exc = openai.RateLimitError(
+            "rate limit hit",
+            response=_resp(429, {"retry-after": "42"}),
+            body=None,
+        )
+        result = classify_call_result(exc, model_name="gpt-4o-2024-08-06")
+        # Round-trip via the wire format PR-B telemetry will use.
+        encoded = result.model_dump_json()
+        decoded = LLMCallResult.model_validate_json(encoded)
+        assert decoded == result
+        assert decoded.outcome is LLMCallOutcome.RATE_LIMITED
+        assert decoded.http_status == 429
+        assert decoded.retry_after_seconds == 42.0
+        assert decoded.model_name == "gpt-4o-2024-08-06"
+        assert decoded.action is BanditAction.INJECT_ZERO_REWARD
diff --git a/tests/llm/test_llm_routing.py b/tests/llm/test_llm_routing.py
index a6714bb6..25eba543 100644
--- a/tests/llm/test_llm_routing.py
+++ b/tests/llm/test_llm_routing.py
@@ -111,6 +111,107 @@ def test_from_response_usage_key_fallback(self):
         assert usage.total == 75
 
 
+# ---------------------------------------------------------------------------
+# Hostile-payload hardening — provider returns non-dict / non-int values
+#
+# The bandit's ``_safe_track`` wrapper swallows AttributeError from
+# malformed payloads, but direct callers in ``gigaevo/llm/models.py``
+# (``MultiModelRouter.invoke``) have no such guard. Harden the extractor
+# itself so it always returns either a valid ``TokenUsage`` or ``None``.
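+# A minimal sketch of the coercion these tests pin down (``_coerce_count``
+# is a hypothetical name; the real logic lives in TokenUsage's field
+# validator):
+#     def _coerce_count(value: object) -> int:
+#         try:
+#             return int(float(value))   # "100" -> 100, 100.5 -> 100
+#         except (TypeError, ValueError):
+#             return 0                   # None / "not-a-number" -> 0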
+# --------------------------------------------------------------------------- + + +class TestTokenUsageHostilePayloads: + def test_completion_tokens_details_as_string(self): + # Positive: pre-fix this raised AttributeError ('str' has no get). + resp = MagicMock() + resp.response_metadata = { + "token_usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "completion_tokens_details": "NOT_A_DICT", + } + } + usage = TokenUsage.from_response(resp) + assert usage is not None + assert usage.context == 100 + assert usage.reasoning == 0 # ignored — wrong shape + + def test_usage_as_string(self): + # Negative: hostile provider stuffs a JSON-encoded string in place + # of the dict. Should return None, not crash. + resp = MagicMock() + resp.response_metadata = {"token_usage": "broken"} + assert TokenUsage.from_response(resp) is None + + def test_response_metadata_as_string(self): + # Negative: response_metadata itself wrong type. + resp = MagicMock() + resp.response_metadata = "not-a-dict" + assert TokenUsage.from_response(resp) is None + + def test_string_token_counts_coerced(self): + # Positive: provider returns string-encoded counts (some HTTP + # clients leave numbers as text). Coerced to int. + resp = MagicMock() + resp.response_metadata = { + "token_usage": { + "prompt_tokens": "100", + "completion_tokens": "50", + "total_tokens": "150", + } + } + usage = TokenUsage.from_response(resp) + assert usage.context == 100 + assert usage.generated == 50 + assert usage.total == 150 + + def test_none_token_counts_default_to_zero(self): + # Negative: missing/None counts must not raise on the int field + # validator; default to 0. + resp = MagicMock() + resp.response_metadata = { + "token_usage": { + "prompt_tokens": None, + "completion_tokens": None, + "total_tokens": None, + } + } + usage = TokenUsage.from_response(resp) + assert usage.context == 0 + assert usage.generated == 0 + assert usage.total == 0 + + def test_float_token_counts_truncated(self): + # Positive: some providers return floats (rare). + resp = MagicMock() + resp.response_metadata = { + "token_usage": { + "prompt_tokens": 100.5, + "completion_tokens": 50.0, + "total_tokens": 150.5, + } + } + usage = TokenUsage.from_response(resp) + assert usage.context == 100 + assert usage.generated == 50 + + def test_garbage_string_count_defaults_to_zero(self): + # Negative: a string that can't parse as int defaults to 0. + resp = MagicMock() + resp.response_metadata = { + "token_usage": { + "prompt_tokens": "not-a-number", + "completion_tokens": 50, + "total_tokens": 150, + } + } + usage = TokenUsage.from_response(resp) + assert usage.context == 0 + assert usage.generated == 50 + + # --------------------------------------------------------------------------- # TokenTracker # ---------------------------------------------------------------------------