2121# Optional imports for visualization
2222try :
2323 import matplotlib .pyplot as plt
24+
2425 HAS_MATPLOTLIB = True
2526except ImportError :
2627 HAS_MATPLOTLIB = False
2728
2829try :
2930 import numpy as np
31+
3032 HAS_NUMPY = True
3133except ImportError :
3234 HAS_NUMPY = False
3537@dataclass
3638class BenchmarkResult :
3739 """Container for benchmark results."""
40+
3841 kernel : str
3942 hpc_ms : float
4043 baseline_ms : float
@@ -48,6 +51,7 @@ class BenchmarkResult:
4851@dataclass
4952class DeviceInfo :
5053 """GPU device information."""
54+
5155 name : str
5256 compute_capability : Tuple [int , int ]
5357 total_memory_gb : float
@@ -99,7 +103,7 @@ def benchmark_kernel(
99103 min_run_time : float = 1.0 ,
100104 flops : Optional [int ] = None ,
101105 bytes_accessed : Optional [int ] = None ,
102- ** kwargs
106+ ** kwargs ,
103107) -> BenchmarkResult :
104108 """
105109 Compare HPC kernel with baseline implementation.
@@ -128,14 +132,14 @@ def benchmark_kernel(
128132 # Benchmark HPC kernel
129133 hpc_timer = Timer (
130134 stmt = "hpc_fn(*args, **kwargs)" ,
131- globals = {"hpc_fn" : hpc_fn , "args" : args , "kwargs" : kwargs }
135+ globals = {"hpc_fn" : hpc_fn , "args" : args , "kwargs" : kwargs },
132136 )
133137 hpc_result = hpc_timer .blocked_autorange (min_run_time = min_run_time )
134138
135139 # Benchmark baseline
136140 baseline_timer = Timer (
137141 stmt = "baseline_fn(*args, **kwargs)" ,
138- globals = {"baseline_fn" : baseline_fn , "args" : args , "kwargs" : kwargs }
142+ globals = {"baseline_fn" : baseline_fn , "args" : args , "kwargs" : kwargs },
139143 )
140144 baseline_result = baseline_timer .blocked_autorange (min_run_time = min_run_time )
141145
@@ -201,7 +205,9 @@ def analyze(self, result: BenchmarkResult) -> Dict[str, Any]:
201205 achieved_tflops = result .tflops
202206
203207 # Ridge point: where compute and memory rooflines meet
204- ridge_point = self .device_info .peak_fp32_tflops / self .device_info .peak_bandwidth_gb_s
208+ ridge_point = (
209+ self .device_info .peak_fp32_tflops / self .device_info .peak_bandwidth_gb_s
210+ )
205211
206212 # Determine bottleneck
207213 if ai < ridge_point :
@@ -228,7 +234,7 @@ def plot_roofline(
228234 self ,
229235 results : List [BenchmarkResult ],
230236 output_path : str = "roofline.png" ,
231- title : str = "Roofline Analysis"
237+ title : str = "Roofline Analysis" ,
232238 ):
233239 """Generate roofline plot for multiple kernels."""
234240 if not HAS_MATPLOTLIB or not HAS_NUMPY :
@@ -252,9 +258,11 @@ def plot_roofline(
252258 roofline = np .minimum (memory_roof , compute_roof )
253259
254260 # Plot roofline
255- ax .loglog (ai_range , roofline , 'b-' , linewidth = 2 , label = 'Roofline' )
256- ax .loglog (ai_range , memory_roof , 'b--' , alpha = 0.5 , label = 'Memory Bound' )
257- ax .axhline (y = peak_compute , color = 'b' , linestyle = ':' , alpha = 0.5 , label = 'Compute Bound' )
261+ ax .loglog (ai_range , roofline , "b-" , linewidth = 2 , label = "Roofline" )
262+ ax .loglog (ai_range , memory_roof , "b--" , alpha = 0.5 , label = "Memory Bound" )
263+ ax .axhline (
264+ y = peak_compute , color = "b" , linestyle = ":" , alpha = 0.5 , label = "Compute Bound"
265+ )
258266
259267 # Plot kernel results
260268 colors = plt .cm .tab10 (np .linspace (0 , 1 , len (results )))
@@ -265,25 +273,25 @@ def plot_roofline(
265273 result .tflops ,
266274 s = 200 ,
267275 c = [color ],
268- marker = 'o' ,
276+ marker = "o" ,
269277 label = result .kernel ,
270- zorder = 5
278+ zorder = 5 ,
271279 )
272280
273281 # Ridge point
274282 ridge_point = peak_compute / peak_bandwidth
275- ax .axvline (x = ridge_point , color = ' gray' , linestyle = '--' , alpha = 0.5 )
283+ ax .axvline (x = ridge_point , color = " gray" , linestyle = "--" , alpha = 0.5 )
276284 ax .annotate (
277- f' Ridge Point\n ({ ridge_point :.1f} FLOP/B)' ,
285+ f" Ridge Point\n ({ ridge_point :.1f} FLOP/B)" ,
278286 xy = (ridge_point , peak_compute * 0.5 ),
279287 fontsize = 9 ,
280- ha = ' center'
288+ ha = " center" ,
281289 )
282290
283- ax .set_xlabel (' Arithmetic Intensity (FLOP/Byte)' , fontsize = 12 )
284- ax .set_ylabel (' Performance (TFLOPS)' , fontsize = 12 )
285- ax .set_title (f' { title } \n { self .device_info .name } ' , fontsize = 14 )
286- ax .legend (loc = ' lower right' )
291+ ax .set_xlabel (" Arithmetic Intensity (FLOP/Byte)" , fontsize = 12 )
292+ ax .set_ylabel (" Performance (TFLOPS)" , fontsize = 12 )
293+ ax .set_title (f" { title } \n { self .device_info .name } " , fontsize = 14 )
294+ ax .legend (loc = " lower right" )
287295 ax .grid (True , alpha = 0.3 )
288296 ax .set_xlim (0.01 , 10000 )
289297 ax .set_ylim (0.01 , peak_compute * 2 )
@@ -294,13 +302,17 @@ def plot_roofline(
294302 print (f"Roofline plot saved to { output_path } " )
295303
296304
297- def print_results (results : List [BenchmarkResult ], device_info : Optional [DeviceInfo ] = None ):
305+ def print_results (
306+ results : List [BenchmarkResult ], device_info : Optional [DeviceInfo ] = None
307+ ):
298308 """Print benchmark results in a formatted table."""
299309 print ("\n " + "=" * 90 )
300310 if device_info :
301311 print (f"Device: { device_info .name } " )
302- print (f"Peak FP32: { device_info .peak_fp32_tflops :.1f} TFLOPS | "
303- f"Peak Bandwidth: { device_info .peak_bandwidth_gb_s :.0f} GB/s" )
312+ print (
313+ f"Peak FP32: { device_info .peak_fp32_tflops :.1f} TFLOPS | "
314+ f"Peak Bandwidth: { device_info .peak_bandwidth_gb_s :.0f} GB/s"
315+ )
304316 print ("=" * 90 )
305317
306318 header = f"{ 'Kernel' :<25} { 'HPC (ms)' :<10} { 'Base (ms)' :<10} { 'Speedup' :<10} "
@@ -325,7 +337,7 @@ def generate_html_report(
325337 results : List [BenchmarkResult ],
326338 device_info : DeviceInfo ,
327339 output_path : str = "benchmark_report.html" ,
328- roofline_image : Optional [str ] = None
340+ roofline_image : Optional [str ] = None ,
329341):
330342 """Generate HTML benchmark report."""
331343 timestamp = datetime .now ().strftime ("%Y-%m-%d %H:%M:%S" )
@@ -415,15 +427,15 @@ def generate_html_report(
415427</html>
416428"""
417429
418- with open (output_path , 'w' ) as f :
430+ with open (output_path , "w" ) as f :
419431 f .write (html )
420432 print (f"HTML report saved to { output_path } " )
421433
422434
423435def plot_speedup_chart (
424436 results : List [BenchmarkResult ],
425437 output_path : str = "speedup_chart.png" ,
426- title : str = "Kernel Speedup vs Baseline"
438+ title : str = "Kernel Speedup vs Baseline" ,
427439):
428440 """Generate speedup bar chart."""
429441 if not HAS_MATPLOTLIB :
@@ -434,30 +446,30 @@ def plot_speedup_chart(
434446
435447 kernels = [r .kernel for r in results ]
436448 speedups = [r .speedup for r in results ]
437- colors = [' #4CAF50' if s >= 1.0 else ' #f44336' for s in speedups ]
449+ colors = [" #4CAF50" if s >= 1.0 else " #f44336" for s in speedups ]
438450
439451 bars = ax .bar (kernels , speedups , color = colors )
440452
441453 # Add value labels
442454 for bar , speedup in zip (bars , speedups ):
443455 height = bar .get_height ()
444456 ax .annotate (
445- f' { speedup :.2f} x' ,
457+ f" { speedup :.2f} x" ,
446458 xy = (bar .get_x () + bar .get_width () / 2 , height ),
447459 xytext = (0 , 3 ),
448460 textcoords = "offset points" ,
449- ha = ' center' ,
450- va = ' bottom' ,
451- fontsize = 10
461+ ha = " center" ,
462+ va = " bottom" ,
463+ fontsize = 10 ,
452464 )
453465
454- ax .axhline (y = 1.0 , color = ' gray' , linestyle = '--' , alpha = 0.7 , label = ' Baseline' )
455- ax .set_xlabel (' Kernel' , fontsize = 12 )
456- ax .set_ylabel (' Speedup' , fontsize = 12 )
466+ ax .axhline (y = 1.0 , color = " gray" , linestyle = "--" , alpha = 0.7 , label = " Baseline" )
467+ ax .set_xlabel (" Kernel" , fontsize = 12 )
468+ ax .set_ylabel (" Speedup" , fontsize = 12 )
457469 ax .set_title (title , fontsize = 14 )
458470 ax .legend ()
459471
460- plt .xticks (rotation = 45 , ha = ' right' )
472+ plt .xticks (rotation = 45 , ha = " right" )
461473 plt .tight_layout ()
462474 plt .savefig (output_path , dpi = 150 )
463475 plt .close ()
@@ -474,16 +486,26 @@ def main():
474486 python benchmark.py --suite gemm --sizes 1024,2048,4096
475487 python benchmark.py --suite all --output results.json --html report.html
476488 python benchmark.py --roofline --output roofline.png
477- """
489+ """ ,
490+ )
491+ parser .add_argument (
492+ "--suite" ,
493+ type = str ,
494+ default = "all" ,
495+ choices = ["all" , "gemm" , "elementwise" , "reduction" , "attention" ],
496+ help = "Benchmark suite to run" ,
497+ )
498+ parser .add_argument (
499+ "--sizes" ,
500+ type = str ,
501+ default = "1024,2048,4096" ,
502+ help = "Comma-separated list of sizes to benchmark" ,
478503 )
479- parser .add_argument ("--suite" , type = str , default = "all" ,
480- choices = ["all" , "gemm" , "elementwise" , "reduction" , "attention" ],
481- help = "Benchmark suite to run" )
482- parser .add_argument ("--sizes" , type = str , default = "1024,2048,4096" ,
483- help = "Comma-separated list of sizes to benchmark" )
484504 parser .add_argument ("--output" , type = str , help = "Output JSON file for results" )
485505 parser .add_argument ("--html" , type = str , help = "Output HTML report file" )
486- parser .add_argument ("--roofline" , action = "store_true" , help = "Generate roofline plot" )
506+ parser .add_argument (
507+ "--roofline" , action = "store_true" , help = "Generate roofline plot"
508+ )
487509 parser .add_argument ("--chart" , action = "store_true" , help = "Generate speedup chart" )
488510 args = parser .parse_args ()
489511
@@ -494,7 +516,9 @@ def main():
494516 # Get device info
495517 device_info = get_device_info ()
496518 print (f"\n Device: { device_info .name } " )
497- print (f"Compute Capability: { device_info .compute_capability [0 ]} .{ device_info .compute_capability [1 ]} " )
519+ print (
520+ f"Compute Capability: { device_info .compute_capability [0 ]} .{ device_info .compute_capability [1 ]} "
521+ )
498522 print (f"Peak FP32: { device_info .peak_fp32_tflops :.1f} TFLOPS" )
499523 print (f"Peak Bandwidth: { device_info .peak_bandwidth_gb_s :.0f} GB/s" )
500524
@@ -515,12 +539,16 @@ def main():
515539 print_results (results , device_info )
516540
517541 if args .output :
518- with open (args .output , 'w' ) as f :
519- json .dump ({
520- "device" : asdict (device_info ),
521- "results" : [asdict (r ) for r in results ],
522- "timestamp" : datetime .now ().isoformat ()
523- }, f , indent = 2 )
542+ with open (args .output , "w" ) as f :
543+ json .dump (
544+ {
545+ "device" : asdict (device_info ),
546+ "results" : [asdict (r ) for r in results ],
547+ "timestamp" : datetime .now ().isoformat (),
548+ },
549+ f ,
550+ indent = 2 ,
551+ )
524552 print (f"Results saved to { args .output } " )
525553
526554 if args .html :
0 commit comments