Skip to content

Commit b272765

Browse files
authored
Merge pull request #78 from FluffyAIcode/AgentMemory/v04-pr-k1g-memory-tracking-8e7f
PR-K1.G: memory usage tracking for K1.E NIAH validation (stacked on #77)
2 parents 9531f89 + 7f9c905 commit b272765

4 files changed

Lines changed: 362 additions & 3 deletions

File tree

inference_engine/v04/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,15 @@
4949
NIAHSample,
5050
aggregate_recall,
5151
evaluate,
52+
format_memory_summary,
5253
greedy_decode_oracle,
5354
greedy_decode_sink_window,
5455
greedy_decode_v04,
5556
make_niah_dataset,
5657
make_sink_window_4d_mask,
5758
recall_predicate,
59+
record_memory,
60+
reset_memory_peak,
5861
)
5962

6063
__all__ = [
@@ -83,4 +86,8 @@
8386
"make_niah_dataset",
8487
"make_sink_window_4d_mask",
8588
"recall_predicate",
89+
# K1.G — memory tracking
90+
"format_memory_summary",
91+
"record_memory",
92+
"reset_memory_peak",
8693
]

inference_engine/v04/niah_eval.py

Lines changed: 171 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import math
4343
import random
4444
import time
45-
from typing import Callable, List, Optional, Sequence, Tuple
45+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
4646

