
Commit 61ad37f

Merge pull request #62 from FluffyAIcode/AgentMemory/kakeyalattice-gpu-b41e
v1.6.1 — Gemma-4 drop-in (per-layer head_dim + mask sizing) + bit-packing as unified comparison standard
2 parents 198b4f9 + f01fc81 commit 61ad37f

14 files changed

Lines changed: 626 additions & 80 deletions


CHANGELOG.md

Lines changed: 32 additions & 0 deletions
@@ -1,5 +1,37 @@
 # Changelog
 
+## v1.6.1 — 2026-06-15
+
+**Drop-in support for heterogeneous per-layer head_dim (Gemma-4) + bit-packing
+adopted as the unified compression-ratio standard.**
+
+### Fixed
+- **Per-layer head_dim.** Models whose layers expose different K/V head dims now
+work drop-in. Gemma-4-26B mixes `sliding_attention` (head_dim=256) and
+`full_attention` (head_dim=512) layers, which raised
+`AssertionError: expected last dim 256, got 512`. Each layer's codec is now
+built lazily from the head_dim actually observed at that layer
+(`KakeyaLatticeQuantizedCache`, `KakeyaLatticeCache`, `TurboQuantPackedCache`).
+- **Attention-mask sizes.** The int-storage caches keep their compressed state
+outside `self.layers`, so transformers-5's `DynamicCache.get_mask_sizes` fell
+through to `(query_length, 0)` and corrupted Gemma-4's sliding-window /
+multimodal blockwise mask during multi-step decode (CUDA device-side assert).
+`get_mask_sizes` is now overridden to report the true cache length.
+- Verified on H200: **Gemma-4-26B generates end-to-end** with
+`KakeyaLatticePackedCache` (E8 Q=38), real CR **2.44×**, lossless; per-layer
+codecs 256 (sliding) / 512 (full). Qwen3-4B regression unchanged.
+
+### Changed
+- **Bit-packing + iso-quality is now the unified comparison standard.** All
+codec-vs-codec comparisons (KakeyaLattice and the TurboQuant baseline) use the
+bit-packed caches (`KakeyaLatticePackedCache`, `TurboQuantPackedCache`) **and**
+match quality (each codec taken at the operating point meeting a fixed |Δppl|
+threshold, then real bytes compared). Raw CR at unmatched bit budgets is never
+used to rank codecs. Iso-ppl result on Qwen3-4B (|Δppl| ≤ 2 %): **E8 +7.7 %,
+D4 +5.0 %** real-byte advantage over TurboQuant. The int8
+`KakeyaLatticeQuantizedCache` (1.94×) remains as the simpler, dependency-free
+storage option. README and reports updated accordingly.
+
 ## v1.6.0 — 2026-06-15
 
 **fix codec.roundtrip bug — contiguous, directly-SDPA-feedable K/V decode.**
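
The iso-quality rule in the v1.6.1 "Changed" entry can be sketched as follows. This is a minimal illustration, not the repo's `compare_real_cr.py`: the `OperatingPoint` type, both helper functions, the definition of "advantage" as the baseline's extra bytes in percent, and all numbers are assumptions made up for the example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OperatingPoint:
    """One measured (setting, quality, size) point for a codec (hypothetical type)."""
    setting: str          # label only, e.g. a Q value or a bit width
    abs_dppl_pct: float   # |delta ppl| vs the bf16 baseline, in percent
    kv_bytes: int         # real packed KV bytes at this point


def best_point_under(points, max_abs_dppl_pct):
    """Smallest-footprint point whose |delta ppl| meets the threshold, else None."""
    ok = [p for p in points if p.abs_dppl_pct <= max_abs_dppl_pct]
    return min(ok, key=lambda p: p.kv_bytes) if ok else None


def real_byte_advantage_pct(ours, baseline):
    """Percent by which `baseline` stores more bytes than `ours` (assumed definition)."""
    return 100.0 * (baseline.kv_bytes / ours.kv_bytes - 1.0)


# Placeholder numbers for illustration only; NOT measurements from this repo.
codec_a = [
    OperatingPoint("a-fine", 0.2, 1000),
    OperatingPoint("a-coarse", 1.8, 800),
    OperatingPoint("a-too-coarse", 3.5, 600),  # smallest, but fails the threshold
]
codec_b = [
    OperatingPoint("b-fine", 1.5, 880),
    OperatingPoint("b-coarse", 4.8, 500),      # highest raw CR, worst quality
]

a = best_point_under(codec_a, 2.0)
b = best_point_under(codec_b, 2.0)
print(a.setting, b.setting, round(real_byte_advantage_pct(a, b), 1))
```

Ranking by raw bytes alone would pick `b-coarse`; filtering by the quality threshold first is what removes that point from the comparison.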

README.md

Lines changed: 19 additions & 6 deletions
@@ -3,20 +3,33 @@
 > **A D4 / E8 nested-lattice codec that realises a discrete *Kakeya
 > cover* over the direction sphere of transformer KV activations.**
 >
-> Two `transformers.DynamicCache` subclasses ship together:
+> Three `transformers.DynamicCache` subclasses ship together:
 >
+> - **`KakeyaLatticePackedCache`** — stores **bit-packed lattice codes**.
+> **Real ~2.4× HBM compression** (D4 Q=38 ≈ 2.46×, E8 Q=38 ≈ 2.37×;
+> measured end-to-end on Qwen3-4B / H200). **This is the unified
+> comparison standard** for all reported compression ratios as of v1.6.
 > - **`KakeyaLatticeQuantizedCache`** — stores **int8 lattice indices**.
-> **Real ~1.94× HBM compression** (measured at the tensor-byte
-> level; see [`reports/v1_5_release/hbm_savings/REAL_HBM_PROOF.md`](reports/v1_5_release/hbm_savings/REAL_HBM_PROOF.md)).
+> Simpler, dependency-free storage; **real ~1.94× HBM compression**
+> (the int8-vs-6.3-bit overhead). Bit-identical reconstruction to the
+> packed cache — use it when you prefer the simplest storage type.
 > - **`KakeyaLatticeCache`** — stores reconstructed bf16. **Zero HBM
 > savings**; use as a reconstruction-quality probe.
 >
 > At the **codec bit-rate level** (a Q=38 lattice vector needs ~6.3
 > bits per coordinate, vs 16 bits for bf16), the achievable ceiling
 > is **2.4×–2.8× compression at <1 % perplexity loss** on Qwen3,
-> Llama-3, DeepSeek, GLM-4, and Gemma. The current int8
-> implementation hits **1.94×** of that ceiling; the gap to 2.4× is
-> bit-packed int storage, the v1.6 work item.
+> Llama-3, DeepSeek, GLM-4, and Gemma. **`KakeyaLatticePackedCache`
+> realises that ceiling as real bytes** (v1.6); the int8 cache trades
+> ~25 % of it for a plain storage type. **All compression-ratio
+> comparisons in this repo use (1) the bit-packed caches** (both for
+> KakeyaLattice and the TurboQuant baseline) **and (2) iso-quality
+> matching** — each codec is taken at the operating point meeting a fixed
+> |Δppl| threshold, then real bytes are compared. (Raw CR at unmatched bit
+> budgets is never used to rank codecs — a lower-bit point trivially shows a
+> higher CR at worse quality.) Iso-ppl result on Qwen3-4B (|Δppl| ≤ 2 %):
+> **E8 +7.7 %, D4 +5.0 %** real-byte advantage over TurboQuant; see
+> [`reports/v1_5_release/bitpack_vs_tq_2026-06-15/`](reports/v1_5_release/bitpack_vs_tq_2026-06-15/).
 >
 > `pip install kakeyalattice`.
 
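
The bit-rate figures quoted in the README hunk (about 6.3 bits per coordinate at Q=38, 8 bits for int8, 16 bits for bf16) translate into compression ratios by simple division. The sketch below only redoes that arithmetic on the README's own numbers. Reading the gap between the ideal and the measured ratio as "overhead bits" is an interpretation, not something the README breaks down, and both function names are hypothetical.

```python
BF16_BITS = 16.0  # bits per coordinate in an uncompressed bf16 KV cache


def ideal_cr(bits_per_coord: float) -> float:
    """Compression ratio if only the payload bits were stored."""
    return BF16_BITS / bits_per_coord


def effective_bits(measured_cr: float) -> float:
    """Bits per coordinate actually spent, implied by a measured real CR."""
    return BF16_BITS / measured_cr


# (name, payload bits per coordinate, measured real CR), all taken from the README text.
for name, payload_bits, measured in [("packed", 6.3, 2.4), ("int8", 8.0, 1.94)]:
    overhead = effective_bits(measured) - payload_bits
    print(f"{name}: ideal {ideal_cr(payload_bits):.2f}x, measured {measured:.2f}x, "
          f"implied extra {overhead:.2f} bits/coord")
```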
Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+"""Gemma-4-26B heterogeneous-head_dim check / repro / fix verification.
+
+Gemma-4 uses head_dim=256 (sliding_attention layers) and global_head_dim=512
+(full_attention layers). This script:
+  1. loads the model (text-only generate),
+  2. inspects the per-layer K head_dim from a bf16 DynamicCache,
+  3. tries KakeyaLatticePackedCache (E8 Q=38) and reports success/CR/coherence
+     or the assertion (pre-fix repro).
+"""
+from __future__ import annotations
+import argparse, json, os, traceback
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
+from kakeyalattice.hf import KakeyaLatticePackedCache
+
+
+def layer_kv_dims(cache):
+    dims = []
+    if hasattr(cache, "layers"):
+        for layer in cache.layers:
+            k = getattr(layer, "keys", None)
+            dims.append(None if k is None else int(k.shape[-1]))
+    return dims
+
+
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--model", default="google/gemma-4-26B-A4B-it")
+    ap.add_argument("--max-new", type=int, default=24)
+    ap.add_argument("--out", default="/root/kakeyalattice-test/reports/v1_5_release/gemma4_hetero_headdim_2026-06-15/gemma4_check.json")
+    args = ap.parse_args()
+    dev = "cuda"
+    tok = AutoTokenizer.from_pretrained(args.model)
+    model = AutoModelForCausalLM.from_pretrained(args.model, dtype=torch.bfloat16, device_map=dev).eval()
+    cfg = model.config
+    tcfg = getattr(cfg, "text_config", cfg)
+    L = tcfg.num_hidden_layers
+    hd = getattr(tcfg, "head_dim", None)
+    ghd = getattr(tcfg, "global_head_dim", None)
+    print(f"[cfg] layers={L} head_dim={hd} global_head_dim={ghd}", flush=True)
+    print(f"[cfg] layer_types={getattr(tcfg,'layer_types',None)}", flush=True)
+
+    msgs = [{"role": "user", "content": "In one sentence, what is lattice quantization?"}]
+    enc = tok.apply_chat_template(msgs, add_generation_prompt=True, return_tensors="pt", return_dict=True)
+    ids = enc["input_ids"].to(dev)
+    in_len = ids.shape[1]
+    gen = dict(max_new_tokens=args.max_new, do_sample=False, use_cache=True)
+
+    report = {"model": args.model, "layers": L, "head_dim": hd, "global_head_dim": ghd}
+
+    # 1) bf16 baseline + per-layer K dims
+    cacheA = DynamicCache()
+    with torch.inference_mode():
+        outA = model.generate(ids, past_key_values=cacheA, **gen)
+    dimsA = layer_kv_dims(cacheA)
+    base_bytes = sum(
+        (layer.keys.element_size()*layer.keys.numel() + layer.values.element_size()*layer.values.numel())
+        for layer in cacheA.layers if getattr(layer, "keys", None) is not None)
+    textA = tok.decode(outA[0][in_len:], skip_special_tokens=True)
+    print(f"[bf16] per-layer K head_dim = {dimsA}", flush=True)
+    print(f"[bf16] distinct dims = {sorted(set(d for d in dimsA if d))}", flush=True)
+    print(f"[bf16] text: {textA[:160]}", flush=True)
+    report["per_layer_kv_dim"] = dimsA
+    report["distinct_dims"] = sorted(set(d for d in dimsA if d))
+    report["bf16_text"] = textA
+    report["bf16_kv_bytes"] = base_bytes
+    seqA = int(outA.shape[1]); del cacheA, outA; torch.cuda.empty_cache()
+
+    # 2) packed cache (E8 Q=38)
+    try:
+        cacheB = KakeyaLatticePackedCache(variant="e8", q_range=38,
+                    num_hidden_layers=L, head_dim=hd or 256, device=dev)
+        with torch.inference_mode():
+            outB = model.generate(ids, past_key_values=cacheB, **gen)
+        textB = tok.decode(outB[0][in_len:], skip_special_tokens=True)
+        kb = cacheB.kv_storage_bytes()
+        cr = base_bytes / kb if kb else None
+        codec_dims = {li: (c.D_shape if c is not None else None)
+                      for li, c in enumerate(cacheB._codecs)}
+        print(f"[packed] OK seq={int(outB.shape[1])} kv={kb/2**20:.2f}MiB realCR={cr:.3f}x lossless={cacheB.packed_pack_unpack_ok()}", flush=True)
+        print(f"[packed] per-layer codec D_shape = {codec_dims}", flush=True)
+        print(f"[packed] text: {textB[:160]}", flush=True)
+        report.update({"packed_ok": True, "packed_kv_bytes": kb, "packed_real_cr": cr,
+                       "packed_text": textB, "codec_dims": codec_dims,
+                       "lossless": cacheB.packed_pack_unpack_ok()})
+    except Exception as e:
+        print(f"[packed] FAILED: {type(e).__name__}: {e}", flush=True)
+        traceback.print_exc()
+        report.update({"packed_ok": False, "error": f"{type(e).__name__}: {e}"})
+
+    os.makedirs(os.path.dirname(args.out), exist_ok=True)
+    with open(args.out, "w") as f:
+        json.dump(report, f, indent=2, default=str)
+    print(f"[out] {args.out}", flush=True)
+
+
+if __name__ == "__main__":
+    main()
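
The script above derives its real compression ratio from tensor byte counts: element size times element count for every K and V tensor, divided by the packed cache's storage bytes. The same accounting can be sketched without a GPU, under stated assumptions: every layer shares one KV-head count and one sequence length (sliding-window layers may in practice hold fewer positions), the layer pattern is invented, and `kv_bytes` and `real_cr` are hypothetical helpers rather than repo functions.

```python
def kv_bytes(per_layer_head_dim, num_kv_heads, seq_len, bytes_per_elem=2, batch=1):
    """Total bytes of K plus V over all layers; bf16 means 2 bytes per element."""
    return sum(
        2 * batch * num_kv_heads * seq_len * head_dim * bytes_per_elem  # 2 = K and V
        for head_dim in per_layer_head_dim
    )


def real_cr(base_bytes, packed_bytes):
    """Real compression ratio; None when nothing was stored (mirrors the script)."""
    return base_bytes / packed_bytes if packed_bytes else None


# Hypothetical 3-layer pattern: two sliding layers (256) and one full layer (512).
dims = [256, 256, 512]
base = kv_bytes(dims, num_kv_heads=4, seq_len=10)
print(base, real_cr(base, base // 2), real_cr(base, 0))
```

A per-layer sum is needed here because a single `head_dim` multiplied by the layer count would miscount the bf16 baseline for a model with mixed 256 / 512 layers.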

benchmarks/bitpack_vs_tq/verify_packed_e2e.py

Lines changed: 25 additions & 8 deletions
@@ -1,11 +1,21 @@
-"""End-to-end REAL bit-packed storage on a live model.
+"""End-to-end REAL bit-packed storage on a live model (per-operating-point CR).
 
-Generates with the bit-packed caches on Qwen3-4B and reports the real packed
-KV footprint vs bf16 DynamicCache:
-  * KakeyaLatticePackedCache (D4 Q=38) -> ~2.46x
-  * KakeyaLatticePackedCache (E8 Q=38) -> ~2.42x
-  * TurboQuantPackedCache (b=4) -> ~3.76x (lower quality; see iso-ppl)
-Also verifies the pack->unpack cycle is lossless (so quality == unpacked cache).
+Generates with the bit-packed caches on Qwen3-4B and reports the real packed KV
+footprint vs bf16 DynamicCache, and verifies pack->unpack is lossless.
+
+!!! NOT A FAIR HEAD-TO-HEAD !!!
+The points below (D4/E8 @ Q=38, TurboQuant @ b=4) are at DIFFERENT bit budgets /
+quality, so their raw CRs are NOT comparable: TurboQuant b=4 shows a higher CR
+ONLY because it is a much more aggressive, much lower-quality point
+(|Δppl| ~4.8% vs ~0.2% for KakeyaLattice Q=38). Comparing CR across unmatched
+quality is meaningless.
+
+>>> The canonical KakeyaLattice-vs-TurboQuant comparison is ISO-QUALITY (matched
+|Δppl|) and lives in `compare_real_cr.py`. At |Δppl| <= 2% on Qwen3-4B the
+real-byte winners are E8 +7.7% / D4 +5.0% over TurboQuant. <<<
+
+This script is only a sanity check that each packed cache works end-to-end and
+hits its expected real CR at its own operating point.
 """
 from __future__ import annotations
 import argparse, json, time
@@ -80,6 +90,8 @@ def run(make_cache, name):
 
     print(f"model={args.model} layers={L} head_dim={hd} seq={seqA}")
     print(f"bf16 DynamicCache KV bytes = {base_bytes:,} ({base_bytes/2**20:.2f} MiB)")
+    print("NOTE: per-operating-point raw CR — NOT quality-matched. Do not rank "
+          "codecs by these numbers. Iso-quality comparison: compare_real_cr.py.")
     print(f"{'cache':<26} {'KV MiB':>9} {'real CR':>9} {'lossless':>9} {'time(s)':>8}")
     rows = []
     for r in runs:
@@ -96,7 +108,12 @@ def run(make_cache, name):
     with open(args.out, "w") as f:
         json.dump({"model": args.model, "gpu": torch.cuda.get_device_name(0),
                    "head_dim": hd, "layers": L, "seq": seqA,
-                   "bf16_kv_bytes": base_bytes, "runs": rows}, f, indent=2)
+                   "bf16_kv_bytes": base_bytes, "runs": rows,
+                   "note": ("per-operating-point raw CR, NOT quality-matched; "
+                            "ranking codecs by these is meaningless. Iso-quality "
+                            "(matched |Dppl|) comparison is in compare_real_cr.py."),
+                   "iso_quality_comparison": "benchmarks/bitpack_vs_tq/compare_real_cr.py"},
+                  f, indent=2)
     print(f"[out] {args.out}")
 
 
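
The docstring above says the script verifies that pack->unpack is lossless. A minimal pure-Python sketch of such a round trip is given below, assuming fixed-width packing: indices in [-38, 38] are shifted to [0, 76] and stored in 7 bits each. The repo's packed caches may well use a denser scheme (the README quotes about 6.3 bits per coordinate), so `pack_codes` and `unpack_codes` are hypothetical stand-ins, not the library's routines.

```python
def pack_codes(codes, bits):
    """Pack non-negative ints below 2**bits into bytes, least-significant bits first."""
    acc, nbits, out = 0, 0, bytearray()
    for c in codes:
        if not 0 <= c < (1 << bits):
            raise ValueError(f"code {c} does not fit in {bits} bits")
        acc |= c << nbits
        nbits += bits
        while nbits >= 8:          # flush every completed byte
            out.append(acc & 0xFF)
            acc >>= 8
            nbits -= 8
    if nbits:                      # final partial byte, zero-padded
        out.append(acc & 0xFF)
    return bytes(out)


def unpack_codes(data, bits, count):
    """Inverse of pack_codes: read `count` codes of width `bits`."""
    acc, nbits, out = 0, 0, []
    stream = iter(data)
    mask = (1 << bits) - 1
    for _ in range(count):
        while nbits < bits:        # refill until one full code is available
            acc |= next(stream) << nbits
            nbits += 8
        out.append(acc & mask)
        acc >>= bits
        nbits -= bits
    return out


Q = 38                                         # q_range used throughout this commit
indices = [-38, 0, 38, 5, -7, 12, 1, -1]       # made-up lattice indices
codes = [i + Q for i in indices]               # shift to the unsigned range [0, 76]
packed = pack_codes(codes, bits=7)
restored = [c - Q for c in unpack_codes(packed, bits=7, count=len(codes))]
print(len(indices), len(packed), restored == indices)
```

Eight int8 indices would occupy 8 bytes; the 7-bit stream occupies 7, and the round trip returns the original indices, which is the property the lossless check relies on.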

kakeyalattice/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "kakeyalattice"
-version = "1.6.0"
+version = "1.6.1"
 description = "Nested-lattice KV-cache compression for LLM inference: Zamir-Feder D4 and E8 variants with shaping gain over scalar quantisation."
 readme = "README.md"
 # NOTE: we intentionally declare the license only via classifier

kakeyalattice/python/kakeyalattice/hf/cache.py

Lines changed: 43 additions & 25 deletions
@@ -173,36 +173,58 @@ def __init__(
             warnings.warn(msg, UserWarning, stacklevel=2)
             logger.warning(msg)
 
-        # One codec instance per layer. Codec has no cross-layer state
-        # but per-layer instantiation allows future per-layer Q sweeps
-        # without re-architecting.
-        self._codecs: list[Any | None] = []
-        if self._supports_lattice:
-            self._init_codecs()
+        # Per-layer codecs are built LAZILY on first update(), keyed by the
+        # head_dim actually observed at each layer — so models with
+        # heterogeneous per-layer head_dim (e.g. Gemma-4 sliding=256 / full=512)
+        # work drop-in. ``head_dim`` is the declared default for back-compat.
+        self._codecs: list[Any | None] = [None] * self.num_hidden_layers
+        self._raw_layers: set[int] = set()
 
         # Fire counters for sanity / audit.
         self.codec_fired_per_layer: dict[int, int] = {}
         self.skip_fired_per_layer: dict[int, int] = {}
 
     # ----- codec management -----
 
-    def _init_codecs(self) -> None:
+    def _codec_cls(self):
         if self.variant == "d4":
             from kakeyalattice import V14KakeyaZamirLatticeGPU as CodecCls
         else:
             from kakeyalattice import V15KakeyaZamirE8GPU as CodecCls
+        return CodecCls
 
-        self._codecs = []
-        for layer_idx in range(self.num_hidden_layers):
-            if self._is_boundary_layer(layer_idx):
-                self._codecs.append(None)
-            else:
-                codec = CodecCls(
-                    D=self.head_dim,
-                    q_range=self.q_range,
-                    device=str(self.device),
+    def _get_codec(self, layer_idx: int, observed_dim: int):
+        """Lazily build/return the per-layer codec from the observed head_dim.
+        Returns None for raw bf16 (boundary / incompatible-with-strict=False)."""
+        if self._is_boundary_layer(layer_idx) or layer_idx in self._raw_layers:
+            return None
+        codec = self._codecs[layer_idx]
+        if codec is not None:
+            if codec.D_shape != observed_dim:
+                raise ValueError(
+                    f"layer {layer_idx} head_dim changed "
+                    f"{codec.D_shape} -> {observed_dim} between updates"
                 )
-                self._codecs.append(codec)
+            return codec
+        bd = self._block_dim
+        is_pow2 = observed_dim > 0 and (observed_dim & (observed_dim - 1)) == 0
+        if (observed_dim % bd != 0) or not is_pow2:
+            msg = (
+                f"KakeyaLatticeCache(variant={self.variant!r}): layer "
+                f"{layer_idx} head_dim={observed_dim} is incompatible "
+                f"(need a power of 2 divisible by {bd})."
+            )
+            if self.strict:
+                raise ValueError(msg + " Pass strict=False to keep raw bf16.")
+            warnings.warn(msg + " strict=False: layer kept as raw bf16.",
+                          UserWarning, stacklevel=2)
+            self._raw_layers.add(layer_idx)
+            return None
+        codec = self._codec_cls()(
+            D=observed_dim, q_range=self.q_range, device=str(self.device),
+        )
+        self._codecs[layer_idx] = codec
+        return codec
 
     def _is_boundary_layer(self, layer_idx: int) -> bool:
         if self.boundary <= 0:
@@ -239,21 +261,17 @@ def update(
         """Roundtrip K and V through the per-layer codec, then delegate
         to ``DynamicCache.update`` to concat with existing cache state.
         """
-        # Fast path: codec disabled (strict=False on incompatible model,
-        # or boundary layer).
-        if (
-            not self._supports_lattice
-            or layer_idx >= len(self._codecs)
-            or self._codecs[layer_idx] is None
-        ):
+        # Lazily resolve the per-layer codec from the observed head_dim so
+        # heterogeneous-head_dim models work drop-in. None => raw bf16.
+        codec = self._get_codec(layer_idx, key_states.shape[-1])
+        if codec is None:
             self.skip_fired_per_layer[layer_idx] = (
                 self.skip_fired_per_layer.get(layer_idx, 0) + 1
             )
             return super().update(
                 key_states, value_states, layer_idx, *args, **kwargs
             )
 
-        codec = self._codecs[layer_idx]
         k_rt = self._roundtrip(key_states, codec)
         v_rt = self._roundtrip(value_states, codec)
         self.codec_fired_per_layer[layer_idx] = (
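
The lazy per-layer codec lookup introduced in this diff can be exercised in isolation with a toy stand-in. The sketch below mirrors the control flow of `_get_codec` (reuse, dimension-change error, power-of-two and block-divisibility check, raw fallback) but is not the library code: `ToyCodec`, `LazyPerLayerCodecs`, and the block size of 8 are assumptions for illustration, and boundary-layer handling is left out.

```python
class ToyCodec:
    """Stand-in for a real lattice codec; it only records its dimension."""

    def __init__(self, D):
        self.D_shape = D


class LazyPerLayerCodecs:
    """Builds one codec per layer on first use, from the head_dim seen there."""

    def __init__(self, num_layers, block_dim, strict=True):
        self._codecs = [None] * num_layers
        self._raw_layers = set()     # layers kept uncompressed (strict=False)
        self._block_dim = block_dim
        self.strict = strict

    def get(self, layer_idx, observed_dim):
        if layer_idx in self._raw_layers:
            return None
        codec = self._codecs[layer_idx]
        if codec is not None:
            if codec.D_shape != observed_dim:
                raise ValueError(
                    f"layer {layer_idx} head_dim changed "
                    f"{codec.D_shape} -> {observed_dim} between updates"
                )
            return codec
        # n & (n - 1) clears the lowest set bit, so it is 0 only for powers of 2.
        is_pow2 = observed_dim > 0 and (observed_dim & (observed_dim - 1)) == 0
        if observed_dim % self._block_dim != 0 or not is_pow2:
            if self.strict:
                raise ValueError(f"layer {layer_idx} head_dim={observed_dim} is incompatible")
            self._raw_layers.add(layer_idx)
            return None
        codec = ToyCodec(observed_dim)
        self._codecs[layer_idx] = codec
        return codec


codecs = LazyPerLayerCodecs(num_layers=2, block_dim=8)
sliding = codecs.get(0, 256)   # first call builds the 256-dim codec
full = codecs.get(1, 512)      # a different layer may use a different dim
print(sliding.D_shape, full.D_shape, codecs.get(0, 256) is sliding)
```

Keying construction on the observed dimension, rather than on one declared `head_dim`, is what lets layers with 256 and 512 dimensions coexist in one cache; the dimension-change check still catches a layer whose shape drifts between updates.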
