Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions gigaevo/evolution/bus/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from abc import ABC, abstractmethod
import json
import re
from typing import TYPE_CHECKING, Any

from loguru import logger
Expand All @@ -18,6 +19,45 @@
from gigaevo.evolution.bus.topology import Topology


# Lone UTF-16 surrogate code points (U+D800..U+DFFF) survive ``str(exc)`` and
# ``json.dumps`` (stdlib escapes them as ``\uD800`` literals) but the receiver
# eventually round-trips the payload back through ``gigaevo.utils.json.dumps``
# (= orjson) when it persists the migrated Program to Redis. orjson rejects
# surrogates with ``TypeError: str is not valid UTF-8: surrogates not allowed``,
# which would crash the migration handler mid-restore. We scrub them at the
# publish boundary so every consumer downstream sees an orjson-safe payload.
_SURROGATE_RE = re.compile(r"[\ud800-\udfff]")


def _scrub_str(value: str) -> str:
"""Replace lone UTF-16 surrogates with U+FFFD in a single ``str`` leaf.

Identity on surrogate-free strings; idempotent (``f(f(x)) == f(x)``).
"""
if not _SURROGATE_RE.search(value):
return value
return _SURROGATE_RE.sub("�", value)


def _scrub_surrogates(value: Any) -> Any:
    """Recursively replace lone UTF-16 surrogate code points with U+FFFD.

    Walks ``dict``/``list``/``tuple`` containers and rewrites every ``str``
    leaf — *including dict keys*: orjson rejects a surrogate-bearing key
    exactly like a surrogate-bearing value, so leaving keys unscrubbed would
    still crash the receiver's persist step. Non-string leaves pass through
    unchanged (identity). The transform is idempotent: ``f(f(x)) == f(x)``
    for every input. Cycle-free by construction —
    ``MigrantEnvelope.program_data`` is the output of ``Program.to_dict()``
    which is always a JSON-shaped tree.

    NOTE: scrubbing keys can in principle collide two keys that differ only
    in their surrogate code points (both map to U+FFFD); the later entry
    wins, which is preferable to crashing the migration handler.
    """
    if isinstance(value, str):
        return _scrub_str(value)
    if isinstance(value, dict):
        # Keys are str leaves too — scrub both sides of every entry.
        return {
            _scrub_surrogates(k): _scrub_surrogates(v) for k, v in value.items()
        }
    if isinstance(value, (list, tuple)):
        scrubbed = [_scrub_surrogates(v) for v in value]
        return type(value)(scrubbed) if isinstance(value, tuple) else scrubbed
    return value


class MigrantEnvelope(BaseModel):
"""Wire format for a program migrating between runs."""

Expand All @@ -28,10 +68,16 @@ class MigrantEnvelope(BaseModel):
generation: int

def to_stream_fields(self) -> dict[str, str]:
    """Flatten the envelope into Redis-stream string fields.

    Every ``str`` leaf is surrogate-scrubbed before serialization: stdlib
    ``json.dumps`` would happily escape lone surrogates as ``\\uD800``
    literals — legal JSON, but the receiver restores the Program and
    persists it through ``gigaevo.utils.json.dumps`` (orjson), which raises
    ``TypeError: surrogates not allowed``. See the module docstring for the
    full failure-mode rationale.
    """
    fields = {
        "source_run_id": _scrub_str(self.source_run_id),
        "program_id": _scrub_str(self.program_id),
    }
    fields["program_data"] = json.dumps(_scrub_surrogates(self.program_data))
    fields["published_at"] = str(self.published_at)
    fields["generation"] = str(self.generation)
    return fields
Expand Down
184 changes: 182 additions & 2 deletions gigaevo/llm/bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,25 @@
from __future__ import annotations

from collections import deque
from collections.abc import AsyncIterator, Iterator
from dataclasses import dataclass, field
from enum import Enum
import math
from typing import TYPE_CHECKING, Any

from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from loguru import logger
import numpy as np

from gigaevo.llm.models import MultiModelRouter, _StructuredOutputRouter
from gigaevo.llm.call_outcome import BanditAction, classify_call_result
from gigaevo.llm.models import (
MultiModelRouter,
_remember_selected_model,
_StructuredOutputRouter,
)
from gigaevo.utils.trackers.base import LogWriter

