Skip to content

Commit 0f1c122

Browse files
authored
Merge pull request #79 from FluffyAIcode/AgentMemory/v04-pr-k1h-attention-window-metric-8e7f
PR-K1.H: structural effective attention-window metric for K1.E NIAH harness
2 parents b272765 + 793cd15 commit 0f1c122

10 files changed

Lines changed: 2125 additions & 47 deletions

inference_engine/v04/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@
4747
DEFAULT_NEEDLE_PREFIXES,
4848
NIAHEvalResult,
4949
NIAHSample,
50+
aggregate_attention_window_metrics,
5051
aggregate_recall,
52+
compute_effective_attention_window,
5153
evaluate,
54+
format_attention_window_summary,
5255
format_memory_summary,
5356
greedy_decode_oracle,
5457
greedy_decode_sink_window,
@@ -90,4 +93,8 @@
9093
"format_memory_summary",
9194
"record_memory",
9295
"reset_memory_peak",
96+
# K1.H — effective attention-window metric
97+
"aggregate_attention_window_metrics",
98+
"compute_effective_attention_window",
99+
"format_attention_window_summary",
93100
]

inference_engine/v04/niah_eval.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,197 @@ def record_memory(device: torch.device) -> Dict[str, Any]:
632632
}
633633

634634

635+
# ---------------------------------------------------------------------------
636+
# Effective attention-window metric
637+
# ---------------------------------------------------------------------------
638+
#
639+
# ADR 0008 §11.5 §"Five properties" item 2 — "approximates full
640+
# attention intelligence" — turns on a structural property: how
641+
# many of the prompt's preceding key positions can the verifier's
642+
# last query *actually* attend to? In v0.3 sink+window, that
643+
# number is bounded at ``sink + window`` (≈ 68 for sink=4 +
644+
# window=64) regardless of context length T, so the verifier sees
645+
# ~5 % of context at T=1.4k and ~0.07 % at T=100k — a direct
646+
# intelligence cap. v0.4's dLM K/V Restoration design fills the
647+
# evicted positions with reconstructed K/V, so the structural
648+
# attention range is the full preceding context T regardless of
649+
# the verifier's local cache size.
650+
#
651+
# This metric is *structural*, not behavioural. Behavioural
652+
# attention-mass measurement (count keys whose post-softmax weight
653+
# exceeds ε) requires materialising the [B, H, T, T] attention
654+
# matrix and is incompatible with the SDPA path K1.F enabled for
655+
# long-context runs (SDPA fuses softmax inside the kernel and
656+
# does not return weights). The structural metric is sufficient
657+
# to answer the user-facing question "did the inference engine
658+
# reduce the verifier's intelligence by capping its attention
659+
# range?" — and it composes cleanly with the recall metric: if
660+
# v0.4 restores recall to oracle parity *and* preserves full
661+
# structural attention range, then the architecture really does
662+
# satisfy the "no intelligence loss" claim of ADR 0008 §11.5.
663+
#
664+
# Because the metric is derived from the configuration alone (no
665+
# instrumentation required, no SDPA incompatibility), it can be
666+
# computed at any context length, including the canonical 100k
667+
# rung that K1.F unlocked.
668+
669+
670+
def compute_effective_attention_window(
    config_name: str,
    *,
    seq_len: int,
    sink_size: int,
    window_size: int,
) -> Dict[str, Any]:
    """Return the structural attention reach of the last query for one sample.

    The result is derived purely from the arguments; no model or attention
    weights are inspected.

    Parameters
    ----------
    config_name
        ``"oracle_full_attention"``, ``"v03_sink_window"`` or
        ``"v04_dlm_restored"``.
    seq_len
        Prompt token length T, i.e. how many preceding keys exist for the
        last query at the first decode step.
    sink_size, window_size
        v0.3 cache shape. They bound the reach only for
        ``v03_sink_window``; for ``v04_dlm_restored`` they are merely
        echoed inside the constraint label, and for
        ``oracle_full_attention`` they do not affect the result. They are
        nevertheless validated for every configuration.

    Returns
    -------
    dict
        * ``config`` / ``seq_len``: the inputs, echoed back.
        * ``effective_keys_at_last_query``: ``seq_len`` for oracle and
          v0.4; ``min(sink_size + window_size, seq_len)`` for v0.3.
        * ``effective_attention_fraction``: that count divided by
          ``seq_len``, or ``0.0`` when ``seq_len`` is 0. It is 1.0 for
          oracle and v0.4 on any non-empty prompt and drops below 1 for
          v0.3 once ``seq_len > sink_size + window_size``.
        * ``structural_constraint``: human-readable label of the
          constraint that produced the count.

    Raises
    ------
    ValueError
        If ``seq_len``, ``sink_size`` or ``window_size`` is negative, or
        if ``config_name`` is not one of the three known configurations.
        The numeric checks run before the name check.

    Notes
    -----
    The ``v04_dlm_restored`` figure takes the architecture's claim at face
    value (evicted positions are refilled via
    ``prepare_restored_attention_kv``); nothing here verifies it. The
    recall metric reported alongside is the cross-check for that claim.
    """
    if seq_len < 0:
        raise ValueError(f"seq_len must be non-negative; got {seq_len}")
    if any(size < 0 for size in (sink_size, window_size)):
        raise ValueError(
            f"sink_size={sink_size}, window_size={window_size} must be "
            "non-negative"
        )

    # Shared fragment of the v0.3 and v0.4 constraint labels.
    cache_shape = f"sink={sink_size}+window={window_size}"

    if config_name == "v03_sink_window":
        # The only configuration whose reach is capped by the local cache.
        reachable, label = min(sink_size + window_size, seq_len), cache_shape
    elif config_name == "oracle_full_attention":
        reachable, label = seq_len, "causal"
    elif config_name == "v04_dlm_restored":
        reachable = seq_len
        label = f"causal_with_dlm_reconstruction (local_cache={cache_shape})"
    else:
        raise ValueError(
            f"unknown config_name {config_name!r}; expected one of "
            "'oracle_full_attention', 'v03_sink_window', 'v04_dlm_restored'"
        )

    # An empty prompt has no keys to cover; report 0.0 rather than divide by 0.
    if seq_len > 0:
        coverage = reachable / seq_len
    else:
        coverage = 0.0

    return {
        "config": config_name,
        "seq_len": seq_len,
        "effective_keys_at_last_query": int(reachable),
        "effective_attention_fraction": float(coverage),
        "structural_constraint": label,
    }
