Commit add46de

Server reads KV bytes from engine.kv_state(), not the slab pool
The 2026-05-30 short test #2 confirmed:

- bench in-flight metrics poller works (7313 samples / 58 turns,
  median 110 / turn, max 429)
- orphan-session fix works (idle pool_in_use settles to 0)
- slab IS acquired during turns (in-flight pool_in_use peak = 1)

But scheduler_kv_live_bytes still read 0.0 in 58/58 turns.

Root cause: SlabPool.live_kv_bytes (added in PR #24) sums slabs'
live_kv_bytes_override, which is only ever set by PooledVerifier, and
PooledVerifier is never wired into scripts/serve.py. Wrapping the
verifier in PooledVerifier requires plumbing the slab through
Scheduler -> Engine -> SpeculativeDecoder -> Verifier, which is a
non-trivial structural change.

Cheaper fix
-----------

The verifier already holds the real KV cache tensors and is the
canonical source of truth for live KV bytes. Expose it directly:

- kv_cache_proposer/verifier.py
    SinkWindowVerifier.live_kv_bytes() -> int
    Sums layer.keys.numel() * element_size() + same for values across
    the cache. Returns 0 when cache is None (between reset() and
    prefill()). _record_peak_kv now reads through it.

- inference_engine/backends/mlx/verifier.py
    MLXSinkWindowVerifier.live_kv_bytes() -> int
    Same surface as the CPU verifier; reads from
    cache_ops.total_kv_bytes(self.cache). _record_peak_kv now reads
    through it too.

- inference_engine/server/engine.py
    Engine protocol: new kv_state() -> int method.
    SpeculativeEngine: returns decoder.verifier.live_kv_bytes() if
    exists else 0. Defensive on the verifier surface so legacy
    verifiers that don't expose the optional method don't break the
    engine.

- inference_engine/server/app.py
    /metrics handler: replace
      kv_live_bytes=pool.live_kv_bytes
    with
      kv_live_bytes=int(engine.kv_state())

The pool-side gauge is preserved as infrastructure; once
PooledVerifier is wired (post-v0.3.0), the slab will report correctly
via override and aggregate matches engine. For v0.3, the engine is
the source of truth.

Thread safety
-------------

Both verifiers' live_kv_bytes() reads are int-attribute walks over
tensor shape descriptors. CPython torch.Tensor.numel() / mlx
array.size are atomic reads: a concurrent worker writing the cache
produces some valid intermediate value, never garbage. Documented
inline.

Tests (no mock; all real concrete classes)
------------------------------------------

tests/core/test_verifier.py
  + test_live_kv_bytes_zero_before_prefill
  + test_live_kv_bytes_nonzero_after_prefill
  + test_live_kv_bytes_returns_zero_when_layer_kv_is_null

tests/backends/mlx/test_verifier.py
  + test_live_kv_bytes_zero_before_prefill
  + test_live_kv_bytes_nonzero_after_prefill

tests/inference_engine/server/test_engine.py
  + test_kv_state_reads_from_verifier_live_kv_bytes
    (with concrete _VerifierDouble exposing live_kv_bytes)
  + test_kv_state_returns_zero_when_verifier_has_no_method
  + test_kv_state_called_each_invocation
    (asserts /metrics scrape contract: no caching)

tests/inference_engine/server/test_app_metrics_and_auth.py
  + test_metrics_kv_live_bytes_reflects_engine_kv_state
    (regression test pinning the fix; uses _KVAwareEngine subclass
    returning a deterministic non-zero value and asserts it appears
    in the prometheus text exposition)

Test doubles updated (DeterministicEngine in conftest.py,
_RaisingEngine in two test files): all return kv_state() == 0 as a
no-real-cache default.

Verified locally:
  pytest tests/inference_engine/server/test_engine.py
         tests/inference_engine/server/test_app_metrics_and_auth.py
         tests/core/test_verifier.py
    -> 65 passed
  pytest tests/inference_engine/
    -> 389 passed (no regression)

Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
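
The root cause described in the commit message can be illustrated with a small sketch. `_Slab` and `_Pool` are hypothetical stand-ins written for this note, not the repository's SlabPool; the only behaviour taken from the message is that the pool gauge sums each slab's `live_kv_bytes_override`, which nothing in the serving path sets. The 4 MiB cache size is an assumed figure.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class _Slab:
    """Hypothetical stand-in for a pool slab (not the real class)."""
    in_use: bool = False
    # Only a pooled verifier would set this; the serving path never does.
    live_kv_bytes_override: Optional[int] = None


@dataclass
class _Pool:
    """Hypothetical stand-in for the slab pool."""
    slabs: List[_Slab] = field(default_factory=list)

    @property
    def in_use_count(self) -> int:
        return sum(1 for s in self.slabs if s.in_use)

    @property
    def live_kv_bytes(self) -> int:
        # An unset override contributes nothing to the sum.
        return sum(s.live_kv_bytes_override or 0 for s in self.slabs)


pool = _Pool(slabs=[_Slab(), _Slab()])
pool.slabs[0].in_use = True              # a turn acquires a slab
verifier_cache_bytes = 4 * 1024 * 1024   # assumed size of the verifier's real cache

# The slab is in use, yet the pool-side gauge still reads zero.
print(pool.in_use_count, pool.live_kv_bytes, verifier_cache_bytes)
```

This matches the observed symptom: pool_in_use peaks at 1 while the byte gauge stays at 0, because the two numbers come from unrelated bookkeeping.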
1 parent bc3fbcd commit add46de

11 files changed

Lines changed: 256 additions & 10 deletions

inference_engine/backends/mlx/verifier.py

Lines changed: 19 additions & 3 deletions
@@ -187,10 +187,26 @@ def _cache_buffer_size(self) -> int:
             return 0
         return cache_ops.cache_seq_length(self.cache)
 
-    def _record_peak_kv(self) -> None:
+    def live_kv_bytes(self) -> int:
+        """Return the current size of the verifier's live KV cache in bytes.
+
+        This is the *now* size, not a peak. Reads from any thread:
+        ``cache_ops.total_kv_bytes`` walks the per-layer
+        :class:`SinkWindowKVCache` instances and sums
+        ``keys.size * keys.dtype.size`` + same for values, all of
+        which are integer attributes that don't tear under a
+        concurrent reader. The HTTP ``/metrics`` handler relies on
+        this property to scrape KV usage during in-flight generation.
+
+        Returns 0 when the cache has not been allocated yet (between
+        ``reset()`` and the next ``prefill()``).
+        """
         if self.cache is None:
-            return
-        total = cache_ops.total_kv_bytes(self.cache)
+            return 0
+        return cache_ops.total_kv_bytes(self.cache)
+
+    def _record_peak_kv(self) -> None:
+        total = self.live_kv_bytes()
         if total > self.stats.peak_kv_bytes:
             self.stats.peak_kv_bytes = total
 
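
The split between a "now" reading and a peak reading in the hunk above can be sketched with a toy model. `_ToyVerifier` is a hypothetical stand-in in which the cache is just a byte count; the real verifiers hold tensors, and their `prefill()` and `reset()` signatures differ from the ones assumed here.

```python
class _Stats:
    def __init__(self) -> None:
        self.peak_kv_bytes = 0


class _ToyVerifier:
    """Hypothetical stand-in; the cache is modelled as a plain byte count."""

    def __init__(self) -> None:
        self.cache = None  # None models "not allocated yet"
        self.stats = _Stats()

    def live_kv_bytes(self) -> int:
        # Current size; falls back to 0 when there is no cache.
        if self.cache is None:
            return 0
        return self.cache

    def _record_peak_kv(self) -> None:
        # The peak reads through the live value and only ever grows.
        total = self.live_kv_bytes()
        if total > self.stats.peak_kv_bytes:
            self.stats.peak_kv_bytes = total

    def prefill(self, n_bytes: int) -> None:
        self.cache = n_bytes
        self._record_peak_kv()

    def reset(self) -> None:
        self.cache = None


v = _ToyVerifier()
v.prefill(4096)
v.reset()
idle = v.live_kv_bytes()
v.prefill(1024)
print(idle, v.live_kv_bytes(), v.stats.peak_kv_bytes)  # 0 1024 4096
```

Routing `_record_peak_kv` through `live_kv_bytes()` keeps a single definition of "bytes held", so the gauge and the peak statistic cannot drift apart.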

inference_engine/server/app.py

Lines changed: 10 additions & 1 deletion
@@ -291,12 +291,21 @@ async def metrics_endpoint() -> Response:
         # Refresh scheduler-state gauges on every scrape so the
         # exposition reflects "now" rather than the last
         # admission/completion event.
+        engine_for_kv: Engine = app.state.engine
+        # Read KV bytes directly from the engine's verifier rather
+        # than from pool.live_kv_bytes. Rationale: in v0.3 the slab
+        # is a session ticket (acquired/released per request) — the
+        # verifier holds the real KV cache tensors and is the
+        # canonical source of truth. Pool-side accounting only
+        # populates once PooledVerifier is wired (a post-v0.3.0
+        # change) and otherwise reads 0 even while the verifier
+        # cache is several MiB.
         metrics.snapshot_scheduler(
             active=scheduler.active_count,
             pool_in_use=pool.in_use_count,
             pool_total=pool.total_count,
             pending=scheduler.pending_count,
-            kv_live_bytes=pool.live_kv_bytes,
+            kv_live_bytes=int(engine_for_kv.kv_state()),
         )
         return PlainTextResponse(
             content=metrics.render(),
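
The "refresh on every scrape" pattern in this handler can be sketched without the web framework. `_Gauges` and `scrape` are hypothetical names invented for this illustration, and the one-line `render()` is a simplification, not the repository's exposition code.

```python
from typing import Callable, Dict


class _Gauges:
    """Hypothetical stand-in for the metrics registry."""

    def __init__(self) -> None:
        self.values: Dict[str, float] = {}

    def snapshot_scheduler(self, *, kv_live_bytes: int) -> None:
        self.values["scheduler_kv_live_bytes"] = float(kv_live_bytes)

    def render(self) -> str:
        return "".join(f"{k} {v}\n" for k, v in sorted(self.values.items()))


def scrape(gauges: _Gauges, kv_state: Callable[[], int]) -> str:
    # Refresh first, render second, so the text reflects "now".
    gauges.snapshot_scheduler(kv_live_bytes=int(kv_state()))
    return gauges.render()


cache_bytes = [0]  # mutable cell standing in for the verifier's cache size
gauges = _Gauges()
first = scrape(gauges, lambda: cache_bytes[0])
cache_bytes[0] = 2048  # generation grows the cache between scrapes
second = scrape(gauges, lambda: cache_bytes[0])
print(first, second, sep="")
```

Because the value is pulled at scrape time, no other code path has to remember to push an update when the cache grows or is reset.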

inference_engine/server/engine.py

Lines changed: 27 additions & 0 deletions
@@ -64,6 +64,13 @@ class Engine(Protocol):
         returns ``True``, generation stops at that token boundary.
         The callback is the only way streaming routes inject
         cancellation signals (e.g. client disconnect).
+    kv_state
+        Return the engine's current verifier KV-cache size in
+        bytes, or 0 if the engine has no real KV cache (test
+        doubles). Read on every ``/metrics`` scrape to populate
+        the ``scheduler_kv_live_bytes`` gauge so the ADR 0006
+        §2.3 long-session memory-stability claim is verifiable
+        in production.
     """
 
     @property
@@ -83,6 +90,9 @@ def generate(
     ) -> EngineResult:
         ...  # pragma: no cover - Protocol body, never executed
 
+    def kv_state(self) -> int:
+        ...  # pragma: no cover - Protocol body, never executed
+
 
 class SpeculativeEngine:
     """Concrete :class:`Engine` backed by a real SpeculativeDecoder.
@@ -175,3 +185,20 @@ def generate(
             verifier_forward_calls=int(result.verifier_forward_calls),
             stopped_on_eos=stopped_on_eos,
         )
+
+    def kv_state(self) -> int:
+        """Live KV cache bytes from the underlying verifier.
+
+        Reads ``self._decoder.verifier.live_kv_bytes()`` if the
+        verifier exposes that method (both the CPU and MLX
+        verifiers in this repository do). Returns 0 if the verifier
+        is older / a stub that does not. Called from the
+        ``/metrics`` handler on every scrape and must be safe to
+        call concurrently with the worker thread that is mutating
+        the verifier's cache (see verifier docstrings for the
+        thread-safety argument).
+        """
+        live = getattr(self._decoder.verifier, "live_kv_bytes", None)
+        if live is None:
+            return 0
+        return int(live())

kv_cache_proposer/verifier.py

Lines changed: 19 additions & 2 deletions
@@ -246,15 +246,32 @@ def _truncate_tail_in_place(self, drop: int) -> None:
             layer.keys = keys[:, :, :keep, :].contiguous()
             layer.values = values[:, :, :keep, :].contiguous()
 
-    def _record_peak_kv(self) -> None:
+    def live_kv_bytes(self) -> int:
+        """Return the current size of the verifier's live KV cache in bytes.
+
+        This is the *now* size, not a peak. Reads cleanly from any
+        thread (no locks): in CPython, walking ``self.cache.layers``
+        and reading ``Tensor.numel()`` / ``element_size()`` on each
+        is safe even while the worker thread is mutating the cache —
+        a concurrent write produces a value somewhere between the
+        two adjacent stable states, never garbage. The HTTP
+        ``/metrics`` handler relies on this property.
+
+        Returns 0 when the cache has not been allocated yet (between
+        ``reset()`` and the next ``prefill()``).
+        """
         if self.cache is None:
-            return
+            return 0
         total = 0
         for layer in self.cache.layers:
             if layer.keys is not None:
                 total += layer.keys.numel() * layer.keys.element_size()
             if layer.values is not None:
                 total += layer.values.numel() * layer.values.element_size()
+        return total
+
+    def _record_peak_kv(self) -> None:
+        total = self.live_kv_bytes()
         self.stats.peak_kv_bytes = max(self.stats.peak_kv_bytes, total)
 
     def _record_peak_activation(self, logits: torch.Tensor) -> None:
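
As a worked example of the sum in `live_kv_bytes()`: with every layer populated, the total is layers x 2 (keys and values) x elements per tensor x bytes per element. The 4-D layout (batch, heads, seq, head_dim) is inferred from the `keys[:, :, :keep, :]` slicing in the context lines, and the concrete dimensions below are invented for illustration.

```python
from math import prod
from typing import Tuple


def kv_bytes(num_layers: int, shape: Tuple[int, ...], element_size: int) -> int:
    """Bytes held by keys + values across all layers, assuming equal shapes."""
    per_tensor = prod(shape) * element_size  # numel() * element_size()
    return num_layers * 2 * per_tensor       # keys and values per layer


shape = (1, 8, 16, 64)  # (batch, heads, seq, head_dim), assumed values
total = kv_bytes(num_layers=4, shape=shape, element_size=2)  # bfloat16 is 2 bytes
print(total)  # 131072
```

Only the sequence dimension changes during generation, so for a fixed model the gauge scales linearly with the number of cached positions.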

tests/backends/mlx/test_verifier.py

Lines changed: 18 additions & 0 deletions
@@ -255,6 +255,24 @@ def test_record_peak_kv_handles_null_cache() -> None:
     assert v.stats.peak_kv_bytes == pre
 
 
+def test_live_kv_bytes_zero_before_prefill() -> None:
+    """The /metrics gauge must read 0 before any prefill."""
+    v = _build_mlx_verifier()
+    assert v.live_kv_bytes() == 0
+
+
+def test_live_kv_bytes_nonzero_after_prefill() -> None:
+    """During in-flight generation the gauge must read the actual
+    bytes — this is what bench_long_session.py polls on each turn
+    to verify the ADR 0006 §2.3 KV-bounded claim."""
+    v = _build_mlx_verifier()
+    v.prefill(list(range(16)))
+    n = v.live_kv_bytes()
+    assert n > 0
+    # Right after prefill, peak == live.
+    assert v.stats.peak_kv_bytes == n
+
+
 def test_record_peak_activation_grows_only() -> None:
     v = _build_mlx_verifier()
     a = mx.zeros((1, 4, 32), dtype=mx.bfloat16)

tests/core/test_verifier.py

Lines changed: 42 additions & 0 deletions
@@ -317,6 +317,48 @@ def test_record_peak_kv_handles_layers_with_null_kv(fresh_verifier_factory) -> N
         layer0.values = saved_v
 
 
+def test_live_kv_bytes_zero_before_prefill(fresh_verifier_factory) -> None:
+    """Before any prefill, ``live_kv_bytes()`` must read 0. Required
+    by the /metrics scrape contract: the gauge has a stable value at
+    process startup."""
+    verifier = fresh_verifier_factory()
+    assert verifier.live_kv_bytes() == 0
+
+
+def test_live_kv_bytes_nonzero_after_prefill(fresh_verifier_factory) -> None:
+    """After prefill the cache holds tensors; live_kv_bytes returns
+    the sum of bytes across all layers' keys + values. This is the
+    gauge value the bench scrapes during in-flight generation."""
+    verifier = fresh_verifier_factory()
+    verifier.prefill(list(range(16)))
+    n = verifier.live_kv_bytes()
+    assert n > 0
+    # Round-trip: peak_kv_bytes is set from the same source so they
+    # must agree right after prefill.
+    assert verifier.stats.peak_kv_bytes == n
+
+
+def test_live_kv_bytes_returns_zero_when_layer_kv_is_null(
+    fresh_verifier_factory,
+) -> None:
+    """The keys-None branch is taken on cleared layers and must not
+    raise — live_kv_bytes simply skips them in the sum."""
+    verifier = fresh_verifier_factory()
+    verifier.prefill(list(range(4)))
+    layer0 = verifier.cache.layers[0]
+    saved_k, saved_v = layer0.keys, layer0.values
+    layer0.keys = None
+    layer0.values = None
+    try:
+        # Must not raise. Returns the sum across the *remaining*
+        # non-null layers (potentially less than the full prefill total).
+        n = verifier.live_kv_bytes()
+        assert n >= 0
+    finally:
+        layer0.keys = saved_k
+        layer0.values = saved_v
+
+
 def test_record_peak_activation_grows_only(fresh_verifier_factory) -> None:
     verifier = fresh_verifier_factory()
     a = torch.zeros((1, 4, 32), dtype=torch.bfloat16)

tests/inference_engine/server/conftest.py

Lines changed: 5 additions & 0 deletions
@@ -189,6 +189,11 @@ def tokenizer(self) -> DeterministicTokenizer:
     def model_id_label(self) -> str:
         return self._model_id_label
 
+    def kv_state(self) -> int:
+        """Test double has no real KV cache — 0 by default. Tests that
+        want to drive a non-zero gauge value override this."""
+        return 0
+
     def generate(
         self,
         prompt_ids: List[int],

tests/inference_engine/server/test_app_metrics_and_auth.py

Lines changed: 44 additions & 3 deletions
@@ -92,9 +92,9 @@ async def test_metrics_kv_live_bytes_gauge_present_and_zero_at_idle(
     short_engine,
 ):
     """The KV-live-bytes gauge must be exposed and read 0 on an idle
-    pool (every slab has logical_size == 0). This is the gauge that
-    bench_long_session.py scrapes to verify the ADR 0006 §2.3
-    KV-bounded claim, so its presence is part of the public contract.
+    engine. This is the gauge that bench_long_session.py scrapes to
+    verify the ADR 0006 §2.3 KV-bounded claim, so its presence is
+    part of the public contract.
     """
     app = create_app(short_engine, ServerConfig(max_concurrent=2))
     async with AsyncClient(transport=ASGITransport(app=app),
@@ -105,6 +105,47 @@ async def test_metrics_kv_live_bytes_gauge_present_and_zero_at_idle(
     assert "scheduler_kv_live_bytes 0.0" in text
 
 
+async def test_metrics_kv_live_bytes_reflects_engine_kv_state(tokenizer):
+    """The /metrics handler must read KV bytes from the engine on
+    every scrape (not from the pool). This is the v0.3 wiring that
+    makes bench_long_session.py's in-flight scrape produce a
+    non-zero number on real hardware — without it the gauge
+    unconditionally reads 0 because no production code path sets
+    the slab's live_kv_bytes_override.
+
+    The 2026-05-30 short test #2 (results/.../bench_long_session_mac_short2_
+    1780196477.json) recorded 7313 in-flight samples across 58 turns
+    with pool_in_use=1 throughout, yet kv_live_bytes was 0.0 in every
+    sample. This regression test pins the fix.
+    """
+    from tests.inference_engine.server.conftest import DeterministicEngine
+
+    class _KVAwareEngine(DeterministicEngine):
+        def __init__(self, *args, kv_value: int, **kwargs):
+            super().__init__(*args, **kwargs)
+            self._kv_value = kv_value
+
+        def kv_state(self) -> int:
+            return self._kv_value
+
+    eos = tokenizer.eos_token_id
+    assert eos is not None
+    hello = tokenizer._intern("hi")
+    eng = _KVAwareEngine(
+        fixed_tokens=[hello, eos],
+        tokenizer=tokenizer,
+        model_id_label="kv-aware",
+        kv_value=12345678,
+    )
+    app = create_app(eng, ServerConfig(max_concurrent=1))
+    async with AsyncClient(transport=ASGITransport(app=app),
+                           base_url="http://t") as c:
+        r = await c.get("/metrics")
+    assert r.status_code == 200
+    assert "scheduler_kv_live_bytes 1.2345678e+07" in r.text or \
+        "scheduler_kv_live_bytes 12345678" in r.text
+
+
 # ---------------------------------------------------------------------------
 # OpenAI error envelope
 # ---------------------------------------------------------------------------
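
The regression test above accepts two spellings of the same number. One plausible reason, sketched below, is that Prometheus text exposition is often rendered with Go-style float formatting, which switches to exponent notation for larger values. `go_style_float` is a hypothetical helper written for this note; whether this repository's `metrics.render()` formats values this way is an assumption the diff does not confirm.

```python
import math


def go_style_float(value: float) -> str:
    """Approximate Go-style float formatting (hypothetical helper)."""
    d = float(value)
    if math.isinf(d):
        return "+Inf" if d > 0 else "-Inf"
    if math.isnan(d):
        return "NaN"
    s = repr(d)
    dot = s.find(".")
    # Switch to exponent form once there are more than six integer digits.
    if d > 0 and dot > 6:
        mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1:]}".rstrip("0.")
        return f"{mantissa}e+{dot - 1:02d}"
    return s


print(go_style_float(12345678), go_style_float(0))  # 1.2345678e+07 0.0
```

Under that assumption the idle gauge renders as `0.0` and the test value as `1.2345678e+07`, which lines up with the strings asserted in both tests in this file.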

tests/inference_engine/server/test_app_streaming.py

Lines changed: 3 additions & 0 deletions
@@ -350,6 +350,9 @@ def tokenizer(self):
     def model_id_label(self):
         return "raising"
 
+    def kv_state(self) -> int:
+        return 0
+
     def generate(self, prompt_ids, max_new_tokens, eos_token_ids, on_token=None):
         raise RuntimeError("synthetic engine failure")
 

tests/inference_engine/server/test_app_with_scheduler.py

Lines changed: 3 additions & 0 deletions
@@ -275,6 +275,9 @@ def tokenizer(self):
     def model_id_label(self):
         return "raising"
 
+    def kv_state(self) -> int:
+        return 0
+
     def generate(self, prompt_ids, max_new_tokens, eos_token_ids, on_token=None):
         raise RuntimeError("synthetic engine failure")
 
