Skip to content

Commit c0fb18a

Browse files
Update README
1 parent 03be222 commit c0fb18a

4 files changed

Lines changed: 504 additions & 0 deletions

File tree

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ CLI args for integrations/hf/model_prep.py:
6262
benchmarking/bit_1_58/reports/best_k_{device}.json
6363
```
6464

65+
> [!NOTE]
66+
> `k` is hardware-dependent, so run the `best_k` benchmark on the same machine
67+
> and device you plan to use for inference, then reuse the generated JSON. If
68+
> no `best_k_{device}.json` is found, `model_prep.py` falls back to `--k`.
69+
6570
### Run model inference 🤖
6671
Use `integrations/hf/model_infer.py` to run generation from a preprocessed
6772
model directory. The default backend is `rsr`.
@@ -201,6 +206,10 @@ frontend on `http://localhost:5173`. Press `Ctrl+C` to stop both.
201206

202207
## Benchmark Results 📊
203208

209+
> [!NOTE]
210+
> The results below were measured on a machine with Python 3.12.7, PyTorch
211+
> 2.10.0+cu128, an NVIDIA GeForce RTX 5090, and a 64-logical-thread CPU.
212+
204213
### Matrix-Vector Multiplication 🧮
205214

206215
#### CPU 🖥️

ui/backend/main.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,229 @@ def run():
493493
return {"job_id": job_id}
494494

495495