if TYPE_CHECKING:
Expand Down Expand Up @@ -263,15 +272,181 @@ def __init__(
# -- selection ----------------------------------------------------------

def _select(self) -> tuple[ChatOpenAI, str]:
    """Pick an arm via UCB1, record the pull, and publish the selection.

    Beyond selecting and recording, this writes the chosen name into the
    ``_selected_model`` ContextVar so downstream consumers
    (``BaseAgent.acall_llm``, ``MutationAgent``, ``InsightsAgent``, ...)
    can read it with ``get_selected_model()``. ``MultiModelRouter._select``
    does the same; the bandit override must honor that contract, otherwise
    ``state['metadata']['model_used']`` would report whatever stale value
    an earlier non-bandit selection left in the ContextVar.
    """
    name = self._bandit.select()
    self._bandit.record_pull(name)
    _remember_selected_model(name)
    task_id = self._current_task_id()
    if task_id is not None:
        self._task_model_map[task_id] = name
    return self.models[self.model_names.index(name)], name

# -- dispatch -----------------------------------------------------------

def _inject_failure_reward(self, exc: BaseException, arm_name: str) -> None:
    """Classify a failed LLM call and, when warranted, inject a zero reward.

    ``_select`` records the pull *before* the call runs, so an unhandled
    failure would leave ``total_pulls`` incremented with no matching entry
    in the reward window — the UCB1 confidence term shrinks for that arm
    and the bandit underexplores flaky models. The classifier maps the
    exception to an outcome; every failure variant's ``OUTCOME_ACTION`` is
    currently ``INJECT_ZERO_REWARD``, so we normalize a zero reward and
    append it to the arm's window.
    """
    result = classify_call_result(exc, model_name=arm_name)
    if result.action is BanditAction.DEFER_TO_OUTCOME:
        # SUCCESS cannot occur inside an except block, but dispatching on
        # the action keeps the contract explicit and honest.
        return
    if arm_name not in self._bandit.arms:
        # Same guard as on_mutation_outcome: calling update_reward with an
        # unknown arm name would raise KeyError.
        logger.debug(
            "[BanditModelRouter] Skipping zero-reward injection for "
            "unknown arm {!r} (outcome={})",
            arm_name,
            result.outcome.value,
        )
        return
    self._bandit.update_reward(arm_name, self._reward_normalizer.normalize(0.0))
    logger.debug(
        "[BanditModelRouter] Zero reward injected for {} | outcome={} exception={}",
        arm_name,
        result.outcome.value,
        result.exception_class,
    )

def _safe_inject_failure_reward(self, exc: BaseException, arm_name: str) -> None:
    """Run ``_inject_failure_reward`` without ever raising.

    Mirrors ``_StructuredOutputRouter._maybe_fire_failure_hook``: the
    ledger-symmetry update is observability-only and must never mask the
    original LLM exception or its traceback. Anything the hook raises is
    swallowed with a debug log so the ``raise`` at the call site still
    re-raises the real failure.
    """
    try:
        self._inject_failure_reward(exc, arm_name)
    except Exception as inner_exc:  # noqa: BLE001 — observability-only
        logger.debug(
            "[BanditModelRouter] Failure-reward injection itself raised "
            "for arm {!r}: {!r}. Original exception preserved.",
            arm_name,
            inner_exc,
        )

def _safe_track(self, response: Any, name: str) -> None:
    """Record token usage for *response* without ever raising.

    ``self._tracker.track`` reads ``response.usage_metadata`` and writes to
    a ``LogWriter``. A telemetry-side bug (malformed token_usage from a
    hostile provider, broken writer, ...) must not reach the caller: the
    LLM call already succeeded, and on the bandit success path the reward
    is *deferred* to ``on_mutation_outcome`` — raising here would both lose
    the response and orphan the deferred reward, because the caller would
    never return.
    """
    try:
        self._tracker.track(response, name)
    except Exception as track_exc:  # noqa: BLE001 — telemetry only
        logger.debug(
            "[BanditModelRouter] Token tracking failed for arm {!r}: {!r}. "
            "LLM response preserved; reward still deferred to "
            "on_mutation_outcome.",
            name,
            track_exc,
        )

def invoke(
    self,
    input: LanguageModelInput,
    config: RunnableConfig | None = None,
    **kwargs: Any,
) -> BaseMessage:
    """Dispatch a synchronous call through the bandit-selected arm.

    On failure the exception is classified via the bandit hook (zero-reward
    injection) and re-raised; on success token usage is tracked best-effort
    and the response returned.
    """
    model, name = self._select()
    try:
        response = model.invoke(input, self._config(config, name), **kwargs)
    except BaseException as exc:
        self._safe_inject_failure_reward(exc, name)
        raise
    else:
        self._safe_track(response, name)
        return response

async def ainvoke(
    self,
    input: LanguageModelInput,
    config: RunnableConfig | None = None,
    **kwargs: Any,
) -> BaseMessage:
    """Async counterpart of :meth:`invoke`.

    Failures are classified via the bandit hook (zero-reward injection) and
    re-raised; successful responses get best-effort token tracking.
    """
    model, name = self._select()
    try:
        response = await model.ainvoke(input, self._config(config, name), **kwargs)
    except BaseException as exc:
        self._safe_inject_failure_reward(exc, name)
        raise
    else:
        self._safe_track(response, name)
        return response

def stream(
    self,
    input: LanguageModelInput,
    config: RunnableConfig | None = None,
    **kwargs: Any,
) -> Iterator[BaseMessage]:
    """Streaming counterpart to :meth:`invoke` with ledger-symmetry guard.

    ``MultiModelRouter.stream`` records the pull (via ``_select``) and
    then yields from ``model.stream`` without try/except. A mid-stream
    failure would leave ``total_pulls`` and the reward window out of
    step exactly as the unwrapped ``invoke`` path used to. We mirror
    the ``invoke`` contract: classify any exception via the bandit
    hook, then re-raise.

    ``GeneratorExit`` is exempted: it is raised *into* the generator when
    the consumer closes it (breaking out of the loop, ``.close()``, GC) and
    says nothing about the model's health — injecting a zero reward for it
    would punish an arm for normal early termination.
    """
    model, name = self._select()
    last = None
    try:
        for chunk in model.stream(input, self._config(config, name), **kwargs):
            last = chunk
            yield chunk
    except GeneratorExit:
        # Consumer-initiated close, not an LLM failure — no reward injection.
        raise
    except BaseException as exc:
        self._safe_inject_failure_reward(exc, name)
        raise
    if last is not None:
        # Usage metadata rides on the final chunk; track it best-effort.
        self._safe_track(last, name)

async def astream(
    self,
    input: LanguageModelInput,
    config: RunnableConfig | None = None,
    **kwargs: Any,
) -> AsyncIterator[BaseMessage]:
    """Async streaming counterpart to :meth:`ainvoke`.

    Mirrors the sync path: a mid-stream failure is classified via the
    bandit hook (zero-reward injection) and re-raised, keeping
    ``total_pulls`` and the reward window in step. ``GeneratorExit`` —
    raised into the async generator by ``aclose()`` or early consumer
    exit — is re-raised without injecting a reward: consumer-initiated
    shutdown is not a model failure.

    NOTE(review): ``asyncio.CancelledError`` still reaches the
    ``BaseException`` handler and injects a zero reward even though
    cancellation is usually caller-initiated — confirm whether it should
    be exempted like ``GeneratorExit``.
    """
    model, name = self._select()
    last = None
    try:
        async for chunk in model.astream(
            input, self._config(config, name), **kwargs
        ):
            last = chunk
            yield chunk
    except GeneratorExit:
        # Consumer closed the stream; not an LLM failure — no injection.
        raise
    except BaseException as exc:
        self._safe_inject_failure_reward(exc, name)
        raise
    if last is not None:
        # Usage metadata rides on the final chunk; track it best-effort.
        self._safe_track(last, name)

# -- mutation outcome ---------------------------------------------------

def on_mutation_outcome(
Expand Down Expand Up @@ -350,6 +525,10 @@ def with_structured_output(self, schema: Any, **kwargs) -> _StructuredOutputRout
def _bandit_select() -> tuple[Any, str]:
name = self._bandit.select()
self._bandit.record_pull(name)
# Mirror MultiModelRouter._select: publish the selection to the
# ContextVar so ``get_selected_model()`` reads the actual arm
# name during structured-output dispatch.
_remember_selected_model(name)
tid = self._current_task_id()
if tid is not None:
self._task_model_map[tid] = name
Expand All @@ -364,4 +543,5 @@ def _bandit_select() -> tuple[Any, str]:
self._tracker,
task_model_map=self._task_model_map,
select_override=_bandit_select,
failure_hook=self._inject_failure_reward,
)
Loading