|
42 | 42 | import math |
43 | 43 | import random |
44 | 44 | import time |
45 | | -from typing import Callable, List, Optional, Sequence, Tuple |
| 45 | +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple |
46 | 46 |
|
47 | 47 | import torch |
48 | 48 | import torch.nn as nn |
@@ -497,3 +497,173 @@ def evaluate( |
497 | 497 | decoded_texts.append(text) |
498 | 498 | latencies_s.append(latency) |
499 | 499 | return aggregate_recall(name, samples, decoded_texts, latencies_s) |
| 500 | + |
| 501 | + |
| 502 | +# --------------------------------------------------------------------------- |
| 503 | +# Memory measurement helpers |
| 504 | +# --------------------------------------------------------------------------- |
| 505 | +# |
| 506 | +# ADR 0008 §11.5 §"Five properties" item 1 — "constant memory in |
| 507 | +# context length" — is a measurable claim, not a presumption. The |
| 508 | +# helpers below let runners record per-config peak / current memory |
| 509 | +# on the active device and emit it into the run's JSON evidence so |
| 510 | +# the constant-memory claim becomes empirically verifiable rather |
| 511 | +# than rhetorical. |
| 512 | +# |
| 513 | +# CUDA: torch.cuda.max_memory_allocated tracks the high-water mark |
| 514 | +# since the last reset. Reset before each config evaluation, sample |
| 515 | +# after, and the peak is the config's memory cost. |
| 516 | +# |
| 517 | +# MPS: torch.mps does not expose a peak counter as of torch 2.x, so |
| 518 | +# we record current_allocated and driver_allocated as point-in-time |
| 519 | +# samples. Mac runs cannot demonstrate the sustained-memory claim |
| 520 | +# with the same precision as CUDA runs but they can still show |
| 521 | +# rough magnitudes. |
| 522 | +# |
| 523 | +# CPU: optional dependency on psutil. If present, RSS is recorded; |
| 524 | +# if absent, memory fields are None and the run continues. Tests |
| 525 | +# pass psutil-less to verify graceful degradation. |
| 526 | + |
| 527 | + |
def reset_memory_peak(device: torch.device) -> None:
    """Clear the peak-memory statistics for ``device`` where they exist.

    Call this immediately before the work you want to measure so that a
    later :func:`record_memory` snapshot reports peaks for that work only.

    Only CUDA devices are acted on. For every other device type (MPS,
    CPU, ...) the function returns without doing anything, so it is safe
    to call unconditionally and repeatedly.

    Args:
        device: The device whose peak counters should be reset.
    """
    if device.type != "cuda":
        # MPS: per the module notes, torch.mps exposes no
        # reset_peak_memory_stats, so that backend only reports
        # point-in-time allocations.
        # CPU: RSS is process-level and has nothing to reset; a caller
        # that wants a per-config delta must take its own "before"
        # snapshot with record_memory and subtract.
        return

    # Wait for queued kernels first so their allocations are attributed
    # to the period before the reset, not after it.
    torch.cuda.synchronize(device)
    # Release cached blocks before the counters are cleared.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
| 548 | + |
| 549 | + |
def record_memory(device: torch.device) -> Dict[str, Any]:
    """Take a JSON-serialisable memory snapshot for ``device``.

    The keys present depend on the device type:

    * **cuda** -- ``device_kind``, ``device_name``, ``device_total_bytes``,
      ``current_allocated_bytes``, ``current_reserved_bytes``,
      ``peak_allocated_bytes`` and ``peak_reserved_bytes`` (peaks are
      measured since the last reset of the peak statistics). All byte
      fields are ``int``.
    * **mps** -- ``device_kind``, ``device_name``, ``device_total_bytes``
      (always ``None``), ``current_allocated_bytes``,
      ``driver_allocated_bytes``, ``peak_allocated_bytes`` and
      ``peak_reserved_bytes`` (both always ``None``). The two measured
      fields fall back to ``None`` if the torch.mps query raises.
    * **anything else (cpu, ...)** -- ``device_kind`` (the device type),
      ``device_name`` (``str(device)``), ``device_total_bytes``,
      ``current_allocated_bytes``, ``peak_allocated_bytes`` and
      ``peak_reserved_bytes``. ``current_allocated_bytes`` is the process
      RSS reported by psutil, or ``None`` when psutil is unavailable or
      fails; the remaining byte fields are always ``None``.

    On CUDA the device is synchronized before sampling so that pending
    asynchronous kernels have finished. No synchronization is performed
    on the MPS path.

    Args:
        device: The device to sample.

    Returns:
        A dict of plain ``str`` / ``int`` / ``None`` values as described
        above.
    """
    kind = device.type

    if kind == "cuda":
        torch.cuda.synchronize(device)
        properties = torch.cuda.get_device_properties(device)
        allocated_now = int(torch.cuda.memory_allocated(device))
        reserved_now = int(torch.cuda.memory_reserved(device))
        allocated_peak = int(torch.cuda.max_memory_allocated(device))
        reserved_peak = int(torch.cuda.max_memory_reserved(device))
        return {
            "device_kind": "cuda",
            "device_name": properties.name,
            "device_total_bytes": int(properties.total_memory),
            "current_allocated_bytes": allocated_now,
            "current_reserved_bytes": reserved_now,
            "peak_allocated_bytes": allocated_peak,
            "peak_reserved_bytes": reserved_peak,
        }

    if kind == "mps":

        def _query_mps(counter_name: str) -> Optional[int]:
            # Best-effort: resolve and call the counter inside the try so
            # a torch build without it degrades to None instead of raising.
            try:
                return int(getattr(torch.mps, counter_name)())
            except Exception:
                return None

        allocated_now = _query_mps("current_allocated_memory")
        driver_now = _query_mps("driver_allocated_memory")
        return {
            "device_kind": "mps",
            "device_name": "Apple MPS",
            "device_total_bytes": None,
            "current_allocated_bytes": allocated_now,
            "driver_allocated_bytes": driver_now,
            "peak_allocated_bytes": None,
            "peak_reserved_bytes": None,
        }

    # CPU or any other device type: psutil is an optional dependency, so
    # both the import and the query are allowed to fail quietly.
    resident: Optional[int]
    try:
        import psutil  # type: ignore

        resident = int(psutil.Process().memory_info().rss)
    except Exception:
        resident = None

    return {
        "device_kind": kind,
        "device_name": str(device),
        "device_total_bytes": None,
        "current_allocated_bytes": resident,
        "peak_allocated_bytes": None,
        "peak_reserved_bytes": None,
    }
| 633 | + |
| 634 | + |
def format_memory_summary(snapshot: Dict[str, Any]) -> str:
    """Return a one-line human-readable summary of a memory snapshot.

    Intended for runners that print per-config memory next to their
    latency / recall summary lines; callers prepend their own prefix.

    Args:
        snapshot: A dict with the keys produced by :func:`record_memory`.
            Missing keys are treated the same as ``None`` values.

    Returns:
        * ``device_kind == "cuda"``: peak in GB, peak as a percentage of
          device total, and current allocation in GB. If peak or total is
          unavailable (or total is not positive) the raw peak / current
          values are printed instead.
        * ``device_kind == "mps"``: current and driver allocation in GB.
        * anything else: ``"cpu rss=<GB>"`` when a current value exists,
          otherwise ``"<kind> (no memory accounting available)"``.

        Any individual value that is ``None`` is rendered as ``"n/a"``
        rather than raising.
    """

    def _gb(num_bytes: Optional[int]) -> str:
        # Single None-safe formatter so every branch degrades the same way.
        if num_bytes is None:
            return "n/a"
        return f"{num_bytes / 1e9:.2f}GB"

    kind = snapshot.get("device_kind", "?")
    cur = snapshot.get("current_allocated_bytes")

    if kind == "cuda":
        peak = snapshot.get("peak_allocated_bytes")
        total = snapshot.get("device_total_bytes")
        if peak is not None and total is not None and total > 0:
            pct = peak / total * 100
            # `cur` may be absent even when peak/total are present (e.g. a
            # hand-built or partially populated snapshot); format it
            # through _gb so that case no longer raises TypeError.
            return (
                f"cuda peak={_gb(peak)} ({pct:.0f}% of "
                f"{total / 1e9:.0f}GB) current={_gb(cur)}"
            )
        return f"cuda peak={peak} current={cur}"

    if kind == "mps":
        drv = snapshot.get("driver_allocated_bytes")
        return f"mps current={_gb(cur)} driver={_gb(drv)} (no peak counter)"

    if cur is not None:
        return f"cpu rss={_gb(cur)}"
    return f"{kind} (no memory accounting available)"
0 commit comments