496+
class MatvecBenchmarkRequest(BaseModel):
497+
shapes: list[str] # e.g. ["4096x4096", "2560x6912"]
498+
k_values: list[int] = [2, 4, 6, 8, 10]
499+
device: str = "cpu" # "cpu" or "cuda"
500+
bit_width: str = "1.58" # "1" or "1.58"
501+
warmup: int = 5
502+
repeats: int = 20
503+
504+
505+
@app.post("/api/benchmarks/run-matvec")
506+
async def run_matvec_benchmark(req: MatvecBenchmarkRequest):
507+
"""Start a kernel-level matvec benchmark (background thread)."""
508+
job_id = f"bench_matvec_{req.device}_{int(time.time())}"
509+
510+
def run():
511+
import numpy as np
512+
import torch
513+
514+
with _job_lock:
515+
_jobs[job_id] = {
516+
"status": "running",
517+
"progress": "Discovering multipliers...",
518+
"current": 0,
519+
"total": len(req.shapes),
520+
}
521+
522+
try:
523+
# Parse shapes
524+
shapes = []
525+
for s in req.shapes:
526+
for sep in ("x", "X", ","):
527+
if sep in s:
528+
parts = s.split(sep)
529+
shapes.append((int(parts[0].strip()), int(parts[1].strip())))
530+
break
531+
532+
bit_dir = "bit_1_58" if req.bit_width == "1.58" else "bit_1"
533+
is_cuda = req.device == "cuda"
534+
535+
# --- Discover multipliers ---
536+
# Baselines (no k)
537+
pt_mod = importlib.import_module(f"multiplier.{bit_dir}.pytorch")
538+
baselines = []
539+
for name, obj in inspect.getmembers(pt_mod, inspect.isclass):
540+
if obj.__module__ == pt_mod.__name__ and name.endswith("Multiplier"):
541+
label = name.replace("Multiplier", "").replace("Pytorch", "pytorch_").strip("_")
542+
if not label:
543+
label = "pytorch"
544+
baselines.append((label, obj))
545+
# Keep only fp32 and bf16 for brevity
546+
baselines = [
547+
(l, c) for l, c in baselines
548+
if any(tag in l.lower() for tag in ("pytorch", "fp32", "bf16"))
549+
]
550+
if not baselines:
551+
baselines = [(name, obj) for name, obj in baselines[:2]]
552+
553+
# RSR multipliers (need k)
554+
rsr_versions = []
555+
platform = "cuda" if is_cuda else "cpu"
556+
pkg_dir = _PROJECT_ROOT / "multiplier" / bit_dir / platform
557+
if pkg_dir.exists():
558+
for py_file in sorted(pkg_dir.glob("*.py")):
559+
if py_file.stem.startswith("_") or py_file.stem in ("__init__", "base"):
560+
continue
561+
module_path = f"multiplier.{bit_dir}.{platform}.{py_file.stem}"
562+
try:
563+
mod = importlib.import_module(module_path)
564+
cls = next(
565+
(obj for _, obj in inspect.getmembers(mod, inspect.isclass)
566+
if obj.__module__ == module_path and obj.__name__.endswith("Multiplier")),
567+
None,
568+
)
569+
if cls is None:
570+
continue
571+
needs_k = "k" in inspect.signature(cls.__init__).parameters
572+
if needs_k:
573+
rsr_versions.append((py_file.stem, cls))
574+
except Exception:
575+
continue
576+
577+
# Pick primary RSR version (prefer "nonsquare" or last available)
578+
primary_rsr = None
579+
for stem, cls in rsr_versions:
580+
if "nonsquare" in stem or "v2_0" in stem:
581+
primary_rsr = ("RSR", cls)
582+
break
583+
if primary_rsr is None and rsr_versions:
584+
primary_rsr = ("RSR", rsr_versions[-1][1])
585+
586+
# --- Bench helpers ---
587+
def bench_cpu(multiplier, v, warmup, repeats):
588+
for _ in range(warmup):
589+
multiplier(v)
590+
times = []
591+
for _ in range(repeats):
592+
t0 = time.perf_counter()
593+
multiplier(v)
594+
t1 = time.perf_counter()
595+
times.append(t1 - t0)
596+
return float(np.median(times))
597+
598+
def bench_cuda(multiplier, v, warmup, repeats):
599+
for _ in range(warmup):
600+
multiplier(v)
601+
torch.cuda.synchronize()
602+
times = []
603+
for _ in range(repeats):
604+
start_ev = torch.cuda.Event(enable_timing=True)
605+
end_ev = torch.cuda.Event(enable_timing=True)
606+
start_ev.record()
607+
multiplier(v)
608+
end_ev.record()
609+
torch.cuda.synchronize()
610+
times.append(start_ev.elapsed_time(end_ev) / 1000.0)
611+
return float(np.median(times))
612+
613+
bench_fn = bench_cuda if is_cuda else bench_cpu
614+
615+
# --- Run benchmarks ---
616+
results = []
617+
618+
for idx, (n_rows, n_cols) in enumerate(shapes):
619+
with _job_lock:
620+
_jobs[job_id]["progress"] = f"Benchmarking {n_rows}x{n_cols}..."
621+
_jobs[job_id]["current"] = idx
622+
623+
# Create matrix and vector
624+
if req.bit_width == "1.58":
625+
M = torch.randint(-1, 2, (n_rows, n_cols), dtype=torch.float32)
626+
else:
627+
M = torch.randint(0, 2, (n_rows, n_cols), dtype=torch.float32)
628+
629+
v_device = "cuda" if is_cuda else "cpu"
630+
v = torch.randn(n_cols, dtype=torch.float32, device=v_device)
631+
632+
# Baseline timings
633+
baseline_results = {}
634+
for label, cls in baselines:
635+
try:
636+
m_input = M.cuda() if is_cuda else M
637+
mul = cls(m_input)
638+
t = bench_fn(mul, v, req.warmup, req.repeats)
639+
baseline_results[label] = round(t * 1e3, 4)
640+
except Exception:
641+
baseline_results[label] = None
642+
643+
# RSR per k
644+
if primary_rsr:
645+
rsr_label, rsr_cls = primary_rsr
646+
for k in req.k_values:
647+
if n_rows % k != 0:
648+
continue
649+
try:
650+
mul = rsr_cls(M, k)
651+
t = bench_fn(mul, v, req.warmup, req.repeats)
652+
rsr_ms = round(t * 1e3, 4)
653+
# Pick a reference baseline for speedup
654+
ref_key = next(
655+
(key for key in ("pytorch_BF16", "pytorch_bf16", "pytorch")
656+
if key in baseline_results and baseline_results[key] is not None),
657+
None,
658+
)
659+
fp32_key = next(
660+
(key for key in ("pytorch", "pytorch_FP32", "pytorch_fp32")
661+
if key in baseline_results and baseline_results[key] is not None),
662+
None,
663+
)
664+
row = {
665+
"shape": f"{n_rows}x{n_cols}",
666+
"n_rows": n_rows,
667+
"n_cols": n_cols,
668+
"k": k,
669+
"rsr_ms": rsr_ms,
670+
}
671+
# Attach all baselines
672+
for bl, val in baseline_results.items():
673+
row[f"{bl}_ms"] = val
674+
# Compute speedups
675+
if fp32_key and baseline_results[fp32_key]:
676+
row["fp32_ms"] = baseline_results[fp32_key]
677+
row["speedup_vs_fp32"] = round(baseline_results[fp32_key] / rsr_ms, 3)
678+
if ref_key and baseline_results[ref_key]:
679+
row["bf16_ms"] = baseline_results[ref_key]
680+
row["speedup_vs_bf16"] = round(baseline_results[ref_key] / rsr_ms, 3)
681+
results.append(row)
682+
except Exception as e:
683+
results.append({
684+
"shape": f"{n_rows}x{n_cols}",
685+
"n_rows": n_rows,
686+
"n_cols": n_cols,
687+
"k": k,
688+
"error": str(e),
689+
})
690+
691+
# Clean up
692+
del M
693+
if is_cuda:
694+
torch.cuda.empty_cache()
695+
gc.collect()
696+
697+
with _job_lock:
698+
_jobs[job_id] = {
699+
"status": "completed",
700+
"results": results,
701+
"current": len(shapes),
702+
"total": len(shapes),
703+
}
704+
705+
except Exception as e:
706+
import traceback
707+
with _job_lock:
708+
_jobs[job_id] = {
709+
"status": "error",
710+
"progress": str(e),
711+
"traceback": traceback.format_exc(),
712+
}
713+
714+
thread = threading.Thread(target=run, daemon=True)
715+
thread.start()
716+
return {"job_id": job_id}
717+
718+
496719
@app.get("/api/benchmarks/job/{job_id}")
497720
async def get_benchmark_status(job_id: str):
498721
"""Check benchmark job status."""

ui/frontend/src/api.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ export const getShapesResults = (category, device) =>
5858
export const runBenchmark = (body) =>
5959
request("/benchmarks/run", { method: "POST", body: JSON.stringify(body) });
6060
export const getBenchmarkJob = (jobId) => request(`/benchmarks/job/${jobId}`);
61+
export const runMatvecBenchmark = (body) =>
62+
request("/benchmarks/run-matvec", { method: "POST", body: JSON.stringify(body) });
6163

6264
// System
6365
export const getSystemInfo = () => request("/system");

0 commit comments

Comments
 (0)