
Commit f093b99

refactor(detect): replace asyncio fan-out with ThreadPoolExecutor; cache Jinja template
Drops asyncio.run / asyncio.to_thread / asyncio.gather from the chunked validation path. _dispatch_chunk and chunked_validate_row become sync defs; chunked_validate_row dispatches chunks via a ThreadPoolExecutor and calls facade.generate() directly. Per-alias concurrency is still enforced downstream by each facade's ThrottledModelClient, so the pool exists purely to overlap this row's chunks. Under DataDesigner's async engine the sync calls are transparently bridged to agenerate by the DD runtime (DD#545), so the code path stays engine-agnostic. Test call sites lose their asyncio.run(...) wrappers. A TODO(async-native) comment in the module docstring flags the follow-up migration once the async engine becomes the DD default.

Wraps _compile_template in functools.lru_cache(maxsize=4) so a row with N chunks parses the Jinja source once instead of N times.

Folds in the post-#119 column rename for this module (COL_MERGED_TAGGED_TEXT -> COL_SEED_TAGGED_TEXT, COL_VALIDATION_CANDIDATES -> COL_SEED_VALIDATION_CANDIDATES, prompt placeholder _merged_tagged_text -> _seed_tagged_text) that the rebase resolved in neighbouring modules but never applied here.

Module docstring gains two follow-up-deferred paragraphs explaining (a) why per-instance validation is intentional -- dedup by (value, label) was considered and rejected because it conflates surface form with meaning -- and (b) how peak prompt memory scales with chunk count, with pointers to the orthogonal cost levers to pull if pressure shows up in a real workload. Both are tracked as separate follow-up issues.

Made-with: Cursor
1 parent 2fe252e commit f093b99

3 files changed

Lines changed: 134 additions & 103 deletions

File tree

src/anonymizer/engine/detection/chunked_validation.py

Lines changed: 102 additions & 71 deletions
@@ -15,13 +15,12 @@
 Failure contract. Each chunk attempts its round-robin-assigned alias first
 and, on any terminal exception from that alias (5xx after transport retries,
 connection errors, parser errors), fails over sequentially to the rest of
-the pool. A chunk only fails when every pool member has raised. A failed
-chunk cancels sibling chunks (``asyncio.gather`` semantics) and raises out
-of the custom column generator, at which point DataDesigner drops the row
-from the output DataFrame and ``NddAdapter._detect_missing_records``
-surfaces it as a ``FailedRecord``. This is the "best effort on first pass,
-identify and reprocess misses" contract: raw text never silently leaks
-through as unscrubbed output.
+the pool. A chunk only fails when every pool member has raised. The first
+failing chunk re-raises out of the custom column generator, at which point
+DataDesigner drops the row from the output DataFrame and
+``NddAdapter._detect_missing_records`` surfaces it as a ``FailedRecord``.
+This is the "best effort on first pass, identify and reprocess misses"
+contract: raw text never silently leaks through as unscrubbed output.
 
 The per-chunk decisions are merged into a ``ValidationDecisionsSchema``-shaped
 payload so the downstream ``enrich_validation_decisions`` column keeps
@@ -32,25 +31,58 @@
 decorated with ``@custom_column_generator`` and bound to a concrete pool.
 The pure helpers below are exposed for unit testing.
 
-Sync/async boundary with DataDesigner. The DD-exposed wrapper is sync
-(DD's default thread-pool engine calls ``generator.generate(record)``
-directly; an async wrapper would return a coroutine and DD would reject it
-with ``"must return a dict, got coroutine"``). The wrapper runs the
-``chunked_validate_row`` coroutine on a fresh event loop per row via
-``asyncio.run``; inside, each chunk dispatch calls the *sync*
-``facade.generate()`` through ``asyncio.to_thread`` because the facades DD
-hands us under the sync engine are sync-mode ``HttpModelClient``s whose
-``agenerate()`` raises. DD#545 (shipped in ``data-designer>=0.5.7``)
-bridges sync ``.generate()`` calls to ``agenerate`` transparently under
-the async engine, so this same code path works both before and after
-Anonymizer#119 turns the async engine on for detection.
+Per-instance validation is intentional. Every candidate goes to the LLM in
+its own excerpt, even when the same ``(value, label)`` pair appears many
+times in a row. Deduping by ``(value, label)`` and broadcasting one
+canonical decision to every duplicate was considered (it is the single
+largest cost lever on this path) and rejected: it conflates surface form
+with meaning, produces silently-wrong answers when the detector's labels
+are context-dependent, and gives us no signal when it does. See the
+"Validator chunking" section of ``docs/concepts/detection.md`` and the
+``context`` block on the C5 review thread in PR #126 for the full
+reasoning. If cost becomes pressing before we have data on how often
+duplicates genuinely are semantically equivalent, prefer orthogonal
+levers: tighter excerpt windows, larger ``max_entities_per_call`` for
+cheap validators, or token-budget-aware chunking (tracked as C7 in the
+same review thread).
+
+Concurrency model. ``chunked_validate_row`` dispatches its chunks through a
+``ThreadPoolExecutor``. Per-alias concurrency is already enforced downstream
+by each facade's ``ThrottledModelClient`` (AIMD on 429), so we intentionally
+do not impose a row-level cap here; the pool exists purely to overlap the
+chunks for this single row. Under DataDesigner's opt-in async engine the
+sync ``facade.generate()`` calls we make are transparently bridged to
+``agenerate`` by the DD runtime, so this code path is agnostic to which
+engine is active.
+
+TODO(async-native): once DataDesigner's async engine becomes the default
+(tracking: DATA_DESIGNER_ASYNC_ENGINE opt-in flag flips off), drop the
+``ThreadPoolExecutor`` + sync ``facade.generate()`` pattern here in favour
+of ``async def`` functions calling ``facade.agenerate()`` directly, with
+``asyncio.gather`` replacing the executor. That removes one thread hop per
+chunk and lets DD's scheduler see per-chunk dispatch as first-class async
+work. See the PR #126 review thread (step 2 of Andre's async
+simplification suggestion) for context.
+
+Eager prompt construction. ``chunked_validate_row`` builds the excerpt,
+skeleton, and rendered prompt for every chunk before submitting any worker.
+Peak prompt memory is ``len(chunks) * (2 * excerpt_window_chars + skeleton
++ ~3KB template overhead)``, i.e. low-MB even at 1000 entities with
+``max_entities_per_call`` in the single digits. At workloads where this
+becomes observable (very small ``max_entities_per_call`` on multi-thousand-
+entity rows, or large ``excerpt_window_chars``), the fix is to move prompt
+construction inside each worker and pair it with a row-level concurrency
+cap (otherwise all workers race the construction phase in parallel and the
+bound is unchanged). Deferred because we have no evidence of memory
+pressure at realistic scale; tracked as C6 in the PR #126 review thread.
 """
 
 from __future__ import annotations
 
-import asyncio
+import functools
 import logging
 from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any
 
 from data_designer.config import custom_column_generator
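The failure contract in the docstring above (try the round-robin-assigned alias first, then fail over sequentially through the rest of the pool, failing only when every member has raised) can be sketched standalone. Everything below is a hypothetical stand-in, not the project's real facade API; plain callables play the role of model facades.

```python
# Sketch of the per-chunk failover contract: rotate the pool so each chunk
# starts at a different alias, then try aliases in order until one succeeds.
from collections.abc import Callable


def rotated_pool(pool: list[str], chunk_index: int) -> list[str]:
    # Round-robin assignment: chunk 0 starts at pool[0], chunk 1 at pool[1], ...
    start = chunk_index % len(pool)
    return [pool[(start + offset) % len(pool)] for offset in range(len(pool))]


def dispatch_with_failover(
    facades: dict[str, Callable[[str], str]], aliases: list[str], prompt: str
) -> str:
    last_exc: BaseException | None = None
    for alias in aliases:
        try:
            return facades[alias](prompt)
        except Exception as exc:  # terminal error: this alias is exhausted
            last_exc = exc
    # Every pool member raised: the chunk (and hence the row) fails.
    assert last_exc is not None
    raise last_exc
```

The rotation spreads first-attempt load across the pool while keeping every chunk's fallback order deterministic.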
@@ -59,11 +91,11 @@
 from pydantic import BaseModel, Field
 
 from anonymizer.engine.constants import (
-    COL_MERGED_TAGGED_TEXT,
     COL_SEED_ENTITIES,
+    COL_SEED_TAGGED_TEXT,
+    COL_SEED_VALIDATION_CANDIDATES,
     COL_TAG_NOTATION,
     COL_TEXT,
-    COL_VALIDATION_CANDIDATES,
     COL_VALIDATION_DECISIONS,
     COL_VALIDATION_SKELETON,
 )
@@ -85,7 +117,7 @@
 
 # Jinja2 environment used to render the per-chunk validation prompt.
 # The template mirrors the production prompt exactly: we substitute the same
-# placeholders (``_merged_tagged_text``, ``_validation_skeleton``,
+# placeholders (``_seed_tagged_text``, ``_validation_skeleton``,
 # ``_tag_notation``) but with per-chunk values.
 _PROMPT_ENV = Environment(
     loader=BaseLoader(),
@@ -95,6 +127,20 @@
 )
 
 
+@functools.lru_cache(maxsize=4)
+def _compile_template(template: str) -> Any:
+    """Return a compiled Jinja2 template, cached by source string.
+
+    A row with ``N`` chunks would otherwise re-parse the same template ``N``
+    times, which is wasteful for the exact workload this module targets
+    (high-entity-count rows). ``maxsize=4`` is deliberately tiny: in practice
+    there is one prompt string per ``EntityDetectionWorkflow`` instance, but
+    tests may instantiate a handful of distinct templates in a single
+    process so we keep a small LRU rather than an unbounded cache.
+    """
+    return _PROMPT_ENV.from_string(template)
+
+
 class ChunkedValidationParams(BaseModel):
     """Parameters supplied to :func:`chunked_validate_row` via DD's ``generator_params``.
 
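The caching behaviour the new `_compile_template` helper adds can be demonstrated in isolation. This sketch uses the stdlib `string.Template` as a stand-in for Jinja2 so it stays dependency-free; the parse counter and function name are illustrative only — the point is the `lru_cache` semantics (one parse per distinct source, LRU-bounded), not the template engine.

```python
# Stand-in sketch of the template-compile cache: a row with N chunks renders
# the same source N times but parses it exactly once.
import functools
from string import Template

PARSE_COUNT = {"n": 0}  # observable proxy for "expensive parse"


@functools.lru_cache(maxsize=4)
def compile_template(source: str) -> Template:
    PARSE_COUNT["n"] += 1
    return Template(source)


def render_chunk(source: str, **values: str) -> str:
    # Every chunk of a row calls this with the same source string; only the
    # first call pays the parse, the rest are cache hits.
    return compile_template(source).substitute(values)
```

`compile_template.cache_info()` exposes the hit/miss counts, which makes the once-per-source claim easy to verify in a test.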
@@ -109,7 +155,7 @@ class ChunkedValidationParams(BaseModel):
         excerpt_window_chars: Chars of surrounding raw text included in each
             chunk's excerpt on either side of the chunk span.
         prompt_template: Jinja2 source for the validation prompt (with
-            ``_merged_tagged_text``, ``_validation_skeleton``, ``_tag_notation``
+            ``_seed_tagged_text``, ``_validation_skeleton``, ``_tag_notation``
             placeholders). Typically produced by ``_get_validation_prompt``.
         system_prompt: Optional system prompt forwarded to each chunk call.
     """
@@ -229,10 +275,10 @@ def render_chunk_prompt(
     call: dicts are rendered with Python ``str()`` (Jinja2 default), which is
     how the existing prompt has always served ``{{ _validation_skeleton }}``.
     """
-    compiled = _PROMPT_ENV.from_string(template)
+    compiled = _compile_template(template)
     return compiled.render(
         **{
-            COL_MERGED_TAGGED_TEXT: excerpt,
+            COL_SEED_TAGGED_TEXT: excerpt,
             COL_VALIDATION_SKELETON: skeleton,
             COL_TAG_NOTATION: notation.value,
         }
@@ -286,11 +332,11 @@ def merge_chunk_decisions(
 
 
 # ---------------------------------------------------------------------------
-# Async dispatch. Testable by passing fake ``models``.
+# Chunk dispatch. Testable by passing fake ``models``.
 # ---------------------------------------------------------------------------
 
 
-async def _dispatch_chunk(
+def _dispatch_chunk(
     *,
     facades: list[tuple[str, Any]],
     prompt: str,
@@ -307,17 +353,6 @@ async def _dispatch_chunk(
     errors) and its own AIMD throttling on 429, so by the time an exception
     escapes the facade call we consider that alias exhausted for this chunk.
 
-    We call the *sync* ``facade.generate()`` inside ``asyncio.to_thread`` on
-    purpose. Under DataDesigner's default thread-pool engine the facades
-    DD hands us are sync-mode ``HttpModelClient``s, and calling
-    ``facade.agenerate()`` on them raises
-    ``"Async methods are not available on a sync-mode HttpModelClient"``.
-    The sync call works under both engines: DD#545 (shipped in
-    ``data-designer>=0.5.7``) bridges sync ``.generate()`` calls from
-    custom column generators to ``agenerate`` under the async engine, so
-    wrapping in ``asyncio.to_thread`` gives us per-row chunk concurrency
-    today and forward-compatibility after Anonymizer#119 lands.
-
     We use ``PydanticResponseRecipe`` so the facade appends JSON task
     instructions and parses the response into ``RawValidationDecisionsSchema``.
 
@@ -337,8 +372,7 @@ async def _dispatch_chunk(
     last_exc: BaseException | None = None
     for attempt_index, (alias, facade) in enumerate(facades):
         try:
-            output, _messages = await asyncio.to_thread(
-                facade.generate,
+            output, _messages = facade.generate(
                 prompt=final_prompt,
                 parser=recipe.parse,
                 system_prompt=final_system,
@@ -390,14 +424,14 @@
     raise last_exc
 
 
-async def chunked_validate_row(
+def chunked_validate_row(
     row: dict[str, Any],
     params: ChunkedValidationParams,
     models: dict[str, Any],
 ) -> dict[str, Any]:
     """Run chunked validation for a single row and write ``COL_VALIDATION_DECISIONS``.
 
-    This is the async workhorse. Call it directly in tests with fake ``models``;
+    This is the workhorse. Call it directly in tests with fake ``models``;
     the DataDesigner-decorated wrapper produced by
     :func:`make_chunked_validation_generator` just forwards to it.
     """
@@ -410,7 +444,7 @@ async def chunked_validate_row(
     )
 
     text = str(row.get(COL_TEXT, ""))
-    candidates = ValidationCandidatesSchema.from_raw(row.get(COL_VALIDATION_CANDIDATES, {}))
+    candidates = ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {}))
     seed_entities_schema = EntitiesSchema.from_raw(row.get(COL_SEED_ENTITIES, {}))
     notation_raw = row.get(COL_TAG_NOTATION) or TagNotation.sentinel.value
     notation = TagNotation(str(notation_raw))
@@ -436,7 +470,7 @@ async def chunked_validate_row(
     ordered = order_candidates_by_position(candidates, all_spans)
     chunks = chunk_candidates(ordered, params.max_entities_per_call)
 
-    tasks: list[Any] = []
+    dispatch_kwargs_per_chunk: list[dict[str, Any]] = []
     for chunk_index, chunk in enumerate(chunks):
         chunk_candidates_ = [pair[0] for pair in chunk]
         chunk_spans = [pair[1] for pair in chunk]
@@ -462,18 +496,28 @@ async def chunked_validate_row(
         start = chunk_index % len(params.pool)
         rotated_aliases = [params.pool[(start + offset) % len(params.pool)] for offset in range(len(params.pool))]
         chunk_facades = [(alias, models[alias]) for alias in rotated_aliases]
-        tasks.append(
-            _dispatch_chunk(
-                facades=chunk_facades,
-                prompt=prompt,
-                system_prompt=params.system_prompt,
-                chunk_index=chunk_index,
-            )
+        dispatch_kwargs_per_chunk.append(
+            {
+                "facades": chunk_facades,
+                "prompt": prompt,
+                "system_prompt": params.system_prompt,
+                "chunk_index": chunk_index,
+            }
         )
 
-    # gather() propagates the first exception, cancelling siblings. That's the
-    # all-or-nothing row contract: a single terminal chunk failure fails the row.
-    chunk_results = await asyncio.gather(*tasks)
+    # Dispatch all chunks concurrently via a ThreadPoolExecutor. Per-alias
+    # concurrency is still capped downstream by each facade's
+    # ``ThrottledModelClient`` (AIMD on 429), so the pool's only job here is
+    # to overlap one row's chunks. ``f.result()`` re-raises the first chunk
+    # exception, which is what we want: a single terminal chunk failure
+    # fails the row. Pending workers finish naturally as the ``with`` block
+    # exits -- we just stop observing their results once we re-raise.
+    if not chunks:
+        chunk_results: list[RawValidationDecisionsSchema] = []
+    else:
+        with ThreadPoolExecutor(max_workers=len(chunks)) as executor:
+            futures = [executor.submit(_dispatch_chunk, **kwargs) for kwargs in dispatch_kwargs_per_chunk]
+            chunk_results = [f.result() for f in futures]
 
     row[COL_VALIDATION_DECISIONS] = merge_chunk_decisions(chunk_results, candidates)
     return row
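The dispatch pattern in the hunk above can be distilled into a few lines: submit every chunk to a `ThreadPoolExecutor`, then collect results in submission order, letting the first `Future.result()` that re-raises fail the whole batch. The helper name and zero-arg workers below are illustrative stand-ins for the real `_dispatch_chunk` calls.

```python
# Minimal sketch of fail-fast fan-out over a thread pool. result() blocks and
# re-raises the worker's exception, preserving an all-or-nothing contract;
# remaining workers simply run to completion as the `with` block exits.
from concurrent.futures import ThreadPoolExecutor


def dispatch_all(workers):
    """Run zero-arg callables concurrently; return results in submission order."""
    if not workers:
        return []
    with ThreadPoolExecutor(max_workers=len(workers)) as executor:
        futures = [executor.submit(w) for w in workers]
        return [f.result() for f in futures]
```

Note that, unlike `asyncio.gather`, nothing here cancels siblings on failure; the executor just stops being observed once the first exception propagates, which matches the comment in the diff.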
@@ -485,27 +529,14 @@ async def chunked_validate_row(
 
 
 def make_chunked_validation_generator(pool: list[str]) -> Any:
-    """Build a ``@custom_column_generator``-decorated sync function bound to ``pool``.
+    """Build a ``@custom_column_generator``-decorated function bound to ``pool``.
 
     ``model_aliases`` must be declared statically on the decorator so
     DataDesigner knows which facades to materialise for the generator. Since
     the pool is config-driven (per-run), we generate the function dynamically.
     The required_columns are exhaustive for DataDesigner's DAG ordering: the
     generator reads the raw text, seed entities (for positions), the candidate
     list (what to decide), and the tag notation (for excerpt tagging).
-
-    Why the outer wrapper is sync. DataDesigner routes cell-by-cell custom
-    generators through a ThreadPoolExecutor by default (the async engine is
-    opt-in via ``DATA_DESIGNER_ASYNC_ENGINE=1``). The thread-pool path calls
-    ``generator.generate(record)`` which synchronously invokes our function
-    and passes its return value straight into ``_postprocess_result``. If the
-    outer function were ``async``, the sync caller would receive a coroutine
-    object and DD would reject it with ``"must return a dict, got coroutine"``.
-    The sync wrapper here runs the async row logic on a fresh event loop
-    (safe because each DD worker thread has no ambient loop) and returns the
-    resolved dict, so this works under both the default thread engine and
-    the opt-in async engine (which wraps sync generators in
-    ``asyncio.to_thread``).
     """
     if not pool:
         raise ValueError("Cannot build chunked validation generator: pool is empty.")
@@ -514,7 +545,7 @@ def make_chunked_validation_generator(pool: list[str]) -> Any:
         required_columns=[
             COL_TEXT,
             COL_SEED_ENTITIES,
-            COL_VALIDATION_CANDIDATES,
+            COL_SEED_VALIDATION_CANDIDATES,
             COL_TAG_NOTATION,
         ],
         model_aliases=list(pool),
@@ -524,6 +555,6 @@ def chunked_validate(
         generator_params: ChunkedValidationParams,
         models: dict[str, Any],
     ) -> dict[str, Any]:
-        return asyncio.run(chunked_validate_row(row, generator_params, models))
+        return chunked_validate_row(row, generator_params, models)
 
     return chunked_validate