4747
import torch
4848
import torch.nn as nn
@@ -497,3 +497,173 @@ def evaluate(
497497
decoded_texts.append(text)
498498
latencies_s.append(latency)
499499
return aggregate_recall(name, samples, decoded_texts, latencies_s)
500+
501+
502+
# ---------------------------------------------------------------------------
503+
# Memory measurement helpers
504+
# ---------------------------------------------------------------------------
505+
#
506+
# ADR 0008 §11.5 §"Five properties" item 1 — "constant memory in
507+
# context length" — is a measurable claim, not a presumption. The
508+
# helpers below let runners record per-config peak / current memory
509+
# on the active device and emit it into the run's JSON evidence so
510+
# the constant-memory claim becomes empirically verifiable rather
511+
# than rhetorical.
512+
#
513+
# CUDA: torch.cuda.max_memory_allocated tracks the high-water mark
514+
# since the last reset. Reset before each config evaluation, sample
515+
# after, and the peak is the config's memory cost.
516+
#
517+
# MPS: torch.mps does not expose a peak counter as of torch 2.x, so
518+
# we record current_allocated and driver_allocated as point-in-time
519+
# samples. Mac runs cannot demonstrate the sustained-memory claim
520+
# with the same precision as CUDA runs but they can still show
521+
# rough magnitudes.
522+
#
523+
# CPU: optional dependency on psutil. If present, RSS is recorded;
524+
# if absent, memory fields are None and the run continues. Tests
525+
# pass psutil-less to verify graceful degradation.
526+
527+
528+
def reset_memory_peak(device: torch.device) -> None:
    """Zero the peak-memory statistics tracked for ``device``.

    A later :func:`record_memory` snapshot then reports peaks that
    accumulated only after this call.

    Only CUDA devices are acted on: the device is synchronized, the
    caching allocator's unused blocks are released, and the peak
    counters are reset. For every other device kind (MPS, CPU, ...)
    the function does nothing, so it is safe to call unconditionally
    and repeatedly.
    """
    if device.type != "cuda":
        # MPS: torch.mps has no reset_peak_memory_stats in the current
        # torch line, so that backend reports point-in-time
        # allocations only (documented limitation).
        # CPU: RSS is process-level, so there is nothing to reset; a
        # caller wanting a per-config delta compares against a
        # "before" snapshot taken with record_memory.
        return

    # Let in-flight kernels finish before touching the counters, then
    # drop cached blocks and restart the high-water marks.
    torch.cuda.synchronize(device)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
548+
549+
550+
def record_memory(device: torch.device) -> Dict[str, Any]:
    """Capture a JSON-serialisable memory snapshot for ``device``.

    The returned dict always contains ``device_kind``, ``device_name``,
    ``device_total_bytes``, ``current_allocated_bytes``,
    ``peak_allocated_bytes`` and ``peak_reserved_bytes``. Byte fields are
    ``int`` when measurable and ``None`` otherwise. Per device kind:

    * **cuda**: ``{
          "device_kind": "cuda",
          "device_name": str,                  # from device properties
          "device_total_bytes": int,
          "current_allocated_bytes": int,
          "current_reserved_bytes": int,
          "peak_allocated_bytes": int,         # since last reset
          "peak_reserved_bytes": int,          # since last reset
      }``
    * **mps**: ``{
          "device_kind": "mps",
          "device_name": "Apple MPS",
          "device_total_bytes": None,
          "current_allocated_bytes": int|None,
          "driver_allocated_bytes": int|None,
          "peak_allocated_bytes": None,        # not exposed on MPS
          "peak_reserved_bytes": None,
      }``
    * **cpu / any other kind**: ``{
          "device_kind": device.type,
          "device_name": str(device),
          "device_total_bytes": None,
          "current_allocated_bytes": int|None, # process RSS via psutil
          "peak_allocated_bytes": None,
          "peak_reserved_bytes": None,
      }``

    The CUDA branch synchronizes the device before sampling so queued
    kernels have committed. The MPS branch does the same on a
    best-effort basis through ``torch.mps.synchronize()``; if that call
    or either MPS counter is unavailable, the affected field is ``None``
    and the snapshot is still returned. On the CPU path ``psutil`` is an
    optional dependency: when it is missing or fails, RSS is ``None``.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
        props = torch.cuda.get_device_properties(device)
        return {
            "device_kind": "cuda",
            "device_name": props.name,
            "device_total_bytes": int(props.total_memory),
            "current_allocated_bytes": int(torch.cuda.memory_allocated(device)),
            "current_reserved_bytes": int(torch.cuda.memory_reserved(device)),
            "peak_allocated_bytes": int(torch.cuda.max_memory_allocated(device)),
            "peak_reserved_bytes": int(torch.cuda.max_memory_reserved(device)),
        }
    if device.type == "mps":
        # Best-effort flush of queued MPS work so the counters below
        # describe completed kernels, mirroring the CUDA branch. A
        # failure only makes the sample less precise, so it must not
        # abort the evaluation run.
        try:
            torch.mps.synchronize()
        except Exception:
            pass
        # Each counter is probed independently so one missing API does
        # not blank out the other.
        current: Optional[int]
        try:
            current = int(torch.mps.current_allocated_memory())
        except Exception:
            current = None
        driver: Optional[int]
        try:
            driver = int(torch.mps.driver_allocated_memory())
        except Exception:
            driver = None
        return {
            "device_kind": "mps",
            "device_name": "Apple MPS",
            "device_total_bytes": None,
            "current_allocated_bytes": current,
            "driver_allocated_bytes": driver,
            "peak_allocated_bytes": None,
            "peak_reserved_bytes": None,
        }
    # CPU or other: try psutil for process RSS. psutil is optional, so
    # both a missing module and a failing query degrade to None.
    rss: Optional[int] = None
    try:
        import psutil  # type: ignore
        rss = int(psutil.Process().memory_info().rss)
    except Exception:
        rss = None
    return {
        "device_kind": device.type,
        "device_name": str(device),
        "device_total_bytes": None,
        "current_allocated_bytes": rss,
        "peak_allocated_bytes": None,
        "peak_reserved_bytes": None,
    }
633+
634+
635+
def format_memory_summary(snapshot: Dict[str, Any]) -> str:
    """Return a one-line human-readable summary of a memory snapshot.

    Used by runners to print per-config memory at the same density as
    the latency / recall summary lines. The result is suitable for
    direct ``print()``-ing; callers prepend their own prefix.

    Args:
        snapshot: A dict shaped like the return value of
            :func:`record_memory`. Missing keys are tolerated and are
            treated the same as ``None`` values.

    Returns:
        * ``device_kind == "cuda"`` with a known peak and a positive
          device total: peak in GB, peak as a percentage of the total,
          and current allocation in GB (``n/a`` when not recorded).
          Otherwise the raw peak / current values are printed.
        * ``device_kind == "mps"``: current and driver allocation in GB,
          each ``n/a`` when not recorded.
        * any other kind: process RSS in GB when
          ``current_allocated_bytes`` is present, else a
          "no memory accounting available" note tagged with the kind.
    """
    kind = snapshot.get("device_kind", "?")
    cur = snapshot.get("current_allocated_bytes")
    if kind == "cuda":
        peak = snapshot.get("peak_allocated_bytes")
        total = snapshot.get("device_total_bytes")
        if peak is not None and total is not None and total > 0:
            pct = peak / total * 100
            # ``cur`` may be absent in a partial snapshot; render "n/a"
            # rather than dividing None.
            return (
                f"cuda peak={peak / 1e9:.2f}GB ({pct:.0f}% of "
                f"{total / 1e9:.0f}GB) current={_gb_or_na(cur)}"
            )
        return f"cuda peak={peak} current={cur}"
    if kind == "mps":
        drv = snapshot.get("driver_allocated_bytes")
        return (
            f"mps current={_gb_or_na(cur)} driver={_gb_or_na(drv)} "
            f"(no peak counter)"
        )
    if cur is not None:
        return f"cpu rss={cur / 1e9:.2f}GB"
    return f"{kind} (no memory accounting available)"


def _gb_or_na(num_bytes: Optional[int]) -> str:
    """Format a byte count as decimal gigabytes, or ``n/a`` for ``None``."""
    return "n/a" if num_bytes is None else f"{num_bytes / 1e9:.2f}GB"

scripts/research/k1e_niah_validation.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,13 @@ def main() -> int:
137137
DLMRestoredVerifier,
138138
NIAHEvalResult,
139139
evaluate,
140+
format_memory_summary,
140141
greedy_decode_oracle,
141142
greedy_decode_sink_window,
142143
greedy_decode_v04,
143144
make_niah_dataset,
145+
record_memory,
146+
reset_memory_peak,
144147
)
145148

146149
samples = make_niah_dataset(
@@ -173,7 +176,22 @@ def encode_chat(prompt_text: str) -> torch.Tensor:
173176
file=sys.stderr,
174177
)
175178

179+
# K1.G: baseline memory snapshot. Captured BEFORE any config
180+
# runs, after model + tokenizer + dataset are loaded — represents
181+
# the minimum sustained working set for this run. Per-config
182+
# peak is reported relative to this baseline so the
183+
# constant-memory claim of ADR 0008 §11.5 §"Five properties"
184+
# item 1 is empirically verifiable from the JSON evidence.
185+
reset_memory_peak(device)
186+
baseline_memory = record_memory(device)
187+
print(
188+
f"[k1e] baseline memory after model+dataset load: "
189+
f"{format_memory_summary(baseline_memory)}",
190+
file=sys.stderr,
191+
)
192+
176193
results = {}
194+
memory_per_config = {}
177195

178196
# ----------------------------------------------------------------
179197
# (a) full-attention oracle
@@ -191,14 +209,21 @@ def oracle_decode(sample) -> Tuple[str, float]:
191209
)
192210
return text, time.perf_counter() - t0
193211

212+
reset_memory_peak(device)
194213
oracle = evaluate("oracle_full_attention", samples, oracle_decode)
214+
oracle_memory = record_memory(device)
195215
results["oracle_full_attention"] = _result_to_dict(oracle)
216+
memory_per_config["oracle_full_attention"] = oracle_memory
196217
print(
197218
f"[k1e] oracle recall={oracle.recall:.3f} "
198219
f"({oracle.samples_correct}/{oracle.samples_total}) "
199220
f"mean_latency={oracle.mean_latency_s:.2f}s",
200221
file=sys.stderr,
201222
)
223+
print(
224+
f"[k1e] oracle memory: {format_memory_summary(oracle_memory)}",
225+
file=sys.stderr,
226+
)
202227

203228
# ----------------------------------------------------------------
204229
# (b) v0.3 sink+window baseline
@@ -221,14 +246,21 @@ def v03_decode(sample) -> Tuple[str, float]:
221246
)
222247
return text, time.perf_counter() - t0
223248

249+
reset_memory_peak(device)
224250
v03 = evaluate("v03_sink_window", samples, v03_decode)
251+
v03_memory = record_memory(device)
225252
results["v03_sink_window"] = _result_to_dict(v03)
253+
memory_per_config["v03_sink_window"] = v03_memory
226254
print(
227255
f"[k1e] v0.3 recall={v03.recall:.3f} "
228256
f"({v03.samples_correct}/{v03.samples_total}) "
229257
f"mean_latency={v03.mean_latency_s:.2f}s",
230258
file=sys.stderr,
231259
)
260+
print(
261+
f"[k1e] v0.3 memory: {format_memory_summary(v03_memory)}",
262+
file=sys.stderr,
263+
)
232264

233265
# ----------------------------------------------------------------
234266
# (c) v0.4 DLMRestoredVerifier
@@ -255,14 +287,21 @@ def v04_decode(sample) -> Tuple[str, float]:
255287
)
256288
return text, time.perf_counter() - t0
257289

290+
reset_memory_peak(device)
258291
v04 = evaluate("v04_dlm_restored", samples, v04_decode)
292+
v04_memory = record_memory(device)
259293
results["v04_dlm_restored"] = _result_to_dict(v04)
294+
memory_per_config["v04_dlm_restored"] = v04_memory
260295
print(
261296
f"[k1e] v0.4 recall={v04.recall:.3f} "
262297
f"({v04.samples_correct}/{v04.samples_total}) "
263298
f"mean_latency={v04.mean_latency_s:.2f}s",
264299
file=sys.stderr,
265300
)
301+
print(
302+
f"[k1e] v0.4 memory: {format_memory_summary(v04_memory)}",
303+
file=sys.stderr,
304+
)
266305

267306
# ----------------------------------------------------------------
268307
# Gate evaluation (only meaningful if both oracle and v04 ran)
@@ -283,7 +322,9 @@ def v04_decode(sample) -> Tuple[str, float]:
283322
gate["v04_dominates_v03"] = v04_recall > v03_recall
284323

285324
report = {
286-
"schema_version": 1,
325+
# schema v2: K1.G adds a top-level 'memory' block with 'baseline' and 'per_config' keys.
326+
# v1 consumers must default the memory blocks to {} on read.
327+
"schema_version": 2,
287328
"kind": "k1e_niah_validation",
288329
"config": {
289330
"model": args.model,
@@ -302,6 +343,10 @@ def v04_decode(sample) -> Tuple[str, float]:
302343
"prompt_token_len_mean": sum(seq_lens) // len(seq_lens),
303344
},
304345
"results": results,
346+
"memory": {
347+
"baseline": baseline_memory,
348+
"per_config": memory_per_config,
349+
},
305350
"gate": gate,
306351
}
307352

@@ -316,9 +361,15 @@ def v04_decode(sample) -> Tuple[str, float]:
316361
# Top-line summary
317362
print("[k1e] ─── SUMMARY ──────────────────────────────────────", file=sys.stderr)
318363
for name, r in results.items():
364+
mem = memory_per_config.get(name, {})
365+
mem_str = ""
366+
if mem.get("device_kind") == "cuda" and mem.get("peak_allocated_bytes") is not None:
367+
mem_str = f" peak_mem={mem['peak_allocated_bytes'] / 1e9:.2f}GB"
368+
elif mem.get("device_kind") == "mps" and mem.get("current_allocated_bytes") is not None:
369+
mem_str = f" current_mem={mem['current_allocated_bytes'] / 1e9:.2f}GB"
319370
print(
320371
f"[k1e] {name:<24s} recall={r['recall']:.3f} "
321-
f"mean_latency={r['mean_latency_s']:.2f}s",
372+
f"mean_latency={r['mean_latency_s']:.2f}s{mem_str}",
322373
file=sys.stderr,
323374
)
324375
if gate:

0 commit comments

Comments
 (0)