753+
754+
755+
def aggregate_attention_window_metrics(
    config_name: str,
    *,
    prompt_token_lens: Sequence[int],
    sink_size: int,
    window_size: int,
) -> Dict[str, Any]:
    """Aggregate per-sample attention-window metrics over an evaluation set.

    Calls :func:`compute_effective_attention_window` once per prompt length
    and summarises the results.

    Parameters
    ----------
    config_name
        Verifier configuration name, forwarded unchanged to
        :func:`compute_effective_attention_window`.
    prompt_token_lens
        Prompt token length of every sample. Any iterable of int-like
        values is accepted (list, tuple, array, generator); each element
        is converted with ``int()``.
    sink_size, window_size
        Cache shape, forwarded unchanged.

    Returns
    -------
    dict
        ``config``, ``structural_constraint`` (taken from the first
        sample; every sample shares the same configuration),
        ``samples_total``, the mean / min / max / median of
        ``effective_keys_at_last_query`` and
        ``effective_attention_fraction``, and ``per_sample`` holding the
        individual per-sample dicts.

    Raises
    ------
    ValueError
        If ``prompt_token_lens`` is empty (or ``None``), or if
        :func:`compute_effective_attention_window` rejects an argument.
    """
    # Function-scope import so the block does not depend on the module's
    # import section.
    import statistics

    # Materialise before testing for emptiness. Truth-testing the raw
    # argument is ambiguous for multi-element arrays and always True for
    # iterators, which previously let an empty generator fall through to an
    # IndexError further down.
    token_lens = (
        [] if prompt_token_lens is None else [int(t) for t in prompt_token_lens]
    )
    if not token_lens:
        raise ValueError("prompt_token_lens must be non-empty")

    per_sample = [
        compute_effective_attention_window(
            config_name,
            seq_len=t,
            sink_size=sink_size,
            window_size=window_size,
        )
        for t in token_lens
    ]
    keys = [s["effective_keys_at_last_query"] for s in per_sample]
    fracs = [s["effective_attention_fraction"] for s in per_sample]

    return {
        "config": config_name,
        "structural_constraint": per_sample[0]["structural_constraint"],
        "samples_total": len(per_sample),
        "effective_keys_at_last_query_mean": sum(keys) / len(keys),
        "effective_keys_at_last_query_min": min(keys),
        "effective_keys_at_last_query_max": max(keys),
        # Converted to float first so the median is a float for both odd
        # and even sample counts.
        "effective_keys_at_last_query_median": statistics.median(
            [float(k) for k in keys]
        ),
        "effective_attention_fraction_mean": sum(fracs) / len(fracs),
        "effective_attention_fraction_min": min(fracs),
        "effective_attention_fraction_max": max(fracs),
        "effective_attention_fraction_median": statistics.median(fracs),
        "per_sample": per_sample,
    }
805+
806+
807+
def format_attention_window_summary(metrics: Dict[str, Any]) -> str:
    """Render aggregated attention-window metrics as a single line.

    Companion to :func:`format_memory_summary`, intended for runners that
    print one compact line per configuration.

    Parameters
    ----------
    metrics
        Mapping read for ``effective_keys_at_last_query_mean``,
        ``effective_attention_fraction_mean`` and
        ``structural_constraint``. All three are optional.

    Returns
    -------
    str
        ``"attn_window: mean_keys=<N> (<P>% of context) constraint=<C>"``
        when both means are present, otherwise
        ``"attn_window: n/a (constraint=<C>)"``. ``<C>`` falls back to
        ``"?"`` when no constraint is recorded.
    """
    label = metrics.get("structural_constraint", "?")
    mean_keys = metrics.get("effective_keys_at_last_query_mean")
    mean_fraction = metrics.get("effective_attention_fraction_mean")

    if mean_keys is not None and mean_fraction is not None:
        percent = mean_fraction * 100
        return (
            f"attn_window: mean_keys={mean_keys:.0f} "
            f"({percent:.2f}% of context) constraint={label}"
        )
    # At least one of the means is absent, so there is nothing to report.
    return f"attn_window: n/a (constraint={label})"
824+
825+
635826
def format_memory_summary(snapshot: Dict[str, Any]) -> str:
636827
"""Return a one-line human-readable summary of a memory snapshot.
637828

0 commit comments

Comments
 (0)