diff --git a/docs/concepts/detection.md b/docs/concepts/detection.md index 9fdb7a5..f9f7e4e 100644 --- a/docs/concepts/detection.md +++ b/docs/concepts/detection.md @@ -45,7 +45,35 @@ config = AnonymizerConfig( |-------|---------|-------------| | `entity_labels` | `None` (all defaults) | List of labels to detect. Leave unset (or pass `None`) to use the full default set. | | `gliner_threshold` | `0.3` | GLiNER confidence threshold (0.0--1.0). Lower values detect more entities but may increase false positives. | +| `validation_max_entities_per_call` | `100` | Maximum candidate entities per validator LLM call. Rows with more candidates are split into chunks. See [Chunked validation](#chunked-validation). | +| `validation_excerpt_window_chars` | `500` | Characters of context included before and after a chunk's entity spans in the validator prompt. Bounds per-chunk prompt size; not the model's context-window limit. | +--- + +## Chunked validation + +When a row yields many entity candidates, validating them in a single LLM call can exceed the model's context window or the provider's rate limits (tokens-per-minute or requests-per-minute quotas that many hosted models enforce). Anonymizer automatically splits validation for such rows: candidates are grouped in position order into chunks of at most `validation_max_entities_per_call`, and each chunk is validated independently with its own bounded text excerpt (`validation_excerpt_window_chars` before and after the chunk's span). Decisions are merged back into a single per-row set. + +The chunked path is always on; if a row has no more candidates than the limit, it runs as a single call and is exactly equivalent to the unchunked behavior. Tuning guidance: + +- **Raise `validation_max_entities_per_call`** if your validator has a large context window and you want fewer, larger calls. +- **Lower it** if you hit provider rate limits or want more uniform per-call latency.
+- **Raise `validation_excerpt_window_chars`** when short windows hide the context needed to disambiguate entities (e.g., `"John"` as first name vs. last name depends on surrounding text). +- **Lower it** to reduce per-chunk prompt tokens, at the risk of lower validation quality on context-sensitive labels. + +### Validator pools + +`entity_validator` can be a single alias (the default) or a list of aliases — a **pool**. When multiple aliases are configured, each chunk in a row is dispatched to the next alias in round-robin order, which lets you work around per-alias rate limits by spreading requests across equivalent endpoints. + +Pools also act as **failover**. If a chunk's assigned alias can't complete the call (an unrecoverable rate limit, a 5xx that didn't clear on retry, a malformed response), the same chunk is automatically retried against the other aliases in your pool before the row is given up on. A chunk only fails once every alias in the pool has failed for it. This is a cheap way to harden validation against any one endpoint having a bad day, on top of the load-spreading role. + +#### What happens when a row can't be validated + +If validation can't get a complete answer for a row — every alias in the pool has failed on at least one of that row's chunks — the row is **dropped from the output** rather than passed through with some entities unvalidated. This is deliberate: the alternative would be writing the original text back out with those entities still unscrubbed, which is exactly the outcome you're trying to avoid. + +Dropped rows show up in `result.failed_records` with `step="detection"`, so you can tell which inputs didn't make it through (compare input IDs against output IDs) and reprocess them in a follow-up pass. + +See [Validator pools](models.md#validator-pools) for the YAML syntax and caveats.
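The round-robin-with-failover order described above can be sketched in a few lines. This is a toy model of the dispatch rule, not part of the library's API; the aliases are the ones from the pool example in [Validator pools](models.md#validator-pools):

```python
# Toy sketch of round-robin chunk dispatch with wrap-around failover:
# chunk i's primary alias is pool[i % len(pool)], and on failure the
# chunk walks the rest of the pool in order before the row is dropped.
def rotated_pool(pool: list[str], chunk_index: int) -> list[str]:
    """Return the failover order for one chunk: primary first, then the rest."""
    start = chunk_index % len(pool)
    return [pool[(start + offset) % len(pool)] for offset in range(len(pool))]

pool = ["gpt5-primary", "gpt5-secondary"]
# Chunk 0 tries gpt5-primary first and falls back to gpt5-secondary;
# chunk 1 starts on gpt5-secondary, spreading load across the pool.
assert rotated_pool(pool, 0) == ["gpt5-primary", "gpt5-secondary"]
assert rotated_pool(pool, 1) == ["gpt5-secondary", "gpt5-primary"]
```

A single-alias pool degenerates to the same one-element list for every chunk, which is why it behaves identically to not using a pool.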
## Entity labels diff --git a/docs/concepts/models.md b/docs/concepts/models.md index 77a9439..c7adc8d 100644 --- a/docs/concepts/models.md +++ b/docs/concepts/models.md @@ -109,6 +109,33 @@ Roles you don't override keep their default alias selections, but those aliases Use [`anonymizer.validate_config(config)`](../reference/anonymizer/interface/anonymizer.md) (or [`anonymizer validate`](../reference/anonymizer/interface/cli/main.md) from the CLI) after changing model configs to catch alias mismatches before processing data. +### Validator pools + +`entity_validator` accepts either a single alias (shown above) or a list of aliases. A list forms a **validator pool** with two jobs: + +1. **Load spreading.** [Chunked validation](detection.md#chunked-validation) dispatches each chunk to the next alias in round-robin order, aggregating quota across equivalent endpoints when a single alias would hit the provider's rate limits (tokens-per-minute or requests-per-minute quotas). +2. **Failover.** If a chunk's assigned alias can't complete the call (an unrecoverable rate limit, a 5xx that didn't clear on retry, a malformed response), the same chunk is automatically retried against the other aliases in your pool before the row is given up on. A row is only dropped when *every* alias in the pool has failed for the same chunk. Single-alias pools have nothing to fall back to, so they behave the same as not using a pool. + +```yaml +selected_models: + detection: + entity_detector: gliner-pii-detector + entity_validator: + - gpt5-primary + - gpt5-secondary + entity_augmenter: gpt5-primary + latent_detector: claude-sonnet +``` + +Every alias in the pool must also appear in `model_configs`; `anonymizer validate` flags unknown aliases by index. A scalar value remains valid and is equivalent to a one-element list. + +!!! 
warning "`max_parallel_requests` is enforced per alias" + + A pool with N aliases effectively allows up to `sum(max_parallel_requests for alias in pool)` concurrent validator calls per row when chunks exist. Budget your provider rate limits accordingly — the whole point of pooling is to multiply in-flight requests, but the multiplication is real. + + Pool aliases should target **equivalent models** (same model family, similar quality). Mixing heterogeneous models produces inconsistent validation across chunks in the same row and is almost always a misconfiguration. + + ### Choosing custom models For Anonymizer, the best overall leaderboard model is not always the best default for every role. diff --git a/src/anonymizer/config/anonymizer_config.py b/src/anonymizer/config/anonymizer_config.py index 2ecd072..2c8c4d1 100644 --- a/src/anonymizer/config/anonymizer_config.py +++ b/src/anonymizer/config/anonymizer_config.py @@ -82,6 +82,24 @@ class Detect(BaseModel): gliner_threshold: float = Field( default=0.3, ge=0.0, le=1.0, description="GLiNER detection confidence threshold (0.0-1.0)." ) + validation_max_entities_per_call: int = Field( + default=100, + gt=0, + description=( + "Maximum number of candidate entities included in a single validator LLM call. " + "When a row has more candidates than this, validation is split into chunks that " + "are dispatched (round-robin) across the validator pool." + ), + ) + validation_excerpt_window_chars: int = Field( + default=500, + gt=0, + description=( + "Number of characters to include before and after a chunk's entity span when " + "building the text excerpt sent to the validator. Bounds the prompt context the " + "validator sees per chunk; it is NOT the LLM's context window limit." 
+ ), + ) @field_validator("entity_labels") @classmethod diff --git a/src/anonymizer/config/models.py b/src/anonymizer/config/models.py index f01983f..5979219 100644 --- a/src/anonymizer/config/models.py +++ b/src/anonymizer/config/models.py @@ -3,17 +3,76 @@ from __future__ import annotations -from pydantic import BaseModel +import logging +from typing import Any + +from pydantic import BaseModel, field_validator + +logger = logging.getLogger(__name__) class DetectionModelSelection(BaseModel): - """Model aliases for the entity detection pipeline.""" + """Model aliases for the entity detection pipeline. + + ``entity_validator`` accepts either a single alias or a list of aliases. + A list forms a validator *pool*: chunked validation rotates calls + across the pool in round-robin order, which is useful for bypassing + per-alias TPM/RPM limits. A single scalar is normalized to a + one-element list. + """ entity_detector: str - entity_validator: str + entity_validator: list[str] entity_augmenter: str latent_detector: str + @field_validator("entity_validator", mode="before") + @classmethod + def normalize_entity_validator(cls, value: Any) -> list[str]: + """Accept a scalar alias, a list of aliases, or a tuple of aliases; return a non-empty deduplicated list. + + Normalizing at parse time keeps every downstream consumer on the + same shape (``list[str]``) regardless of whether the user wrote + ``entity_validator: some-alias`` or + ``entity_validator: [alias-a, alias-b]``. Tuples are accepted for + parity with Pydantic v2's default coercion for ``list[str]`` fields, + which lets programmatic callers pass either + ``DetectionModelSelection(entity_validator=["a", "b"])`` or + ``DetectionModelSelection(entity_validator=("a", "b"))`` without + caring about the concrete sequence type. Any other input type + raises ``TypeError``. + + Duplicate aliases are collapsed to the first occurrence (order + preserved) and a warning is logged. 
A duplicate in the pool would + burn a failover attempt on an already-exhausted endpoint, which + almost certainly isn't what the user wants. + """ + if isinstance(value, str): + aliases: list[str] = [value] + elif isinstance(value, (list, tuple)): + aliases = [str(item) for item in value] + else: + raise TypeError(f"entity_validator must be a string or list of strings, got {type(value).__name__}") + cleaned = [alias.strip() for alias in aliases if alias.strip()] + if not cleaned: + raise ValueError("entity_validator must name at least one model alias.") + seen: set[str] = set() + deduped: list[str] = [] + for alias in cleaned: + if alias in seen: + continue + seen.add(alias) + deduped.append(alias) + if len(deduped) != len(cleaned): + removed = [alias for alias in cleaned if cleaned.count(alias) > 1] + logger.warning( + "entity_validator pool contained duplicate aliases %s; collapsing to %s. " + "Duplicates burn a failover attempt on an already-exhausted endpoint.", + sorted(set(removed)), + deduped, + ) + return deduped + + class ReplaceModelSelection(BaseModel): """Model aliases for the replacement pipeline.""" diff --git a/src/anonymizer/engine/detection/chunked_validation.py b/src/anonymizer/engine/detection/chunked_validation.py new file mode 100644 index 0000000..50601ca --- /dev/null +++ b/src/anonymizer/engine/detection/chunked_validation.py @@ -0,0 +1,545 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Chunked LLM validation for the entity detection pipeline. + +Partition a row's validation candidates into chunks, build a small tagged +excerpt around each chunk, render the validation prompt per chunk, and +dispatch each chunk to an alias selected round-robin from a configured +validator pool. The per-chunk decisions are merged into a +``ValidationDecisionsSchema``-shaped payload consumed by +``enrich_validation_decisions``.
+ +Public entry point: :func:`make_chunked_validation_generator`, which +produces a ``@custom_column_generator``-decorated function bound to a +concrete pool. The helpers below are exposed for unit testing. + +Failure contract. Each chunk attempts its round-robin primary first and +fails over sequentially to the rest of the pool; a chunk only fails when +every pool member has raised. The first failing chunk re-raises out of +the generator, DataDesigner drops the row, and +``NddAdapter._detect_missing_records`` surfaces it as a ``FailedRecord``. +Raw text never silently leaks through as unscrubbed output. + +Concurrency. Chunks dispatch through a ``ThreadPoolExecutor``. Per-alias +concurrency is already enforced downstream by each facade's +``ThrottledModelClient`` (AIMD on 429), so there is no row-level cap +here; the pool exists purely to overlap this row's chunks. + +TODO(async-native): once DataDesigner's async engine becomes the default +(``DATA_DESIGNER_ASYNC_ENGINE`` flips off), replace the +``ThreadPoolExecutor`` + sync ``facade.generate()`` pattern with +``async def`` functions calling ``facade.agenerate()`` under +``asyncio.gather``. 
+""" + +from __future__ import annotations + +import functools +import logging +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +from data_designer.config import custom_column_generator +from data_designer.engine.models.recipes.response_recipes import PydanticResponseRecipe +from jinja2 import BaseLoader, Environment, StrictUndefined +from pydantic import BaseModel, Field + +from anonymizer.engine.constants import ( + COL_SEED_ENTITIES, + COL_SEED_TAGGED_TEXT, + COL_SEED_VALIDATION_CANDIDATES, + COL_TAG_NOTATION, + COL_TEXT, + COL_VALIDATION_DECISIONS, + COL_VALIDATION_SKELETON, +) +from anonymizer.engine.detection.postprocess import ( + EntitySpan, + TagNotation, + build_tagged_text, +) +from anonymizer.engine.schemas import ( + EntitiesSchema, + RawValidationDecisionsSchema, + ValidationCandidatesSchema, + ValidationDecisionsSchema, + ValidationSkeletonDecisionSchema, + ValidationSkeletonSchema, +) + +logger = logging.getLogger("anonymizer.detection.chunked_validation") + +# Jinja2 environment used to render the per-chunk validation prompt. +# The template mirrors the production prompt exactly: we substitute the same +# placeholders (``_seed_tagged_text``, ``_validation_skeleton``, +# ``_tag_notation``) but with per-chunk values. +_PROMPT_ENV = Environment( + loader=BaseLoader(), + autoescape=False, + undefined=StrictUndefined, + keep_trailing_newline=True, +) + + +@functools.lru_cache(maxsize=4) +def _compile_template(template: str) -> Any: + """Return a compiled Jinja2 template, cached by source string.""" + return _PROMPT_ENV.from_string(template) + + +class ChunkedValidationParams(BaseModel): + """Parameters supplied to :func:`chunked_validate_row` via DD's ``generator_params``. + + Attributes: + pool: Ordered list of validator model aliases. 
Chunk ``i`` is dispatched + to ``pool[i % len(pool)]`` as its primary; on any terminal exception + from that alias the chunk fails over through the rest of the pool + (starting from the next position, wrapping around). Must be + non-empty and every alias must also be present in the decorator's + ``model_aliases`` so DataDesigner materialises the facade. + max_entities_per_call: Upper bound on candidates per chunk. + excerpt_window_chars: Chars of surrounding raw text included in each + chunk's excerpt on either side of the chunk span. + prompt_template: Jinja2 source for the validation prompt (with + ``_seed_tagged_text``, ``_validation_skeleton``, ``_tag_notation`` + placeholders). Typically produced by ``_get_validation_prompt``. + system_prompt: Optional system prompt forwarded to each chunk call. + + ``prompt_template`` and ``system_prompt`` are marked ``repr=False`` because + DataDesigner's pre-generation logger f-strings this model + (``generator_params: {params}``) and our validation prompt is multi-kB of + entity rules; a non-trivial system prompt would compound that. Hiding + them from ``__str__``/``__repr__`` keeps setup logs readable without + touching serialization — ``model_dump()`` still carries both, so the + generator receives them unchanged. + """ + + pool: list[str] = Field(min_length=1) + max_entities_per_call: int = Field(gt=0) + excerpt_window_chars: int = Field(gt=0) + prompt_template: str = Field(repr=False) + system_prompt: str | None = Field(default=None, repr=False) + + +# --------------------------------------------------------------------------- +# Pure helpers (no DataDesigner, no LLM). Tested directly. +# --------------------------------------------------------------------------- + + +def order_candidates_by_position( + candidates: ValidationCandidatesSchema, + seed_entities: list[EntitySpan], +) -> list[tuple[Any, EntitySpan]]: + """Pair each candidate with its matching seed entity and sort by text position. 
+ + Every candidate id must resolve to a seed entity. A missing id indicates an + upstream bug in ``merge_and_build_candidates`` or ``prepare_validation_inputs`` + (both produce candidates whose ids come from ``EntitySpan.entity_id``). We + raise early with the offending id so the failure is easy to triage. + """ + seed_by_id = {span.entity_id: span for span in seed_entities} + paired: list[tuple[Any, EntitySpan]] = [] + for candidate in candidates.candidates: + seed = seed_by_id.get(candidate.id) + if seed is None: + raise ValueError( + f"Validation candidate id {candidate.id!r} has no matching seed entity. " + "Every candidate produced by merge_and_build_candidates or " + "prepare_validation_inputs must correspond to a seed entity with " + "start_position and end_position populated; this inconsistency " + "indicates a bug in one of those upstream generators." + ) + paired.append((candidate, seed)) + paired.sort(key=lambda pair: (pair[1].start_position, pair[1].end_position, pair[1].entity_id)) + return paired + + +def chunk_candidates( + ordered: Sequence[tuple[Any, EntitySpan]], + max_entities_per_call: int, +) -> list[list[tuple[Any, EntitySpan]]]: + """Partition the ordered (candidate, seed) pairs into chunks of at most ``max_entities_per_call``. + + Assumes ``max_entities_per_call > 0``; positivity is enforced upstream at + ``ChunkedValidationParams.max_entities_per_call`` and + ``AnonymizerDetectConfig.validation_max_entities_per_call`` (both + ``Field(gt=0)``). + """ + return [list(ordered[i : i + max_entities_per_call]) for i in range(0, len(ordered), max_entities_per_call)] + + +def build_chunk_excerpt( + *, + text: str, + chunk_spans: list[EntitySpan], + all_spans: list[EntitySpan], + window_chars: int, + notation: TagNotation, +) -> str: + """Build a tagged text excerpt wide enough to give the LLM context around ``chunk_spans``. + + The excerpt spans ``[min(chunk.start) - window, max(chunk.end) + window]`` + clamped to the text bounds. 
Any entity from ``all_spans`` fully contained + in that window is re-tagged inside the excerpt so the surrounding context + matches the full-document view. The forced ``notation`` keeps tags stable + across chunks of the same row even when a local slice would otherwise + pick a different heuristic. + """ + if not chunk_spans: + return "" + chunk_start = min(span.start_position for span in chunk_spans) + chunk_end = max(span.end_position for span in chunk_spans) + excerpt_start = max(0, chunk_start - window_chars) + excerpt_end = min(len(text), chunk_end + window_chars) + excerpt_raw = text[excerpt_start:excerpt_end] + in_window = [ + EntitySpan( + entity_id=span.entity_id, + value=span.value, + label=span.label, + start_position=span.start_position - excerpt_start, + end_position=span.end_position - excerpt_start, + score=span.score, + source=span.source, + ) + for span in all_spans + if span.start_position >= excerpt_start and span.end_position <= excerpt_end + ] + return build_tagged_text(excerpt_raw, in_window, notation=notation) + + +def build_chunk_skeleton(chunk_candidates_: list[Any]) -> dict[str, Any]: + """Build the validation skeleton (``ValidationSkeletonSchema``) for a chunk.""" + skeleton = ValidationSkeletonSchema( + decisions=[ValidationSkeletonDecisionSchema(id=c.id, value=c.value, label=c.label) for c in chunk_candidates_] + ) + return skeleton.model_dump(mode="json") + + +def render_chunk_prompt( + *, + template: str, + excerpt: str, + skeleton: dict[str, Any], + notation: TagNotation, +) -> str: + """Render the validation prompt for a single chunk via Jinja2. + + The template and context match the production ``LLMStructuredColumnConfig`` + call: dicts are rendered with Python ``str()`` (Jinja2 default), which is + how the existing prompt has always served ``{{ _validation_skeleton }}``. 
+ """ + compiled = _compile_template(template) + return compiled.render( + **{ + COL_SEED_TAGGED_TEXT: excerpt, + COL_VALIDATION_SKELETON: skeleton, + COL_TAG_NOTATION: notation.value, + } + ) + + +def merge_chunk_decisions( + chunk_results: list[RawValidationDecisionsSchema], + candidates: ValidationCandidatesSchema, +) -> dict[str, Any]: + """Flatten chunk decisions into a single ``ValidationDecisionsSchema`` payload. + + Mirrors the single-call contract: + - Only decisions whose ids match a known candidate are retained. This is + consistent with ``enrich_validation_decisions``, which also filters to + valid ids; doing it here too keeps COL_VALIDATION_DECISIONS minimal. + - Null-decision entries are treated as "no answer" and do NOT reserve + the id, so if a later chunk yields a real verdict for the same id, + that verdict wins. The null entry itself never leaks through: downstream + ``apply_validation_decisions`` relies on candidate-not-in-output to + mean "keep unchanged", which would break if we emitted ``decision=null``. + - Among multiple real verdicts for the same id (shouldn't happen because + candidates partition cleanly, but kept as defence-in-depth), the first + wins. + """ + candidate_lookup = {c.id: c for c in candidates.candidates} + valid_ids = set(candidate_lookup) + seen: set[str] = set() + merged_decisions: list[dict[str, Any]] = [] + for result in chunk_results: + for decision in result.decisions: + if decision.id not in valid_ids or decision.id in seen: + continue + # Skip null-decision entries without marking the id as seen, so a + # later chunk with a real verdict for the same id can still win. 
+ if decision.decision is None: + continue + cand = candidate_lookup[decision.id] + merged_decisions.append( + { + "id": decision.id, + "value": cand.value, + "label": cand.label, + "decision": decision.decision.value, + "proposed_label": decision.proposed_label or "", + "reason": decision.reason, + } + ) + seen.add(decision.id) + return ValidationDecisionsSchema.model_validate({"decisions": merged_decisions}).model_dump(mode="json") + + +# --------------------------------------------------------------------------- +# Chunk dispatch. Testable by passing fake ``models``. +# --------------------------------------------------------------------------- + + +def _dispatch_chunk( + *, + facades: list[tuple[str, Any]], + prompt: str, + system_prompt: str | None, + chunk_index: int, +) -> RawValidationDecisionsSchema: + """Dispatch a single chunk with cross-alias failover across the pool. + + ``facades`` is an ordered list of ``(alias, facade)`` pairs. The first + entry is the chunk's round-robin-assigned primary; subsequent entries + are the rest of the pool, tried in order on any terminal exception from + the primary. Each facade carries its own transport-level retry policy + (``RetryConfig.max_retries`` + exponential backoff on 5xx and connection + errors) and its own AIMD throttling on 429, so by the time an exception + escapes the facade call we consider that alias exhausted for this chunk. + + We use ``PydanticResponseRecipe`` so the facade appends JSON task + instructions and parses the response into ``RawValidationDecisionsSchema``. + + Single-alias pools run the loop exactly once and re-raise the original + exception (no alternate alias to try). Multi-alias pools get + ``len(pool)`` total attempts. If every pool member raises, the *last* + exception propagates so DataDesigner records the row as a + ``FailedRecord`` via ``NddAdapter._detect_missing_records``. 
+ + Each failover attempt is logged at WARNING so operators can correlate + degraded pool members with run-level failure-rate spikes. + """ + recipe = PydanticResponseRecipe(data_type=RawValidationDecisionsSchema) + final_prompt = recipe.apply_recipe_to_user_prompt(prompt) + final_system = recipe.apply_recipe_to_system_prompt(system_prompt) + + last_exc: BaseException | None = None + for attempt_index, (alias, facade) in enumerate(facades): + try: + output, _messages = facade.generate( + prompt=final_prompt, + parser=recipe.parse, + system_prompt=final_system, + purpose=f"entity-validation-chunk-{chunk_index}-attempt-{attempt_index}", + ) + if attempt_index > 0: + logger.info( + "validator chunk %d: recovered on failover alias=%s (attempt %d of %d)", + chunk_index, + alias, + attempt_index + 1, + len(facades), + ) + return output + except Exception as exc: # noqa: BLE001 — we classify by failover position, not type + last_exc = exc + remaining = len(facades) - attempt_index - 1 + if remaining > 0: + logger.warning( + "validator chunk %d: alias=%s raised %s (%s); failing over to next pool member (%d remaining)", + chunk_index, + alias, + type(exc).__name__, + exc, + remaining, + ) + else: + logger.error( + "validator chunk %d: alias=%s raised %s (%s); pool exhausted — row will be dropped", + chunk_index, + alias, + type(exc).__name__, + exc, + ) + + # ``facades`` is non-empty by caller contract: ``chunked_validate_row`` + # builds it from the configured pool, and the config validator requires + # a non-empty pool. After the loop, ``last_exc`` is therefore set and + # we re-raise it. The ``None`` branch exists only to give a loud, + # named error if that precondition is ever violated (rather than + # ``raise None``, which would surface as ``TypeError: exceptions must + # derive from BaseException``) and to keep the guard live under + # ``python -O``, which strips ``assert``. 
+ if last_exc is None: + raise RuntimeError( + "_dispatch_chunk was called with an empty facades list; " + "this violates the caller contract (a non-empty validator pool)." + ) + raise last_exc + + +def chunked_validate_row( + row: dict[str, Any], + params: ChunkedValidationParams, + models: dict[str, Any], +) -> dict[str, Any]: + """Run chunked validation for a single row and write ``COL_VALIDATION_DECISIONS``. + + This is the workhorse. Call it directly in tests with fake ``models``; + the DataDesigner-decorated wrapper produced by + :func:`make_chunked_validation_generator` just forwards to it. + """ + missing_aliases = [alias for alias in params.pool if alias not in models] + if missing_aliases: + raise KeyError( + f"Validator pool aliases {missing_aliases} not present in models dict. " + f"Ensure make_chunked_validation_generator was invoked with the same pool " + f"passed in ChunkedValidationParams.pool." + ) + + text = str(row.get(COL_TEXT, "")) + candidates = ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})) + seed_entities_schema = EntitiesSchema.from_raw(row.get(COL_SEED_ENTITIES, {})) + notation_raw = row.get(COL_TAG_NOTATION) or TagNotation.sentinel.value + notation = TagNotation(str(notation_raw)) + + # Short-circuit: a row with no candidates has no decisions to make. 
+ if not candidates.candidates: + row[COL_VALIDATION_DECISIONS] = ValidationDecisionsSchema().model_dump(mode="json") + return row + + all_spans = [ + EntitySpan( + entity_id=entity.id, + value=entity.value, + label=entity.label, + start_position=entity.start_position, + end_position=entity.end_position, + score=entity.score, + source=entity.source, + ) + for entity in seed_entities_schema.entities + ] + + ordered = order_candidates_by_position(candidates, all_spans) + chunks = chunk_candidates(ordered, params.max_entities_per_call) + + if len(chunks) == 1: + logger.debug( + "chunked validation: %d candidate(s) in 1 chunk (full-text excerpt), pool=%s", + len(ordered), + params.pool, + ) + else: + logger.debug( + "chunked validation: %d candidate(s) in %d chunks (max=%d per chunk, window=%d chars), pool=%s", + len(ordered), + len(chunks), + params.max_entities_per_call, + params.excerpt_window_chars, + params.pool, + ) + + # Single-chunk rows preserve parity with the pre-chunking + # ``LLMStructuredColumnConfig`` path by sending the fully tagged + # document. The excerpt window is strictly a cost-control lever for + # multi-chunk dispatch (it bounds per-chunk input tokens); when we're + # only making one call there's no cost reason to clip, and clipping + # would silently narrow the context the validator sees. Computed once + # here because ``len(chunks) == 1`` is loop-invariant. 
+ single_chunk_tagged_text = build_tagged_text(text, all_spans, notation=notation) if len(chunks) == 1 else None + + dispatch_kwargs_per_chunk: list[dict[str, Any]] = [] + for chunk_index, chunk in enumerate(chunks): + chunk_candidates_ = [pair[0] for pair in chunk] + chunk_spans = [pair[1] for pair in chunk] + excerpt = ( + single_chunk_tagged_text + if single_chunk_tagged_text is not None + else build_chunk_excerpt( + text=text, + chunk_spans=chunk_spans, + all_spans=all_spans, + window_chars=params.excerpt_window_chars, + notation=notation, + ) + ) + skeleton = build_chunk_skeleton(chunk_candidates_) + prompt = render_chunk_prompt( + template=params.prompt_template, + excerpt=excerpt, + skeleton=skeleton, + notation=notation, + ) + # Round-robin across the validator pool. ``ChunkedValidationParams`` + # guarantees ``pool`` is non-empty; ``chunk_index`` comes from + # ``enumerate`` so it's non-negative by construction. The rotated + # order (primary first, then the rest of the pool) is what + # ``_dispatch_chunk`` walks on cross-alias failover. + start = chunk_index % len(params.pool) + rotated_aliases = [params.pool[(start + offset) % len(params.pool)] for offset in range(len(params.pool))] + chunk_facades = [(alias, models[alias]) for alias in rotated_aliases] + dispatch_kwargs_per_chunk.append( + { + "facades": chunk_facades, + "prompt": prompt, + "system_prompt": params.system_prompt, + "chunk_index": chunk_index, + } + ) + + # Dispatch all chunks concurrently via a ThreadPoolExecutor. Per-alias + # concurrency is still capped downstream by each facade's + # ``ThrottledModelClient`` (AIMD on 429), so the pool's only job here is + # to overlap one row's chunks. ``f.result()`` re-raises the first chunk + # exception, which is what we want: a single terminal chunk failure + # fails the row. Pending workers finish naturally as the ``with`` block + # exits — we just stop observing their results once we re-raise. 
+ if not chunks: + chunk_results: list[RawValidationDecisionsSchema] = [] + else: + with ThreadPoolExecutor(max_workers=len(chunks)) as executor: + futures = [executor.submit(_dispatch_chunk, **kwargs) for kwargs in dispatch_kwargs_per_chunk] + chunk_results = [f.result() for f in futures] + + row[COL_VALIDATION_DECISIONS] = merge_chunk_decisions(chunk_results, candidates) + return row + + +# --------------------------------------------------------------------------- +# DataDesigner wiring factory. +# --------------------------------------------------------------------------- + + +def make_chunked_validation_generator(pool: list[str]) -> Any: + """Build a ``@custom_column_generator``-decorated function bound to ``pool``. + + ``model_aliases`` must be declared statically on the decorator so + DataDesigner knows which facades to materialise for the generator. Since + the pool is config-driven (per-run), we generate the function dynamically. + The required_columns are exhaustive for DataDesigner's DAG ordering: the + generator reads the raw text, seed entities (for positions), the candidate + list (what to decide), and the tag notation (for excerpt tagging). 
+ """ + if not pool: + raise ValueError("Cannot build chunked validation generator: pool is empty.") + + @custom_column_generator( + required_columns=[ + COL_TEXT, + COL_SEED_ENTITIES, + COL_SEED_VALIDATION_CANDIDATES, + COL_TAG_NOTATION, + ], + model_aliases=list(pool), + ) + def chunked_validate( + row: dict[str, Any], + generator_params: ChunkedValidationParams, + models: dict[str, Any], + ) -> dict[str, Any]: + return chunked_validate_row(row, generator_params, models) + + return chunked_validate diff --git a/src/anonymizer/engine/detection/custom_columns.py b/src/anonymizer/engine/detection/custom_columns.py index bad41d7..059d82a 100644 --- a/src/anonymizer/engine/detection/custom_columns.py +++ b/src/anonymizer/engine/detection/custom_columns.py @@ -34,7 +34,6 @@ COL_VALIDATED_SEED_ENTITIES, COL_VALIDATION_CANDIDATES, COL_VALIDATION_DECISIONS, - COL_VALIDATION_SKELETON, ) from anonymizer.engine.detection.postprocess import ( EntitySpan, @@ -52,8 +51,6 @@ ValidatedDecisionSchema, ValidatedDecisionsSchema, ValidationCandidatesSchema, - ValidationSkeletonDecisionSchema, - ValidationSkeletonSchema, ) @@ -135,19 +132,6 @@ def prepare_validation_inputs(row: dict[str, Any]) -> dict[str, Any]: return row -@custom_column_generator(required_columns=[COL_SEED_VALIDATION_CANDIDATES]) -def build_validation_skeleton(row: dict[str, Any]) -> dict[str, Any]: - """Pre-populate the decisions template with candidate IDs so the LLM only fills in decision/reason.""" - candidates = ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})) - skeleton = ValidationSkeletonSchema( - decisions=[ - ValidationSkeletonDecisionSchema(id=c.id, value=c.value, label=c.label) for c in candidates.candidates - ] - ) - row[COL_VALIDATION_SKELETON] = skeleton.model_dump(mode="json") - return row - - @custom_column_generator( required_columns=[COL_VALIDATION_DECISIONS, COL_SEED_VALIDATION_CANDIDATES], ) diff --git a/src/anonymizer/engine/detection/detection_workflow.py 
b/src/anonymizer/engine/detection/detection_workflow.py index d633041..24e4ef9 100644 --- a/src/anonymizer/engine/detection/detection_workflow.py +++ b/src/anonymizer/engine/detection/detection_workflow.py @@ -11,6 +11,7 @@ from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig from data_designer.config.models import ModelConfig +from anonymizer.config.anonymizer_config import Detect as AnonymizerDetectConfig from anonymizer.config.models import DetectionModelSelection from anonymizer.config.rewrite import PrivacyGoal from anonymizer.engine.constants import ( @@ -36,10 +37,13 @@ ENTITY_LABEL_EXAMPLES, _jinja, ) +from anonymizer.engine.detection.chunked_validation import ( + ChunkedValidationParams, + make_chunked_validation_generator, +) from anonymizer.engine.detection.custom_columns import ( apply_validation_and_finalize, apply_validation_to_seed_entities, - build_validation_skeleton, enrich_validation_decisions, merge_and_build_candidates, parse_detected_entities, @@ -47,18 +51,27 @@ ) from anonymizer.engine.detection.postprocess import EntitySpan, group_entities_by_value from anonymizer.engine.ndd.adapter import FailedRecord, NddAdapter -from anonymizer.engine.ndd.model_loader import resolve_model_alias +from anonymizer.engine.ndd.model_loader import resolve_model_alias, resolve_model_aliases from anonymizer.engine.prompt_utils import substitute_placeholders from anonymizer.engine.schemas import ( AugmentedEntitiesSchema, EntitiesByValueSchema, EntitiesSchema, LatentEntitiesSchema, - ValidationDecisionsSchema, ) logger = logging.getLogger("anonymizer.detection") +# Defaults for the two chunked-validation knobs. Sourced from the Detect config +# so there is a single source of truth; the workflow method defaults exist so +# internal tests and ad-hoc callers do not have to wire plumbing by hand. 
+_DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL: int = AnonymizerDetectConfig.model_fields[ + "validation_max_entities_per_call" +].default +_DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS: int = AnonymizerDetectConfig.model_fields[ + "validation_excerpt_window_chars" +].default + @dataclass(frozen=True) class EntityDetectionResult: @@ -79,6 +92,8 @@ def detect_and_validate_entities( model_configs: list[ModelConfig], selected_models: DetectionModelSelection, gliner_detection_threshold: float, + validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, entity_labels: list[str] | None = None, data_summary: str | None = None, preview_num_records: int | None = None, @@ -86,9 +101,10 @@ def detect_and_validate_entities( """Run the core detection pipeline: GLiNER NER, LLM validation, LLM augmentation, and finalization. This is the primary detection workflow. It detects entities via GLiNER, - validates/reclassifies them with an LLM, augments with additional - entities the detector may have missed, and produces final standoff - entity spans with overlap resolution. + validates/reclassifies them with an LLM (chunked across a pool of + validator aliases), augments with additional entities the detector may + have missed, and produces final standoff entity spans with overlap + resolution. 
""" labels = _resolve_detection_labels(entity_labels) workflow_model_configs = self._inject_detector_params( @@ -99,14 +115,36 @@ def detect_and_validate_entities( ) detection_alias = resolve_model_alias("entity_detector", selected_models) - validator_alias = resolve_model_alias("entity_validator", selected_models) + validator_aliases = resolve_model_aliases("entity_validator", selected_models) augmenter_alias = resolve_model_alias("entity_augmenter", selected_models) logger.debug( - "detection aliases: detector=%s, validator=%s, augmenter=%s", + "detection aliases: detector=%s, validator_pool=%s, augmenter=%s", detection_alias, - validator_alias, + validator_aliases, augmenter_alias, ) + # ModelConfig.max_parallel_requests caps concurrency *per alias*. When + # the pool has multiple validators each gets its own cap, so total + # in-flight validator calls can reach sum(per-alias caps). Operators + # provisioning rate budgets for downstream providers should size each + # alias's cap accordingly. + if len(validator_aliases) > 1: + logger.warning( + "entity validation runs across a pool of %d aliases (%s). 
" + "max_parallel_requests is enforced per alias, so the pool " + "multiplies total in-flight validator calls; budget provider " + "TPM/RPM accordingly.", + len(validator_aliases), + validator_aliases, + ) + + validator_generator = make_chunked_validation_generator(validator_aliases) + validator_params = ChunkedValidationParams( + pool=list(validator_aliases), + max_entities_per_call=validation_max_entities_per_call, + excerpt_window_chars=validation_excerpt_window_chars, + prompt_template=_get_validation_prompt(data_summary=data_summary, labels=labels), + ) detection_result = self._adapter.run_workflow( dataframe, @@ -126,13 +164,9 @@ def detect_and_validate_entities( generator_function=prepare_validation_inputs, ), CustomColumnConfig( - name=COL_VALIDATION_SKELETON, generator_function=build_validation_skeleton, drop=True - ), - LLMStructuredColumnConfig( name=COL_VALIDATION_DECISIONS, - prompt=_get_validation_prompt(data_summary=data_summary, labels=labels), - model_alias=validator_alias, - output_format=ValidationDecisionsSchema, + generator_function=validator_generator, + generator_params=validator_params, drop=True, ), CustomColumnConfig( @@ -217,6 +251,8 @@ def run( model_configs: list[ModelConfig], selected_models: DetectionModelSelection, gliner_detection_threshold: float, + validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, entity_labels: list[str] | None = None, privacy_goal: PrivacyGoal | None = None, data_summary: str | None = None, @@ -239,6 +275,8 @@ def run( model_configs=model_configs, selected_models=selected_models, gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, entity_labels=entity_labels, data_summary=data_summary, preview_num_records=preview_num_records, diff --git 
a/src/anonymizer/engine/detection/postprocess.py b/src/anonymizer/engine/detection/postprocess.py index 1433808..9e3557e 100644 --- a/src/anonymizer/engine/detection/postprocess.py +++ b/src/anonymizer/engine/detection/postprocess.py @@ -255,11 +255,30 @@ def resolve_overlaps(entities: list[EntitySpan]) -> list[EntitySpan]: return sorted(accepted, key=lambda item: (item.start_position, item.end_position, item.label)) -def build_tagged_text(text: str, entities: list[EntitySpan]) -> str: - """Render human-readable tagged text for downstream LLM prompts.""" +def build_tagged_text( + text: str, + entities: list[EntitySpan], + *, + notation: TagNotation | str | None = None, +) -> str: + """Render human-readable tagged text for downstream LLM prompts. + + Args: + text: Source text to annotate. + entities: Entities to tag within ``text``; positions are relative to ``text``. + notation: Optional override of the tag notation. When ``None`` the + notation is chosen heuristically from ``text`` (default behaviour). + Callers that tag a substring of a larger document should pass the + parent document's notation so tags remain stable across excerpts. 
+ """ if not entities: return text - notation = _choose_tag_notation(text) + if notation is None: + resolved_notation = _choose_tag_notation(text) + elif isinstance(notation, TagNotation): + resolved_notation = notation + else: + resolved_notation = TagNotation(notation) cursor = 0 parts: list[str] = [] for entity in sorted(entities, key=lambda item: (item.start_position, item.end_position)): @@ -270,7 +289,7 @@ def build_tagged_text(text: str, entities: list[EntitySpan]) -> str: _format_entity_tag( value=text[entity.start_position : entity.end_position], label=entity.label, - notation=notation, + notation=resolved_notation, ) ) cursor = entity.end_position diff --git a/src/anonymizer/engine/ndd/model_loader.py b/src/anonymizer/engine/ndd/model_loader.py index b567002..d74cbac 100644 --- a/src/anonymizer/engine/ndd/model_loader.py +++ b/src/anonymizer/engine/ndd/model_loader.py @@ -10,6 +10,7 @@ from data_designer.config.models import ModelConfig, load_model_configs from data_designer.config.utils.io_helpers import load_config_file +from pydantic import BaseModel from anonymizer.config.models import ( DetectionModelSelection, @@ -88,15 +89,32 @@ def load_models_config(config_dir: Path | None = None) -> dict[str, Any]: return _load_yaml_dict(resolved_dir / "models.yaml") -def load_workflow_selections(workflow_name: WorkflowName, config_dir: Path | None = None) -> dict[str, str]: - """Load selected model aliases for a workflow.""" +def load_workflow_selections( + workflow_name: WorkflowName, + config_dir: Path | None = None, +) -> dict[str, str | list[str]]: + """Load selected model aliases for a workflow. + + Scalar roles (e.g. ``entity_detector``) come back as strings and + list-valued roles (e.g. ``entity_validator``, which accepts a pool) come + back as ``list[str]``. 
Native types are preserved rather than stringified + so downstream Pydantic selection models see the shape the YAML actually + declared — stringifying a list would silently collapse the pool to a + single garbled alias. + """ resolved_dir = config_dir or DEFAULT_CONFIG_DIR workflow_file = resolved_dir / f"{workflow_name.value}.yaml" workflow_config = _load_yaml_dict(workflow_file) selected_models = workflow_config.get("selected_models", {}) if not isinstance(selected_models, dict): raise ValueError(f"{workflow_file} must define a top-level 'selected_models' mapping.") - return {str(key): str(value) for key, value in selected_models.items()} + normalized: dict[str, str | list[str]] = {} + for key, value in selected_models.items(): + if isinstance(value, list): + normalized[str(key)] = [str(item) for item in value] + else: + normalized[str(key)] = str(value) + return normalized def load_workflow_config(workflow_name: WorkflowName, config_dir: Path | None = None) -> dict[str, Any]: @@ -114,42 +132,88 @@ def load_workflow_config(workflow_name: WorkflowName, config_dir: Path | None = def get_model_alias(workflow_name: WorkflowName, role: str, config_dir: Path | None = None) -> str: - """Return the model alias assigned to a workflow role.""" + """Return the scalar model alias assigned to a workflow role. + + Raises ``TypeError`` if the role is list-valued in the YAML (e.g. a + validator pool). Callers that need the full pool should read the + populated selection model via ``load_default_model_selection()`` and + ``resolve_model_aliases`` instead. + """ selected_models = load_workflow_selections(workflow_name=workflow_name, config_dir=config_dir) if role not in selected_models: available = ", ".join(sorted(selected_models.keys())) raise ValueError(f"Role '{role}' not found in workflow '{workflow_name.value}'. 
Available roles: {available}") - return selected_models[role] + value = selected_models[role] + if isinstance(value, list): + raise TypeError( + f"Role '{role}' in workflow '{workflow_name.value}' is list-valued (a pool); " + f"use resolve_model_aliases() on the populated selection model instead." + ) + return value def resolve_model_alias( role: str, selection_model: DetectionModelSelection | ReplaceModelSelection | RewriteModelSelection, ) -> str: - """Read model alias directly from the selection model. + """Read a scalar model alias directly from the selection model. The selection model is already populated with defaults from YAML (via ``load_default_model_selection``) or user overrides. + + For list-valued roles (e.g. ``entity_validator``), use + ``resolve_model_aliases`` instead. """ - return getattr(selection_model, role) + value = getattr(selection_model, role) + if isinstance(value, list): + raise TypeError(f"Role {role!r} is list-valued; use resolve_model_aliases() to read it.") + return value + + +def resolve_model_aliases( + role: str, + selection_model: DetectionModelSelection | ReplaceModelSelection | RewriteModelSelection, +) -> list[str]: + """Read model aliases from the selection model as a list. + + Returns the stored list for list-valued roles (e.g. ``entity_validator``) + or a one-element list wrapping the scalar for scalar roles. Callers + that need to iterate a possible model pool should prefer this helper. + """ + value = getattr(selection_model, role) + if isinstance(value, list): + return list(value) + return [value] def _merge_selections(user_selections: dict[str, dict[str, str]] | None) -> ModelSelection: - """Merge user-provided role selections onto YAML defaults.""" + """Merge user-provided role selections onto YAML defaults. + + Re-validates via ``type(section).model_validate(merged)`` rather than + ``model_copy(update=...)``. 
Pydantic v2's ``model_copy`` silently skips + field validators, which would let invalid pool configs (empty list, + duplicate aliases, whitespace-only entries) bypass + ``DetectionModelSelection.normalize_entity_validator`` at parse time + and surface as opaque runtime failures later. + """ defaults = load_default_model_selection() if not user_selections or not isinstance(user_selections, dict): return defaults + def _merge(section: BaseModel, overrides: dict[str, Any]) -> BaseModel: + if not overrides: + return section + merged = {**section.model_dump(), **overrides} + return type(section).model_validate(merged) + detection_overrides = user_selections.get(WorkflowName.detection.value, {}) replace_overrides = user_selections.get(WorkflowName.replace.value, {}) rewrite_overrides = user_selections.get(WorkflowName.rewrite.value, {}) return ModelSelection( - detection=defaults.detection.model_copy(update=detection_overrides) - if detection_overrides - else defaults.detection, - replace=defaults.replace.model_copy(update=replace_overrides) if replace_overrides else defaults.replace, - rewrite=defaults.rewrite.model_copy(update=rewrite_overrides) if rewrite_overrides else defaults.rewrite, + detection=_merge(defaults.detection, detection_overrides), + replace=_merge(defaults.replace, replace_overrides), + rewrite=_merge(defaults.rewrite, rewrite_overrides), ) @@ -164,21 +228,16 @@ def validate_model_alias_references( known_aliases = {model_config.alias for model_config in model_configs} detection_roles = selected_models.detection.model_dump() - roles_to_check: dict[str, str] = { - f"detection.{role}": detection_roles[role] - for role in ("entity_detector", "entity_validator", "entity_augmenter") - } + roles_to_check: dict[str, str] = {} + for role in ("entity_detector", "entity_validator", "entity_augmenter"): + _collect_role(roles_to_check, f"detection.{role}", detection_roles[role]) if check_rewrite: - roles_to_check.update( - { - "detection.latent_detector": 
detection_roles["latent_detector"], - **{f"rewrite.{role}": alias for role, alias in selected_models.rewrite.model_dump().items()}, - } - ) + _collect_role(roles_to_check, "detection.latent_detector", detection_roles["latent_detector"]) + for role, alias in selected_models.rewrite.model_dump().items(): + _collect_role(roles_to_check, f"rewrite.{role}", alias) if check_substitute: - roles_to_check.update( - {f"replace.{role}": alias for role, alias in selected_models.replace.model_dump().items()} - ) + for role, alias in selected_models.replace.model_dump().items(): + _collect_role(roles_to_check, f"replace.{role}", alias) unknown = {path: alias for path, alias in roles_to_check.items() if alias not in known_aliases} if unknown: @@ -188,14 +247,34 @@ def validate_model_alias_references( ) +def _collect_role(bucket: dict[str, str], path: str, value: object) -> None: + """Flatten a role entry into ``bucket`` so list-valued roles produce one entry per alias.""" + if isinstance(value, list): + for idx, alias in enumerate(value): + bucket[f"{path}[{idx}]"] = str(alias) + else: + bucket[path] = str(value) + + def _validate_alias_references( models_config: dict[str, Any], - selections: dict[str, str], + selections: dict[str, str | list[str]], workflow_name: str, ) -> None: - """Validate that bundled workflow YAMLs reference aliases defined in models.yaml.""" + """Validate that bundled workflow YAMLs reference aliases defined in models.yaml. + + Handles both scalar-valued roles and list-valued roles (e.g. a validator + pool). A plain ``set(selections.values())`` would raise + ``TypeError: unhashable type: 'list'`` once any role is list-valued. 
+ """ known_aliases = {m["alias"] for m in models_config.get("model_configs", [])} - unknown = set(selections.values()) - known_aliases + referenced: set[str] = set() + for value in selections.values(): + if isinstance(value, list): + referenced.update(value) + else: + referenced.add(value) + unknown = referenced - known_aliases if unknown: raise ValueError( f"Workflow '{workflow_name}' references unknown model aliases: {unknown}. Known aliases: {known_aliases}" diff --git a/src/anonymizer/interface/anonymizer.py b/src/anonymizer/interface/anonymizer.py index cca83f7..8cdbb04 100644 --- a/src/anonymizer/interface/anonymizer.py +++ b/src/anonymizer/interface/anonymizer.py @@ -95,7 +95,7 @@ def __init__( logger.info("🔧 Anonymizer initialized with %d model configs", len(self._model_configs)) det = self._selected_models.detection logger.info(LOG_INDENT + "🔎 detector: %s", det.entity_detector) - logger.info(LOG_INDENT + "✅ validator: %s", det.entity_validator) + logger.info(LOG_INDENT + "✅ validator: %s", ", ".join(det.entity_validator)) logger.info(LOG_INDENT + "🧩 augmenter: %s", det.entity_augmenter) if data_designer is not None: @@ -210,6 +210,8 @@ def _run_internal( model_configs=self._model_configs, selected_models=self._selected_models.detection, gliner_detection_threshold=config.detect.gliner_threshold, + validation_max_entities_per_call=config.detect.validation_max_entities_per_call, + validation_excerpt_window_chars=config.detect.validation_excerpt_window_chars, entity_labels=config.detect.entity_labels, privacy_goal=config.rewrite.privacy_goal if config.rewrite else None, data_summary=data.data_summary, diff --git a/tests/config/test_anonymizer_config.py b/tests/config/test_anonymizer_config.py index 37d2b6e..ee8c859 100644 --- a/tests/config/test_anonymizer_config.py +++ b/tests/config/test_anonymizer_config.py @@ -117,3 +117,31 @@ def test_both_modes_set_exits() -> None: """Setting both replace and rewrite on AnonymizerConfig violates the model_validator.""" 
with pytest.raises(ValidationError): AnonymizerConfig(replace=Redact(), rewrite=Rewrite()) + + +def test_detect_chunked_validation_defaults() -> None: + config = AnonymizerConfig(replace=Redact()) + assert config.detect.validation_max_entities_per_call == 100 + assert config.detect.validation_excerpt_window_chars == 500 + + +def test_detect_chunked_validation_accepts_overrides() -> None: + config = AnonymizerConfig( + detect={ + "validation_max_entities_per_call": 25, + "validation_excerpt_window_chars": 1000, + }, + replace=Redact(), + ) + assert config.detect.validation_max_entities_per_call == 25 + assert config.detect.validation_excerpt_window_chars == 1000 + + +def test_detect_validation_max_entities_per_call_must_be_positive() -> None: + with pytest.raises(ValidationError): + AnonymizerConfig(detect={"validation_max_entities_per_call": 0}, replace=Redact()) + + +def test_detect_validation_excerpt_window_chars_must_be_positive() -> None: + with pytest.raises(ValidationError): + AnonymizerConfig(detect={"validation_excerpt_window_chars": 0}, replace=Redact()) diff --git a/tests/engine/test_chunked_validation.py b/tests/engine/test_chunked_validation.py new file mode 100644 index 0000000..f9b402a --- /dev/null +++ b/tests/engine/test_chunked_validation.py @@ -0,0 +1,975 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for chunked LLM validation. + +The module is layered: pure helpers (ordering, chunking, excerpts, prompt +rendering, merging) are tested directly; the chunk dispatch is tested via a +fake ``ModelFacade`` that records calls and returns preconfigured responses. 
+""" + +from __future__ import annotations + +import json +import re +from typing import Any, Callable + +import pytest + +from anonymizer.engine.constants import ( + COL_MERGED_TAGGED_TEXT, + COL_SEED_ENTITIES, + COL_SEED_VALIDATION_CANDIDATES, + COL_TAG_NOTATION, + COL_TEXT, + COL_VALIDATION_CANDIDATES, + COL_VALIDATION_DECISIONS, + COL_VALIDATION_SKELETON, +) +from anonymizer.engine.detection.chunked_validation import ( + ChunkedValidationParams, + build_chunk_excerpt, + build_chunk_skeleton, + chunk_candidates, + chunked_validate_row, + make_chunked_validation_generator, + merge_chunk_decisions, + order_candidates_by_position, + render_chunk_prompt, +) +from anonymizer.engine.detection.postprocess import EntitySpan, TagNotation, apply_validation_decisions +from anonymizer.engine.schemas import ( + EntitiesSchema, + RawValidationDecisionsSchema, + ValidationCandidateSchema, + ValidationCandidatesSchema, +) + +# --------------------------------------------------------------------------- +# Test fixtures +# --------------------------------------------------------------------------- + + +class FakeFacade: + """Test double for ``ModelFacade`` recording invocations and replaying canned responses. + + Exposes a ``generate()`` method because ``_dispatch_chunk`` calls the + sync facade primitive directly; per-row chunk concurrency comes from + the ``ThreadPoolExecutor`` inside ``chunked_validate_row``. Under DD's + async engine the sync ``.generate()`` call is transparently bridged to + ``agenerate`` by the DD runtime. + + A canned response can be a ``dict`` (auto-wrapped in a ```json fence), a + raw string, or a callable that receives the prompt and returns either. + Set ``raise_on_call=True`` to simulate a terminal LLM failure. 
+ """ + + def __init__( + self, + alias: str, + response: dict | str | Callable[[str], dict | str] | None = None, + *, + raise_on_call: bool = False, + ) -> None: + self.alias = alias + self._response = response + self._raise = raise_on_call + self.calls: list[dict[str, Any]] = [] + + def generate(self, *, prompt, parser, system_prompt=None, purpose=None, **kwargs): + self.calls.append( + { + "prompt": prompt, + "system_prompt": system_prompt, + "purpose": purpose, + "kwargs": kwargs, + } + ) + if self._raise: + raise RuntimeError(f"forced failure from {self.alias}") + response = self._response + if callable(response): + response = response(prompt) + raw = response if isinstance(response, str) else f"```json\n{json.dumps(response)}\n```" + return parser(raw), [] + + +def _entity_span(entity_id: str, value: str, label: str, start: int, end: int) -> EntitySpan: + return EntitySpan( + entity_id=entity_id, + value=value, + label=label, + start_position=start, + end_position=end, + score=1.0, + source="detector", + ) + + +def _candidates_schema(*candidates: tuple[str, str, str]) -> ValidationCandidatesSchema: + return ValidationCandidatesSchema( + candidates=[ValidationCandidateSchema(id=cid, value=val, label=lab) for cid, val, lab in candidates] + ) + + +# --------------------------------------------------------------------------- +# Pure helpers: ordering / chunking / excerpt / skeleton / prompt / merge +# --------------------------------------------------------------------------- + + +class TestOrderCandidatesByPosition: + def test_orders_by_start_then_end_then_id(self) -> None: + candidates = _candidates_schema( + ("a_10_13", "foo", "first_name"), + ("b_0_5", "bar", "email"), + ("c_10_12", "baz", "city"), + ) + spans = [ + _entity_span("a_10_13", "foo", "first_name", 10, 13), + _entity_span("b_0_5", "bar", "email", 0, 5), + _entity_span("c_10_12", "baz", "city", 10, 12), + ] + ordered = order_candidates_by_position(candidates, spans) + assert [pair[0].id for pair in 
ordered] == ["b_0_5", "c_10_12", "a_10_13"] + + def test_missing_seed_raises_with_triage_hint(self) -> None: + candidates = _candidates_schema(("missing", "x", "y")) + with pytest.raises(ValueError, match="merge_and_build_candidates or prepare_validation_inputs"): + order_candidates_by_position(candidates, []) + + +class TestChunkCandidates: + def test_splits_into_chunks_of_at_most_n(self) -> None: + ordered = [(i, None) for i in range(5)] # type: ignore[misc] + chunks = chunk_candidates(ordered, max_entities_per_call=2) + assert [len(c) for c in chunks] == [2, 2, 1] + + def test_empty_input_yields_no_chunks(self) -> None: + assert chunk_candidates([], max_entities_per_call=10) == [] + + def test_tuple_input_returns_list_of_lists(self) -> None: + # Input-type tolerance: a sequence (not only a list) should work, and + # the declared ``list[list[...]]`` return contract must hold even + # when the caller hands in a tuple. + ordered: tuple[tuple[int, None], ...] = tuple((i, None) for i in range(3)) # type: ignore[misc] + chunks = chunk_candidates(ordered, max_entities_per_call=2) + assert chunks == [[(0, None), (1, None)], [(2, None)]] + assert all(isinstance(c, list) for c in chunks) + + +class TestBuildChunkExcerpt: + def test_includes_fully_contained_neighbors_and_drops_outside_spans(self) -> None: + text = "Alice met Bob at Acme HQ in Seattle yesterday." + spans = [ + _entity_span("alice_0_5", "Alice", "first_name", 0, 5), + _entity_span("bob_10_13", "Bob", "first_name", 10, 13), + _entity_span("acme_17_21", "Acme", "organization", 17, 21), + _entity_span("seattle_28_35", "Seattle", "city", 28, 35), + ] + # Window 8 around Bob (10..13) => excerpt [2, 21]; contains Acme fully (17..21) + # but truncates 'Alice' (0..5) and excludes 'Seattle' (28..35). 
+        excerpt = build_chunk_excerpt(
+            text=text,
+            chunk_spans=[spans[1]],
+            all_spans=spans,
+            window_chars=8,
+            notation=TagNotation.xml,
+        )
+        assert "Bob" in excerpt
+        assert "Acme" in excerpt
+        # Alice is partially in the window (end=5 inside, but start=0 before) → excluded
+        assert "<first_name>Alice" not in excerpt
+        assert "Seattle" not in excerpt
+
+    def test_partially_contained_neighbor_is_excluded(self) -> None:
+        """A neighbor that only partially overlaps the excerpt window must not be re-tagged.
+
+        Tagging a truncated span would emit text that doesn't match the
+        entity's actual value, which is worse than omitting the tag entirely.
+        """
+        text = "PREFIX Alice SUFFIX"  # Alice at 7..12
+        spans = [_entity_span("a", "Alice", "first_name", 7, 12)]
+        # Excerpt window [8, 12] cuts off 'A' from Alice → partial, must be excluded.
+        excerpt = build_chunk_excerpt(
+            text=text,
+            chunk_spans=spans,
+            all_spans=spans,
+            window_chars=0,
+            notation=TagNotation.xml,
+        )
+        # chunk_spans IS Alice (7..12), so its own span is within. This test
+        # instead builds the partial case via a distinct chunk entity:
+        bob = _entity_span("b", "lice", "first_name", 8, 12)
+        chunk = [bob]
+        excerpt2 = build_chunk_excerpt(
+            text=text,
+            chunk_spans=chunk,
+            all_spans=[spans[0], bob],
+            window_chars=0,
+            notation=TagNotation.xml,
+        )
+        # 'Alice' at 7..12 starts before excerpt_start=8 → excluded.
+        # 'bob' (the chunk entity) has positions 8..12 which are within the window → included.
+        assert "lice" in excerpt2
+        assert "Alice" not in excerpt2
+        _ = excerpt  # suppress "unused" lint
+
+    def test_forces_requested_notation_over_heuristic(self) -> None:
+        text = "Alice met Bob at HQ"
+        spans = [_entity_span("alice_0_5", "Alice", "first_name", 0, 5)]
+        excerpt = build_chunk_excerpt(
+            text=text,
+            chunk_spans=spans,
+            all_spans=spans,
+            window_chars=100,
+            notation=TagNotation.paren,
+        )
+        assert "((SENSITIVE:first_name|Alice))" in excerpt
+        assert "<first_name>" not in excerpt
+
+    def test_empty_chunk_returns_empty_string(self) -> None:
+        assert (
+            build_chunk_excerpt(text="x", chunk_spans=[], all_spans=[], window_chars=5, notation=TagNotation.xml) == ""
+        )
+
+
+class TestBuildChunkSkeleton:
+    def test_skeleton_matches_chunk_only(self) -> None:
+        candidates = _candidates_schema(("a", "Alice", "first_name"), ("b", "Bob", "first_name"))
+        skeleton = build_chunk_skeleton([candidates.candidates[0]])
+        assert skeleton == {
+            "decisions": [
+                {
+                    "id": "a",
+                    "value": "Alice",
+                    "label": "first_name",
+                    "decision": None,
+                    "proposed_label": None,
+                    "reason": None,
+                }
+            ]
+        }
+
+
+class TestRenderChunkPrompt:
+    def test_substitutes_excerpt_skeleton_and_notation(self) -> None:
+        template = (
+            "Input: {{ _seed_tagged_text }}\n"
+            "Skeleton: {{ _validation_skeleton }}\n"
+            '{%- if _tag_notation == "xml" -%}notation-is-xml{%- endif -%}'
+        )
+        rendered = render_chunk_prompt(
+            template=template,
+            excerpt="hello Alice",
+            skeleton={"decisions": [{"id": "a"}]},
+            notation=TagNotation.xml,
+        )
+        assert "Input: hello Alice" in rendered
+        assert "notation-is-xml" in rendered
+        # Dict rendered via Python str(); this matches the existing production prompt path.
+ assert "'id': 'a'" in rendered + + +class TestMergeChunkDecisions: + def test_filters_unknown_ids_and_deduplicates(self) -> None: + candidates = _candidates_schema(("a", "Alice", "first_name"), ("b", "Bob", "first_name")) + chunk_one = RawValidationDecisionsSchema.model_validate( + {"decisions": [{"id": "a", "decision": "keep"}, {"id": "ghost", "decision": "drop"}]} + ) + chunk_two = RawValidationDecisionsSchema.model_validate( + {"decisions": [{"id": "b", "decision": "drop"}, {"id": "a", "decision": "reclass"}]} + ) + merged = merge_chunk_decisions([chunk_one, chunk_two], candidates) + ids = [d["id"] for d in merged["decisions"]] + assert ids == ["a", "b"] # 'ghost' dropped; duplicate 'a' from chunk_two ignored + by_id = {d["id"]: d for d in merged["decisions"]} + assert by_id["a"]["decision"] == "keep" + assert by_id["b"]["decision"] == "drop" + assert by_id["a"]["value"] == "Alice" # enriched from candidate + assert by_id["a"]["label"] == "first_name" + + def test_drops_decisions_without_verdict(self) -> None: + """A decision with ``decision=None`` is equivalent to 'no answer' and must not leak through. + + Downstream ``apply_validation_decisions`` interprets missing-id as + 'keep unchanged'. Emitting a null-decision entry would break that. + """ + candidates = _candidates_schema(("a", "Alice", "first_name")) + chunk = RawValidationDecisionsSchema.model_validate({"decisions": [{"id": "a", "decision": None}]}) + merged = merge_chunk_decisions([chunk], candidates) + assert merged == {"decisions": []} + + def test_later_real_verdict_wins_over_earlier_null_duplicate(self) -> None: + """A null-decision entry must not reserve the id against a later real verdict. + + If an earlier chunk returns ``decision=None`` for id 'a' and a later + chunk returns ``decision='keep'`` for 'a', the merged payload should + retain the real verdict. Otherwise a transient "no answer" in one + chunk would suppress a valid answer from another. 
+ """ + candidates = _candidates_schema(("a", "Alice", "first_name")) + chunk_one = RawValidationDecisionsSchema.model_validate({"decisions": [{"id": "a", "decision": None}]}) + chunk_two = RawValidationDecisionsSchema.model_validate({"decisions": [{"id": "a", "decision": "keep"}]}) + merged = merge_chunk_decisions([chunk_one, chunk_two], candidates) + assert merged["decisions"] == [ + { + "id": "a", + "value": "Alice", + "label": "first_name", + "decision": "keep", + "proposed_label": "", + "reason": None, + } + ] + + +# --------------------------------------------------------------------------- +# Async dispatch: chunked_validate_row end-to-end with fake facades +# --------------------------------------------------------------------------- + + +def _build_row( + *, + text: str, + seed_entities: list[EntitySpan], + candidates: ValidationCandidatesSchema, + notation: TagNotation = TagNotation.xml, +) -> dict[str, Any]: + return { + COL_TEXT: text, + COL_SEED_ENTITIES: EntitiesSchema( + entities=[ + { + "id": span.entity_id, + "value": span.value, + "label": span.label, + "start_position": span.start_position, + "end_position": span.end_position, + "score": span.score, + "source": span.source, + } + for span in seed_entities + ] + ).model_dump(mode="json"), + COL_SEED_VALIDATION_CANDIDATES: candidates.model_dump(mode="json"), + COL_TAG_NOTATION: notation.value, + } + + +_MINIMAL_TEMPLATE = "TAGGED:{{ _seed_tagged_text }}|SKELETON:{{ _validation_skeleton }}|NOTATION:{{ _tag_notation }}" + + +class TestChunkedValidateRowPoolOfOne: + def test_single_chunk_single_alias_dispatches_once_and_merges(self) -> None: + text = "Alice and Bob met." 
+ spans = [ + _entity_span("a", "Alice", "first_name", 0, 5), + _entity_span("b", "Bob", "first_name", 10, 13), + ] + candidates = _candidates_schema(("a", "Alice", "first_name"), ("b", "Bob", "first_name")) + row = _build_row(text=text, seed_entities=spans, candidates=candidates) + + facade = FakeFacade( + "v0", + response={ + "decisions": [ + {"id": "a", "decision": "keep", "proposed_label": "", "reason": "real"}, + {"id": "b", "decision": "drop", "proposed_label": "", "reason": "placeholder"}, + ] + }, + ) + params = ChunkedValidationParams( + pool=["v0"], + max_entities_per_call=10, + excerpt_window_chars=100, + prompt_template=_MINIMAL_TEMPLATE, + ) + + out = chunked_validate_row(row, params, {"v0": facade}) + + assert len(facade.calls) == 1 + decisions = out[COL_VALIDATION_DECISIONS]["decisions"] + assert {d["id"]: d["decision"] for d in decisions} == {"a": "keep", "b": "drop"} + + def test_single_chunk_sends_single_chunk_tagged_text_not_windowed_excerpt(self) -> None: + """Single-chunk rows must receive the fully tagged document, not a windowed excerpt. + + Regression: before this fix, every chunk — including a lone single + chunk — was routed through ``build_chunk_excerpt`` with the configured + window. That silently narrowed the validator's context relative to the + pre-chunking ``LLMStructuredColumnConfig`` path, which always served + the full ``_seed_tagged_text``. The excerpt window is a multi-chunk + cost-control lever; applying it to the single-call path is incorrect. + """ + prefix = "HEADER_MARKER_ALPHA " * 80 + suffix = " FOOTER_MARKER_OMEGA" * 80 + middle = "Alice met Bob." 
+ text = prefix + middle + suffix + alice_start = len(prefix) + bob_start = alice_start + 10 + + spans = [ + _entity_span("a", "Alice", "first_name", alice_start, alice_start + 5), + _entity_span("b", "Bob", "first_name", bob_start, bob_start + 3), + ] + candidates = _candidates_schema( + ("a", "Alice", "first_name"), + ("b", "Bob", "first_name"), + ) + row = _build_row(text=text, seed_entities=spans, candidates=candidates) + + facade = FakeFacade( + "v0", + response={ + "decisions": [ + {"id": "a", "decision": "keep"}, + {"id": "b", "decision": "keep"}, + ] + }, + ) + params = ChunkedValidationParams( + pool=["v0"], + max_entities_per_call=10, + excerpt_window_chars=50, + prompt_template=_MINIMAL_TEMPLATE, + ) + + chunked_validate_row(row, params, {"v0": facade}) + + assert len(facade.calls) == 1 + prompt = facade.calls[0]["prompt"] + assert "HEADER_MARKER_ALPHA" in prompt, ( + "single-chunk row should receive the full text prefix; a 50-char window " + "around Alice/Bob would clip the prefix entirely." + ) + assert "FOOTER_MARKER_OMEGA" in prompt, ( + "single-chunk row should receive the full text suffix; a 50-char window " + "around Alice/Bob would clip the suffix entirely." + ) + + def test_empty_candidates_short_circuits_without_calls(self) -> None: + row = _build_row(text="hello", seed_entities=[], candidates=_candidates_schema()) + facade = FakeFacade("v0", response={"decisions": []}) + params = ChunkedValidationParams( + pool=["v0"], max_entities_per_call=10, excerpt_window_chars=50, prompt_template=_MINIMAL_TEMPLATE + ) + + out = chunked_validate_row(row, params, {"v0": facade}) + + assert facade.calls == [] + assert out[COL_VALIDATION_DECISIONS] == {"decisions": []} + + def test_system_prompt_is_forwarded_to_facade(self) -> None: + # ``ChunkedValidationParams.system_prompt`` must reach ``facade.generate``. + # The recipe appends JSON task instructions before dispatch, so we assert + # substring containment with a distinctive sentinel rather than equality. 
+ text = "Alice spoke." + spans = [_entity_span("a", "Alice", "first_name", 0, 5)] + candidates = _candidates_schema(("a", "Alice", "first_name")) + row = _build_row(text=text, seed_entities=spans, candidates=candidates) + + facade = FakeFacade( + "v0", + response={"decisions": [{"id": "a", "decision": "keep", "proposed_label": "", "reason": "x"}]}, + ) + sentinel = "SYSPROMPT_SENTINEL_CHUNKED_VALIDATION_TEST" + params = ChunkedValidationParams( + pool=["v0"], + max_entities_per_call=10, + excerpt_window_chars=50, + prompt_template=_MINIMAL_TEMPLATE, + system_prompt=f"You are a validator. {sentinel}", + ) + + chunked_validate_row(row, params, {"v0": facade}) + + assert len(facade.calls) == 1 + forwarded = facade.calls[0]["system_prompt"] + assert forwarded is not None + assert sentinel in forwarded + + def test_system_prompt_default_none_is_forwarded_untouched(self) -> None: + # The recipe maps ``None`` input to ``None`` output; this test pins + # that no intermediate layer replaces it with a placeholder string + # when ``ChunkedValidationParams.system_prompt`` is left at its + # default. + text = "Alice spoke." + spans = [_entity_span("a", "Alice", "first_name", 0, 5)] + candidates = _candidates_schema(("a", "Alice", "first_name")) + row = _build_row(text=text, seed_entities=spans, candidates=candidates) + + facade = FakeFacade( + "v0", + response={"decisions": [{"id": "a", "decision": "keep", "proposed_label": "", "reason": "x"}]}, + ) + params = ChunkedValidationParams( + pool=["v0"], + max_entities_per_call=10, + excerpt_window_chars=50, + prompt_template=_MINIMAL_TEMPLATE, + # system_prompt intentionally omitted (default None) + ) + + chunked_validate_row(row, params, {"v0": facade}) + + assert len(facade.calls) == 1 + # The recipe maps ``None`` input to ``None`` output, so the facade + # should receive no system prompt. 
+ assert facade.calls[0]["system_prompt"] is None
+
+
+class TestChunkedValidateRowPoolOfTwoRoundRobin:
+ def test_chunks_assigned_round_robin_across_pool(self) -> None:
+ text = "A B C D E F" + " " * 50 # pad so excerpts don't overlap
+ spans = [_entity_span(f"e{i}", chr(ord("A") + i), "first_name", i * 2, i * 2 + 1) for i in range(6)]
+ candidates = _candidates_schema(*[(f"e{i}", chr(ord("A") + i), "first_name") for i in range(6)])
+ row = _build_row(text=text, seed_entities=spans, candidates=candidates)
+
+ # Response contents don't matter for this test: the assertion is about
+ # which alias each chunk was dispatched to, not about the decisions, so
+ # both facades simply return empty decision lists.
+ v0 = FakeFacade("v0", response={"decisions": []})
+ v1 = FakeFacade("v1", response={"decisions": []})
+ params = ChunkedValidationParams(
+ pool=["v0", "v1"],
+ max_entities_per_call=2,
+ excerpt_window_chars=50,
+ prompt_template=_MINIMAL_TEMPLATE,
+ )
+ chunked_validate_row(row, params, {"v0": v0, "v1": v1})
+ # 6 candidates / 2 per chunk = 3 chunks; round-robin → v0,v1,v0
+ assert len(v0.calls) == 2
+ assert len(v1.calls) == 1
+
+
+class TestChunkedValidateRowMultiChunkReassembly:
+ def test_decisions_merged_across_chunks(self) -> None:
+ text = "one two three four five"
+ spans = [
+ _entity_span("c1", "one", "first_name", 0, 3),
+ _entity_span("c2", "two", "first_name", 4, 7),
+ _entity_span("c3", "three", "first_name", 8, 13),
+ _entity_span("c4", "four", "first_name", 14, 18),
+ ]
+ candidates = _candidates_schema(
+ ("c1", "one", "first_name"),
+ ("c2", "two", "first_name"),
+ ("c3", "three", "first_name"),
+ ("c4", "four", "first_name"),
+ )
+ row = _build_row(text=text, seed_entities=spans, candidates=candidates)
+
+ def responder_for(alias: str) -> Callable[[str], dict]:
+ def respond(prompt: str) -> dict:
+ # Return a decision for every id mentioned in the skeleton portion of this prompt. + # Use alias-encoded decisions so we can verify which chunk decided which id. + ids_here = [cid for cid in ("c1", "c2", "c3", "c4") if f"'id': '{cid}'" in prompt] + return { + "decisions": [ + {"id": cid, "decision": "keep", "proposed_label": "", "reason": alias} for cid in ids_here + ] + } + + return respond + + v0 = FakeFacade("v0", response=responder_for("v0")) + v1 = FakeFacade("v1", response=responder_for("v1")) + params = ChunkedValidationParams( + pool=["v0", "v1"], + max_entities_per_call=2, + excerpt_window_chars=50, + prompt_template=_MINIMAL_TEMPLATE, + ) + out = chunked_validate_row(row, params, {"v0": v0, "v1": v1}) + decisions = {d["id"]: d for d in out[COL_VALIDATION_DECISIONS]["decisions"]} + assert set(decisions) == {"c1", "c2", "c3", "c4"} + # Chunk 0 (c1,c2) → v0; chunk 1 (c3,c4) → v1. + assert decisions["c1"]["reason"] == "v0" + assert decisions["c2"]["reason"] == "v0" + assert decisions["c3"]["reason"] == "v1" + assert decisions["c4"]["reason"] == "v1" + + +class TestChunkedValidateRowFailurePropagation: + def test_row_fails_only_when_every_pool_member_raises_for_a_chunk(self) -> None: + """With failover enabled, a chunk only fails when *every* pool member raises. + + Downstream DD reporting then turns that row into a FailedRecord via + ``NddAdapter._detect_missing_records`` — no unscrubbed passthrough. 
+ """ + spans = [ + _entity_span("a", "Alice", "first_name", 0, 5), + _entity_span("b", "Bob", "first_name", 10, 13), + ] + candidates = _candidates_schema(("a", "Alice", "first_name"), ("b", "Bob", "first_name")) + row = _build_row(text="Alice and Bob", seed_entities=spans, candidates=candidates) + + v0 = FakeFacade("v0", raise_on_call=True) + v1 = FakeFacade("v1", raise_on_call=True) + params = ChunkedValidationParams( + pool=["v0", "v1"], + max_entities_per_call=1, + excerpt_window_chars=50, + prompt_template=_MINIMAL_TEMPLATE, + ) + # The last alias tried propagates; order is round-robin + wrap, so + # chunk 0 starts at v0 → fails → v1 (fails, last) → raises "from v1". + with pytest.raises(RuntimeError, match="forced failure from v1"): + chunked_validate_row(row, params, {"v0": v0, "v1": v1}) + + +class TestChunkedValidateRowCrossAliasFailover: + def test_primary_alias_failure_falls_over_to_secondary(self) -> None: + """Primary raises, secondary in pool returns a valid response → chunk succeeds, row does not fail.""" + spans = [_entity_span("a", "Alice", "first_name", 0, 5)] + candidates = _candidates_schema(("a", "Alice", "first_name")) + row = _build_row(text="Alice", seed_entities=spans, candidates=candidates) + + v0 = FakeFacade("v0", raise_on_call=True) + v1 = FakeFacade("v1", response={"decisions": [{"id": "a", "decision": "keep"}]}) + params = ChunkedValidationParams( + pool=["v0", "v1"], + max_entities_per_call=5, + excerpt_window_chars=50, + prompt_template=_MINIMAL_TEMPLATE, + ) + out = chunked_validate_row(row, params, {"v0": v0, "v1": v1}) + decisions = out[COL_VALIDATION_DECISIONS]["decisions"] + assert len(decisions) == 1 + assert decisions[0]["id"] == "a" + assert decisions[0]["decision"] == "keep" + # Each alias was tried exactly once for this single-chunk row. 
+ assert len(v0.calls) == 1 + assert len(v1.calls) == 1 + + def test_single_alias_pool_does_not_failover(self) -> None: + """A one-alias pool has no fallback — one attempt, propagate immediately. + + This keeps the behavioural guarantee that pools of size 1 behave + exactly as the pre-failover dispatch did: no hidden extra attempts. + """ + spans = [_entity_span("a", "Alice", "first_name", 0, 5)] + candidates = _candidates_schema(("a", "Alice", "first_name")) + row = _build_row(text="Alice", seed_entities=spans, candidates=candidates) + + v0 = FakeFacade("v0", raise_on_call=True) + params = ChunkedValidationParams( + pool=["v0"], + max_entities_per_call=5, + excerpt_window_chars=50, + prompt_template=_MINIMAL_TEMPLATE, + ) + with pytest.raises(RuntimeError, match="forced failure from v0"): + chunked_validate_row(row, params, {"v0": v0}) + assert len(v0.calls) == 1 + + def test_failover_wraps_starting_from_round_robin_primary(self) -> None: + """With pool [v0,v1] and 3 chunks, chunk 1's primary is v1. + + If v1 fails, failover wraps to v0 (next position mod len(pool)). + We verify by ensuring chunk 1 succeeds on v0's response, while + chunk 0 and chunk 2 succeed on v0 (their primary) directly. + """ + spans = [ + _entity_span("a", "Alice", "first_name", 0, 5), + _entity_span("b", "Bob", "first_name", 10, 13), + _entity_span("c", "Carol", "first_name", 20, 25), + ] + candidates = _candidates_schema( + ("a", "Alice", "first_name"), + ("b", "Bob", "first_name"), + ("c", "Carol", "first_name"), + ) + row = _build_row(text="Alice and Bob and Carol", seed_entities=spans, candidates=candidates) + + def v0_response(prompt: str) -> dict: + # The chunk's skeleton is serialized into the prompt; pick the id + # from there. We can't use the raw text excerpt to distinguish + # chunks because the text is short enough that every chunk's + # excerpt window covers the whole string. 
+ for candidate_id in ("a", "b", "c"): + if f"'id': '{candidate_id}'" in prompt: + return {"decisions": [{"id": candidate_id, "decision": "keep"}]} + raise AssertionError(f"no known candidate id found in prompt: {prompt!r}") + + v0 = FakeFacade("v0", response=v0_response) + v1 = FakeFacade("v1", raise_on_call=True) + params = ChunkedValidationParams( + pool=["v0", "v1"], + max_entities_per_call=1, + excerpt_window_chars=50, + prompt_template=_MINIMAL_TEMPLATE, + ) + out = chunked_validate_row(row, params, {"v0": v0, "v1": v1}) + decisions = {d["id"]: d["decision"] for d in out[COL_VALIDATION_DECISIONS]["decisions"]} + assert decisions == {"a": "keep", "b": "keep", "c": "keep"} + # v0 serviced all three chunks: chunk 0 + chunk 2 directly, chunk 1 via failover. + assert len(v0.calls) == 3 + # v1 saw exactly one call — chunk 1's primary attempt that raised. + assert len(v1.calls) == 1 + + +class TestChunkedValidateRowMissingIdMirrorsSingleCall: + def test_decision_for_non_candidate_id_is_dropped(self) -> None: + """Matches single-call contract: ``enrich_validation_decisions`` filters to candidate ids.""" + spans = [_entity_span("a", "Alice", "first_name", 0, 5)] + candidates = _candidates_schema(("a", "Alice", "first_name")) + row = _build_row(text="Alice", seed_entities=spans, candidates=candidates) + facade = FakeFacade( + "v0", + response={ + "decisions": [ + {"id": "a", "decision": "keep"}, + {"id": "unknown", "decision": "drop"}, + ] + }, + ) + params = ChunkedValidationParams( + pool=["v0"], max_entities_per_call=5, excerpt_window_chars=20, prompt_template=_MINIMAL_TEMPLATE + ) + out = chunked_validate_row(row, params, {"v0": facade}) + ids = [d["id"] for d in out[COL_VALIDATION_DECISIONS]["decisions"]] + assert ids == ["a"] + + +class TestChunkedValidateRowGuardsBadConfig: + def test_pool_alias_missing_from_models_raises_helpful_error(self) -> None: + spans = [_entity_span("a", "Alice", "first_name", 0, 5)] + candidates = _candidates_schema(("a", "Alice", 
"first_name")) + row = _build_row(text="Alice", seed_entities=spans, candidates=candidates) + params = ChunkedValidationParams( + pool=["missing_alias"], + max_entities_per_call=5, + excerpt_window_chars=20, + prompt_template=_MINIMAL_TEMPLATE, + ) + with pytest.raises(KeyError, match="missing_alias"): + chunked_validate_row(row, params, {"v0": FakeFacade("v0", response={"decisions": []})}) + + +# --------------------------------------------------------------------------- +# Factory: make_chunked_validation_generator +# --------------------------------------------------------------------------- + + +class TestMakeChunkedValidationGenerator: + def test_decorator_metadata_encodes_pool_and_required_columns(self) -> None: + fn = make_chunked_validation_generator(["v0", "v1"]) + meta = fn.custom_column_metadata + assert meta["model_aliases"] == ["v0", "v1"] + assert set(meta["required_columns"]) == { + COL_TEXT, + COL_SEED_ENTITIES, + COL_SEED_VALIDATION_CANDIDATES, + COL_TAG_NOTATION, + } + # Must not declare columns we deliberately don't read; an over-broad + # required_columns would distort DAG ordering elsewhere. In particular + # the post-augmentation views (COL_MERGED_TAGGED_TEXT, COL_VALIDATION_CANDIDATES) + # are downstream of this step and would create a cycle if declared. + assert COL_VALIDATION_SKELETON not in meta["required_columns"] + assert COL_MERGED_TAGGED_TEXT not in meta["required_columns"] + assert COL_VALIDATION_CANDIDATES not in meta["required_columns"] + + def test_factory_rejects_empty_pool(self) -> None: + with pytest.raises(ValueError, match="pool is empty"): + make_chunked_validation_generator([]) + + def test_generator_forwards_to_chunked_validate_row(self) -> None: + """The DD-exposed wrapper is sync; calling it directly must return a + dict populated by the row logic (not a coroutine). 
+ """ + spans = [_entity_span("a", "Alice", "first_name", 0, 5)] + candidates = _candidates_schema(("a", "Alice", "first_name")) + row = _build_row(text="Alice", seed_entities=spans, candidates=candidates) + fn = make_chunked_validation_generator(["v0"]) + facade = FakeFacade("v0", response={"decisions": [{"id": "a", "decision": "keep"}]}) + params = ChunkedValidationParams( + pool=["v0"], max_entities_per_call=5, excerpt_window_chars=20, prompt_template=_MINIMAL_TEMPLATE + ) + out = fn(row, params, {"v0": facade}) + assert isinstance(out, dict) + assert out[COL_VALIDATION_DECISIONS]["decisions"][0]["id"] == "a" + + def test_generator_is_sync_callable_returning_dict(self) -> None: + """Regression: DD's default thread-pool engine calls the wrapper + synchronously and rejects coroutine returns with "must return a dict, + got coroutine". Guard against accidentally re-introducing an async + outer wrapper.""" + import asyncio as _asyncio + import inspect as _inspect + + fn = make_chunked_validation_generator(["v0"]) + inner = _inspect.unwrap(fn) + assert not _asyncio.iscoroutinefunction(inner), ( + "DD-exposed generator must be sync; an async outer wrapper breaks the default thread-pool engine." + ) + + +# --------------------------------------------------------------------------- +# Behavioral regression: single-chunk vs multi-chunk parity (pool-of-one). +# --------------------------------------------------------------------------- + + +def _selective_facade(alias: str, decisions_by_id: dict[str, dict[str, Any]]) -> FakeFacade: + """Return a facade whose response depends on the ids embedded in each chunk's prompt. + + The prompt renders ``_validation_skeleton`` as Python ``str(dict)``; we + parse ``'id': 'X'`` tokens back out to select which decisions to return. 
+ This makes the LLM behaviour a pure function of the *candidate ids in
+ the chunk*, not of chunk shape or sequencing -- which is exactly the
+ assumption a real deterministic validator would satisfy when re-run.
+ """
+
+ def respond(prompt: str) -> dict[str, Any]:
+ ids = re.findall(r"'id':\s*'([^']+)'", prompt)
+ return {"decisions": [decisions_by_id[i] for i in ids if i in decisions_by_id]}
+
+ return FakeFacade(alias, response=respond)
+
+
+def _summarize(validated: list[EntitySpan]) -> list[tuple[str, str]]:
+ return [(e.entity_id, e.label) for e in validated]
+
+
+def _tally(before: list[EntitySpan], after: list[EntitySpan], decisions: dict) -> dict[str, int]:
+ """Count keep/reclass/drop/untouched relative to the decisions the LLM returned."""
+ before_labels = {e.entity_id: e.label for e in before}
+ after_ids = {e.entity_id for e in after}
+ decided_ids: dict[str, dict[str, Any]] = {d["id"]: d for d in decisions["decisions"]}
+ counts = {"keep": 0, "reclass": 0, "drop": 0, "untouched": 0}
+ for entity_id in before_labels:
+ decision = decided_ids.get(entity_id)
+ if decision is None:
+ counts["untouched"] += 1
+ continue
+ verdict = decision.get("decision")
+ if verdict == "drop":
+ assert entity_id not in after_ids, f"drop decision for {entity_id} did not remove it"
+ counts["drop"] += 1
+ elif verdict == "reclass":
+ counts["reclass"] += 1
+ elif verdict == "keep":
+ counts["keep"] += 1
+ return counts
+
+
+def _normalize_decisions(doc: dict) -> list[tuple[str, str, str]]:
+ return sorted((d["id"], d.get("decision") or "", d.get("proposed_label") or "") for d in doc["decisions"])
+
+
+class TestChunkedValidationRegression:
+ """Partitioning must not change outcomes when the LLM is deterministic per candidate.
+
+ Guards the most important property we promised when we switched from a
+ single ``LLMStructuredColumnConfig`` to chunked ``CustomColumnConfig``
+ dispatch: given the same set of per-id decisions, chunk sizing is an
+ implementation detail of *how* we talk to the validator, not *what*
+ entities survive validation.
+ """
+
+ SCENARIO_TEXT = (
+ # Positions referenced below are into this exact string.
+ "Alice met Bob in Chicago at Acme HQ; Doe introduced Eve to the team later."
+ # Alice@0, Bob@10, Chicago@17, Acme@28, Doe@37, Eve@52
+ )
+
+ @pytest.fixture
+ def scenario(self) -> tuple[list[EntitySpan], ValidationCandidatesSchema, dict[str, dict[str, Any]]]:
+ spans = [
+ _entity_span("a", "Alice", "first_name", 0, 5),
+ _entity_span("b", "Bob", "first_name", 10, 13),
+ _entity_span("c", "Chicago", "city", 17, 24),
+ _entity_span("d", "Acme", "organization", 28, 32),
+ _entity_span("e", "Doe", "last_name", 37, 40),
+ _entity_span("f", "Eve", "first_name", 52, 55),
+ ]
+ candidates = _candidates_schema(
+ ("a", "Alice", "first_name"),
+ ("b", "Bob", "first_name"),
+ ("c", "Chicago", "city"),
+ ("d", "Acme", "organization"),
+ ("e", "Doe", "last_name"),
+ ("f", "Eve", "first_name"),
+ )
+ # Deterministic per-id decisions covering all branches.
+ # ``f`` intentionally has no decision: downstream must keep it as-is,
+ # regardless of whether it lands in its own chunk or shares one.
+ decisions_by_id: dict[str, dict[str, Any]] = { + "a": {"id": "a", "decision": "keep"}, + "b": {"id": "b", "decision": "drop"}, + "c": {"id": "c", "decision": "reclass", "proposed_label": "location"}, + "d": {"id": "d", "decision": "keep"}, + "e": {"id": "e", "decision": "reclass", "proposed_label": "surname"}, + } + return spans, candidates, decisions_by_id + + def _run( + self, + *, + spans: list[EntitySpan], + candidates: ValidationCandidatesSchema, + decisions_by_id: dict[str, dict[str, Any]], + max_per_call: int, + ) -> tuple[dict, list[EntitySpan], int]: + row = _build_row(text=self.SCENARIO_TEXT, seed_entities=spans, candidates=candidates) + facade = _selective_facade("solo", decisions_by_id) + params = ChunkedValidationParams( + pool=["solo"], + max_entities_per_call=max_per_call, + excerpt_window_chars=200, + prompt_template=_MINIMAL_TEMPLATE, + ) + out = chunked_validate_row(row, params, {"solo": facade}) + decisions_doc = out[COL_VALIDATION_DECISIONS] + validated = apply_validation_decisions(spans, decisions_doc) + return decisions_doc, validated, len(facade.calls) + + def test_multi_chunk_matches_single_chunk(self, scenario) -> None: + spans, candidates, decisions_by_id = scenario + + # 1 chunk -- stands in for the "legacy single-call" path: all + # candidates fit into one validator call, no partitioning. + single_doc, single_validated, single_calls = self._run( + spans=spans, candidates=candidates, decisions_by_id=decisions_by_id, max_per_call=10 + ) + # 3 chunks of size 2 -- pool-of-one with real partitioning. + multi_doc, multi_validated, multi_calls = self._run( + spans=spans, candidates=candidates, decisions_by_id=decisions_by_id, max_per_call=2 + ) + + # Sanity: the two configurations actually differ in *how* they call + # the validator. Without this the test is trivially satisfiable. + assert single_calls == 1 + assert multi_calls == 3 + + # Decisions merged back together are identical (order-insensitive). 
+ assert _normalize_decisions(single_doc) == _normalize_decisions(multi_doc) + + # Final per-entity outcomes (surviving ids + their post-validation + # labels) are identical. This is the ``COL_DETECTED_ENTITIES`` parity + # claim: downstream stages cannot tell the two runs apart. + assert _summarize(single_validated) == _summarize(multi_validated) + + # Keep/reclass/drop/untouched tallies match the fixed decision set. + expected_tally = {"keep": 2, "reclass": 2, "drop": 1, "untouched": 1} + assert _tally(spans, single_validated, single_doc) == expected_tally + assert _tally(spans, multi_validated, multi_doc) == expected_tally + + # Concrete post-validation outcome, pinned so a regression in + # ``apply_validation_decisions`` or chunk merging is caught + # precisely, not just "something changed". + assert _summarize(multi_validated) == [ + ("a", "first_name"), # keep + ("c", "location"), # reclass + ("d", "organization"), # keep + ("e", "surname"), # reclass + ("f", "first_name"), # untouched (no decision) + # "b" dropped + ] diff --git a/tests/engine/test_detection_custom_columns.py b/tests/engine/test_detection_custom_columns.py index bd2aa6c..1a42f22 100644 --- a/tests/engine/test_detection_custom_columns.py +++ b/tests/engine/test_detection_custom_columns.py @@ -26,12 +26,10 @@ COL_VALIDATED_SEED_ENTITIES, COL_VALIDATION_CANDIDATES, COL_VALIDATION_DECISIONS, - COL_VALIDATION_SKELETON, ) from anonymizer.engine.detection.custom_columns import ( _parse_entity_spans, apply_validation_and_finalize, - build_validation_skeleton, enrich_validation_decisions, merge_and_build_candidates, parse_detected_entities, @@ -144,57 +142,6 @@ def test_enrich_validation_decisions_ignores_non_dict_validation_payload() -> No assert result[COL_VALIDATED_ENTITIES] == {"decisions": []} -def test_build_validation_skeleton_produces_null_decisions() -> None: - row: dict[str, Any] = { - COL_SEED_VALIDATION_CANDIDATES: { - "candidates": [ - { - "id": "first_name_0_5", - "value": "Alice", - 
"label": "first_name", - "context_before": "", - "context_after": " works", - }, - { - "id": "org_15_19", - "value": "Acme", - "label": "organization", - "context_before": "at ", - "context_after": "", - }, - ] - }, - } - result = build_validation_skeleton(row) - skeleton = result[COL_VALIDATION_SKELETON] - assert len(skeleton["decisions"]) == 2 - assert skeleton["decisions"][0]["id"] == "first_name_0_5" - assert skeleton["decisions"][0]["value"] == "Alice" - assert skeleton["decisions"][0]["label"] == "first_name" - assert skeleton["decisions"][0]["decision"] is None - assert skeleton["decisions"][0]["proposed_label"] is None - assert skeleton["decisions"][1]["id"] == "org_15_19" - - -def test_build_validation_skeleton_handles_candidates_with_missing_keys() -> None: - row: dict[str, Any] = { - COL_SEED_VALIDATION_CANDIDATES: { - "candidates": [ - {"id": "x"}, - {"value": "Alice"}, - {}, - ] - }, - } - result = build_validation_skeleton(row) - skeleton = result[COL_VALIDATION_SKELETON] - assert len(skeleton["decisions"]) == 3 - assert skeleton["decisions"][0]["id"] == "x" - assert skeleton["decisions"][0]["value"] == "" - assert skeleton["decisions"][1]["id"] == "" - assert skeleton["decisions"][1]["value"] == "Alice" - - def test_apply_validation_and_finalize_handles_malformed_merged_entities() -> None: row: dict[str, Any] = { COL_TEXT: "Alice works at Acme.", diff --git a/tests/engine/test_detection_workflow.py b/tests/engine/test_detection_workflow.py index a6cb0c4..3783dd3 100644 --- a/tests/engine/test_detection_workflow.py +++ b/tests/engine/test_detection_workflow.py @@ -7,7 +7,7 @@ import pandas as pd import pytest -from data_designer.config.column_configs import LLMStructuredColumnConfig +from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig from data_designer.config.models import ModelConfig from anonymizer.config.models import DetectionModelSelection @@ -17,10 +17,15 @@ COL_ENTITIES_BY_VALUE, COL_FINAL_ENTITIES, 
COL_LATENT_ENTITIES, + COL_SEED_ENTITIES, + COL_SEED_VALIDATION_CANDIDATES, + COL_TAG_NOTATION, COL_TAGGED_TEXT, COL_TEXT, + COL_VALIDATION_DECISIONS, DEFAULT_ENTITY_LABELS, ) +from anonymizer.engine.detection.chunked_validation import ChunkedValidationParams from anonymizer.engine.detection.detection_workflow import ( EntityDetectionWorkflow, _format_label_examples, @@ -30,7 +35,11 @@ _resolve_detection_labels, ) from anonymizer.engine.ndd.adapter import FailedRecord, WorkflowRunResult -from anonymizer.engine.ndd.model_loader import load_default_model_selection, resolve_model_alias +from anonymizer.engine.ndd.model_loader import ( + load_default_model_selection, + resolve_model_alias, + resolve_model_aliases, +) from anonymizer.engine.schemas import EntitiesSchema @@ -326,7 +335,18 @@ def test_resolve_model_alias_reads_from_selection_model() -> None: defaults = load_default_model_selection().detection selection = defaults.model_copy(update={"entity_detector": "custom-model"}) assert resolve_model_alias("entity_detector", selection) == "custom-model" - assert resolve_model_alias("entity_validator", selection) == defaults.entity_validator + assert resolve_model_aliases("entity_validator", selection) == defaults.entity_validator + + +def test_resolve_model_alias_raises_for_list_valued_role() -> None: + selection = load_default_model_selection().detection + with pytest.raises(TypeError, match="list-valued"): + resolve_model_alias("entity_validator", selection) + + +def test_resolve_model_aliases_wraps_scalar_roles() -> None: + selection = load_default_model_selection().detection + assert resolve_model_aliases("entity_detector", selection) == [selection.entity_detector] def test_resolve_detection_labels_none_uses_defaults() -> None: @@ -458,3 +478,171 @@ def test_default_entity_labels_preserves_novel_augmented_entities( assert "server_name" in final_labels assert "hostname" in final_labels assert "ipv4" in final_labels + + +# 
--------------------------------------------------------------------------- +# Chunked validation wiring (Commit 2) +# --------------------------------------------------------------------------- + + +def _find_column(columns: list, name: str): + for col in columns: + if getattr(col, "name", None) == name: + return col + raise AssertionError(f"Column {name!r} not found in workflow columns: {[getattr(c, 'name', c) for c in columns]}") + + +def test_validation_column_is_custom_chunked_generator( + stub_detector_model_configs: list[ModelConfig], + stub_detection_model_selection: DetectionModelSelection, +) -> None: + """COL_VALIDATION_DECISIONS is now a CustomColumnConfig bound to the chunked generator, + not an LLMStructuredColumnConfig.""" + adapter = Mock() + adapter.run_workflow.return_value = WorkflowRunResult( + dataframe=pd.DataFrame( + { + COL_TEXT: ["Alice"], + COL_DETECTED_ENTITIES: [{"entities": [{"value": "Alice", "label": "first_name"}]}], + } + ), + failed_records=[], + ) + workflow = EntityDetectionWorkflow(adapter=adapter) + workflow.run( + pd.DataFrame({COL_TEXT: ["Alice"]}), + model_configs=stub_detector_model_configs, + selected_models=stub_detection_model_selection, + gliner_detection_threshold=0.5, + tag_latent_entities=False, + ) + columns = adapter.run_workflow.call_args.kwargs["columns"] + validation_col = _find_column(columns, COL_VALIDATION_DECISIONS) + assert isinstance(validation_col, CustomColumnConfig) + # Must NOT be the old structured-output LLM column. + assert not isinstance(validation_col, LLMStructuredColumnConfig) + assert validation_col.drop is True + # generator_params must match the Detect config defaults that flow through. 
+    assert isinstance(validation_col.generator_params, ChunkedValidationParams)
+    assert validation_col.generator_params.pool == stub_detection_model_selection.entity_validator
+    assert validation_col.generator_params.max_entities_per_call > 0
+    assert validation_col.generator_params.excerpt_window_chars > 0
+    # The decorated generator's metadata must expose the pool and the exact
+    # set of columns it reads, so DataDesigner resolves facades and DAG ordering.
+    metadata = validation_col.generator_function.custom_column_metadata
+    assert metadata["model_aliases"] == list(stub_detection_model_selection.entity_validator)
+    assert set(metadata["required_columns"]) == {
+        COL_TEXT,
+        COL_SEED_ENTITIES,
+        COL_SEED_VALIDATION_CANDIDATES,
+        COL_TAG_NOTATION,
+    }
+
+
+def test_validator_pool_kwargs_thread_through_to_generator_params(
+    stub_detector_model_configs: list[ModelConfig],
+    stub_detection_model_selection: DetectionModelSelection,
+) -> None:
+    """Explicit ``validation_max_entities_per_call`` and ``validation_excerpt_window_chars``
+    propagate from ``run()`` all the way to ``ChunkedValidationParams``."""
+    adapter = Mock()
+    adapter.run_workflow.return_value = WorkflowRunResult(
+        dataframe=pd.DataFrame(
+            {
+                COL_TEXT: ["Alice"],
+                COL_DETECTED_ENTITIES: [{"entities": [{"value": "Alice", "label": "first_name"}]}],
+            }
+        ),
+        failed_records=[],
+    )
+    workflow = EntityDetectionWorkflow(adapter=adapter)
+    workflow.run(
+        pd.DataFrame({COL_TEXT: ["Alice"]}),
+        model_configs=stub_detector_model_configs,
+        selected_models=stub_detection_model_selection,
+        gliner_detection_threshold=0.5,
+        validation_max_entities_per_call=17,
+        validation_excerpt_window_chars=42,
+        tag_latent_entities=False,
+    )
+    columns = adapter.run_workflow.call_args.kwargs["columns"]
+    params = _find_column(columns, COL_VALIDATION_DECISIONS).generator_params
+    assert params.max_entities_per_call == 17
+    assert params.excerpt_window_chars == 42
+
+
+def test_pool_size_greater_than_one_emits_warning(
+    stub_detector_model_configs: list[ModelConfig],
+    stub_detection_model_selection: DetectionModelSelection,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """Operators with multiple validator aliases must be alerted that
+    ``max_parallel_requests`` is enforced per alias, so a pool multiplies
+    the total number of in-flight requests."""
+    selection = stub_detection_model_selection.model_copy(
+        update={"entity_validator": [*stub_detection_model_selection.entity_validator, "extra-validator"]}
+    )
+    adapter = Mock()
+    adapter.run_workflow.return_value = WorkflowRunResult(
+        dataframe=pd.DataFrame(
+            {
+                COL_TEXT: ["Alice"],
+                COL_DETECTED_ENTITIES: [{"entities": [{"value": "Alice", "label": "first_name"}]}],
+            }
+        ),
+        failed_records=[],
+    )
+    workflow = EntityDetectionWorkflow(adapter=adapter)
+    with caplog.at_level("WARNING"):
+        workflow.run(
+            pd.DataFrame({COL_TEXT: ["Alice"]}),
+            model_configs=stub_detector_model_configs,
+            selected_models=selection,
+            gliner_detection_threshold=0.5,
+            tag_latent_entities=False,
+        )
+    # caplog can attach handlers at both the target logger and the root, so
+    # the same record may appear twice in ``records``; dedupe by identity.
+    pool_warnings = {
+        id(r): r
+        for r in caplog.records
+        if r.name == "anonymizer.detection" and "pool of" in r.getMessage() and "aliases" in r.getMessage()
+    }
+    assert len(pool_warnings) == 1
+    (only,) = pool_warnings.values()
+    assert "multiplies total in-flight" in only.getMessage()
+
+
+def test_pool_size_one_does_not_emit_warning(
+    stub_detector_model_configs: list[ModelConfig],
+    stub_detection_model_selection: DetectionModelSelection,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """Default single-alias configurations must not spam the warning: it's a pool caveat, not advice for everyone."""
+    adapter = Mock()
+    adapter.run_workflow.return_value = WorkflowRunResult(
+        dataframe=pd.DataFrame(
+            {
+                COL_TEXT: ["Alice"],
+                COL_DETECTED_ENTITIES: [{"entities": [{"value": "Alice", "label": "first_name"}]}],
+            }
+        ),
+        failed_records=[],
+    )
+    assert len(stub_detection_model_selection.entity_validator) == 1, (
+        "baseline default must be a single validator for this test to be meaningful"
+    )
+    workflow = EntityDetectionWorkflow(adapter=adapter)
+    with caplog.at_level("WARNING"):
+        workflow.run(
+            pd.DataFrame({COL_TEXT: ["Alice"]}),
+            model_configs=stub_detector_model_configs,
+            selected_models=stub_detection_model_selection,
+            gliner_detection_threshold=0.5,
+            tag_latent_entities=False,
+        )
+    pool_warnings = [
+        r
+        for r in caplog.records
+        if r.name == "anonymizer.detection" and "pool of" in r.getMessage() and "aliases" in r.getMessage()
+    ]
+    assert pool_warnings == []
diff --git a/tests/engine/test_model_loader.py b/tests/engine/test_model_loader.py
index 29ff299..14e9471 100644
--- a/tests/engine/test_model_loader.py
+++ b/tests/engine/test_model_loader.py
@@ -32,11 +32,98 @@ def test_load_workflow_config_contains_selected_models() -> None:
     assert isinstance(entity_detector, str) and entity_detector
 
 
+def test_load_workflow_config_accepts_list_valued_entity_validator(tmp_path) -> None:
+    """A list-valued role in detection.yaml must not blow up the alias check.
+
+    Regression test for ``set(selections.values())`` raising
+    ``TypeError: unhashable type: 'list'`` once any role is list-valued.
+    """
+    (tmp_path / "models.yaml").write_text(
+        "model_configs:\n"
+        "  - alias: gliner\n"
+        "    model: test/gliner\n"
+        "  - alias: v1\n"
+        "    model: test/v1\n"
+        "  - alias: v2\n"
+        "    model: test/v2\n"
+        "  - alias: a\n"
+        "    model: test/a\n"
+    )
+    (tmp_path / "detection.yaml").write_text(
+        "selected_models:\n"
+        "  entity_detector: gliner\n"
+        "  entity_validator: [v1, v2]\n"
+        "  entity_augmenter: a\n"
+        "  latent_detector: gliner\n"
+    )
+    config = load_workflow_config(WorkflowName.detection, tmp_path)
+    assert config["selected_models"]["entity_validator"] == ["v1", "v2"]
+
+
+def test_load_workflow_config_raises_on_unknown_alias_in_validator_pool(tmp_path) -> None:
+    """An unknown alias inside a list-valued pool should surface by name."""
+    (tmp_path / "models.yaml").write_text(
+        "model_configs:\n"
+        "  - alias: gliner\n"
+        "    model: test/gliner\n"
+        "  - alias: v1\n"
+        "    model: test/v1\n"
+        "  - alias: a\n"
+        "    model: test/a\n"
+    )
+    (tmp_path / "detection.yaml").write_text(
+        "selected_models:\n"
+        "  entity_detector: gliner\n"
+        "  entity_validator: [v1, does-not-exist]\n"
+        "  entity_augmenter: a\n"
+        "  latent_detector: gliner\n"
+    )
+    with pytest.raises(ValueError, match="does-not-exist"):
+        load_workflow_config(WorkflowName.detection, tmp_path)
+
+
 def test_get_model_alias_reads_workflow_mapping() -> None:
     alias = get_model_alias(workflow_name=WorkflowName.detection, role="entity_validator")
     assert isinstance(alias, str) and alias
 
 
+def test_load_workflow_selections_preserves_list_values(tmp_path) -> None:
+    """A YAML pool under ``selected_models`` must round-trip as ``list[str]``.
+
+    Stringifying would silently collapse the pool to a single garbled alias
+    ("['v1', 'v2']"), and Pydantic's ``normalize_entity_validator`` would
+    then treat that repr as one alias. Pinning the native-type preservation
+    here keeps that trap closed.
+    """
+    config_dir = tmp_path
+    (config_dir / "detection.yaml").write_text(
+        "selected_models:\n"
+        "  entity_detector: d\n"
+        "  entity_validator:\n"
+        "    - v1\n"
+        "    - v2\n"
+        "  entity_augmenter: a\n"
+        "  latent_detector: l\n"
+    )
+    selections = load_workflow_selections(WorkflowName.detection, config_dir)
+    assert selections["entity_validator"] == ["v1", "v2"]
+    assert selections["entity_detector"] == "d"
+
+
+def test_get_model_alias_rejects_list_valued_role(tmp_path) -> None:
+    """Calling the scalar accessor on a pool-valued role raises ``TypeError``."""
+    config_dir = tmp_path
+    (config_dir / "detection.yaml").write_text(
+        "selected_models:\n"
+        "  entity_detector: d\n"
+        "  entity_validator: [v1, v2]\n"
+        "  entity_augmenter: a\n"
+        "  latent_detector: l\n"
+    )
+    with pytest.raises(TypeError, match="list-valued"):
+        get_model_alias(WorkflowName.detection, "entity_validator", config_dir)
+
+
 WORKFLOW_YAMLS = [p.stem for p in DEFAULT_CONFIG_DIR.glob("*.yaml") if p.stem != "models"]
@@ -46,7 +133,13 @@ def test_default_workflow_aliases_exist_in_models(workflow_name: str) -> None:
     models_config = load_models_config()
     known_aliases = {m["alias"] for m in models_config.get("model_configs", [])}
     selections = load_workflow_selections(WorkflowName(workflow_name))
-    unknown = set(selections.values()) - known_aliases
+    referenced: set[str] = set()
+    for value in selections.values():
+        if isinstance(value, list):
+            referenced.update(value)
+        else:
+            referenced.add(value)
+    unknown = referenced - known_aliases
     assert not unknown, f"Workflow '{workflow_name}' references unknown aliases: {unknown}. Known: {known_aliases}"
@@ -54,7 +147,9 @@ def test_load_default_model_selection_populates_all_workflows() -> None:
     selection = load_default_model_selection()
     # Detection
     assert selection.detection.entity_detector
-    assert selection.detection.entity_validator
+    assert selection.detection.entity_validator  # list[str]
+    assert isinstance(selection.detection.entity_validator, list)
+    assert all(isinstance(alias, str) and alias for alias in selection.detection.entity_validator)
     assert selection.detection.entity_augmenter
     assert selection.detection.latent_detector
     # Replace
@@ -104,6 +199,66 @@ def test_parse_model_configs_yaml_without_selections_uses_defaults() -> None:
     assert result.selected_models.detection.entity_detector == "gliner-pii-detector"
 
 
+# parse_model_configs regression tests: user overrides in selected_models must
+# rerun field validators. model_copy(update=...) in Pydantic v2 silently skips
+# them, so the three DetectionModelSelection.normalize_entity_validator checks
+# (non-empty, deduped, whitespace-stripped) would be bypassed on override
+# unless _merge_selections re-validates.
+
+
+def test_parse_model_configs_rejects_empty_entity_validator_override() -> None:
+    yaml_str = """
+selected_models:
+  detection:
+    entity_validator: []
+model_configs:
+  - alias: gliner-pii-detector
+    model: test/gliner
+"""
+    with pytest.raises(ValueError, match="at least one model alias"):
+        parse_model_configs(yaml_str)
+
+
+def test_parse_model_configs_dedupes_duplicate_override_aliases_with_warning(
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    yaml_str = """
+selected_models:
+  detection:
+    entity_validator: [v1, v2, v1, v3, v2]
+model_configs:
+  - alias: gliner-pii-detector
+    model: test/gliner
+  - alias: v1
+    model: test/v1
+  - alias: v2
+    model: test/v2
+  - alias: v3
+    model: test/v3
+"""
+    with caplog.at_level("WARNING", logger="anonymizer.config.models"):
+        result = parse_model_configs(yaml_str)
+    assert result.selected_models.detection.entity_validator == ["v1", "v2", "v3"]
+    assert any("duplicate aliases" in r.getMessage() for r in caplog.records)
+
+
+def test_parse_model_configs_strips_whitespace_only_override_entries() -> None:
+    yaml_str = """
+selected_models:
+  detection:
+    entity_validator: [" v1 ", " ", "v2"]
+model_configs:
+  - alias: gliner-pii-detector
+    model: test/gliner
+  - alias: v1
+    model: test/v1
+  - alias: v2
+    model: test/v2
+"""
+    result = parse_model_configs(yaml_str)
+    assert result.selected_models.detection.entity_validator == ["v1", "v2"]
+
+
 def test_validate_model_alias_references_accepts_valid_detection_aliases(
     stub_known_model_configs: list[ModelConfig],
     stub_slim_model_selection: ModelSelection,
@@ -228,3 +383,149 @@ def test_validate_model_alias_references_skips_rewrite_alias_when_not_enabled(
         stub_known_model_configs,
         selected_models,
     )
+
+
+class TestEntityValidatorNormalization:
+    """``DetectionModelSelection.entity_validator`` accepts scalar or list input.
+
+    Scalars normalize to single-item lists so every downstream consumer
+    sees ``list[str]``.
+    """
+
+    def test_scalar_normalizes_to_single_item_list(self) -> None:
+        selection = DetectionModelSelection(
+            entity_detector="d",
+            entity_validator="v",
+            entity_augmenter="a",
+            latent_detector="l",
+        )
+        assert selection.entity_validator == ["v"]
+
+    def test_list_preserved(self) -> None:
+        selection = DetectionModelSelection(
+            entity_detector="d",
+            entity_validator=["v1", "v2", "v3"],
+            entity_augmenter="a",
+            latent_detector="l",
+        )
+        assert selection.entity_validator == ["v1", "v2", "v3"]
+
+    def test_tuple_coerced_to_list(self) -> None:
+        # Tuples are accepted for parity with Pydantic v2's default coercion
+        # for ``list[str]`` fields; programmatic callers should not need to
+        # care about the concrete sequence type. The normalizer must return
+        # a real ``list`` so downstream ``isinstance(value, list)`` branches
+        # (e.g. in ``resolve_model_alias``) behave consistently.
+        selection = DetectionModelSelection(
+            entity_detector="d",
+            entity_validator=("v1", "v2"),  # type: ignore[arg-type]
+            entity_augmenter="a",
+            latent_detector="l",
+        )
+        assert selection.entity_validator == ["v1", "v2"]
+        assert isinstance(selection.entity_validator, list)
+
+    def test_empty_list_rejected(self) -> None:
+        with pytest.raises(ValueError, match="at least one model alias"):
+            DetectionModelSelection(
+                entity_detector="d",
+                entity_validator=[],
+                entity_augmenter="a",
+                latent_detector="l",
+            )
+
+    def test_whitespace_only_rejected(self) -> None:
+        with pytest.raises(ValueError, match="at least one model alias"):
+            DetectionModelSelection(
+                entity_detector="d",
+                entity_validator=[" ", ""],
+                entity_augmenter="a",
+                latent_detector="l",
+            )
+
+    def test_non_string_non_list_rejected(self) -> None:
+        with pytest.raises((ValueError, TypeError)):
+            DetectionModelSelection(
+                entity_detector="d",
+                entity_validator=42,  # type: ignore[arg-type]
+                entity_augmenter="a",
+                latent_detector="l",
+            )
+
+    def test_duplicate_aliases_are_deduped_with_warning(
+        self,
+        caplog: pytest.LogCaptureFixture,
+    ) -> None:
+        # A duplicate alias in the pool would burn a failover attempt on an
+        # already-exhausted endpoint. The normalizer collapses duplicates to
+        # the first occurrence (preserving order) and logs a warning so the
+        # user can see their config wasn't applied exactly as written.
+        with caplog.at_level("WARNING", logger="anonymizer.config.models"):
+            selection = DetectionModelSelection(
+                entity_detector="d",
+                entity_validator=["v1", "v2", "v1", "v3", "v2"],
+                entity_augmenter="a",
+                latent_detector="l",
+            )
+        assert selection.entity_validator == ["v1", "v2", "v3"]
+        # caplog may capture the same record twice when its handler ends up
+        # attached at both the target logger and the root; dedupe on message
+        # content instead of asserting a raw count.
+        dedupe_messages = {
+            r.getMessage()
+            for r in caplog.records
+            if r.levelname == "WARNING" and "duplicate aliases" in r.getMessage()
+        }
+        assert len(dedupe_messages) == 1
+
+    def test_no_warning_when_all_aliases_unique(
+        self,
+        caplog: pytest.LogCaptureFixture,
+    ) -> None:
+        with caplog.at_level("WARNING", logger="anonymizer.config.models"):
+            selection = DetectionModelSelection(
+                entity_detector="d",
+                entity_validator=["v1", "v2", "v3"],
+                entity_augmenter="a",
+                latent_detector="l",
+            )
+        assert selection.entity_validator == ["v1", "v2", "v3"]
+        dedupe_warnings = [
+            r for r in caplog.records if r.levelname == "WARNING" and "duplicate aliases" in r.getMessage()
+        ]
+        assert dedupe_warnings == []
+
+
+class TestValidateAliasReferencesHandlesValidatorPool:
+    """``validate_model_alias_references`` must expand list-valued roles to one check per alias."""
+
+    def test_accepts_all_pool_aliases_present(
+        self,
+        stub_slim_model_selection: ModelSelection,
+    ) -> None:
+        """All pool aliases exist in the known model configs, so validation passes."""
+        configs = [
+            ModelConfig(alias="v1", model="test/v1"),
+            ModelConfig(alias="v2", model="test/v2"),
+            ModelConfig(alias="known", model="some/model"),
+        ]
+        selected_models = stub_slim_model_selection.model_copy(
+            update={
+                "detection": stub_slim_model_selection.detection.model_copy(update={"entity_validator": ["v1", "v2"]})
+            }
+        )
+        validate_model_alias_references(configs, selected_models)
+
+    def test_raises_on_any_pool_alias_missing(
+        self,
+        stub_known_model_configs: list[ModelConfig],
+        stub_slim_model_selection: ModelSelection,
+    ) -> None:
+        """If any alias in the validator pool is unknown, the error names that alias and its index."""
+        selected_models = stub_slim_model_selection.model_copy(
+            update={
+                "detection": stub_slim_model_selection.detection.model_copy(
+                    update={"entity_validator": ["known", "missing-one"]}
+                )
+            }
+        )
+        with pytest.raises(ValueError, match=r"entity_validator\[1\].*missing-one"):
+            validate_model_alias_references(stub_known_model_configs, selected_models)