Skip to content

Commit cd62982

Browse files
authored
Merge pull request #743 from PlanExeOrg/napkin-math/compress-second-pass
napkin-math(compress): second-pass shifts the variance failure from emission to ranking
2 parents 5206bf9 + 5f44c6f commit cd62982

2 files changed

Lines changed: 194 additions & 7 deletions

File tree

worker_plan/worker_plan_internal/parameter_extraction/compress_report_section.py

Lines changed: 142 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,52 @@ def format_scored_item_line(item: PublicScoredItem) -> str:
10191019
return f"{item.line_english} {tag}"
10201020

10211021

1022+
SECOND_PASS_USER_PROMPT_TEMPLATE = (
1023+
"Review the {field_name} items you just produced above. "
1024+
"Identify items present in the section that you missed. "
1025+
"Emit only NEW items not already covered above; do not repeat or "
1026+
"rephrase items you already produced. "
1027+
"Apply the same bucket rules. "
1028+
"Up to 8 new items. "
1029+
"If you captured everything important on the first pass, return an "
1030+
"empty list."
1031+
)
1032+
1033+
1034+
def merge_second_pass_items(
1035+
first_pass: list[ScoredItem],
1036+
second_pass: list[ScoredItem],
1037+
) -> tuple[list[ScoredItem], int]:
1038+
"""Merge two batches of ScoredItem, de-duplicating second-pass items
1039+
whose normalised source_quote already appears in the first pass.
1040+
1041+
The second pass is gated by a "what did you miss?" prompt that asks the
1042+
LLM to surface candidates absent from the first batch. Smaller models
1043+
can mis-count near the per-bucket cap; the two-batch protocol keeps each
1044+
call's cognitive load comparable to the original single-batch flow, and
1045+
leaves the deterministic top-N filter (``annotate_scored_items``) to
1046+
pick the survivors from the combined pool.
1047+
1048+
Sometimes the model re-emits a first-pass item anyway; this merger
1049+
drops those duplicates while preserving order: first-pass items come
1050+
first, second-pass items follow in their emitted order, and any
1051+
duplicate from the second pass is silently skipped.
1052+
1053+
Returns ``(merged_list, newly_added_count)``.
1054+
"""
1055+
seen = {normalise_for_quote_match(item.source_quote) for item in first_pass}
1056+
merged = list(first_pass)
1057+
newly_added = 0
1058+
for item in second_pass:
1059+
key = normalise_for_quote_match(item.source_quote)
1060+
if key in seen:
1061+
continue
1062+
seen.add(key)
1063+
merged.append(item)
1064+
newly_added += 1
1065+
return merged, newly_added
1066+
1067+
10221068
def infer_section_type_from_path(file_path: str | Path) -> str:
10231069
"""Infer the section type from a filename whose stem matches one of the
10241070
known section names. Returns ``"unknown"`` if the stem is not recognised.
@@ -1162,6 +1208,102 @@ def execute(
11621208
"response_byte_count": bucket_byte_count,
11631209
"user_prompt": user_content,
11641210
}
1211+
1212+
# Append the first-pass assistant turn so the second pass (and
1213+
# subsequent buckets) can see what was already produced and avoid
1214+
# duplicating it.
1215+
assistant_content_first = json.dumps(obj.model_dump(), separators=(",", ":"))
1216+
accumulated_chat.append(
1217+
ChatMessage(role=MessageRole.ASSISTANT, content=assistant_content_first)
1218+
)
1219+
1220+
# Second pass: for scored-list buckets only, ask the LLM what it
1221+
# missed. Smaller models can mis-count near the per-bucket cap and
1222+
# silently drop high-signal items on a single pass; the two-batch
1223+
# protocol keeps each call's cognitive load comparable to the
1224+
# original flow, with the deterministic scorer (annotate_scored_items
1225+
# below) picking the survivors from the combined pool.
1226+
if spec.field_name in SCORED_LIST_FIELDS:
1227+
second_pass_user_content = SECOND_PASS_USER_PROMPT_TEMPLATE.format(
1228+
field_name=spec.field_name
1229+
)
1230+
accumulated_chat.append(
1231+
ChatMessage(role=MessageRole.USER, content=second_pass_user_content)
1232+
)
1233+
1234+
second_pass_start = time.perf_counter()
1235+
second_pass_obj = None
1236+
second_pass_chat_response = None
1237+
second_pass_last_error: Optional[Exception] = None
1238+
for retry in range(PER_BUCKET_MAX_ATTEMPTS):
1239+
logger.debug(
1240+
f"Bucket {spec.field_name} second pass: starting LLM call "
1241+
f"(attempt {retry + 1}/{PER_BUCKET_MAX_ATTEMPTS})"
1242+
)
1243+
try:
1244+
second_pass_chat_response = sllm.chat(accumulated_chat)
1245+
second_pass_obj = second_pass_chat_response.raw
1246+
if second_pass_obj is None:
1247+
raise ValueError(
1248+
f"Structured LLM returned None for bucket "
1249+
f"{spec.field_name!r} (second pass)."
1250+
)
1251+
break
1252+
except Exception as e:
1253+
second_pass_last_error = e
1254+
logger.warning(
1255+
f"Bucket {spec.field_name} second pass attempt "
1256+
f"{retry + 1} failed: {type(e).__name__}: "
1257+
f"{str(e)[:160]}"
1258+
)
1259+
if second_pass_obj is None:
1260+
raise ValueError(
1261+
f"Bucket {spec.field_name!r} second pass failed after "
1262+
f"{PER_BUCKET_MAX_ATTEMPTS} attempts. Last error: "
1263+
f"{type(second_pass_last_error).__name__}: "
1264+
f"{second_pass_last_error}"
1265+
) from second_pass_last_error
1266+
second_pass_duration = int(ceil(time.perf_counter() - second_pass_start))
1267+
second_pass_byte_count = len(
1268+
second_pass_chat_response.message.content.encode("utf-8")
1269+
)
1270+
logger.info(
1271+
f"Bucket {spec.field_name} second pass: completed in "
1272+
f"{second_pass_duration}s, {second_pass_byte_count} bytes"
1273+
)
1274+
1275+
first_pass_items = list(raw_field_value or [])
1276+
second_pass_items = list(
1277+
getattr(second_pass_obj, spec.field_name) or []
1278+
)
1279+
merged_items, newly_added_count = merge_second_pass_items(
1280+
first_pass_items, second_pass_items
1281+
)
1282+
raw_field_value = merged_items
1283+
1284+
# Append the second-pass assistant turn so subsequent buckets
1285+
# see the full pool the LLM produced.
1286+
assistant_content_second = json.dumps(
1287+
second_pass_obj.model_dump(), separators=(",", ":")
1288+
)
1289+
accumulated_chat.append(
1290+
ChatMessage(
1291+
role=MessageRole.ASSISTANT,
1292+
content=assistant_content_second,
1293+
)
1294+
)
1295+
1296+
bucket_metadata.update(
1297+
{
1298+
"second_pass_duration": second_pass_duration,
1299+
"second_pass_response_byte_count": second_pass_byte_count,
1300+
"second_pass_user_prompt": second_pass_user_content,
1301+
"first_pass_item_count": len(first_pass_items),
1302+
"second_pass_item_count": len(second_pass_items),
1303+
"newly_added_count": newly_added_count,
1304+
}
1305+
)
1306+
11651307
if spec.field_name in SCORED_LIST_FIELDS:
11661308
bucket_values[spec.field_name], scored_items = annotate_scored_items(
11671309
raw_field_value, section_markdown, spec.field_name
@@ -1170,13 +1312,6 @@ def execute(
11701312
else:
11711313
bucket_values[spec.field_name] = raw_field_value
11721314

1173-
# Append the assistant turn as compact JSON so the next bucket call
1174-
# can see what has already been produced and avoid duplicating it.
1175-
assistant_content = json.dumps(obj.model_dump(), separators=(",", ":"))
1176-
accumulated_chat.append(
1177-
ChatMessage(role=MessageRole.ASSISTANT, content=assistant_content)
1178-
)
1179-
11801315
per_bucket_metadata[spec.field_name] = bucket_metadata
11811316

11821317
total_duration = int(ceil(time.perf_counter() - total_start))

worker_plan/worker_plan_internal/parameter_extraction/tests/test_compress_report_section.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,55 @@ def test_composite_score_prefers_quantified_over_prose() -> None:
393393
# numeric-density bonus alone.
394394
assert kept[0].line_english.startswith("Reserve buffer is 15%")
395395
assert kept[1].line_english.startswith("The plan emphasizes community")
396+
397+
398+
def test_merge_second_pass_items_keeps_first_pass_when_second_pass_empty() -> None:
399+
from worker_plan_internal.parameter_extraction.compress_report_section import (
400+
merge_second_pass_items,
401+
)
402+
403+
first = [_si("alpha", quote="alpha quote"), _si("beta", quote="beta quote")]
404+
merged, added = merge_second_pass_items(first, [])
405+
assert added == 0
406+
assert [item.line_english for item in merged] == ["alpha", "beta"]
407+
408+
409+
def test_merge_second_pass_items_appends_genuinely_new_items() -> None:
410+
from worker_plan_internal.parameter_extraction.compress_report_section import (
411+
merge_second_pass_items,
412+
)
413+
414+
first = [_si("alpha", quote="alpha quote")]
415+
second = [_si("gamma", quote="gamma quote"), _si("delta", quote="delta quote")]
416+
merged, added = merge_second_pass_items(first, second)
417+
assert added == 2
418+
assert [item.line_english for item in merged] == ["alpha", "gamma", "delta"]
419+
420+
421+
def test_merge_second_pass_items_dedupes_by_normalised_source_quote() -> None:
422+
"""Second pass occasionally re-emits a first-pass item with surface
423+
differences (case, whitespace, punctuation). Normalisation should catch
424+
these so the merged pool does not double-count."""
425+
from worker_plan_internal.parameter_extraction.compress_report_section import (
426+
merge_second_pass_items,
427+
)
428+
429+
first = [_si("alpha", quote="Threshold: 100 units")]
430+
second = [
431+
_si("alpha-variant", quote="threshold: 100 UNITS"), # same quote, different casing/punct
432+
_si("beta", quote="some other 50 metric"),
433+
]
434+
merged, added = merge_second_pass_items(first, second)
435+
assert added == 1
436+
assert [item.line_english for item in merged] == ["alpha", "beta"]
437+
438+
439+
def test_merge_second_pass_items_preserves_emit_order() -> None:
440+
from worker_plan_internal.parameter_extraction.compress_report_section import (
441+
merge_second_pass_items,
442+
)
443+
444+
first = [_si("a", quote="q1"), _si("b", quote="q2")]
445+
second = [_si("c", quote="q3"), _si("d", quote="q4"), _si("e", quote="q5")]
446+
merged, _ = merge_second_pass_items(first, second)
447+
assert [item.line_english for item in merged] == ["a", "b", "c", "d", "e"]

0 commit comments

Comments
 (0)