@@ -493,6 +493,229 @@ def run():
493493 return {"job_id" : job_id }
494494
495495
496+ class MatvecBenchmarkRequest (BaseModel ):
497+ shapes : list [str ] # e.g. ["4096x4096", "2560x6912"]
498+ k_values : list [int ] = [2 , 4 , 6 , 8 , 10 ]
499+ device : str = "cpu" # "cpu" or "cuda"
500+ bit_width : str = "1.58" # "1" or "1.58"
501+ warmup : int = 5
502+ repeats : int = 20
503+
504+
505+ @app .post ("/api/benchmarks/run-matvec" )
506+ async def run_matvec_benchmark (req : MatvecBenchmarkRequest ):
507+ """Start a kernel-level matvec benchmark (background thread)."""
508+ job_id = f"bench_matvec_{ req .device } _{ int (time .time ())} "
509+
510+ def run ():
511+ import numpy as np
512+ import torch
513+
514+ with _job_lock :
515+ _jobs [job_id ] = {
516+ "status" : "running" ,
517+ "progress" : "Discovering multipliers..." ,
518+ "current" : 0 ,
519+ "total" : len (req .shapes ),
520+ }
521+
522+ try :
523+ # Parse shapes
524+ shapes = []
525+ for s in req .shapes :
526+ for sep in ("x" , "X" , "," ):
527+ if sep in s :
528+ parts = s .split (sep )
529+ shapes .append ((int (parts [0 ].strip ()), int (parts [1 ].strip ())))
530+ break
531+
532+ bit_dir = "bit_1_58" if req .bit_width == "1.58" else "bit_1"
533+ is_cuda = req .device == "cuda"
534+
535+ # --- Discover multipliers ---
536+ # Baselines (no k)
537+ pt_mod = importlib .import_module (f"multiplier.{ bit_dir } .pytorch" )
538+ baselines = []
539+ for name , obj in inspect .getmembers (pt_mod , inspect .isclass ):
540+ if obj .__module__ == pt_mod .__name__ and name .endswith ("Multiplier" ):
541+ label = name .replace ("Multiplier" , "" ).replace ("Pytorch" , "pytorch_" ).strip ("_" )
542+ if not label :
543+ label = "pytorch"
544+ baselines .append ((label , obj ))
545+ # Keep only fp32 and bf16 for brevity
546+ baselines = [
547+ (l , c ) for l , c in baselines
548+ if any (tag in l .lower () for tag in ("pytorch" , "fp32" , "bf16" ))
549+ ]
550+ if not baselines :
551+ baselines = [(name , obj ) for name , obj in baselines [:2 ]]
552+
553+ # RSR multipliers (need k)
554+ rsr_versions = []
555+ platform = "cuda" if is_cuda else "cpu"
556+ pkg_dir = _PROJECT_ROOT / "multiplier" / bit_dir / platform
557+ if pkg_dir .exists ():
558+ for py_file in sorted (pkg_dir .glob ("*.py" )):
559+ if py_file .stem .startswith ("_" ) or py_file .stem in ("__init__" , "base" ):
560+ continue
561+ module_path = f"multiplier.{ bit_dir } .{ platform } .{ py_file .stem } "
562+ try :
563+ mod = importlib .import_module (module_path )
564+ cls = next (
565+ (obj for _ , obj in inspect .getmembers (mod , inspect .isclass )
566+ if obj .__module__ == module_path and obj .__name__ .endswith ("Multiplier" )),
567+ None ,
568+ )
569+ if cls is None :
570+ continue
571+ needs_k = "k" in inspect .signature (cls .__init__ ).parameters
572+ if needs_k :
573+ rsr_versions .append ((py_file .stem , cls ))
574+ except Exception :
575+ continue
576+
577+ # Pick primary RSR version (prefer "nonsquare" or last available)
578+ primary_rsr = None
579+ for stem , cls in rsr_versions :
580+ if "nonsquare" in stem or "v2_0" in stem :
581+ primary_rsr = ("RSR" , cls )
582+ break
583+ if primary_rsr is None and rsr_versions :
584+ primary_rsr = ("RSR" , rsr_versions [- 1 ][1 ])
585+
586+ # --- Bench helpers ---
587+ def bench_cpu (multiplier , v , warmup , repeats ):
588+ for _ in range (warmup ):
589+ multiplier (v )
590+ times = []
591+ for _ in range (repeats ):
592+ t0 = time .perf_counter ()
593+ multiplier (v )
594+ t1 = time .perf_counter ()
595+ times .append (t1 - t0 )
596+ return float (np .median (times ))
597+
598+ def bench_cuda (multiplier , v , warmup , repeats ):
599+ for _ in range (warmup ):
600+ multiplier (v )
601+ torch .cuda .synchronize ()
602+ times = []
603+ for _ in range (repeats ):
604+ start_ev = torch .cuda .Event (enable_timing = True )
605+ end_ev = torch .cuda .Event (enable_timing = True )
606+ start_ev .record ()
607+ multiplier (v )
608+ end_ev .record ()
609+ torch .cuda .synchronize ()
610+ times .append (start_ev .elapsed_time (end_ev ) / 1000.0 )
611+ return float (np .median (times ))
612+
613+ bench_fn = bench_cuda if is_cuda else bench_cpu
614+
615+ # --- Run benchmarks ---
616+ results = []
617+
618+ for idx , (n_rows , n_cols ) in enumerate (shapes ):
619+ with _job_lock :
620+ _jobs [job_id ]["progress" ] = f"Benchmarking { n_rows } x{ n_cols } ..."
621+ _jobs [job_id ]["current" ] = idx
622+
623+ # Create matrix and vector
624+ if req .bit_width == "1.58" :
625+ M = torch .randint (- 1 , 2 , (n_rows , n_cols ), dtype = torch .float32 )
626+ else :
627+ M = torch .randint (0 , 2 , (n_rows , n_cols ), dtype = torch .float32 )
628+
629+ v_device = "cuda" if is_cuda else "cpu"
630+ v = torch .randn (n_cols , dtype = torch .float32 , device = v_device )
631+
632+ # Baseline timings
633+ baseline_results = {}
634+ for label , cls in baselines :
635+ try :
636+ m_input = M .cuda () if is_cuda else M
637+ mul = cls (m_input )
638+ t = bench_fn (mul , v , req .warmup , req .repeats )
639+ baseline_results [label ] = round (t * 1e3 , 4 )
640+ except Exception :
641+ baseline_results [label ] = None
642+
643+ # RSR per k
644+ if primary_rsr :
645+ rsr_label , rsr_cls = primary_rsr
646+ for k in req .k_values :
647+ if n_rows % k != 0 :
648+ continue
649+ try :
650+ mul = rsr_cls (M , k )
651+ t = bench_fn (mul , v , req .warmup , req .repeats )
652+ rsr_ms = round (t * 1e3 , 4 )
653+ # Pick a reference baseline for speedup
654+ ref_key = next (
655+ (key for key in ("pytorch_BF16" , "pytorch_bf16" , "pytorch" )
656+ if key in baseline_results and baseline_results [key ] is not None ),
657+ None ,
658+ )
659+ fp32_key = next (
660+ (key for key in ("pytorch" , "pytorch_FP32" , "pytorch_fp32" )
661+ if key in baseline_results and baseline_results [key ] is not None ),
662+ None ,
663+ )
664+ row = {
665+ "shape" : f"{ n_rows } x{ n_cols } " ,
666+ "n_rows" : n_rows ,
667+ "n_cols" : n_cols ,
668+ "k" : k ,
669+ "rsr_ms" : rsr_ms ,
670+ }
671+ # Attach all baselines
672+ for bl , val in baseline_results .items ():
673+ row [f"{ bl } _ms" ] = val
674+ # Compute speedups
675+ if fp32_key and baseline_results [fp32_key ]:
676+ row ["fp32_ms" ] = baseline_results [fp32_key ]
677+ row ["speedup_vs_fp32" ] = round (baseline_results [fp32_key ] / rsr_ms , 3 )
678+ if ref_key and baseline_results [ref_key ]:
679+ row ["bf16_ms" ] = baseline_results [ref_key ]
680+ row ["speedup_vs_bf16" ] = round (baseline_results [ref_key ] / rsr_ms , 3 )
681+ results .append (row )
682+ except Exception as e :
683+ results .append ({
684+ "shape" : f"{ n_rows } x{ n_cols } " ,
685+ "n_rows" : n_rows ,
686+ "n_cols" : n_cols ,
687+ "k" : k ,
688+ "error" : str (e ),
689+ })
690+
691+ # Clean up
692+ del M
693+ if is_cuda :
694+ torch .cuda .empty_cache ()
695+ gc .collect ()
696+
697+ with _job_lock :
698+ _jobs [job_id ] = {
699+ "status" : "completed" ,
700+ "results" : results ,
701+ "current" : len (shapes ),
702+ "total" : len (shapes ),
703+ }
704+
705+ except Exception as e :
706+ import traceback
707+ with _job_lock :
708+ _jobs [job_id ] = {
709+ "status" : "error" ,
710+ "progress" : str (e ),
711+ "traceback" : traceback .format_exc (),
712+ }
713+
714+ thread = threading .Thread (target = run , daemon = True )
715+ thread .start ()
716+ return {"job_id" : job_id }
717+
718+
496719@app .get ("/api/benchmarks/job/{job_id}" )
497720async def get_benchmark_status (job_id : str ):
498721 """Check benchmark job status."""
0 commit comments