diff --git a/gigaevo/database/redis_program_storage.py b/gigaevo/database/redis_program_storage.py index 68694a94..78cae2c9 100644 --- a/gigaevo/database/redis_program_storage.py +++ b/gigaevo/database/redis_program_storage.py @@ -25,6 +25,7 @@ from gigaevo.programs.program_state import ProgramState, validate_transition from gigaevo.utils.json import dumps as _dumps from gigaevo.utils.json import loads as _loads +from gigaevo.utils.text_sanitize import sanitize_for_log from gigaevo.utils.trackers.base import LogWriter T = TypeVar("T") @@ -109,7 +110,11 @@ def _safe_deserialize( try: return Program.from_dict(_loads(raw), exclude=exclude) except Exception as e: - logger.warning("[RedisProgramStorage] Corrupt data in {}: {}", ctx, e) + logger.warning( + "[RedisProgramStorage] Corrupt data in {}: {}", + ctx, + sanitize_for_log(str(e)), + ) return None async def _mget_by_keys( diff --git a/gigaevo/database/state_manager.py b/gigaevo/database/state_manager.py index b8ee2fd1..3ebdf945 100644 --- a/gigaevo/database/state_manager.py +++ b/gigaevo/database/state_manager.py @@ -7,6 +7,7 @@ from gigaevo.programs.core_types import ProgramStageResult, StageState from gigaevo.programs.program import Program from gigaevo.programs.program_state import ProgramState, validate_transition +from gigaevo.utils.text_sanitize import sanitize_for_log # States after which the DagRunner never accesses the program again. # Evict per-program locks for these states to prevent unbounded memory growth. @@ -100,7 +101,7 @@ async def set_program_state( logger.error( "[ProgramStateManager] Invalid state transition for {}: {}", program.short_id, - e, + sanitize_for_log(str(e)), ) raise diff --git a/gigaevo/evolution/bus/transport.py b/gigaevo/evolution/bus/transport.py index 786cb172..960746d2 100644 --- a/gigaevo/evolution/bus/transport.py +++ b/gigaevo/evolution/bus/transport.py @@ -8,12 +8,14 @@ from abc import ABC, abstractmethod import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from loguru import logger from pydantic import BaseModel import redis.asyncio as aioredis +from gigaevo.utils.text_sanitize import deep_sanitize_for_json + if TYPE_CHECKING: from gigaevo.evolution.bus.topology import Topology @@ -28,10 +30,19 @@ class MigrantEnvelope(BaseModel): generation: int def to_stream_fields(self) -> dict[str, str]: + # Belt-and-suspenders: program_data carries LLM-generated code plus + # stage errors whose origins span Python / Triton / CUDA C++ / + # CUTLASS / Mojo / Pallas / CuTe. Any one of those toolchains can + # emit text that contains a lone UTF-16 surrogate; json.dumps then + # raises UnicodeEncodeError and the migration write aborts. Scrub + # surrogates at the boundary. + safe_program_data = cast( + dict[str, Any], deep_sanitize_for_json(self.program_data) + ) return { "source_run_id": self.source_run_id, "program_id": self.program_id, - "program_data": json.dumps(self.program_data), + "program_data": json.dumps(safe_program_data), "published_at": str(self.published_at), "generation": str(self.generation), } diff --git a/gigaevo/evolution/mutation/mutation_operator.py b/gigaevo/evolution/mutation/mutation_operator.py index 42a119c7..b0109dd7 100644 --- a/gigaevo/evolution/mutation/mutation_operator.py +++ b/gigaevo/evolution/mutation/mutation_operator.py @@ -17,6 +17,7 @@ from gigaevo.problems.context import ProblemContext from gigaevo.programs.metrics.formatter import MetricsFormatter from gigaevo.programs.program import Program +from gigaevo.utils.text_sanitize import sanitize_for_log if TYPE_CHECKING: from gigaevo.database.program_storage import ProgramStorage @@ -95,7 +96,7 @@ def _canonicalize_code(code: str) -> str: logger.warning( "[LLMMutationOperator] Failed to canonicalize code due to syntax error: {}. " "Returning original code.", - e, + sanitize_for_log(str(e)), ) return code @@ -156,7 +157,10 @@ async def mutate_single( if structured_output: mutation_metadata[MutationSpec.META_OUTPUT] = structured_output archetype = result.get("archetype", "unknown") - logger.debug("[LLMMutationOperator] Mutation archetype: {}", archetype) + logger.debug( + "[LLMMutationOperator] Mutation archetype: {}", + sanitize_for_log(str(archetype)), + ) if result.get("changes"): logger.debug( "[LLMMutationOperator] Mutation returned {} tracked change(s)", diff --git a/gigaevo/llm/agents/insights.py b/gigaevo/llm/agents/insights.py index 96565050..299fbac6 100644 --- a/gigaevo/llm/agents/insights.py +++ b/gigaevo/llm/agents/insights.py @@ -8,12 +8,13 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_openai import ChatOpenAI -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from gigaevo.llm.agents.base import LangGraphAgent from gigaevo.llm.models import MultiModelRouter from gigaevo.programs.metrics.formatter import MetricsFormatter from gigaevo.programs.program import OPTIMIZATION_STAGES, Program +from gigaevo.utils.text_sanitize import sanitize_for_log class ProgramInsight(BaseModel): @@ -24,6 +25,11 @@ class ProgramInsight(BaseModel): tag: str = Field(description="Tag for the insight") severity: str = Field(description="Severity of the insight") + @field_validator("type", "insight", "tag", "severity", mode="after") + @classmethod + def _scrub_text(cls, value: str) -> str: + return sanitize_for_log(value) + class ProgramInsights(BaseModel): """Collection of program insights.""" diff --git a/gigaevo/llm/agents/lineage.py b/gigaevo/llm/agents/lineage.py index c3ae277c..f2e76cd6 100644 --- a/gigaevo/llm/agents/lineage.py +++ b/gigaevo/llm/agents/lineage.py @@ -9,12 +9,13 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_openai import ChatOpenAI -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from gigaevo.llm.agents.base import LangGraphAgent from gigaevo.llm.models import MultiModelRouter from gigaevo.programs.metrics.formatter import MetricsFormatter from gigaevo.programs.program import OPTIMIZATION_STAGES, Program +from gigaevo.utils.text_sanitize import sanitize_for_log class TransitionInsight(BaseModel): @@ -27,6 +28,11 @@ class TransitionInsight(BaseModel): description="Specific explanation with evidence (≤30 words)" ) + @field_validator("strategy", "description", mode="after") + @classmethod + def _scrub_text(cls, value: str) -> str: + return sanitize_for_log(value) + class TransitionInsights(BaseModel): """Collection of transition insights.""" diff --git a/gigaevo/llm/agents/memory_selector.py b/gigaevo/llm/agents/memory_selector.py index 1592d43e..298f6c3a 100644 --- a/gigaevo/llm/agents/memory_selector.py +++ b/gigaevo/llm/agents/memory_selector.py @@ -14,6 +14,7 @@ from gigaevo.evolution.mutation.constants import MUTATION_CONTEXT_METADATA_KEY from gigaevo.programs.program import Program +from gigaevo.utils.text_sanitize import sanitize_for_log try: from gigaevo.memory.runtime_config import ( @@ -71,7 +72,7 @@ def _resolve_memory_backend_class(use_api: bool) -> type[Any]: def _create_memory_backend(self) -> Any | None: if _RUNTIME_IMPORT_ERROR is not None: - message = ( + message = sanitize_for_log( "gigaevo.memory is unavailable" f"{': ' + str(_RUNTIME_IMPORT_ERROR) if _RUNTIME_IMPORT_ERROR else ''}" ) @@ -210,17 +211,19 @@ def _create_memory_backend(self) -> Any | None: logger.info( "[MemorySelectorAgent] Using memory backend " "(class={}, use_api={}, namespace={}, channel={}, checkpoint={})", - type(memory).__module__, + sanitize_for_log(type(memory).__module__), use_api, - namespace, - channel, - memory_dir, + sanitize_for_log(str(namespace)), + sanitize_for_log(str(channel)), + sanitize_for_log(str(memory_dir)), ) return memory except Exception as exc: - self._backend_error = str(exc) + safe_exc = sanitize_for_log(str(exc)) + self._backend_error = safe_exc logger.warning( - "[MemorySelectorAgent] Failed to initialize red memory backend: {}", exc + "[MemorySelectorAgent] Failed to initialize red memory backend: {}", + safe_exc, ) return None @@ -259,7 +262,7 @@ async def select( if self.memory is None: logger.warning( "[MemorySelectorAgent] Memory backend unavailable: {}", - self._backend_error or "unknown error", + sanitize_for_log(self._backend_error or "unknown error"), ) return MemorySelection(cards=[], card_ids=[]) @@ -280,7 +283,10 @@ async def select( self._search_with_ids, query ) except Exception as exc: - logger.warning("[MemorySelectorAgent] Red memory search failed: {}", exc) + logger.warning( + "[MemorySelectorAgent] Red memory search failed: {}", + sanitize_for_log(str(exc)), + ) return MemorySelection(cards=[], card_ids=[]) cards = self._parse_search_result(result_text, max_cards=max_cards) @@ -295,7 +301,7 @@ async def select( logger.debug( "[MemorySelectorAgent] Selected {} memory idea(s) via red agent (ids={})", len(cards), - card_ids, + [sanitize_for_log(cid) for cid in card_ids], ) else: logger.debug( @@ -358,7 +364,7 @@ def _search_with_ids(self, query: str) -> tuple[str, list[str]]: except Exception as exc: logger.warning( "[MemorySelectorAgent] Direct GAM research failed, falling back to plain search: {}", - exc, + sanitize_for_log(str(exc)), ) assert self.memory is not None # caller checks self.memory before calling diff --git a/gigaevo/llm/agents/mutation.py b/gigaevo/llm/agents/mutation.py index 8a8ee40f..371011a4 100644 --- a/gigaevo/llm/agents/mutation.py +++ b/gigaevo/llm/agents/mutation.py @@ -8,7 +8,7 @@ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langchain_openai import ChatOpenAI from loguru import logger -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from gigaevo.evolution.mutation.base import MutationSpec from gigaevo.evolution.mutation.constants import ( @@ -18,6 +18,7 @@ from gigaevo.llm.agents.base import LangGraphAgent from gigaevo.llm.models import MultiModelRouter, get_selected_model from gigaevo.programs.program import Program +from gigaevo.utils.text_sanitize import sanitize_for_log if TYPE_CHECKING: from gigaevo.programs.metrics.context import MetricsContext @@ -42,6 +43,14 @@ class MutationChange(BaseModel): ) ) + @field_validator("description", "explanation", mode="after") + @classmethod + def _scrub_text(cls, value: str) -> str: + # LLM-generated free-form text; sanitize so downstream log sinks, + # JSON encoders, and asyncpg TEXT columns never see ANSI escape + # sequences, BIDI overrides, lone UTF-16 surrogates, or NUL bytes. + return sanitize_for_log(value) + class MutationStructuredOutput(BaseModel): """Structured output from the mutation LLM. @@ -76,6 +85,20 @@ class MutationStructuredOutput(BaseModel): ) ) + @field_validator("archetype", "justification", "code", mode="after") + @classmethod + def _scrub_text(cls, value: str) -> str: + # ``code`` is sanitized too — Python source has no legitimate use + # for ANSI/BIDI/C0-other-than-TAB-LF or NUL, and an LLM injecting + # one would otherwise break ast.parse() error formatting, log + # rendering, and asyncpg storage. + return sanitize_for_log(value) + + @field_validator("insights_used", mode="after") + @classmethod + def _scrub_insights(cls, value: list[str]) -> list[str]: + return [sanitize_for_log(v) for v in value] + # Re-export from canonical location for backward compatibility MUTATION_OUTPUT_METADATA_KEY = MutationSpec.META_OUTPUT @@ -202,7 +225,9 @@ def _dump_prompt_to_file( f.write(user) f.write("\n") except Exception as exc: - logger.debug(f"[MutationAgent] prompt dump failed: {exc}") + logger.debug( + "[MutationAgent] prompt dump failed: {}", sanitize_for_log(str(exc)) + ) async def arun(self, input: list[Program], mutation_mode: str) -> dict: """Execute mutation agent. @@ -258,8 +283,9 @@ async def acall_llm(self, state: MutationState) -> MutationState: ) except Exception as e: - logger.error(f"[MutationAgent] Structured LLM call failed: {e}") - state["error"] = str(e) + safe_msg = sanitize_for_log(str(e)) + logger.error("[MutationAgent] Structured LLM call failed: {}", safe_msg) + state["error"] = safe_msg state["llm_response"] = None return state @@ -387,8 +413,10 @@ def parse_response(self, state: MutationState) -> MutationState: model_used = state.get("metadata", {}).get("model_used") if structured_output is None: - error_msg = state.get("error", "No structured output received") - logger.error(f"[MutationAgent] No structured output: {error_msg}") + error_msg = sanitize_for_log( + state.get("error", "No structured output received") + ) + logger.error("[MutationAgent] No structured output: {}", error_msg) state["parsed_output"] = { "code": "", "structured_output": None, @@ -450,14 +478,17 @@ def parse_response(self, state: MutationState) -> MutationState: ) except Exception as e: - logger.error(f"[MutationAgent] Failed to parse structured response: {e}") - state["error"] = str(e) + safe_msg = sanitize_for_log(str(e)) + logger.error( + "[MutationAgent] Failed to parse structured response: {}", safe_msg + ) + state["error"] = safe_msg state["parsed_output"] = { "code": "", "structured_output": ( structured_output.model_dump() if structured_output else None ), - "error": str(e), + "error": safe_msg, "model_used": model_used, } diff --git a/gigaevo/llm/bandit.py b/gigaevo/llm/bandit.py index b63fe5c0..446aeb65 100644 --- a/gigaevo/llm/bandit.py +++ b/gigaevo/llm/bandit.py @@ -19,6 +19,7 @@ import numpy as np from gigaevo.llm.models import MultiModelRouter, _StructuredOutputRouter +from gigaevo.utils.text_sanitize import sanitize_for_log from gigaevo.utils.trackers.base import LogWriter if TYPE_CHECKING: @@ -296,7 +297,7 @@ def on_mutation_outcome( self._bandit.update_reward(model_name, normalized) logger.debug( "[BanditModelRouter] Reward for {} ({}): raw=0.0 norm={:.4f}", - model_name, + sanitize_for_log(str(model_name)), outcome.value, normalized, ) @@ -326,7 +327,7 @@ def on_mutation_outcome( self._bandit.update_reward(model_name, normalized) logger.debug( "[BanditModelRouter] Reward for {} ({}): raw={:.4f} norm={:.4f}", - model_name, + sanitize_for_log(str(model_name)), outcome.value, raw, normalized, diff --git a/gigaevo/llm/models.py b/gigaevo/llm/models.py index 7a810372..0f977797 100644 --- a/gigaevo/llm/models.py +++ b/gigaevo/llm/models.py @@ -6,6 +6,7 @@ import os import random from typing import TYPE_CHECKING, Any, cast +from urllib.parse import urlparse, urlunparse from langchain_core.language_models import LanguageModelInput from langchain_core.messages import BaseMessage @@ -15,6 +16,7 @@ from loguru import logger from gigaevo.llm.token_tracking import TokenTracker +from gigaevo.utils.text_sanitize import clean_identifier, sanitize_for_log from gigaevo.utils.trackers.base import LogWriter if TYPE_CHECKING: @@ -24,6 +26,37 @@ _selected_model_var: ContextVar[str | None] = ContextVar("selected_model", default=None) +def _redact_url(url: str) -> str: + """Strip userinfo (user:password@) from a URL before logging. Other + URL components are preserved verbatim. Returns the input unchanged + on parse failure.""" + try: + parsed = urlparse(url) + except Exception: + return url + if not parsed.hostname: + return url + netloc = parsed.hostname + if parsed.port is not None: + netloc = f"{netloc}:{parsed.port}" + return urlunparse(parsed._replace(netloc=netloc)) + + +def _safe_model_name(raw: object) -> str: + """Validate a model name read off a ChatOpenAI instance. Strips control + characters and ANSI; logs a one-shot WARNING if the cleaning changes + the input so operators notice a misconfigured identifier.""" + raw_str = str(raw) if raw is not None else "" + cleaned = clean_identifier(raw_str, max_len=128) + if cleaned != raw_str: + logger.warning( + "[MultiModelRouter] model_name sanitized: {!r} -> {!r}", + sanitize_for_log(raw_str), + cleaned, + ) + return cleaned + + def get_selected_model() -> str | None: """Return the last selected model name for the current async context.""" return _selected_model_var.get() @@ -95,7 +128,12 @@ def __init__( raise ValueError("All probabilities must be positive") self.models = models - self.model_names = [m.model_name for m in models] + # ChatOpenAI.model_name comes from operator config / env interpolation + # / occasionally LLM-generated overrides; control characters there + # would propagate into every loguru ``{}`` substitution in this + # module and into langfuse trace identifiers. Validate once at + # construction. + self.model_names = [_safe_model_name(m.model_name) for m in models] self.probabilities = [p / sum(probabilities) for p in probabilities] self._task_model_map: dict[int, str] = {} self._name = name @@ -116,12 +154,15 @@ def __init__( model_desc, ) # Log base URLs for debugging server connectivity - for m in models: + for m, safe_name in zip(models, self.model_names): # ChatOpenAI exposes base_url as a property (langchain 0.1+) base_url = getattr(m, "base_url", None) if base_url: logger.info( - "[MultiModelRouter:{}] Model {} at {}", name, m.model_name, base_url + "[MultiModelRouter:{}] Model {} at {}", + name, + safe_name, + _redact_url(sanitize_for_log(str(base_url))), ) self._verify_models() @@ -139,38 +180,45 @@ def _verify_models(self) -> None: if not base_url or base_url in checked: continue checked.add(base_url) + # Redacted view of base_url is what enters log messages. The raw + # value still drives the HTTP GET below — operators need that + # for connectivity debugging, but it must never reach loguru. + safe_base_url = _redact_url(sanitize_for_log(str(base_url))) try: url = f"{base_url}/models" req = urllib.request.Request(url, method="GET") with urllib.request.urlopen(req, timeout=10) as resp: # noqa: S310 data = _json.loads(resp.read()) - available = [d["id"] for d in data.get("data", [])] - for m in self.models: + # Server-returned model ids are LLM-provider-controlled text; + # treat them as untrusted before logging or comparing. + available_raw = [d["id"] for d in data.get("data", [])] + available = [sanitize_for_log(str(x)) for x in available_raw] + for m, safe_name in zip(self.models, self.model_names): m_url = getattr(m, "base_url", None) or getattr( m, "openai_api_base", None ) if m_url == base_url: - if m.model_name in available: + if m.model_name in available_raw: logger.info( "[MultiModelRouter:{}] Model {} verified on {}", self._name, - m.model_name, - base_url, + safe_name, + safe_base_url, ) else: logger.warning( "[MultiModelRouter:{}] Model {} NOT FOUND on {}. Available: {}", self._name, - m.model_name, - base_url, + safe_name, + safe_base_url, available, ) except Exception as exc: logger.warning( "[MultiModelRouter:{}] Cannot verify models at {}: {}", self._name, - base_url, - exc, + safe_base_url, + sanitize_for_log(str(exc)), ) @staticmethod diff --git a/gigaevo/llm/token_tracking.py b/gigaevo/llm/token_tracking.py index 2e912852..98015ad0 100644 --- a/gigaevo/llm/token_tracking.py +++ b/gigaevo/llm/token_tracking.py @@ -2,8 +2,9 @@ from typing import Annotated, Any from loguru import logger -from pydantic import BaseModel, Field, SkipValidation +from pydantic import BaseModel, Field, SkipValidation, ValidationError +from gigaevo.utils.text_sanitize import clean_identifier, sanitize_for_log from gigaevo.utils.trackers.base import LogWriter @@ -60,16 +61,39 @@ class TokenTracker(BaseModel): ) def track(self, response: Any, model_name: str) -> None: - """Track token usage from LLM response. Thread-safe.""" + """Track token usage from LLM response. Thread-safe. + + ``model_name`` is cleaned through ``clean_identifier`` because it + flows into metric path components (``self._write_metrics`` joins it + with slashes and colons replaced) and into loguru ``{}`` slots. The + ``TokenUsage.from_response`` call is wrapped against ``ValidationError`` + — a hostile provider returning ``"prompt_tokens": "lots"`` would + otherwise abort the whole ``ainvoke`` boundary. + """ if self.writer is None: return - usage = TokenUsage.from_response(response) + safe_model_name = clean_identifier(str(model_name), max_len=128) or "unknown" + try: + usage = TokenUsage.from_response(response) + except (ValidationError, TypeError, ValueError) as exc: + logger.debug( + "[TokenTracker:{}] Token usage extraction failed for {}: {}", + self.name, + safe_model_name, + sanitize_for_log(str(exc)), + ) + return if usage is None: logger.debug( - "[TokenTracker:{}] No token usage for {}", self.name, model_name + "[TokenTracker:{}] No token usage for {}", + self.name, + safe_model_name, ) return + # Shadow original variable so the rest of the function uses the + # cleaned name everywhere (cumulative dict key, _write_metrics path). + model_name = safe_model_name with self.lock: if model_name not in self.cumulative: diff --git a/gigaevo/programs/core_types.py b/gigaevo/programs/core_types.py index 417d2b3f..4d411012 100644 --- a/gigaevo/programs/core_types.py +++ b/gigaevo/programs/core_types.py @@ -7,9 +7,10 @@ from typing import Any import cloudpickle -from pydantic import BaseModel, Field, field_serializer +from pydantic import BaseModel, Field, field_serializer, field_validator from gigaevo.programs.utils import pickle_b64_deserialize, pickle_b64_serialize +from gigaevo.utils.text_sanitize import sanitize_for_log class StageIO(BaseModel): @@ -39,6 +40,21 @@ class StageError(BaseModel): stage: str | None = Field(default=None, description="Stage class name, if known") traceback: str | None = Field(default=None, description="Formatted traceback") + @field_validator("type", "message", mode="after") + @classmethod + def _scrub_text(cls, value: str) -> str: + # LLM-generated code triggers compiler errors whose stderr (from + # nvcc / clang / ptxas / Triton / Mojo / Pallas / CUTLASS) may carry + # ANSI colorization, BIDI overrides, lone surrogates, NUL bytes. + # Sanitize at construction so every downstream (loguru, orjson, + # Redis storage, re-injection back into LLM prompts) sees safe text. + return sanitize_for_log(value) + + @field_validator("traceback", mode="after") + @classmethod + def _scrub_traceback(cls, value: str | None) -> str | None: + return None if value is None else sanitize_for_log(value) + @classmethod def from_exception( cls, diff --git a/gigaevo/programs/dag/dag.py b/gigaevo/programs/dag/dag.py index b0f85cd4..7a230318 100644 --- a/gigaevo/programs/dag/dag.py +++ b/gigaevo/programs/dag/dag.py @@ -21,6 +21,7 @@ ) from gigaevo.programs.program import Program from gigaevo.programs.stages.base import Stage +from gigaevo.utils.text_sanitize import sanitize_for_log from gigaevo.utils.trackers.base import LogWriter DEFAULT_STALL_GRACE_SECONDS = 120.0 @@ -409,11 +410,17 @@ async def _process_finished_task( result = cast(ProgramStageResult, outcome) if result.status == StageState.FAILED and result.error is not None: + # ``StageError`` field_validators already sanitize ``type``, + # ``message``, and ``traceback`` at construction, so + # ``pretty()`` returns text whose interpolated leaves are + # safe. The outer literal in ``pretty()`` is hardcoded. + # Defensive wrap guards against future fields being added to + # the format string without a corresponding validator. logger.exception( "[DAG][{}] Stage '{}' FAILED with exception.\n### ERROR SUMMARY ###:\n{}", pid, stage_name, - result.error.pretty(include_traceback=True), + sanitize_for_log(result.error.pretty(include_traceback=True)), ) await self._persist_stage_result(program, stage_name, result) diff --git a/gigaevo/programs/stages/optimization/optuna/stage.py b/gigaevo/programs/stages/optimization/optuna/stage.py index b8218632..2c7fcd8e 100644 --- a/gigaevo/programs/stages/optimization/optuna/stage.py +++ b/gigaevo/programs/stages/optimization/optuna/stage.py @@ -12,6 +12,7 @@ import asyncio import math from pathlib import Path +import re import time from typing import Any, cast import warnings @@ -63,9 +64,71 @@ run_exec_runner, ) from gigaevo.programs.stages.stage_registry import StageRegistry +from gigaevo.utils.text_sanitize import clean_identifier, sanitize_for_log + +# Conservative cap on Optuna parameter names. Long enough to admit +# any reasonable snake_case identifier the LLM proposes; short enough +# that a hostile payload of megabytes of garbage cannot bloat the +# trial dict / optuna storage key. Matches typical SQL identifier +# limits and leaves headroom for compound names like +# "block_size_x_inner_loop_unroll". +_MAX_PARAM_NAME_LEN: int = 64 _DEADLINE_GRACE_S: int = 10 # post-eval margin before hard stage timeout +_OPTUNA_PARAM_REF_RE = re.compile( + rf"(?P{re.escape(_OPTUNA_PARAMS_NAME)}\s*\[\s*)" + r"(?P'(?P(?:\\.|[^'\\])*)'|\"(?P(?:\\.|[^\"\\])*)\")" + r"(?P\s*\])" +) + + +def _dedupe_param_name(base_name: str, used_names: set[str]) -> str: + """Return a unique Optuna parameter name within the configured length cap.""" + + if base_name not in used_names: + return base_name + + counter = 1 + while True: + suffix = f"_{counter}" + prefix_len = max(1, _MAX_PARAM_NAME_LEN - len(suffix)) + candidate = f"{base_name[:prefix_len]}{suffix}" + if candidate not in used_names: + return candidate + counter += 1 + + +def _string_literal_key(match: re.Match[str]) -> str | None: + """Decode the string key from an ``_optuna_params[...]`` regex match.""" + + literal = match.group("literal") + try: + value = ast.literal_eval(literal) + except (SyntaxError, ValueError): + value = match.group("single") + if value is None: + value = match.group("double") + return value if isinstance(value, str) else None + + +def _rewrite_optuna_param_refs(snippet: str, name_map: dict[str, str]) -> str: + """Rewrite string-literal ``_optuna_params`` keys after name sanitization.""" + + if not name_map: + return snippet + + def replace(match: re.Match[str]) -> str: + old_key = _string_literal_key(match) + if old_key is None: + return match.group(0) + new_key = name_map.get(old_key) + if new_key is None or new_key == old_key: + return match.group(0) + return f"{match.group('prefix')}{new_key!r}{match.group('suffix')}" + + return _OPTUNA_PARAM_REF_RE.sub(replace, snippet) + @StageRegistry.register( description="LLM-guided hyperparameter optimization using Optuna" @@ -206,6 +269,38 @@ def _apply_modifications( ValueError If line ranges are invalid or if the resulting code has syntax errors. """ + # LLM-derived parameter names flow into ``trial.suggest_*`` calls, + # which embed the name in Optuna's internal storage key and in + # log records. A name like ``"\x00../"`` or ``"abc‮"`` would + # otherwise corrupt the trial dict and any downstream sink. + # Clean the identifier at the boundary; if the LLM hands us a + # name with no surviving characters, fall back to a stable + # positional id so optimization can still proceed. + name_map: dict[str, str] = {} + assigned_by_original: dict[str, str] = {} + used_names: set[str] = set() + for idx, p in enumerate(search_space.parameters): + original_name = p.name + if original_name in assigned_by_original: + cleaned = assigned_by_original[original_name] + else: + cleaned = clean_identifier(original_name, max_len=_MAX_PARAM_NAME_LEN) + if not cleaned: + cleaned = f"param_{idx}" + cleaned = _dedupe_param_name(cleaned, used_names) + assigned_by_original[original_name] = cleaned + used_names.add(cleaned) + + if cleaned != original_name: + name_map[original_name] = cleaned + if cleaned != p.name: + logger.debug( + "[Optuna] ParamSpec name sanitized: {!r} -> {!r}", + sanitize_for_log(p.name), + cleaned, + ) + p.name = cleaned + lines = original_code.splitlines() num_lines = len(lines) mods = sorted(search_space.modifications, key=lambda x: x.start_line) @@ -230,7 +325,8 @@ def _apply_modifications( for mod in reversed(mods): start_idx = mod.start_line - 1 end_idx = mod.end_line - replacement_lines = mod.parameterized_snippet.splitlines() + snippet = _rewrite_optuna_param_refs(mod.parameterized_snippet, name_map) + replacement_lines = snippet.splitlines() # Defensive: strip any "N | " prefix if the LLM copied the numbered format replacement_lines = strip_line_number_prefix(replacement_lines) # Re-indent to match the original block so we never get "unexpected indent" @@ -251,14 +347,22 @@ def _apply_modifications( try: ast.parse(code) except SyntaxError as e: - logger.error( - "[Optuna] Parameterized code has syntax error: {}\nCode snippet around error:\n{}", - e, + # ``e.msg`` and the surrounding code lines originate from + # LLM output and may carry ANSI / BIDI / control bytes; + # sanitize each interpolation independently. + snippet = ( "\n".join(code.splitlines()[max(0, e.lineno - 5) : e.lineno + 5]) if e.lineno - else "Unknown location", + else "Unknown location" ) - raise ValueError(f"Parameterized code syntax error: {e}") + logger.error( + "[Optuna] Parameterized code has syntax error: {}\nCode snippet around error:\n{}", + sanitize_for_log(str(e)), + sanitize_for_log(snippet), + ) + raise ValueError( + f"Parameterized code syntax error: {sanitize_for_log(str(e))}" + ) from e return code @@ -492,8 +596,16 @@ async def _evaluate_single( except TimeoutError: return None, None, "Timeout" except ExecRunnerError as exc: + # Sanitize compiler-stderr-derived text before it flows into + # StageError.message / loguru sinks downstream. The returned + # error string ends up in failure_reasons and ultimately in + # log lines aggregated for the LLM. last_line = (exc.stderr or "").strip().rsplit("\n", 1)[-1] - return None, None, f"{exc} | {last_line}" + return ( + None, + None, + f"{sanitize_for_log(str(exc))} | {sanitize_for_log(last_line)}", + ) async def _run_optuna( self, @@ -1152,7 +1264,11 @@ async def compute(self, program: Program) -> OptunaOptimizationOutput: n, [p.name for p in param_specs], ) - logger.debug("[Optuna][{}] LLM reasoning: {}", pid, search_space.reasoning) + logger.debug( + "[Optuna][{}] LLM reasoning: {}", + pid, + sanitize_for_log(search_space.reasoning), + ) # 3. Run Optuna ( diff --git a/gigaevo/programs/stages/optimization/utils.py b/gigaevo/programs/stages/optimization/utils.py index e58f209b..f0b1f33b 100644 --- a/gigaevo/programs/stages/optimization/utils.py +++ b/gigaevo/programs/stages/optimization/utils.py @@ -25,6 +25,7 @@ ExecRunnerError, run_exec_runner, ) +from gigaevo.utils.text_sanitize import sanitize_for_log # --------------------------------------------------------------------------- # Shared numeric / AST helpers @@ -244,14 +245,25 @@ async def evaluate_single( return result, None msg = f"Unexpected result type: {type(result).__name__} (expected dict with key '{score_key}')" - logger.warning("[{}] {}", log_tag, msg) + logger.warning("[{}] {}", log_tag, sanitize_for_log(msg)) return None, msg except TimeoutError: logger.trace("[{}] single evaluation timed out", log_tag) return None, "Timeout" except ExecRunnerError as exc: + # exc.stderr may carry ANSI / NUL / BIDI from heterogeneous + # compiler stacks; the str(exc) message is sanitized too because + # ExecRunnerError stores arbitrary text. last_line = (exc.stderr or "").strip().rsplit("\n", 1)[-1] - logger.trace("[{}] eval failed: {} | {}", log_tag, exc, last_line) - # Return the actual error message so the caller can log it if critical - return None, f"{exc} | {last_line}" + safe_exc = sanitize_for_log(str(exc)) + safe_last = sanitize_for_log(last_line) + logger.trace( + "[{}] eval failed: {} | {}", + log_tag, + safe_exc, + safe_last, + ) + # Return the sanitized error message so callers can log / store + # it without re-introducing terminal control bytes. + return None, f"{safe_exc} | {safe_last}" diff --git a/gigaevo/programs/stages/python_executors/execution.py b/gigaevo/programs/stages/python_executors/execution.py index d4a28acc..3f1a1c8f 100644 --- a/gigaevo/programs/stages/python_executors/execution.py +++ b/gigaevo/programs/stages/python_executors/execution.py @@ -22,6 +22,7 @@ ) from gigaevo.programs.stages.stage_registry import StageRegistry from gigaevo.programs.utils import dedent_code +from gigaevo.utils.text_sanitize import sanitize_for_log T = TypeVar("T") @@ -133,12 +134,17 @@ async def compute(self, program: Program) -> ProgramStageResult | Box[Any]: else "Process ran out of memory" ) + # Subprocess stderr may contain ANSI / NUL / BIDI / lone + # surrogates from heterogeneous compiler toolchains (nvcc, + # ptxas, Triton, Mojo). Sanitize the log interpolation; the + # StageError construction below is already covered by + # field_validators on type/message/traceback. logger.warning( "[{}] {} FAILED for {}: {}", stage_name, error_type, program.id[:8], - error_msg[:200], + sanitize_for_log(error_msg[:200]), ) return ProgramStageResult.failure( error=StageError( @@ -153,7 +159,7 @@ async def compute(self, program: Program) -> ProgramStageResult | Box[Any]: "[{}] Exception for {}: {}", stage_name, program.id[:8], - str(e)[:200], + sanitize_for_log(str(e)[:200]), ) return ProgramStageResult.failure( error=StageError.from_exception(e, stage=stage_name) diff --git a/gigaevo/programs/stages/validation.py b/gigaevo/programs/stages/validation.py index f32f83e9..26bcd4d9 100644 --- a/gigaevo/programs/stages/validation.py +++ b/gigaevo/programs/stages/validation.py @@ -12,6 +12,7 @@ from gigaevo.programs.program import Program from gigaevo.programs.stages.base import Stage from gigaevo.programs.stages.stage_registry import StageRegistry +from gigaevo.utils.text_sanitize import sanitize_for_log class CodeValidationOutput(StageIO): @@ -73,9 +74,18 @@ async def compute(self, program: Program) -> StageIO: try: compile(code, "", "exec") except SyntaxError as e: + # ``e.msg`` and ``e.text`` come from the parser operating on + # LLM output, which can embed control bytes that would slip + # into the re-raised exception's args and from there into + # every downstream log / serialization path. Sanitize before + # interpolation; the StageError validators below would catch + # the final log line, but the exception text itself is + # consumed elsewhere (e.g. ``__cause__`` chains in tests). code_line = (e.text or "").strip() or "" raise SyntaxError( - f"SyntaxError at line {e.lineno}, offset {e.offset}: {e.msg}. Line: `{code_line}`" + f"SyntaxError at line {e.lineno}, offset {e.offset}: " + f"{sanitize_for_log(e.msg or '')}. " + f"Line: `{sanitize_for_log(code_line)}`" ) from e if self.safe_mode: diff --git a/gigaevo/prompts/coevolution/stages.py b/gigaevo/prompts/coevolution/stages.py index 88b6fc98..08fcc65c 100644 --- a/gigaevo/prompts/coevolution/stages.py +++ b/gigaevo/prompts/coevolution/stages.py @@ -21,6 +21,7 @@ from gigaevo.programs.stages.insights_lineage import LineageAnalysesOutput, LineageStage from gigaevo.programs.stages.stage_registry import StageRegistry from gigaevo.prompts.coevolution.stats import PromptStatsProvider, prompt_text_to_id +from gigaevo.utils.text_sanitize import sanitize_for_log class PromptExecutionOutput(StageIO): @@ -58,15 +59,19 @@ async def compute(self, program: Program) -> PromptExecutionOutput: # type: ign raise ValueError( "Prompt program must contain 'def entrypoint()'. " "Got non-Python content (possibly JSON template). " - f"Code starts with: {code[:80]!r}" + f"Code starts with: {sanitize_for_log(code[:80])!r}" ) namespace: dict[str, Any] = {} try: exec(compile(code, "", "exec"), namespace) # noqa: S102 except SyntaxError as exc: - raise ValueError(f"Prompt program has syntax error: {exc}") from exc + raise ValueError( + f"Prompt program has syntax error: {sanitize_for_log(str(exc))}" + ) from exc except Exception as exc: - raise ValueError(f"Prompt program failed to compile/exec: {exc}") from exc + raise ValueError( + f"Prompt program failed to compile/exec: {sanitize_for_log(str(exc))}" + ) from exc entrypoint_fn = namespace.get("entrypoint") if not callable(entrypoint_fn): @@ -75,7 +80,9 @@ async def compute(self, program: Program) -> PromptExecutionOutput: # type: ign try: result = entrypoint_fn() except Exception as exc: - raise ValueError(f"entrypoint() raised an exception: {exc}") from exc + raise ValueError( + f"entrypoint() raised an exception: {sanitize_for_log(str(exc))}" + ) from exc if isinstance(result, str): if not result.strip(): @@ -100,6 +107,16 @@ async def compute(self, program: Program) -> PromptExecutionOutput: # type: ign f"entrypoint() must return str or dict, got {type(result).__name__}" ) + system_text = sanitize_for_log(system_text) + if not system_text.strip(): + raise ValueError("entrypoint() returned empty string after sanitization") + if user_text is not None: + user_text = sanitize_for_log(user_text) + if not user_text.strip(): + raise ValueError( + "dict entrypoint() 'user' key became empty after sanitization" + ) + prompt_id = prompt_text_to_id(system_text, user_text=user_text) logger.debug( f"[PromptExecutionStage] Executed entrypoint(): " diff --git a/gigaevo/prompts/coevolution/stats.py b/gigaevo/prompts/coevolution/stats.py index 6ae1738c..70db5b97 100644 --- a/gigaevo/prompts/coevolution/stats.py +++ b/gigaevo/prompts/coevolution/stats.py @@ -14,6 +14,8 @@ from loguru import logger from redis import asyncio as aioredis +from gigaevo.utils.text_sanitize import sanitize_for_log + @dataclass class PromptMutationStats: @@ -131,8 +133,11 @@ async def get_stats(self, prompt_id: str) -> PromptMutationStats: ) + float(v) except Exception as exc: logger.warning( - f"[RedisPromptStatsProvider] Error reading stats from " - f"db={db} for {prompt_id}: {exc}" + "[RedisPromptStatsProvider] Error reading stats from " + "db={} for {}: {}", + db, + sanitize_for_log(str(prompt_id)), + sanitize_for_log(str(exc)), ) if total_trials < self._min_trials: @@ -178,7 +183,9 @@ def prompt_text_to_id(prompt_text: str, user_text: str | None = None) -> str: Returns: 16-char hex string (sha256[:16]) """ + prompt_text = sanitize_for_log(prompt_text) blob = prompt_text if user_text is not None: + user_text = sanitize_for_log(user_text) blob = prompt_text + "\x00" + user_text return hashlib.sha256(blob.encode()).hexdigest()[:16] diff --git a/gigaevo/prompts/fetcher.py b/gigaevo/prompts/fetcher.py index eadcf9a1..dd6c67b4 100644 --- a/gigaevo/prompts/fetcher.py +++ b/gigaevo/prompts/fetcher.py @@ -18,6 +18,7 @@ from gigaevo.prompts import load_prompt from gigaevo.prompts.coevolution.stats import prompt_text_to_id +from gigaevo.utils.text_sanitize import sanitize_for_log if TYPE_CHECKING: from gigaevo.database.program_storage import ProgramStorage @@ -292,7 +293,9 @@ def _refresh_candidates(self) -> list[tuple[str, float, str]] | None: candidates.append((pid, fitness, code)) except Exception as exc: logger.debug( - f"[GigaEvoArchivePromptFetcher] Error parsing program {pid}: {exc}" + "[GigaEvoArchivePromptFetcher] Error parsing program {}: {}", + sanitize_for_log(str(pid)), + sanitize_for_log(str(exc)), ) continue @@ -301,7 +304,9 @@ def _refresh_candidates(self) -> list[tuple[str, float, str]] | None: except Exception as exc: self._fetch_errors += 1 logger.warning( - f"[GigaEvoArchivePromptFetcher] Archive read error (#{self._fetch_errors}): {exc}" + "[GigaEvoArchivePromptFetcher] Archive read error (#{}): {}", + self._fetch_errors, + sanitize_for_log(str(exc)), ) return None @@ -327,16 +332,26 @@ def _sample_prompt(self) -> _PromptPack | None: if pack is None: return None - user_preview = repr(pack.user[:300]) if pack.user else "None" + user_preview = ( + repr(sanitize_for_log(pack.user[:300])) if pack.user else "None" + ) + system_preview = repr(sanitize_for_log(pack.system[:300])) logger.info( - f"[GigaEvoArchivePromptFetcher] Sampled: {chosen_pid[:8]} " - f"fitness={chosen_fitness:.4f} prompt_id={pack.prompt_id} " - f"has_user={pack.user is not None} " - f"(from {len(candidates)} candidates)\n" - f" SYSTEM[:{min(300, len(pack.system))}]: " - f"{pack.system[:300]!r}\n" - f" USER[:{min(300, len(pack.user)) if pack.user else 0}]: " - f"{user_preview}" + "[GigaEvoArchivePromptFetcher] Sampled: {} " + "fitness={:.4f} prompt_id={} " + "has_user={} " + "(from {} candidates)\n" + " SYSTEM[:{}]: {}\n" + " USER[:{}]: {}", + sanitize_for_log(str(chosen_pid[:8])), + chosen_fitness, + sanitize_for_log(str(pack.prompt_id)), + pack.user is not None, + len(candidates), + min(300, len(pack.system)), + system_preview, + min(300, len(pack.user)) if pack.user else 0, + user_preview, ) return pack @@ -365,13 +380,14 @@ def _execute_entrypoint(self, code: str) -> _PromptPack | None: return None result = entrypoint_fn() if isinstance(result, str): - if not result.strip(): + system = sanitize_for_log(result) + if not system.strip(): logger.warning( "[GigaEvoArchivePromptFetcher] entrypoint() returned empty string" ) return None - pid = prompt_text_to_id(result) - return _PromptPack(system=result, user=None, prompt_id=pid) + pid = prompt_text_to_id(system) + return _PromptPack(system=system, user=None, prompt_id=pid) elif isinstance(result, dict): system = result.get("system", "") if not isinstance(system, str) or not system.strip(): @@ -379,23 +395,38 @@ def _execute_entrypoint(self, code: str) -> _PromptPack | None: "[GigaEvoArchivePromptFetcher] dict entrypoint() missing valid 'system' key" ) return None + system = sanitize_for_log(system) + if not system.strip(): + logger.warning( + "[GigaEvoArchivePromptFetcher] dict entrypoint() system became empty after sanitization" + ) + return None user = result.get("user") if user is not None and (not isinstance(user, str) or not user.strip()): logger.warning( "[GigaEvoArchivePromptFetcher] dict entrypoint() has invalid 'user' key — ignoring" ) user = None + if user is not None: + user = sanitize_for_log(user) + if not user.strip(): + logger.warning( + "[GigaEvoArchivePromptFetcher] dict entrypoint() user became empty after sanitization" + ) + user = None pid = prompt_text_to_id(system, user_text=user) return _PromptPack(system=system, user=user, prompt_id=pid) else: logger.warning( - f"[GigaEvoArchivePromptFetcher] entrypoint() returned {type(result)}, " - f"expected str or dict" + "[GigaEvoArchivePromptFetcher] entrypoint() returned {}, " + "expected str or dict", + type(result), ) return None except Exception as exc: logger.warning( - f"[GigaEvoArchivePromptFetcher] entrypoint() execution error: {exc}" + "[GigaEvoArchivePromptFetcher] entrypoint() execution error: {}", + sanitize_for_log(str(exc)), ) return None @@ -552,12 +583,18 @@ def record_outcome( self._redis_main_sync.set(stats_key, _json.dumps(stats)) logger.debug( - f"[GigaEvoArchivePromptFetcher] Stats updated for {prompt_id}: " - f"trials={stats['trials']} successes={stats['successes']} " - f"child_fitness={child_fitness:.4f}" + "[GigaEvoArchivePromptFetcher] Stats updated for {}: " + "trials={} successes={} child_fitness={:.4f}", + sanitize_for_log(str(prompt_id)), + stats["trials"], + stats["successes"], + child_fitness, ) except Exception as exc: - logger.warning(f"[GigaEvoArchivePromptFetcher] Stats write error: {exc}") + logger.warning( + "[GigaEvoArchivePromptFetcher] Stats write error: {}", + sanitize_for_log(str(exc)), + ) def get_stats(self) -> dict[str, Any]: return { diff --git a/gigaevo/runner/dag_runner.py b/gigaevo/runner/dag_runner.py index cd8f570f..f8aabdfb 100644 --- a/gigaevo/runner/dag_runner.py +++ b/gigaevo/runner/dag_runner.py @@ -20,6 +20,7 @@ from gigaevo.programs.program_state import ProgramState from gigaevo.runner.dag_blueprint import DAGBlueprint from gigaevo.utils.metrics_collector import start_metrics_collector +from gigaevo.utils.text_sanitize import sanitize_for_log from gigaevo.utils.trackers.base import LogWriter @@ -290,10 +291,14 @@ async def _maintain(self) -> None: self._metrics.record_timeout() logger.error("[DagScheduler] program {} timed out", info.program_id[:8]) except Exception as e: + # Exception ``__str__`` from downstream subprocess / + # compiler stacks may carry ANSI / control bytes; wrap + # before loguru interpolation to keep log records inert + # for parsers and terminal renderers. logger.error( "[DagScheduler] discard after timeout failed for {}: {}", info.program_id[:8], - e, + sanitize_for_log(str(e)), ) for info in finished: @@ -308,7 +313,9 @@ async def _maintain(self) -> None: except Exception as e: self._metrics.increment_dag_errors() logger.error( - "[DagScheduler] program {} failed: {}", info.program_id[:8], e + "[DagScheduler] program {} failed: {}", + info.program_id[:8], + sanitize_for_log(str(e)), ) finally: del info @@ -340,7 +347,10 @@ async def _launch(self) -> None: self._storage.get_ids_by_status(ProgramState.RUNNING.value), ) except Exception as e: - logger.error("[DagScheduler] fetch-by-status failed: {}", e) + logger.error( + "[DagScheduler] fetch-by-status failed: {}", + sanitize_for_log(str(e)), + ) return # Phase 2: handle orphaned RUNNING programs (fetch full data only for these) @@ -362,10 +372,13 @@ async def _launch(self) -> None: logger.error( "[DagScheduler] orphan discard failed for {}: {}", p.short_id, - se, + sanitize_for_log(str(se)), ) except Exception as e: - logger.error("[DagScheduler] orphan fetch failed: {}", e) + logger.error( + "[DagScheduler] orphan fetch failed: {}", + sanitize_for_log(str(e)), + ) # Phase 3: launch fresh programs up to capacity (fetch only what we need) # Prefetch: create up to max_concurrent_dags * prefetch_factor tasks. @@ -385,7 +398,10 @@ async def _launch(self) -> None: try: fresh = await self._storage.mget(to_launch_ids) except Exception as e: - logger.error("[DagScheduler] mget for launch failed: {}", e) + logger.error( + "[DagScheduler] mget for launch failed: {}", + sanitize_for_log(str(e)), + ) return launched: list[Program] = [] @@ -406,9 +422,14 @@ async def _launch(self) -> None: import traceback logger.error( - "[DagScheduler] DAG build failed for {}: {}", program.short_id, e + "[DagScheduler] DAG build failed for {}: {}", + program.short_id, + sanitize_for_log(str(e)), + ) + logger.error( + "[DagScheduler] Traceback:\n{}", + sanitize_for_log(traceback.format_exc()), ) - logger.error("[DagScheduler] Traceback:\n{}", traceback.format_exc()) self._metrics.record_build_failure() try: await self._state_manager.set_program_state( @@ -418,7 +439,7 @@ async def _launch(self) -> None: logger.error( "[DagScheduler] state update failed for {}: {}", program.short_id, - se, + sanitize_for_log(str(se)), ) self._metrics.record_state_update_failure() continue @@ -447,7 +468,10 @@ async def _run_one(prog: Program = program, dag_inst: DAG = dag) -> None: self._metrics.dag_runs_started += count logger.info("[DagScheduler] launched {} programs", count) except Exception as e: - logger.error("[DagScheduler] batch mark-started failed: {}", e) + logger.error( + "[DagScheduler] batch mark-started failed: {}", + sanitize_for_log(str(e)), + ) # Cancel tasks whose state transition failed for pid in launched_ids: info = self._active.pop(pid, None) @@ -461,8 +485,13 @@ async def _execute_dag(self, dag: DAG, program: Program) -> None: await dag.run(program) except Exception as exc: ok = False + # ``exc`` may originate from Triton / CUDA / Mojo / nvcc + # subprocess stderr propagated as an exception message; + # sanitize before loguru emits the record. logger.error( - "[DagScheduler] DAG run failed for {}: {}", program.short_id, exc + "[DagScheduler] DAG run failed for {}: {}", + program.short_id, + sanitize_for_log(str(exc)), ) finally: # Eagerly release references to allow GC of heavy objects. @@ -497,7 +526,7 @@ async def _execute_dag(self, dag: DAG, program: Program) -> None: logger.error( "[DagScheduler] state update failed for {}: {}", program.short_id, - se, + sanitize_for_log(str(se)), ) async def _flush_done_queue(self) -> None: @@ -519,7 +548,7 @@ async def _flush_done_queue(self) -> None: logger.error( "[DagScheduler] batch RUNNING→DONE failed for {} programs: {}", len(batch), - e, + sanitize_for_log(str(e)), ) async def _cancel_task(self, info: TaskInfo) -> None: diff --git a/gigaevo/utils/json.py b/gigaevo/utils/json.py index cb771d60..3ad0d2ae 100644 --- a/gigaevo/utils/json.py +++ b/gigaevo/utils/json.py @@ -4,6 +4,8 @@ import types from typing import Any +from gigaevo.utils.text_sanitize import deep_sanitize_for_json + __all__ = ["dumps", "loads", "json"] json: types.ModuleType @@ -12,8 +14,15 @@ import orjson as _orjson def dumps(obj: Any) -> str: - """Serialize *obj* to a ``str`` using orjson (bytes -> str).""" - return _orjson.dumps(obj).decode() + """Serialize *obj* to a ``str`` using orjson (bytes -> str). + + Walks ``obj`` first to replace lone UTF-16 surrogates with U+FFFD. + orjson (and stdlib ``json``) raise ``UnicodeEncodeError`` on lone + surrogates; LLM-derived text frequently carries them on the path + from Triton / CUDA / CUTLASS / Mojo error formatters. The walk is + a no-op for clean structures. + """ + return _orjson.dumps(deep_sanitize_for_json(obj)).decode() def loads(data: str | bytes | bytearray) -> Any: """Deserialize *data* using orjson.""" @@ -24,8 +33,11 @@ def loads(data: str | bytes | bytearray) -> Any: except ModuleNotFoundError: # pragma: no cover -- dev/test envs without orjson def dumps(obj: Any) -> str: # type: ignore[misc] # redefinition for fallback branch - """Serialize *obj* to a ``str`` using the stdlib *json* module.""" - return _stdlib_json.dumps(obj) + """Serialize *obj* to a ``str`` using the stdlib *json* module. + + See the orjson branch for the surrogate-scrub rationale. + """ + return _stdlib_json.dumps(deep_sanitize_for_json(obj)) def loads(data: str | bytes | bytearray) -> Any: # type: ignore[misc] # redefinition for fallback branch """Deserialize *data* using the stdlib *json* module.""" diff --git a/gigaevo/utils/text_sanitize.py b/gigaevo/utils/text_sanitize.py new file mode 100644 index 00000000..c67d107b --- /dev/null +++ b/gigaevo/utils/text_sanitize.py @@ -0,0 +1,198 @@ +"""Pure text sanitization for LLM-derived strings. + +Three increasingly strict modes that compose freely. Most callers want +``sanitize_for_log``; the other two are minimal helpers for cases where +the destination only rejects a narrower set of bytes. + +All functions are pure ``str -> str``, idempotent, and preserve printable +Unicode (CJK, Greek letters, math symbols, emoji, directional arrows). +Lone UTF-16 surrogates always collapse to U+FFFD. + +The threat surface this module guards covers compiler error text from +heterogeneous LLM targets (Python tracebacks, Triton MLIR diagnostics, +nvcc / ptxas / CUTLASS template explosions with embedded ANSI from gcc / +clang colorization, Mojo error formatter output with Unicode arrows, +Pallas / JAX jaxpr traces with ASCII art, CuTe layout errors with +``Layout,Stride<...>>`` template syntax). Each toolchain emits +its own conventions; the sanitizers reject what would break log files, +JSON encoders, or Postgres TEXT columns, and preserve everything else. +""" + +from __future__ import annotations + +import re +from typing import Final + +# ============================================================================ +# Compiled patterns (private) +# ============================================================================ + + +# ANSI escape sequences. Covers CSI (the common ``\x1b[...m`` colorization +# from gcc / clang / nvcc), OSC (xterm title-setting and similar), DCS / SOS +# / PM / APC string sequences, and single-character Fe escapes. Reference: +# ECMA-48 + xterm Ctrl Sequences. +_ANSI_RE: Final[re.Pattern[str]] = re.compile( + r"\x1b\[[0-?]*[ -/]*[@-~]" # CSI: ESC [ params intermediates final + r"|\x1b\][^\x07\x1b]*(?:\x07|\x1b\\)" # OSC: ESC ] ... BEL or ST + r"|\x1b[PX^_][^\x1b]*\x1b\\" # DCS / SOS / PM / APC: ESC X ... ST + r"|\x1b[@-Z\\-_]" # Fe single-char: ESC +) + +# C0 controls (0x00-0x1F + 0x7F DEL) except TAB (0x09) and LF (0x0A). CR +# (0x0D) is included so log-line forgery via "\r\nFAKE LINE" is defused +# while real multi-line tracebacks (which use LF only) pass through. +_C0_RE: Final[re.Pattern[str]] = re.compile(r"[\x00-\x08\x0b-\x1f\x7f]") + +# C1 controls (0x80-0x9F). Rare in modern output but legal in Python str and +# capable of confusing terminal parsers. +_C1_RE: Final[re.Pattern[str]] = re.compile(r"[\x80-\x9f]") + +# Unicode BIDI overrides and isolates (U+202A..U+202E plus U+2066..U+2069). +# Invisible characters used to spoof text directionality; have no place in +# machine-readable log records. +_BIDI_RE: Final[re.Pattern[str]] = re.compile(r"[‪-‮⁦-⁩]") + +# UTF-16 surrogate code points (U+D800-U+DFFF). Python ``str`` is sequence-of- +# code-points, not UTF-16, so a "high then low" pair is NOT decoded as one +# astral character; both halves remain independent code points that the +# UTF-8 encoder refuses (``surrogates not allowed``). Match every surrogate +# unconditionally and replace with U+FFFD. Real astral characters in Python +# ``str`` are single code points above U+FFFF, never surrogate pairs. +_LONE_SURROGATE_RE: Final[re.Pattern[str]] = re.compile(r"[\ud800-\udfff]") + +# Identifier strip pattern. Anything outside the conservative charset is +# removed. Slash and colon are permitted because model names sometimes +# carry path-like or provider-prefix forms (``openai:gpt-4o-mini``, +# ``/local/path/to/model.gguf``); ``@`` is permitted because some configs +# route via ``model@host:port`` against local vllm / sglang / tgi servers. +_IDENTIFIER_STRIP_RE: Final[re.Pattern[str]] = re.compile(r"[^A-Za-z0-9._:/+@\-]") + + +def _escape_control(match: re.Match[str]) -> str: + return f"\\x{ord(match.group()):02x}" + + +# ============================================================================ +# Public API +# ============================================================================ + + +def sanitize_for_log(text: str) -> str: + """Make ``text`` safe for log sinks, JSON encoders, and Postgres TEXT. + + Strips ANSI escape sequences (no terminal control on log readers), strips + BIDI overrides (no directional spoofing in log records), escapes C0 and + C1 control characters except TAB and LF (visible but inert; for example + ``\\x07`` BEL becomes the four-character literal ``\\x07``, and + ``\\x0d`` CR is escaped so single-line log records cannot be forged via + carriage-return overwriting), and replaces lone UTF-16 surrogates with + U+FFFD. Result is valid UTF-8, safe for ``json.dumps`` / + ``model_dump_json``, and safe for ``asyncpg`` TEXT columns. + + LF (``\\n``) is preserved so legitimate multi-line tracebacks survive. + A consequence is that a hostile string containing ``"\\n[FORGED]"`` + will produce a second line in plain-text log sinks. Reasonable log + parsers should recognize that authentic entries begin with a timestamp + prefix and treat untimestamped continuations accordingly. Callers that + cannot accept this residual risk (for example, plain-text sinks read + by line-naive consumers) should pipeline through a stricter step that + escapes ``\\n`` as well. + """ + text = _ANSI_RE.sub("", text) + text = _BIDI_RE.sub("", text) + text = _C0_RE.sub(_escape_control, text) + text = _C1_RE.sub(_escape_control, text) + text = _LONE_SURROGATE_RE.sub("�", text) + return text + + +def sanitize_for_json(text: str) -> str: + """Minimum-viable fix for JSON encoders. Replaces lone UTF-16 surrogates + with U+FFFD so ``json.dumps`` and ``pydantic.model_dump_json`` succeed. + Preserves every other byte verbatim — control characters, ANSI escape + sequences, BIDI overrides all pass through. Compose with + ``sanitize_for_log`` when the destination also forbids those. + """ + return _LONE_SURROGATE_RE.sub("�", text) + + +def sanitize_for_dbtext(text: str) -> str: + """Make ``text`` safe for Postgres TEXT columns through asyncpg. + + Handles two failure modes asyncpg surfaces on LLM-derived text. First, + the driver rejects literal NUL bytes outright (``A string literal cannot + contain NUL (0x00) characters``) — NUL is replaced with U+FFFD. Second, + asyncpg UTF-8 encodes string values before sending them on the wire, + and Python's UTF-8 encoder refuses lone UTF-16 surrogates; those are + also replaced with U+FFFD. ANSI escape sequences and other control + bytes pass through verbatim because Postgres TEXT accepts them; compose + with ``sanitize_for_log`` when the column is also displayed to humans + or consumed by log readers. + """ + text = text.replace("\x00", "�") + text = _LONE_SURROGATE_RE.sub("�", text) + return text + + +def deep_sanitize_for_json(value: object) -> object: + """Walk a JSON-shaped structure and apply ``sanitize_for_json`` to every + string leaf. Lists / tuples / dicts are rebuilt recursively; primitive + non-string values (``int`` / ``float`` / ``bool`` / ``None``) pass + through unchanged. + + Use at the boundary where an arbitrary container of LLM-derived text + will be handed to ``json.dumps`` / ``orjson.dumps`` / pydantic + ``model_dump_json``. A single lone surrogate buried anywhere in the + structure would otherwise raise ``UnicodeEncodeError`` from inside the + serializer and abort the write. This is the cheap belt that stops + those aborts without rewriting every producer. + + Returns ``object`` rather than a tight type because the input shape is + arbitrary JSON. Callers narrow at the call site via ``cast`` or by + knowing the input shape; the function preserves the outermost + container type (``dict`` stays ``dict``, ``list`` stays ``list``, + ``tuple`` stays ``tuple``). + """ + if isinstance(value, str): + return sanitize_for_json(value) + if isinstance(value, dict): + return { + deep_sanitize_for_json(k): deep_sanitize_for_json(v) + for k, v in value.items() + } + if isinstance(value, list): + return [deep_sanitize_for_json(v) for v in value] + if isinstance(value, tuple): + return tuple(deep_sanitize_for_json(v) for v in value) + return value + + +def clean_identifier(text: str, *, max_len: int | None = None) -> str: + """Strip every character outside the conservative identifier charset + ``[A-Za-z0-9._:/+@-]`` and optionally truncate to ``max_len`` characters. + + Returns an empty string if nothing in ``text`` survives. Callers decide + whether to reject, fall back to a default, or warn. Intended for places + where a string must be a stable, displayable, file-system-safe + identifier (model names, cache keys, log tags). + + Raises ``ValueError`` if ``max_len`` is negative; the slice ``[:negative]`` + would silently drop trailing characters without warning, which is almost + never what the caller meant. + """ + if max_len is not None and max_len < 0: + raise ValueError(f"max_len must be non-negative, got {max_len}") + cleaned = _IDENTIFIER_STRIP_RE.sub("", text) + if max_len is not None and len(cleaned) > max_len: + cleaned = cleaned[:max_len] + return cleaned + + +__all__ = ( + "clean_identifier", + "deep_sanitize_for_json", + "sanitize_for_dbtext", + "sanitize_for_json", + "sanitize_for_log", +) diff --git a/gigaevo/utils/trackers/backends/redis.py b/gigaevo/utils/trackers/backends/redis.py index f5ed5bd9..04056558 100644 --- a/gigaevo/utils/trackers/backends/redis.py +++ b/gigaevo/utils/trackers/backends/redis.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib import json import threading import time @@ -8,6 +9,12 @@ from loguru import logger import redis +from gigaevo.utils.text_sanitize import ( + clean_identifier, + deep_sanitize_for_json, + sanitize_for_dbtext, + sanitize_for_log, +) from gigaevo.utils.trackers.configs import RedisMetricsConfig from gigaevo.utils.trackers.core import LoggerBackend @@ -31,10 +38,21 @@ def __init__(self, cfg: RedisMetricsConfig): def _k_latest(self) -> str: return f"{self.cfg.key_prefix}:latest" + def _field_tag(self, tag: str) -> str: + """Return the stable Redis hash/list field for a metric tag.""" + + safe_tag = clean_identifier(tag, max_len=128) + if safe_tag: + return safe_tag + + digest = hashlib.sha256(sanitize_for_log(tag).encode()).hexdigest()[:12] + return f"metric_{digest}" + def _k_history(self, tag: str) -> str: - # Sanitize tag for Redis key - safe_tag = tag.replace("/", ":").replace(" ", "_") - return f"{self.cfg.key_prefix}:history:{safe_tag}" + # Sanitize tag for Redis key: strict identifier charset only. + # Defends against ANSI / BIDI / control bytes in LLM-derived tags + # that the previous ad-hoc replace() missed. + return f"{self.cfg.key_prefix}:history:{self._field_tag(tag)}" def _k_meta(self) -> str: return f"{self.cfg.key_prefix}:meta" @@ -82,10 +100,13 @@ def write_hist(self, tag: str, values: Any, step: int, wall_time: float) -> None self._buffer.append(entry) def write_text(self, tag: str, text: str, step: int, wall_time: float) -> None: + # Sanitize the text payload at the boundary so LLM-derived strings + # with NUL bytes or lone surrogates do not poison the latest-hash + # value or the JSON history entry. entry = { "kind": "text", "tag": tag, - "value": text, + "value": sanitize_for_dbtext(text), "step": step, "wall_time": wall_time, } @@ -105,7 +126,12 @@ def flush(self) -> None: pipe = self._client.pipeline(transaction=False) for entry in buf: - tag = entry["tag"] + # Sanitize the tag at the Redis boundary. The wire encoder + # rejects lone UTF-16 surrogates and is also unhappy with + # NUL inside a field name on some clients; clean_identifier + # gives a stable, displayable field name regardless of what + # an LLM-derived tag carried. + tag = self._field_tag(str(entry["tag"])) step = entry["step"] wall_time = entry["wall_time"] kind = entry["kind"] @@ -119,7 +145,10 @@ def flush(self) -> None: # Store history if enabled if self.cfg.store_history: - history_entry = json.dumps( + # deep_sanitize_for_json defuses lone surrogates buried + # in histogram value lists or text strings before they + # reach json.dumps, which would otherwise raise. + payload = deep_sanitize_for_json( { "s": step, "t": wall_time, @@ -127,6 +156,7 @@ def flush(self) -> None: "k": kind, } ) + history_entry = json.dumps(payload) history_key = self._k_history(tag) pipe.rpush(history_key, history_entry) # Trim to max size (FIFO) @@ -137,7 +167,9 @@ def flush(self) -> None: pipe.execute() except Exception as e: - logger.warning("[RedisMetricsBackend] Flush failed: {}", e) + logger.warning( + "[RedisMetricsBackend] Flush failed: {}", sanitize_for_log(str(e)) + ) def clear_series(self, tag: str) -> None: """Delete the history list for *tag* so it can be rewritten.""" @@ -148,7 +180,9 @@ def clear_series(self, tag: str) -> None: self._client.delete(history_key) except Exception as e: logger.warning( - "[RedisMetricsBackend] clear_series failed for {}: {}", tag, e + "[RedisMetricsBackend] clear_series failed for {}: {}", + sanitize_for_log(tag), + sanitize_for_log(str(e)), ) # --------------------- Query Methods --------------------- @@ -160,15 +194,18 @@ def get_latest(self, tag: str | None = None) -> dict[str, Any]: return {} try: if tag: - val = client.hget(self._k_latest(), tag) + field = self._field_tag(tag) + val = client.hget(self._k_latest(), field) if val is None: return {} - return {tag: float(str(val))} + return {field: self._parse_value(str(val))} else: data = client.hgetall(self._k_latest()) return {k: self._parse_value(str(v)) for k, v in data.items()} except Exception as e: - logger.warning("[RedisMetricsBackend] get_latest failed: {}", e) + logger.warning( + "[RedisMetricsBackend] get_latest failed: {}", sanitize_for_log(str(e)) + ) return {} def get_history( @@ -182,7 +219,9 @@ def get_history( entries = client.lrange(self._k_history(tag), start, end) return [json.loads(str(e)) for e in entries] except Exception as e: - logger.warning("[RedisMetricsBackend] get_history failed: {}", e) + logger.warning( + "[RedisMetricsBackend] get_history failed: {}", sanitize_for_log(str(e)) + ) return [] def list_metrics(self) -> list[str]: @@ -193,7 +232,9 @@ def list_metrics(self) -> list[str]: try: return [str(k) for k in client.hkeys(self._k_latest())] except Exception as e: - logger.warning("[RedisMetricsBackend] list_metrics failed: {}", e) + logger.warning( + "[RedisMetricsBackend] list_metrics failed: {}", sanitize_for_log(str(e)) + ) return [] @staticmethod diff --git a/tests/dag/test_sanitize_integration.py b/tests/dag/test_sanitize_integration.py new file mode 100644 index 00000000..c1415701 --- /dev/null +++ b/tests/dag/test_sanitize_integration.py @@ -0,0 +1,117 @@ +"""Hostile-input integration tests for the sanitizer wiring in +``gigaevo/programs/dag/dag.py`` and ``gigaevo/runner/dag_runner.py``. + +We exercise the exact log lines wrapped with ``sanitize_for_log`` by +constructing exceptions whose ``__str__`` carries ANSI / NUL / BIDI and +asserting the captured loguru output contains no raw control bytes. +""" + +from __future__ import annotations + +import re + +from loguru import logger + +from gigaevo.programs.core_types import StageError + +# Patterns the sanitizer must strip. +_ANSI_RE = re.compile(r"\x1b\[") +_BIDI_RE = re.compile(r"[‪-‮⁦-⁩]") +_C0_RAW_RE = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]") + + +def _capture(): + """Attach a list-sink to loguru. Returns (list, sink_id).""" + captured: list[str] = [] + sink_id = logger.add(captured.append, format="{message}", level="TRACE") + return captured, sink_id + + +def _assert_clean(captured: list[str]) -> None: + joined = "\n".join(captured) + assert not _ANSI_RE.search(joined), f"ANSI escape leaked: {joined!r}" + assert not _BIDI_RE.search(joined), f"BIDI override leaked: {joined!r}" + assert not _C0_RAW_RE.search(joined), f"Raw C0 control leaked: {joined!r}" + + +class TestDagRunnerErrorLogSanitization: + """The ``DagRunner`` interpolates the ``__str__`` of caught + exceptions into ``logger.error`` lines. After the sanitizer wiring, + a hostile exception message must arrive at the sink in escaped form. + """ + + def test_sanitize_for_log_module_is_wired(self): + """Smoke test: the dag_runner module imports the sanitizer.""" + from gigaevo.runner import dag_runner + + assert hasattr(dag_runner, "sanitize_for_log") + + def test_hostile_exception_string_is_escaped(self): + """Simulate the wrapped log call directly. This is the exact + idiom every error path in dag_runner.py now uses.""" + from gigaevo.runner.dag_runner import sanitize_for_log + + captured, sink_id = _capture() + try: + exc = RuntimeError("\x1b[31mCUDA OOM\x1b[0m\nat \x00addr ‮swap") + logger.error( + "[DagScheduler] program {} failed: {}", + "deadbeef", + sanitize_for_log(str(exc)), + ) + _assert_clean(captured) + # The visible text "CUDA OOM" must survive — sanitizer + # preserves printable content. + assert "CUDA OOM" in "\n".join(captured) + finally: + logger.remove(sink_id) + + def test_traceback_format_exc_is_escaped(self): + """``traceback.format_exc()`` from a thrown hostile exception + is also wrapped by the new logging path.""" + from gigaevo.runner.dag_runner import sanitize_for_log + + captured, sink_id = _capture() + try: + try: + raise ValueError("\x1b[31mbad\x1b[0m\x00") + except ValueError: + import traceback as tb + + logger.error( + "[DagScheduler] Traceback:\n{}", + sanitize_for_log(tb.format_exc()), + ) + _assert_clean(captured) + finally: + logger.remove(sink_id) + + +class TestDagStageErrorPretty: + """``DAG._process_finished_task`` now wraps ``error.pretty()``. + Verify the StageError model + the wrap together yield clean text + on hostile input.""" + + def test_stage_error_pretty_is_safe(self): + # StageError validators scrub ``type`` / ``message`` / + # ``traceback`` at construction. ``stage`` is supplied + # internally (canonical class names — no LLM influence), so it + # is not scrubbed. The dag-side ``sanitize_for_log`` wrap is + # the defensive belt for the whole ``pretty()`` string. + from gigaevo.utils.text_sanitize import sanitize_for_log + + err = StageError( + type="\x1b[31mError\x1b[0m", + message="\x00invalid\x1b[0m", + stage="HostileStage", + traceback="line1\n\x1b[31mline2\x1b[0m\n‮line3", + ) + rendered = sanitize_for_log(err.pretty(include_traceback=True)) + assert "\x1b[" not in rendered + assert "\x00" not in rendered + assert "‮" not in rendered + + def test_dag_module_imports_sanitizer(self): + from gigaevo.programs.dag import dag as dag_mod + + assert hasattr(dag_mod, "sanitize_for_log") diff --git a/tests/llm/test_insights_scoring_agents.py b/tests/llm/test_insights_scoring_agents.py index 38600d4d..de0b19e1 100644 --- a/tests/llm/test_insights_scoring_agents.py +++ b/tests/llm/test_insights_scoring_agents.py @@ -359,3 +359,69 @@ def test_non_program_score_raises(self): } with pytest.raises(ValueError, match="Expected ProgramScore"): agent.parse_response(state) + + +# --------------------------------------------------------------------------- +# ProgramInsight field validators — schema-layer sanitization +# --------------------------------------------------------------------------- + + +class TestProgramInsightFieldSanitization: + """ProgramInsight str fields receive LLM output verbatim and must scrub + ANSI / BIDI overrides / lone surrogates / control bytes before the value + flows into reports, JSON dumps, Postgres TEXT columns, or re-injection + back into LLM prompts as part of a multi-agent loop.""" + + def test_clean_input_passes_through(self): + insight = ProgramInsight( + type="performance", + insight="The inner loop dominates wall time on N>1000.", + tag="hotspot", + severity="medium", + ) + assert insight.type == "performance" + assert insight.insight == "The inner loop dominates wall time on N>1000." + assert insight.tag == "hotspot" + assert insight.severity == "medium" + + def test_ansi_escape_in_insight_field_stripped(self): + insight = ProgramInsight( + type="perf", + insight="\x1b[31mred\x1b[0m: cache miss on hot path", + tag="cache", + severity="high", + ) + assert "\x1b" not in insight.insight + assert "red: cache miss on hot path" in insight.insight + + def test_lone_surrogate_replaced_in_type(self): + insight = ProgramInsight( + type="perf\ud83d", + insight="ok", + tag="t", + severity="low", + ) + assert "\ud83d" not in insight.type + # Result must be UTF-8 encodable and JSON-serializable. + insight.type.encode("utf-8") + insight.model_dump_json() + + def test_cr_in_severity_does_not_forge_log_line(self): + insight = ProgramInsight( + type="t", + insight="ok", + tag="t", + severity="high\r\n[FAKE LINE]", + ) + assert "\r" not in insight.severity + + def test_unicode_arrows_and_math_symbols_preserved(self): + # Mojo and Pallas error formatters carry these legitimately; the + # validator must not strip printable Unicode the operator wants to see. + insight = ProgramInsight( + type="formal", + insight="∀x ∈ ℝ → ℂ holds after the rewrite", + tag="t", + severity="low", + ) + assert insight.insight == "∀x ∈ ℝ → ℂ holds after the rewrite" diff --git a/tests/llm/test_lineage_agent.py b/tests/llm/test_lineage_agent.py index e811930f..4cae487b 100644 --- a/tests/llm/test_lineage_agent.py +++ b/tests/llm/test_lineage_agent.py @@ -616,3 +616,58 @@ async def _capture(state): assert captured_state["child"] is child assert captured_state["metadata"]["parent_id"] == parent.id assert captured_state["metadata"]["child_id"] == child.id + + +# --------------------------------------------------------------------------- +# TransitionInsight field validators — schema-layer sanitization +# --------------------------------------------------------------------------- + + +class TestTransitionInsightFieldSanitization: + """TransitionInsight str fields receive LLM output verbatim and must scrub + ANSI / BIDI overrides / lone surrogates / control bytes before the value + flows into reports, JSON dumps, Postgres TEXT columns, or re-injection + back into LLM prompts.""" + + def test_clean_input_passes_through(self): + ti = TransitionInsight( + strategy="imitation", + description="Parent's loop tiling improved cache locality 1.8x.", + ) + assert ti.strategy == "imitation" + assert ti.description == "Parent's loop tiling improved cache locality 1.8x." + + def test_ansi_escape_in_description_stripped(self): + ti = TransitionInsight( + strategy="avoidance", + description="\x1b[31mreverted\x1b[0m fused-mul-add — slower on this GPU", + ) + assert "\x1b" not in ti.description + assert "reverted" in ti.description + + def test_lone_surrogate_replaced_in_strategy(self): + ti = TransitionInsight( + strategy="exploration\ud83d", + description="ok", + ) + assert "\ud83d" not in ti.strategy + ti.strategy.encode("utf-8") + ti.model_dump_json() + + def test_cr_in_description_does_not_forge_log_line(self): + ti = TransitionInsight( + strategy="generalization", + description="moved threshold higher\r\n[FAKE LINE]", + ) + assert "\r" not in ti.description + + def test_unicode_arrows_preserved(self): + # Mojo / Pallas error formatters carry U+2192; CUTLASS-style template + # syntax in descriptions must survive verbatim. + ti = TransitionInsight( + strategy="imitation", + description="Shape<_32,_128> → Shape<_64,_64> halves register pressure", + ) + assert ti.description == ( + "Shape<_32,_128> → Shape<_64,_64> halves register pressure" + ) diff --git a/tests/llm/test_sanitize_wiring.py b/tests/llm/test_sanitize_wiring.py new file mode 100644 index 00000000..1c000934 --- /dev/null +++ b/tests/llm/test_sanitize_wiring.py @@ -0,0 +1,464 @@ +"""Integration tests that prove the sanitizer is wired into LLM call sites. + +Each test drives a hostile string (ANSI escape, NUL, lone surrogate, BIDI +override, CR carriage-return) through the production call site and asserts +the destination — a real loguru sink, a JSON dump, or pydantic field state — +never sees the raw hostile bytes. + +The tests intentionally use the production logging path (``loguru.logger`` +with a captured ``StringIO`` sink) so that a regression that silently drops +a ``sanitize_for_log`` wrap surfaces as a hostile byte reappearing in the +captured output. They are organized one class per modified file under +``gigaevo/llm/``. +""" + +from __future__ import annotations + +import io +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +from loguru import logger +import pytest + +from gigaevo.llm.agents.mutation import ( + MutationAgent, + MutationChange, + MutationState, + MutationStructuredOutput, +) +from gigaevo.llm.bandit import BanditModelRouter, MutationOutcome +from gigaevo.llm.models import MultiModelRouter, _redact_url +from gigaevo.llm.token_tracking import TokenTracker +from gigaevo.programs.program import Program +from tests.conftest import NullWriter + + +# --------------------------------------------------------------------------- +# Shared hostile-input fixtures (kept consistent with tests/utils/...) +# --------------------------------------------------------------------------- + +LONE_HIGH = "\ud83d" +HOSTILE = ( + "\x1b[31merr\x1b[0m" # ANSI red + "\x00NUL" # NUL + "\rCR" # CR forgery + "\x07BEL" # bell + f"{LONE_HIGH}LS" # lone surrogate + "‮RLO" # RLO BIDI override +) + + +@pytest.fixture +def loguru_sink(): + """Add a string-buffer loguru sink, yield it, tear down.""" + buf = io.StringIO() + sink_id = logger.add(buf, format="{message}", level="DEBUG") + yield buf + logger.remove(sink_id) + + +def _assert_no_raw_hostile(captured: str) -> None: + assert "\x1b" not in captured, "raw ANSI ESC survived" + assert "\x00" not in captured, "raw NUL survived" + assert "\x07" not in captured, "raw BEL survived" + assert LONE_HIGH not in captured, "lone surrogate survived" + assert "‮" not in captured, "BIDI RLO survived" + # Captured string must encode cleanly as UTF-8 (loguru already wrote it). + captured.encode("utf-8") + + +def _mock_model(name: str) -> MagicMock: + m = MagicMock() + m.model_name = name + m.with_structured_output = MagicMock(return_value=MagicMock()) + return m + + +# --------------------------------------------------------------------------- +# gigaevo/llm/models.py — MultiModelRouter init + _verify_models +# --------------------------------------------------------------------------- + + +class TestModelRouterLogSanitization: + """Init banner and _verify_models warnings must never emit hostile bytes.""" + + def test_init_log_with_hostile_model_name( + self, loguru_sink, monkeypatch + ) -> None: + # Hostile bytes in model_name should be stripped by _safe_model_name + # before reaching the init INFO line. Patch _verify_models out — this + # test is about the init banner, not the server probe; the real probe + # would otherwise spend ~10s timing out against the fake host. + monkeypatch.setattr( + MultiModelRouter, "_verify_models", lambda self: None + ) + models = [_mock_model(f"gpt-4{HOSTILE}"), _mock_model("gpt-3.5-turbo")] + # base_url must be present so the second loop also fires; we use one + # that contains userinfo, exercising _redact_url alongside sanitizing. + models[0].base_url = "http://user:pwd@host:8000/v1" + models[1].base_url = "http://host:8000/v1" + MultiModelRouter(models, [0.5, 0.5], writer=NullWriter(), name="san") + captured = loguru_sink.getvalue() + _assert_no_raw_hostile(captured) + # Cleaned form survives — sanity-check the prefix is recognizable. + assert "[MultiModelRouter:san]" in captured + # Userinfo from base_url must be redacted in the log. + assert "user:pwd" not in captured + assert "pwd@" not in captured + + def test_verify_models_failure_log_sanitized( + self, loguru_sink, monkeypatch + ) -> None: + """When the server probe raises, the exception message is sanitized.""" + import urllib.request + + def boom(*_a, **_kw): + raise OSError(f"connect failed: {HOSTILE}") + + monkeypatch.setattr(urllib.request, "urlopen", boom) + models = [_mock_model("gpt-4")] + models[0].base_url = "http://host:8000/v1" + MultiModelRouter(models, [1.0], writer=NullWriter(), name="probe") + captured = loguru_sink.getvalue() + assert "Cannot verify models" in captured + _assert_no_raw_hostile(captured) + + def test_verify_models_not_found_log_sanitized( + self, loguru_sink, monkeypatch + ) -> None: + """Server-returned model ids with hostile bytes are sanitized in WARN.""" + import urllib.request + + class FakeResp: + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def read(self): + # The server claims to host a different model than the + # configured one, with hostile bytes in its id. + return json.dumps( + {"data": [{"id": f"other-model{HOSTILE}"}]} + ).encode("utf-8") + + def fake_urlopen(*_a, **_kw): + return FakeResp() + + monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen) + models = [_mock_model("gpt-4")] + models[0].base_url = "http://host:8000/v1" + MultiModelRouter(models, [1.0], writer=NullWriter(), name="probe2") + captured = loguru_sink.getvalue() + assert "NOT FOUND" in captured + _assert_no_raw_hostile(captured) + + +class TestRedactUrl: + """Strip userinfo, keep everything else.""" + + def test_userinfo_stripped(self) -> None: + assert _redact_url("http://u:p@h:8000/x") == "http://h:8000/x" + + def test_no_userinfo_preserved(self) -> None: + assert _redact_url("http://h:8000/v1") == "http://h:8000/v1" + + def test_parse_failure_returns_input(self) -> None: + # http://[ is an unparseable URL on stricter parsers — at minimum + # the helper must not raise and must yield a str. + result = _redact_url("not a url at all") + assert isinstance(result, str) + + +# --------------------------------------------------------------------------- +# gigaevo/llm/agents/mutation.py — structured-output validators + log calls +# --------------------------------------------------------------------------- + + +class TestMutationStructuredOutputValidators: + """Field validators must scrub LLM-supplied text on construction.""" + + def test_hostile_archetype_scrubbed_at_validation(self) -> None: + out = MutationStructuredOutput( + archetype=f"clever{HOSTILE}archetype", + justification=f"because{HOSTILE}", + insights_used=[f"insight{HOSTILE}1", "clean"], + changes=[ + MutationChange( + description=f"swap loop{HOSTILE}", + explanation=f"why{HOSTILE}", + ) + ], + code=f"def f():\n return 1{HOSTILE}", + ) + # No raw hostile bytes survive in any string field. + assert "\x1b" not in out.archetype + assert "\x00" not in out.archetype + assert LONE_HIGH not in out.archetype + assert "\x00" not in out.justification + assert all("\x00" not in s for s in out.insights_used) + assert "\x1b" not in out.changes[0].description + assert "\x07" not in out.changes[0].explanation + # The code field is also scrubbed (ANSI/BIDI/C0-non-LF have no + # legitimate place in Python source). LF must still survive. + assert "\x00" not in out.code + assert "\n" in out.code # legitimate newline preserved + + def test_model_dump_json_succeeds_after_validation(self) -> None: + """Lone surrogate inside any field would otherwise abort orjson; the + validator pre-scrubs so JSON serialization is total.""" + out = MutationStructuredOutput( + archetype=f"a{LONE_HIGH}", + justification="j", + insights_used=[f"i{LONE_HIGH}"], + changes=[], + code=f"x = 1{LONE_HIGH}", + ) + blob = out.model_dump_json() + # Round-trip parse must succeed — no encoder failure. + json.loads(blob) + + +class TestMutationAgentLogSanitization: + """Direct logger calls inside MutationAgent must scrub LLM-derived text.""" + + def _make_agent(self) -> MutationAgent: + mock_llm = MagicMock() + mock_llm.with_structured_output = MagicMock(return_value=MagicMock()) + return MutationAgent( + llm=mock_llm, + system_prompt="sys", + user_prompt_template="Mutate {count}:\n{parent_blocks}", + mutation_mode="rewrite", + ) + + async def test_acall_llm_failure_log_sanitized(self, loguru_sink) -> None: + agent = self._make_agent() + agent.structured_llm = MagicMock() + agent.structured_llm.ainvoke = AsyncMock( + side_effect=RuntimeError(f"oops {HOSTILE}") + ) + state: MutationState = { + "input": [], + "mutation_mode": "rewrite", + "messages": [], + "llm_response": None, + "final_code": "", + "mutation_label": "", + } + await agent.acall_llm(state) + captured = loguru_sink.getvalue() + assert "Structured LLM call failed" in captured + _assert_no_raw_hostile(captured) + # State["error"] must also be scrubbed so callers can stash it. + assert "\x00" not in state.get("error", "") + + def test_parse_response_no_output_log_sanitized(self, loguru_sink) -> None: + agent = self._make_agent() + state: MutationState = { + "input": [], + "mutation_mode": "rewrite", + "messages": [], + "llm_response": None, + "final_code": "", + "mutation_label": "", + "error": f"upstream said: {HOSTILE}", + } + agent.parse_response(state) + captured = loguru_sink.getvalue() + assert "No structured output" in captured + _assert_no_raw_hostile(captured) + + def test_parse_response_failure_stores_sanitized_error( + self, loguru_sink, monkeypatch + ) -> None: + agent = self._make_agent() + structured_output = MutationStructuredOutput( + archetype="Rewrite", + justification="test", + insights_used=[], + changes=[], + code="def run_code():\n return 1\n", + ) + + def boom(_code: str) -> str: + raise RuntimeError(f"parse failed: {HOSTILE}") + + monkeypatch.setattr(agent, "_extract_code_block", boom) + state: MutationState = { + "input": [], + "mutation_mode": "rewrite", + "messages": [], + "llm_response": None, + "structured_output": structured_output, + "final_code": "", + "mutation_label": "", + } + + agent.parse_response(state) + + captured = loguru_sink.getvalue() + assert "Failed to parse structured response" in captured + _assert_no_raw_hostile(captured) + parsed = state["parsed_output"] + _assert_no_raw_hostile(parsed["error"]) + assert state["error"] == parsed["error"] + + +# --------------------------------------------------------------------------- +# gigaevo/llm/token_tracking.py — track() error path +# --------------------------------------------------------------------------- + + +class TestTokenTrackerWiring: + def test_validation_error_caught_and_logged_sanitized( + self, loguru_sink + ) -> None: + """A provider returning garbage token-usage types must not raise out + of TokenTracker.track; the failure is logged with sanitized text.""" + tracker = TokenTracker(name="t", writer=NullWriter()) + + class BadResponse: + @property + def response_metadata(self): + # ``prompt_tokens`` is a string — pydantic coerces it, but if + # we use a clearly non-coercible value we exercise the + # try/except path. + return {"token_usage": {"prompt_tokens": object()}} + + # Should not raise even though TokenUsage.from_response will hit + # validation/type errors. + tracker.track(BadResponse(), model_name=f"model{HOSTILE}") + captured = loguru_sink.getvalue() + _assert_no_raw_hostile(captured) + + def test_model_name_cleaned_in_no_usage_branch(self, loguru_sink) -> None: + """When response carries no usage, the debug log must show a cleaned + model name (control chars in ``model_name`` would otherwise reach + loguru via the ``{}`` slot).""" + tracker = TokenTracker(name="t", writer=NullWriter()) + + class EmptyResponse: + response_metadata: dict = {} + + tracker.track(EmptyResponse(), model_name=f"model{HOSTILE}") + captured = loguru_sink.getvalue() + assert "No token usage" in captured + _assert_no_raw_hostile(captured) + + +# --------------------------------------------------------------------------- +# gigaevo/llm/agents/memory_selector.py — backend init + search log paths +# --------------------------------------------------------------------------- + + +class TestMemorySelectorLogSanitization: + """The memory selector logs over backend errors; those strings are + backend-supplied and must be scrubbed before reaching loguru.""" + + async def test_search_unavailable_log_sanitized(self, loguru_sink) -> None: + """When ``self.memory`` is None, ``select()`` logs the cached backend + error verbatim — that path must run the value through sanitize.""" + from gigaevo.llm.agents.memory_selector import ( + MemorySelection, + MemorySelectorAgent, + ) + + agent = MemorySelectorAgent.__new__(MemorySelectorAgent) + agent.memory = None + agent._backend_error = f"backend died: {HOSTILE}" + import asyncio as _aio + + agent._search_lock = _aio.Lock() + + result = await agent.select( + input=[], + mutation_mode="rewrite", + task_description="t", + metrics_description="m", + memory_text="", + max_cards=1, + ) + assert isinstance(result, MemorySelection) + assert result.cards == [] + captured = loguru_sink.getvalue() + assert "Memory backend unavailable" in captured + _assert_no_raw_hostile(captured) + + async def test_search_failure_log_sanitized(self, loguru_sink) -> None: + """When the underlying GAM search raises with hostile bytes in the + exception message, the WARN line must show the scrubbed form.""" + from gigaevo.llm.agents.memory_selector import ( + MemorySelection, + MemorySelectorAgent, + ) + + agent = MemorySelectorAgent.__new__(MemorySelectorAgent) + + # Memory backend that raises from research and from search. + class BadMem: + research_agent = None + + def search(self, query: str) -> str: + raise RuntimeError(f"search exploded {HOSTILE}") + + agent.memory = BadMem() + agent._backend_error = None + import asyncio as _aio + + agent._search_lock = _aio.Lock() + + result = await agent.select( + input=[], + mutation_mode="rewrite", + task_description="t", + metrics_description="m", + memory_text="", + max_cards=1, + ) + assert isinstance(result, MemorySelection) + assert result.cards == [] + captured = loguru_sink.getvalue() + assert "Red memory search failed" in captured + _assert_no_raw_hostile(captured) + + +# --------------------------------------------------------------------------- +# gigaevo/llm/bandit.py — on_mutation_outcome debug log +# --------------------------------------------------------------------------- + + +class TestBanditRouterLogSanitization: + def _make_router(self) -> BanditModelRouter: + models = [_mock_model("m1"), _mock_model("m2")] + return BanditModelRouter( + models, + [0.5, 0.5], + writer=NullWriter(), + name="bandit", + fitness_key="fitness", + higher_is_better=True, + ) + + def test_outcome_log_with_hostile_model_metadata(self, loguru_sink) -> None: + """``program.get_metadata("mutation_model")`` could carry hostile text + if any upstream stashes it raw; the bandit log must sanitize.""" + router = self._make_router() + # Patch the bandit update so its KeyError on the hostile key (a + # separate pre-existing concern queued as follow-up) does not mask + # the wiring assertion we care about here. + router._bandit.update_reward = lambda *a, **kw: None # type: ignore[method-assign] + program = Program(code="x = 1") + program.metadata["mutation_model"] = f"m1{HOSTILE}" + # No parent metrics → REJECTED_ACCEPTOR path that logs raw=0.0 line. + router.on_mutation_outcome( + program, + parents=[], + outcome=MutationOutcome.REJECTED_ACCEPTOR, + ) + captured = loguru_sink.getvalue() + assert "Reward for" in captured + _assert_no_raw_hostile(captured) diff --git a/tests/prompts/test_coevolution_pipeline.py b/tests/prompts/test_coevolution_pipeline.py index 0b7ab926..56bb4b76 100644 --- a/tests/prompts/test_coevolution_pipeline.py +++ b/tests/prompts/test_coevolution_pipeline.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import hashlib import json from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch @@ -434,6 +435,12 @@ def capture_set(key, value): class TestPromptTextToIdUserText: + def test_clean_system_only_hash_matches_previous_sha256(self): + """Clean prompt text keeps the historical sha256[:16] ID.""" + expected = hashlib.sha256("system".encode()).hexdigest()[:16] + + assert prompt_text_to_id("system") == expected + def test_same_system_different_user_different_ids(self): """M4: Two prompts with same system but different user get different IDs.""" id1 = prompt_text_to_id("system", user_text="user1") @@ -452,6 +459,13 @@ def test_user_text_changes_id(self): id_with_user = prompt_text_to_id("system", user_text="user") assert id_system_only != id_with_user + def test_lone_surrogate_prompt_text_is_sanitized_before_hashing(self): + """LLM-returned prompt text with a surrogate must not crash ID generation.""" + expected_blob = "system�\x00user�" + expected = hashlib.sha256(expected_blob.encode()).hexdigest()[:16] + + assert prompt_text_to_id("system\ud83d", user_text="user\udc00") == expected + # =================================================================== # Amendment #11: PromptInsightsStage / PromptLineageStage diff --git a/tests/prompts/test_coevolution_stages.py b/tests/prompts/test_coevolution_stages.py index 55a71dab..d27c968c 100644 --- a/tests/prompts/test_coevolution_stages.py +++ b/tests/prompts/test_coevolution_stages.py @@ -12,7 +12,11 @@ PromptExecutionStage, PromptFitnessStage, ) -from gigaevo.prompts.coevolution.stats import PromptMutationStats, PromptStatsProvider +from gigaevo.prompts.coevolution.stats import ( + PromptMutationStats, + PromptStatsProvider, + prompt_text_to_id, +) # --------------------------------------------------------------------------- # Fixtures @@ -87,6 +91,32 @@ async def test_execute_valid_program(self, simple_prompt_program: Program): ) assert len(result.prompt_id) == 16 # SHA256[:16] + @pytest.mark.asyncio + async def test_execute_sanitizes_returned_prompt_text(self): + """compute() sanitizes prompt text before returning and hashing it.""" + code = ( + "def entrypoint() -> dict:\n" + " return {\n" + " 'system': 'system\\ud83d\\x00\\x1b[31mred\\x1b[0m',\n" + " 'user': 'user\\udc00\\r\\u202e',\n" + " }\n" + ) + stage = PromptExecutionStage(timeout=30.0) + stage.attach_inputs({}) + + result = await stage.compute(Program(code=code)) + + assert result.prompt_text == "system�\\x00red" + assert result.user_text == "user�\\x0d" + assert result.prompt_id == prompt_text_to_id( + result.prompt_text, user_text=result.user_text + ) + assert "\ud83d" not in result.prompt_text + assert "\x1b" not in result.prompt_text + assert "\x00" not in result.prompt_text + assert "\udc00" not in result.user_text + assert "‮" not in result.user_text + @pytest.mark.asyncio async def test_execute_no_entrypoint(self, broken_prompt_program: Program): """compute() raises when program has no entrypoint().""" diff --git a/tests/prompts/test_fetcher.py b/tests/prompts/test_fetcher.py index a2823926..86d326b1 100644 --- a/tests/prompts/test_fetcher.py +++ b/tests/prompts/test_fetcher.py @@ -232,6 +232,35 @@ def test_execute_entrypoint_str_return(self, tmp_prompts_dir: Path): assert pack.prompt_id == prompt_text_to_id("Hello system.") + def test_execute_entrypoint_sanitizes_prompt_pack(self, tmp_prompts_dir: Path): + """_execute_entrypoint() sanitizes returned system/user prompt text.""" + fetcher = GigaEvoArchivePromptFetcher( + prompt_redis_db=6, + main_redis_prefix="prefix", + fallback_prompts_dir=tmp_prompts_dir, + ) + code = ( + "def entrypoint() -> dict:\n" + " return {\n" + " 'system': 'System\\ud83d\\x00\\x1b[31mred\\x1b[0m',\n" + " 'user': 'User\\udc00\\r\\u202e',\n" + " }\n" + ) + + from gigaevo.prompts.coevolution.stats import prompt_text_to_id + + pack = fetcher._execute_entrypoint(code) + + assert pack is not None + assert pack.system == "System�\\x00red" + assert pack.user == "User�\\x0d" + assert pack.prompt_id == prompt_text_to_id(pack.system, user_text=pack.user) + assert "\ud83d" not in pack.system + assert "\x1b" not in pack.system + assert "\x00" not in pack.system + assert "\udc00" not in pack.user + assert "‮" not in pack.user + def test_execute_entrypoint_dict_return_with_user(self, tmp_prompts_dir: Path): """_execute_entrypoint() handles dict-returning entrypoint with user key.""" fetcher = GigaEvoArchivePromptFetcher( diff --git a/tests/stages/test_sanitize_integration.py b/tests/stages/test_sanitize_integration.py new file mode 100644 index 00000000..bd4f9b31 --- /dev/null +++ b/tests/stages/test_sanitize_integration.py @@ -0,0 +1,365 @@ +"""Hostile-input integration tests for the sanitizer wiring in +``gigaevo/programs/stages/`` and ``gigaevo/programs/dag/``. + +Each test feeds ANSI escape sequences, NUL bytes, BIDI overrides, or +lone UTF-16 surrogates into one of the call sites surgically wired with +``sanitize_for_log`` / ``clean_identifier`` and asserts that the +relevant downstream surface (loguru sink contents, optuna trial keys, +SyntaxError args) is free of the offending bytes — without disturbing +the surrounding logic. +""" + +from __future__ import annotations + +from pathlib import Path +import re + +from loguru import logger +import pytest + +from gigaevo.programs.core_types import StageError +from gigaevo.programs.program import Program +from gigaevo.programs.program_state import ProgramState +from gigaevo.programs.stages.optimization.optuna.models import ( + CodeModification, + OptunaSearchSpace, + ParamSpec, +) +from gigaevo.programs.stages.optimization.optuna.stage import OptunaOptimizationStage +from gigaevo.programs.stages.optimization.utils import evaluate_single +from gigaevo.programs.stages.python_executors.execution import PythonCodeExecutor +from gigaevo.programs.stages.python_executors.wrapper import ExecRunnerError +from gigaevo.programs.stages.validation import ValidateCodeStage + +# Byte patterns we never want to see on log sinks or as optuna trial keys. +_ANSI_RE = re.compile(r"\x1b\[") +_BIDI_RE = re.compile(r"[‪-‮⁦-⁩]") +_C0_RAW_RE = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]") + + +def _attach_sink() -> list[str]: + """Add a memory loguru sink for the duration of a single test. + + Returns the underlying list that captures every emitted log message; + each test removes the sink at the end via ``logger.remove``. + """ + messages: list[str] = [] + sink_id = logger.add(messages.append, format="{message}", level="TRACE") + # Stash the sink id on the list so the caller can remove it. + messages.append(f"__sink_id__={sink_id}") + return messages + + +def _detach_sink(messages: list[str]) -> None: + for entry in messages: + if isinstance(entry, str) and entry.startswith("__sink_id__="): + sink_id = int(entry.split("=", 1)[1]) + logger.remove(sink_id) + return + + +def _assert_sink_clean(messages: list[str]) -> None: + for line in messages: + if line.startswith("__sink_id__="): + continue + assert not _ANSI_RE.search(line), f"ANSI escape leaked into log: {line!r}" + assert not _BIDI_RE.search(line), f"BIDI override leaked into log: {line!r}" + assert not _C0_RAW_RE.search(line), f"Raw C0 control leaked into log: {line!r}" + + +# --------------------------------------------------------------------------- +# validation.py — SyntaxError text scrubbing +# --------------------------------------------------------------------------- + + +class TestValidationSyntaxErrorSanitized: + """``ValidateCodeStage`` interpolates ``e.msg`` / ``e.text`` into a + re-raised ``SyntaxError``. With a hostile compiler message that + text would propagate verbatim into every downstream consumer.""" + + async def test_syntax_error_message_strips_ansi(self): + # Inject a raw \x1b sequence into the code as a comment so the + # parser quotes it back in e.text. + code = "def foo(\x1b[31m # bad\n" + stage = ValidateCodeStage(timeout=30.0) + prog = Program(code=code, state=ProgramState.RUNNING) + with pytest.raises(SyntaxError) as ei: + await stage.compute(prog) + rendered = str(ei.value) + assert "\x1b[" not in rendered + assert "\x1b" not in rendered or "\\x1b" in rendered + + async def test_syntax_error_message_escapes_nul(self): + code = "def bar(\x00\n" + stage = ValidateCodeStage(timeout=30.0) + prog = Program(code=code, state=ProgramState.RUNNING) + with pytest.raises(SyntaxError) as ei: + await stage.compute(prog) + rendered = str(ei.value) + assert "\x00" not in rendered + + +# --------------------------------------------------------------------------- +# optuna/stage.py — ParamSpec.name as an Optuna trial key +# --------------------------------------------------------------------------- + + +class TestOptunaParamNameCleaning: + """``OptunaOptimizationStage._apply_modifications`` must scrub + ``ParamSpec.name`` before the name is handed to ``trial.suggest_*`` + where it becomes an optuna storage key.""" + + @staticmethod + def _make_stage(tmp_path: Path) -> OptunaOptimizationStage: + # Minimal validator file so __init__ succeeds. + validator = tmp_path / "validator.py" + validator.write_text( + "def validate(_):\n return {'score': 0.0}\n", + encoding="utf-8", + ) + return OptunaOptimizationStage( + llm=None, # type: ignore[arg-type] # not used here + validator_path=validator, + score_key="score", + timeout=30.0, + ) + + def test_clean_paramspec_name_and_reference_stay_unchanged(self, tmp_path: Path): + stage = self._make_stage(tmp_path) + param = ParamSpec( + name="learning_rate", + initial_value=0.1, + param_type="float", + low=0.01, + high=1.0, + reason="clean", + ) + mod = CodeModification( + start_line=1, + end_line=1, + parameterized_snippet="x = _optuna_params['learning_rate']", + ) + ss = OptunaSearchSpace( + parameters=[param], + modifications=[mod], + reasoning="clean", + ) + + code = stage._apply_modifications("x = 1\n", ss) + + assert param.name == "learning_rate" + assert code == "x = _optuna_params['learning_rate']\n" + + def test_nul_in_paramspec_name_cleaned(self, tmp_path: Path): + stage = self._make_stage(tmp_path) + param = ParamSpec( + name="\x00../etc/passwd", + initial_value=1.0, + param_type="float", + low=0.0, + high=10.0, + reason="hostile", + ) + mod = CodeModification( + start_line=1, + end_line=1, + parameterized_snippet="x = _optuna_params['\x00../etc/passwd']", + ) + ss = OptunaSearchSpace( + parameters=[param], + modifications=[mod], + reasoning="hostile", + ) + + code = stage._apply_modifications("x = 1\n", ss) + + assert "\x00" not in param.name + assert ".." in param.name or "etc" in param.name # printable survived + # And the name is now safe for use as an optuna key (no control bytes). + assert _C0_RAW_RE.search(param.name) is None + assert code == "x = _optuna_params['../etc/passwd']\n" + namespace = {"_optuna_params": {param.name: 3.0}} + exec(code, namespace) # noqa: S102 + assert namespace["x"] == 3.0 + + def test_ansi_in_paramspec_name_cleaned(self, tmp_path: Path): + stage = self._make_stage(tmp_path) + param = ParamSpec( + name="\x1b[31mred_param\x1b[0m", + initial_value=1, + param_type="int", + low=0, + high=10, + reason="hostile", + ) + ss = OptunaSearchSpace( + parameters=[param], + modifications=[ + CodeModification( + start_line=1, end_line=1, parameterized_snippet="x = 1" + ) + ], + reasoning="hostile", + ) + stage._apply_modifications("x = 1\n", ss) + assert "\x1b" not in param.name + assert "red_param" in param.name + + def test_only_bad_chars_falls_back_to_positional(self, tmp_path: Path): + stage = self._make_stage(tmp_path) + # Every character outside the identifier charset. + param = ParamSpec( + name="\x00\x1b‮", + initial_value=1.0, + param_type="float", + low=0.0, + high=1.0, + reason="hostile", + ) + ss = OptunaSearchSpace( + parameters=[param], + modifications=[ + CodeModification( + start_line=1, end_line=1, parameterized_snippet="x = 1" + ) + ], + reasoning="hostile", + ) + stage._apply_modifications("x = 1\n", ss) + assert param.name == "param_0" + + def test_sanitized_param_name_rewrites_double_quoted_reference( + self, tmp_path: Path + ): + stage = self._make_stage(tmp_path) + param = ParamSpec( + name="batch size\r", + initial_value=16, + param_type="int", + low=1, + high=64, + reason="hostile", + ) + ss = OptunaSearchSpace( + parameters=[param], + modifications=[ + CodeModification( + start_line=1, + end_line=1, + parameterized_snippet='x = _optuna_params["batch size\\r"]', + ) + ], + reasoning="hostile", + ) + + code = stage._apply_modifications("x = 16\n", ss) + + assert param.name == "batchsize" + assert code == "x = _optuna_params['batchsize']\n" + namespace = {"_optuna_params": {"batchsize": 32}} + exec(code, namespace) # noqa: S102 + assert namespace["x"] == 32 + + +# --------------------------------------------------------------------------- +# execution.py — subprocess stderr through the warning log +# --------------------------------------------------------------------------- + + +class _HostilePythonCodeExecutor(PythonCodeExecutor): + """Bypass the subprocess pool entirely; raise a precomposed + ``ExecRunnerError`` so we can drive the ``except`` branch + deterministically without spinning a real worker.""" + + def __init__(self, hostile_stderr: str, **kwargs): + super().__init__(**kwargs) + self._hostile_stderr = hostile_stderr + + async def compute(self, program): # noqa: D401 — override + # Re-run the parent implementation but force the error path. + # Easiest: call the parent's exception block directly by invoking + # the same logger pattern via the parent's compute, but inject the + # error. We mimic the structure manually for stability. + raise ExecRunnerError( + returncode=1, + stderr=self._hostile_stderr, + stdout_bytes=b"", + ) + + +class TestExecutionWarningLogSanitized: + async def test_hostile_stderr_does_not_leak_to_loguru(self): + # Build a real PythonCodeExecutor with a stub that raises + # ExecRunnerError on the inner await — exercises the exact + # logger.warning line we wrapped. + hostile = "\x1b[31mCUDA error\x1b[0m: \x00 invalid value ‮malicious" + stage = PythonCodeExecutor(timeout=30.0) + + async def _fake_runner(**_kw): + raise ExecRunnerError( + returncode=1, stderr=hostile, stdout_bytes=b"" + ) + + # Monkey-patch the bound name imported inside execution.py. + from gigaevo.programs.stages.python_executors import execution as exec_mod + + original = exec_mod.run_exec_runner + exec_mod.run_exec_runner = _fake_runner # type: ignore[assignment] + messages = _attach_sink() + try: + prog = Program(code="def run_code(): return 1", state=ProgramState.RUNNING) + result = await stage.compute(prog) + # The stage should return a FAILED result, not raise. + assert hasattr(result, "status") + _assert_sink_clean(messages) + finally: + _detach_sink(messages) + exec_mod.run_exec_runner = original # type: ignore[assignment] + + async def test_stage_error_traceback_scrubbed_by_validator(self): + # The StageError validator must convert any control bytes in + # the constructed StageError.traceback into escaped form. + hostile = "Traceback:\n\x1b[31mline\x1b[0m\nFinal\x00" + err = StageError(type="X", message="m", traceback=hostile) + assert "\x1b[" not in (err.traceback or "") + assert "\x00" not in (err.traceback or "") + + +# --------------------------------------------------------------------------- +# optimization/utils.py — evaluate_single ExecRunnerError path +# --------------------------------------------------------------------------- + + +class TestEvaluateSingleSanitization: + async def test_exec_runner_error_returned_message_is_sanitized( + self, monkeypatch + ): + hostile_stderr = "\x1b[31mcompile error\x1b[0m\nlast: \x00bad" + + async def _fake_runner(**_kw): + raise ExecRunnerError( + returncode=1, stderr=hostile_stderr, stdout_bytes=b"" + ) + + from gigaevo.programs.stages.optimization import utils as utils_mod + + monkeypatch.setattr(utils_mod, "run_exec_runner", _fake_runner) + messages = _attach_sink() + try: + scores, err = await evaluate_single( + eval_code="def _opt(): return 1", + eval_fn_name="_opt", + context=None, + score_key="score", + python_path=[], + timeout=5, + max_memory_mb=None, + log_tag="Unit", + ) + assert scores is None + assert err is not None + assert "\x1b" not in err + assert "\x00" not in err + _assert_sink_clean(messages) + finally: + _detach_sink(messages) diff --git a/tests/utils/test_text_sanitize.py b/tests/utils/test_text_sanitize.py new file mode 100644 index 00000000..56f0f9d5 --- /dev/null +++ b/tests/utils/test_text_sanitize.py @@ -0,0 +1,533 @@ +"""Tests for gigaevo/utils/text_sanitize.py. + +Coverage axes: + * sanitize_for_log: ANSI families, C0, C1, BIDI, lone surrogates, + composition, idempotence, identity on safe input. + * sanitize_for_json: minimal lone-surrogate replacement, identity on + everything else. + * sanitize_for_dbtext: NUL replacement, identity on everything else. + * clean_identifier: charset strip, max_len, empty input. + * multi-language preservation: Greek (Mojo identifiers), Unicode arrows + (Mojo / Pallas formatters), math symbols, CJK, emoji, box-drawing, + CUTLASS-style template syntax. + * composability and idempotence guarantees. +""" + +from __future__ import annotations + +import pytest + +from gigaevo.utils.text_sanitize import ( + clean_identifier, + deep_sanitize_for_json, + sanitize_for_dbtext, + sanitize_for_json, + sanitize_for_log, +) + +# --------------------------------------------------------------------------- +# sanitize_for_log — ANSI escape sequences +# --------------------------------------------------------------------------- + + +class TestSanitizeForLogAnsi: + @pytest.mark.parametrize( + "src,expected", + [ + ("\x1b[2J\x1b[H", ""), # clear screen + home + ("\x1b[31mred\x1b[0m text", "red text"), # SGR colorization + ("\x1b[1m\x1b[31merror\x1b[0m", "error"), # bold red + ("plain", "plain"), # identity + ("\x1b[?25h", ""), # private mode (cursor show) + ("\x1b]0;window title\x07after", "after"), # OSC + BEL terminator + ("\x1b]2;title\x1b\\after", "after"), # OSC + ST terminator + ("\x1b[1;31;42mtext\x1b[0m", "text"), # multi-param CSI + ], + ) + def test_csi_and_osc_stripped(self, src: str, expected: str) -> None: + assert sanitize_for_log(src) == expected + + def test_single_char_fe_escape_stripped(self) -> None: + # ESC followed by a single Fe byte (e.g. ESC M reverse-index). + assert sanitize_for_log("a\x1bMb") == "ab" + + def test_compiler_style_error_block(self) -> None: + # Mimics a gcc / clang / nvcc colorized error line. + src = "\x1b[1m\x1b[31merror:\x1b[0m\x1b[1m undefined reference\x1b[0m" + assert sanitize_for_log(src) == "error: undefined reference" + + def test_dcs_sequence_stripped(self) -> None: + # DCS (Device Control String): ESC P ... ST. Used by e.g. terminal + # sixel image transfer; rare but legal in stderr. + assert sanitize_for_log("a\x1bPpayload\x1b\\b") == "ab" + + def test_apc_sequence_stripped(self) -> None: + # APC (Application Program Command): ESC _ ... ST. + assert sanitize_for_log("a\x1b_data\x1b\\b") == "ab" + + def test_csi_with_intermediate_byte_stripped(self) -> None: + # CSI with intermediate byte: ESC [ params SP final. ``\x1b[1 q`` + # is a cursor-shape control with space as intermediate. + assert sanitize_for_log("a\x1b[1 qb") == "ab" + + def test_consecutive_ansi_sequences_all_stripped(self) -> None: + src = "\x1b[31m\x1b[1m\x1b[4mtext\x1b[0m\x1b[0m\x1b[0m" + assert sanitize_for_log(src) == "text" + + +# --------------------------------------------------------------------------- +# sanitize_for_log — C0 and C1 controls +# --------------------------------------------------------------------------- + + +class TestSanitizeForLogControlChars: + def test_nul_escaped(self) -> None: + assert sanitize_for_log("a\x00b") == "a\\x00b" + + def test_bel_escaped(self) -> None: + assert sanitize_for_log("a\x07b") == "a\\x07b" + + def test_backspace_escaped(self) -> None: + assert sanitize_for_log("a\x08b") == "a\\x08b" + + def test_cr_escaped_so_no_line_forgery(self) -> None: + # The critical case: a forged log entry attempt becomes inert. + assert sanitize_for_log("real\r\nFORGED") == "real\\x0d\nFORGED" + + def test_tab_preserved(self) -> None: + assert sanitize_for_log("a\tb") == "a\tb" + + def test_lf_preserved(self) -> None: + # Real multi-line tracebacks must survive. + assert sanitize_for_log("line1\nline2") == "line1\nline2" + + def test_del_escaped(self) -> None: + assert sanitize_for_log("a\x7fb") == "a\\x7fb" + + @pytest.mark.parametrize("byte", [0x80, 0x9F, 0x85]) + def test_c1_controls_escaped(self, byte: int) -> None: + src = f"a{chr(byte)}b" + assert sanitize_for_log(src) == f"a\\x{byte:02x}b" + + +# --------------------------------------------------------------------------- +# sanitize_for_log — BIDI overrides +# --------------------------------------------------------------------------- + + +class TestSanitizeForLogBidi: + @pytest.mark.parametrize( + "codepoint", + [ + 0x202A, # LRE + 0x202B, # RLE + 0x202C, # PDF + 0x202D, # LRO + 0x202E, # RLO + 0x2066, # LRI + 0x2067, # RLI + 0x2068, # FSI + 0x2069, # PDI + ], + ) + def test_each_bidi_override_stripped(self, codepoint: int) -> None: + src = f"a{chr(codepoint)}b" + assert sanitize_for_log(src) == "ab" + + +# --------------------------------------------------------------------------- +# sanitize_for_log — lone surrogates +# --------------------------------------------------------------------------- + + +class TestSanitizeForLogSurrogates: + def test_lone_high_surrogate_replaced(self) -> None: + assert sanitize_for_log("a\ud83dz") == "a�z" + + def test_lone_low_surrogate_replaced(self) -> None: + assert sanitize_for_log("a\udc00z") == "a�z" + + def test_multiple_lone_surrogates_replaced(self) -> None: + assert sanitize_for_log("\ud800𐀀\ud801") == "�𐀀�" + + def test_valid_surrogate_pair_preserved(self) -> None: + # A paired high+low surrogate represents a single code point above the + # BMP. Python str doesn't typically use this form, but if constructed + # by hand it must be preserved as a valid pair. + src = "a😀z" # paired -> U+1F600 emoji + assert sanitize_for_log(src) == src + + def test_real_astral_emoji_preserved(self) -> None: + # The natural single-code-point form must of course pass through. + assert sanitize_for_log("a😀z") == "a😀z" + + +# --------------------------------------------------------------------------- +# sanitize_for_log — multi-language preservation +# --------------------------------------------------------------------------- + + +class TestSanitizeForLogPreservation: + def test_greek_letters_preserved(self) -> None: + # Mojo permits Greek identifiers; classifier and logs must preserve. + assert sanitize_for_log("αβγ_kernel") == "αβγ_kernel" + + def test_unicode_arrows_preserved(self) -> None: + # Mojo / Pallas error formatters use U+2192 / U+21D2 to highlight + # source positions. + assert sanitize_for_log("a → b ⇒ c") == "a → b ⇒ c" + + def test_math_symbols_preserved(self) -> None: + assert sanitize_for_log("∀x ∈ ℝ → ℂ") == "∀x ∈ ℝ → ℂ" + + def test_cjk_preserved(self) -> None: + assert sanitize_for_log("配置错误") == "配置错误" + + def test_emoji_preserved(self) -> None: + assert sanitize_for_log("done ✅ 🎉") == "done ✅ 🎉" + + def test_box_drawing_preserved(self) -> None: + # clang / rustc carets use box-drawing for source-position pointers. + assert sanitize_for_log("│ ╰─ here") == "│ ╰─ here" + + def test_cutlass_template_syntax_preserved(self) -> None: + # CUTLASS / CuTe error messages routinely contain dense template + # syntax that must survive verbatim. + src = "Layout,Stride<_128,_1>>" + assert sanitize_for_log(src) == src + + def test_compile_error_block_with_mixed_content(self) -> None: + # Realistic: ANSI from compiler colorization wrapping a source line + # that itself contains template syntax, Greek identifier, arrow. + src = "\x1b[31merror\x1b[0m: cannot convert α → β in Layout>" + assert ( + sanitize_for_log(src) == "error: cannot convert α → β in Layout>" + ) + + +# --------------------------------------------------------------------------- +# sanitize_for_log — invariants +# --------------------------------------------------------------------------- + + +class TestSanitizeForLogInvariants: + def test_identity_on_safe_input(self) -> None: + safe = "Plain ASCII with\ttabs and\nnewlines." + assert sanitize_for_log(safe) == safe + + def test_idempotence_on_safe_input(self) -> None: + safe = "plain text" + assert sanitize_for_log(sanitize_for_log(safe)) == safe + + def test_idempotence_on_hostile_input(self) -> None: + hostile = "a\x1b[2J\x07‮b\ud83dc\x00" + once = sanitize_for_log(hostile) + twice = sanitize_for_log(once) + assert twice == once + + def test_empty_input(self) -> None: + assert sanitize_for_log("") == "" + + def test_only_controls_collapses(self) -> None: + # All-controls input becomes a string of escape forms (still safe). + result = sanitize_for_log("\x00\x01\x02\x03") + assert result == "\\x00\\x01\\x02\\x03" + + def test_output_is_utf8_encodable(self) -> None: + # After sanitization the result must round-trip through UTF-8. + hostile = "a\ud83d\x00\x1b[2J‮b" + cleaned = sanitize_for_log(hostile) + cleaned.encode("utf-8") # must not raise + + def test_output_is_json_encodable(self) -> None: + import json + + hostile = "a\ud83d\x00\x1b[2J‮b" + json.dumps(sanitize_for_log(hostile)) # must not raise + + def test_output_contains_no_escape_byte(self) -> None: + # Byte-level invariant: no raw \x1b ever survives sanitize_for_log. + # Catches regex regressions that JSON / UTF-8 checks would not. + hostile = "\x1b[2J\x1bM\x1b]title\x07\x1b_apc\x1b\\plain" + assert "\x1b" not in sanitize_for_log(hostile) + + def test_output_contains_no_lone_surrogate(self) -> None: + hostile = "a\ud800b\udc00c\ud83d" + cleaned = sanitize_for_log(hostile) + for ch in cleaned: + cp = ord(ch) + assert not (0xD800 <= cp <= 0xDFFF), ( + f"surrogate U+{cp:04X} survived" + ) + + def test_output_contains_no_raw_c0_except_tab_lf(self) -> None: + hostile = "".join(chr(c) for c in range(0x20)) + "\x7f" + cleaned = sanitize_for_log(hostile) + for ch in cleaned: + cp = ord(ch) + assert cp in (0x09, 0x0A) or cp >= 0x20, ( + f"raw C0 char U+{cp:04X} survived" + ) + # And no 0x7F either. + assert "\x7f" not in cleaned + + def test_output_contains_no_bidi_overrides(self) -> None: + hostile = "a‪b‫c‬d‭e‮f⁦g⁧h⁨i⁩" + cleaned = sanitize_for_log(hostile) + for ch in cleaned: + cp = ord(ch) + assert not (0x202A <= cp <= 0x202E), f"BIDI U+{cp:04X} survived" + assert not (0x2066 <= cp <= 0x2069), f"BIDI U+{cp:04X} survived" + + +# --------------------------------------------------------------------------- +# sanitize_for_json +# --------------------------------------------------------------------------- + + +class TestSanitizeForJson: + def test_lone_surrogate_replaced(self) -> None: + assert sanitize_for_json("a\ud83dz") == "a�z" + + def test_paired_surrogates_preserved(self) -> None: + assert sanitize_for_json("a😀z") == "a😀z" + + def test_ansi_passes_through(self) -> None: + # JSON encoder handles ANSI fine; this function does not strip it. + assert sanitize_for_json("\x1b[2Jhello") == "\x1b[2Jhello" + + def test_controls_pass_through(self) -> None: + # NUL / BEL / CR survive — sanitize_for_dbtext or _log are stricter. + assert sanitize_for_json("a\x00\x07\rb") == "a\x00\x07\rb" + + def test_bidi_passes_through(self) -> None: + assert sanitize_for_json("a‮b") == "a‮b" + + def test_idempotence(self) -> None: + hostile = "a\ud83dz" + assert sanitize_for_json(sanitize_for_json(hostile)) == sanitize_for_json( + hostile + ) + + def test_json_encodes_after_sanitize(self) -> None: + import json + + hostile = "msg with \ud83d lone" + json.dumps(sanitize_for_json(hostile)) # must not raise + + def test_identity_on_safe_input(self) -> None: + safe = "completely normal text 😀 αβγ" + assert sanitize_for_json(safe) == safe + + +# --------------------------------------------------------------------------- +# sanitize_for_dbtext +# --------------------------------------------------------------------------- + + +class TestSanitizeForDbtext: + def test_nul_replaced(self) -> None: + assert sanitize_for_dbtext("a\x00b") == "a�b" + + def test_multiple_nul_replaced(self) -> None: + assert sanitize_for_dbtext("\x00\x00x\x00") == "��x�" + + def test_no_nul_identity(self) -> None: + assert sanitize_for_dbtext("plain") == "plain" + + def test_ansi_passes_through(self) -> None: + # The dbtext variant is intentionally minimal; ANSI passes. + assert sanitize_for_dbtext("\x1b[2Jhello") == "\x1b[2Jhello" + + def test_other_controls_pass_through(self) -> None: + assert sanitize_for_dbtext("a\x07\r\nb") == "a\x07\r\nb" + + def test_unicode_preserved(self) -> None: + assert sanitize_for_dbtext("a😀b") == "a😀b" + + def test_idempotence(self) -> None: + src = "a\x00b\x00c" + assert sanitize_for_dbtext(sanitize_for_dbtext(src)) == sanitize_for_dbtext(src) + + def test_lone_surrogate_replaced(self) -> None: + # asyncpg UTF-8 encodes; lone surrogates fail there even after NUL + # handling. dbtext variant must cover both. + assert sanitize_for_dbtext("a\ud83dz") == "a�z" + + def test_combined_nul_and_surrogate(self) -> None: + assert sanitize_for_dbtext("a\x00b\ud83dc") == "a�b�c" + + def test_paired_surrogates_preserved(self) -> None: + # Valid emoji must not be mangled by the surrogate fix. + assert sanitize_for_dbtext("a😀b") == "a😀b" + + def test_output_is_utf8_encodable(self) -> None: + # The whole point: after dbtext, the result encodes cleanly. + hostile = "a\x00\ud83db\udc00c" + sanitize_for_dbtext(hostile).encode("utf-8") # must not raise + + +# --------------------------------------------------------------------------- +# Composability +# --------------------------------------------------------------------------- + + +class TestComposability: + def test_dbtext_then_log_handles_everything(self) -> None: + # Pipelined: first replace NUL, then strip ANSI / BIDI / controls / + # surrogates. End result is safe for log, JSON, and DB. + src = "a\x00\x1b[2J‮\ud83db\r" + intermediate = sanitize_for_dbtext(src) + final = sanitize_for_log(intermediate) + assert "\x00" not in final + assert "\x1b" not in final + assert "‮" not in final + assert "\r" not in final + # Validate output safety. + final.encode("utf-8") + import json + + json.dumps(final) + + def test_log_alone_covers_db_safety(self) -> None: + # sanitize_for_log strips C0 (incl NUL) — the escaped \\x00 string + # is safe for Postgres TEXT. + result = sanitize_for_log("a\x00b") + assert "\x00" not in result + + +# --------------------------------------------------------------------------- +# clean_identifier +# --------------------------------------------------------------------------- + + +class TestCleanIdentifier: + @pytest.mark.parametrize( + "src,expected", + [ + ("gpt-4o-mini", "gpt-4o-mini"), + ("claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022"), + ("meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct"), + ("openai:gpt-4o", "openai:gpt-4o"), + ("model@my-host:8080", "model@my-host:8080"), + ("model_v2+experimental", "model_v2+experimental"), + ("/local/path/to/model.gguf", "/local/path/to/model.gguf"), + ], + ) + def test_safe_identifiers_pass_through(self, src: str, expected: str) -> None: + assert clean_identifier(src) == expected + + @pytest.mark.parametrize( + "src,expected", + [ + ("gpt\x00admin", "gptadmin"), # NUL + ("real\nFAKE", "realFAKE"), # LF + ("a‮b", "ab"), # RLO + # ANSI: clean_identifier strips byte-by-byte, so the digits and + # letters inside the escape survive. Callers wanting ANSI removed + # as a unit must pipeline sanitize_for_log first. + ("model\x1b[2J", "model2J"), + ("model with spaces", "modelwithspaces"), # spaces + ("model