11import itertools
2+ import os
23
34import pandas as pd
45import torch
56import triton
67from sgl_kernel import fused_qk_norm_rope
78
9+ # Supported dtypes and their properties
10+ DTYPE_MAP = {
11+ "fp16" : torch .float16 ,
12+ "bf16" : torch .bfloat16 ,
13+ "fp8_e4m3fn" : torch .float8_e4m3fn ,
14+ }
15+ DTYPE_BYTES = {
16+ "fp16" : 2 ,
17+ "bf16" : 2 ,
18+ "fp8_e4m3fn" : 1 ,
19+ }
20+
821
922def llama_rms_norm (x , w , eps = 1e-6 ):
1023 """PyTorch reference implementation of RMS normalization."""
@@ -175,15 +188,46 @@ def fused_qk_norm_rope_reference(
175188 (128 , 128 , 128 , 128 ), # DeepSeek-V3 style
176189]
# Both values of the is_neox flag and every supported dtype name are swept.
is_neox_range = [True, False]
dtype_range = ["fp16", "bf16", "fp8_e4m3fn"]

# One flat tuple per benchmark case:
# (batch_size, seq_len, nq, nk, nv, hd, is_neox, dtype).
configs = [
    (bsz, slen, nq, nk, nv, hd, neox, dtype_name)
    for bsz, slen, (nq, nk, nv, hd), neox, dtype_name in itertools.product(
        batch_size_range,
        seq_len_range,
        head_config_range,
        is_neox_range,
        dtype_range,
    )
]

# Result rows appended by each benchmark run.
all_results = []
186200
def _parse_single_config(spec):
    """Parse a SINGLE_CONFIG string into one benchmark config tuple.

    Args:
        spec: CSV text "batch_size,seq_len,nq,nk,nv,head_dim,is_neox,dtype".

    Returns:
        A tuple (batch_size, seq_len, nq, nk, nv, head_dim, is_neox, dtype)
        whose first six fields are ints, is_neox is a bool and dtype is a
        key of DTYPE_MAP.

    Raises:
        RuntimeError: if spec does not have exactly 8 comma-separated fields.
        ValueError: if a numeric field is not an integer, the is_neox field
            is not a recognised boolean, or the dtype name is unknown.
    """
    fields = [field.strip() for field in spec.split(",")]
    if len(fields) != 8:
        raise RuntimeError("SINGLE_CONFIG must have 8 comma-separated fields")
    *size_fields, neox_field, dtype_name = fields
    sizes = tuple(int(field) for field in size_fields)

    # Reject unknown spellings instead of silently treating them as False,
    # which would benchmark (and label) a different case than requested.
    neox_key = neox_field.lower()
    if neox_key in ("1", "true", "yes"):
        neox = True
    elif neox_key in ("0", "false", "no"):
        neox = False
    else:
        raise ValueError(
            f"SINGLE_CONFIG is_neox must be one of 1/true/yes/0/false/no, "
            f"got {neox_field!r}"
        )

    # Fail here with a clear message rather than with a bare KeyError later.
    if dtype_name not in DTYPE_MAP:
        raise ValueError(
            f"SINGLE_CONFIG dtype must be one of {sorted(DTYPE_MAP)}, "
            f"got {dtype_name!r}"
        )
    return (*sizes, neox, dtype_name)


def _select_chunk(items, num_chunks, chunk_idx):
    """Return the chunk_idx-th of num_chunks contiguous slices of items.

    Slice sizes differ by at most one, so no chunk is empty as long as
    len(items) >= num_chunks.

    Raises:
        ValueError: if chunk_idx is outside [0, num_chunks).
    """
    if not 0 <= chunk_idx < num_chunks:
        raise ValueError(
            f"CHUNK_IDX must be in [0, {num_chunks}), got {chunk_idx}"
        )
    base, extra = divmod(len(items), num_chunks)
    # The first `extra` chunks each take one additional item.
    start = chunk_idx * base + min(chunk_idx, extra)
    stop = start + base + (1 if chunk_idx < extra else 0)
    return items[start:stop]


# Support running the benchmark in "chunked" mode by setting environment variables:
# - NUM_CHUNKS: total number of chunks
# - CHUNK_IDX: index of this chunk (0-based)
num_chunks = int(os.environ.get("NUM_CHUNKS", "1"))
chunk_idx = int(os.environ.get("CHUNK_IDX", "0"))
# Support selecting a single config via env var SINGLE_CONFIG with CSV:
# "batch_size,seq_len,nq,nk,nv,head_dim,is_neox,dtype"
single_cfg = os.environ.get("SINGLE_CONFIG", "")
if single_cfg:
    # A single explicit config replaces the sweep and disables chunking.
    configs = [_parse_single_config(single_cfg)]
    num_chunks = 1
    chunk_idx = 0
if num_chunks > 1:
    configs = _select_chunk(configs, num_chunks, chunk_idx)
230+
187231
188232def calculate_flops (
189233 num_tokens : int ,
@@ -220,6 +264,7 @@ def calculate_effective_bandwidth(
220264 head_dim : int ,
221265 rotary_dim : int ,
222266 time_ms : float ,
267+ bytes_per_elem : int = 2 ,
223268) -> dict :
224269 """
225270 Calculate effective bandwidth and FLOPs for fused QK norm + RoPE kernel.
@@ -229,11 +274,11 @@ def calculate_effective_bandwidth(
229274 num_tokens = batch_size * seq_len
230275 num_heads = num_heads_q + num_heads_k + num_heads_v
231276
232- # Input/output QKV tensor (bf16)
233- qkv_bytes = num_tokens * num_heads * head_dim * 2
277+ # Input/output QKV tensor
278+ qkv_bytes = num_tokens * num_heads * head_dim * bytes_per_elem
234279
235- # Weight tensors (bf16)
236- weight_bytes = 2 * head_dim * 2 # q_weight + k_weight
280+ # Weight tensors
281+ weight_bytes = 2 * head_dim * bytes_per_elem # q_weight + k_weight
237282
238283 # Total bytes (read QKV + write QKV + read weights)
239284 total_bytes = 2 * qkv_bytes + weight_bytes
@@ -265,6 +310,7 @@ def calculate_effective_bandwidth(
265310 "num_heads_v" ,
266311 "head_dim" ,
267312 "is_neox" ,
313+ "dtype" ,
268314 ],
269315 x_vals = configs ,
270316 line_arg = "provider" ,
@@ -284,17 +330,41 @@ def benchmark(
284330 num_heads_v ,
285331 head_dim ,
286332 is_neox ,
333+ dtype ,
287334 provider ,
288335):
289336 device = torch .device ("xpu" )
290337 num_tokens = batch_size * seq_len
291338 num_heads = num_heads_q + num_heads_k + num_heads_v
339+ torch_dtype = DTYPE_MAP [dtype ]
340+ is_fp8 = dtype == "fp8_e4m3fn"
341+
342+ if is_fp8 :
343+ # FP8 tensors: create in float32, clamp to representable range, convert
344+ qkv = (
345+ torch .randn (
346+ num_tokens , num_heads * head_dim , device = device , dtype = torch .float32
347+ )
348+ .clamp (- 448.0 , 448.0 )
349+ .to (torch_dtype )
350+ )
351+ q_weight = (
352+ torch .randn (head_dim , device = device , dtype = torch .float32 )
353+ .clamp (- 448.0 , 448.0 )
354+ .to (torch_dtype )
355+ )
356+ k_weight = (
357+ torch .randn (head_dim , device = device , dtype = torch .float32 )
358+ .clamp (- 448.0 , 448.0 )
359+ .to (torch_dtype )
360+ )
361+ else :
362+ qkv = torch .randn (
363+ num_tokens , num_heads * head_dim , device = device , dtype = torch_dtype
364+ )
365+ q_weight = torch .randn (head_dim , device = device , dtype = torch_dtype )
366+ k_weight = torch .randn (head_dim , device = device , dtype = torch_dtype )
292367
293- qkv = torch .randn (
294- num_tokens , num_heads * head_dim , device = device , dtype = torch .bfloat16
295- )
296- q_weight = torch .randn (head_dim , device = device , dtype = torch .bfloat16 )
297- k_weight = torch .randn (head_dim , device = device , dtype = torch .bfloat16 )
298368 position_ids = torch .arange (num_tokens , device = device , dtype = torch .int32 )
299369
300370 eps = 1e-6
@@ -304,24 +374,48 @@ def benchmark(
304374 quantiles = [0.5 , 0.2 , 0.8 ]
305375
306376 if provider == "torch" :
307- fn = lambda : fused_qk_norm_rope_reference (
308- qkv .clone (),
309- num_heads_q ,
310- num_heads_k ,
311- num_heads_v ,
312- head_dim ,
313- eps ,
314- q_weight ,
315- k_weight ,
316- base ,
317- is_neox ,
318- position_ids ,
319- factor = 1.0 ,
320- low = 1.0 ,
321- high = 1.0 ,
322- attention_factor = 1.0 ,
323- rotary_dim = rotary_dim ,
324- )
377+ if is_fp8 :
378+ # PyTorch has no native FP8 compute; upcast to float32 for reference
379+ qkv_ref = qkv .to (torch .float32 )
380+ q_weight_ref = q_weight .to (torch .float32 )
381+ k_weight_ref = k_weight .to (torch .float32 )
382+ fn = lambda : fused_qk_norm_rope_reference (
383+ qkv_ref .clone (),
384+ num_heads_q ,
385+ num_heads_k ,
386+ num_heads_v ,
387+ head_dim ,
388+ eps ,
389+ q_weight_ref ,
390+ k_weight_ref ,
391+ base ,
392+ is_neox ,
393+ position_ids ,
394+ factor = 1.0 ,
395+ low = 1.0 ,
396+ high = 1.0 ,
397+ attention_factor = 1.0 ,
398+ rotary_dim = rotary_dim ,
399+ ).to (torch_dtype )
400+ else :
401+ fn = lambda : fused_qk_norm_rope_reference (
402+ qkv .clone (),
403+ num_heads_q ,
404+ num_heads_k ,
405+ num_heads_v ,
406+ head_dim ,
407+ eps ,
408+ q_weight ,
409+ k_weight ,
410+ base ,
411+ is_neox ,
412+ position_ids ,
413+ factor = 1.0 ,
414+ low = 1.0 ,
415+ high = 1.0 ,
416+ attention_factor = 1.0 ,
417+ rotary_dim = rotary_dim ,
418+ )
325419 elif provider == "sglang" :
326420 fn = lambda : fused_qk_norm_rope (
327421 qkv .clone (),
@@ -344,7 +438,7 @@ def benchmark(
344438
345439 ms , min_ms , max_ms = triton .testing .do_bench (fn , quantiles = quantiles )
346440
347- # Calculate effective bandwidth
441+ # Calculate effective bandwidth using the correct bytes-per-element for this dtype
348442 bw_metrics = calculate_effective_bandwidth (
349443 batch_size ,
350444 seq_len ,
@@ -354,6 +448,7 @@ def benchmark(
354448 head_dim ,
355449 rotary_dim ,
356450 ms ,
451+ bytes_per_elem = DTYPE_BYTES [dtype ],
357452 )
358453
359454 all_results .append (
@@ -366,6 +461,7 @@ def benchmark(
366461 "num_heads_v" : num_heads_v ,
367462 "head_dim" : head_dim ,
368463 "is_neox" : is_neox ,
464+ "dtype" : dtype ,
369465 "provider" : provider ,
370466 "time_us" : 1000 * ms ,
371467 "bandwidth_gbs" : bw_metrics ["bandwidth_gbs" ],
@@ -382,12 +478,20 @@ def benchmark(
382478 print ("Running benchmarks..." )
383479 benchmark .run (print_data = True )
384480
481+ # Ensure results dir exists and write per-chunk CSV
482+ os .makedirs ("benchmark/results" , exist_ok = True )
483+ df = pd .DataFrame (all_results )
484+ chunk_label = (
485+ f"chunk_{ os .environ .get ('CHUNK_IDX' ,'0' )} _of_{ os .environ .get ('NUM_CHUNKS' ,'1' )} "
486+ )
487+ out_csv = os .path .join ("benchmark/results" , f"results_{ chunk_label } .csv" )
488+ df .to_csv (out_csv , index = False )
489+ print (f"Wrote results CSV: { out_csv } " )
490+
385491 # Print bandwidth results
386492 print ("\n " + "=" * 80 )
387493 print ("Effective Bandwidth Results" )
388494 print ("=" * 80 )
389-
390- df = pd .DataFrame (all_results )
391495 df ["bandwidth_gbs" ] = df ["bandwidth_gbs" ].round (2 )
392496 df ["total_bytes_mb" ] = df ["total_bytes_mb" ].round (2 )
393497 df ["time_us" ] = df ["time_us" ].round (2 )
@@ -400,7 +504,7 @@ def benchmark(
400504 print ("\n " + "=" * 80 )
401505 print ("Summary Statistics by Provider" )
402506 print ("=" * 80 )
403- summary = df .groupby (" provider" ).agg (
507+ summary = df .groupby ([ "dtype" , " provider"] ).agg (
404508 {
405509 "bandwidth_gbs" : ["mean" , "min" , "max" ],
406510 "time_us" : ["mean" , "min" , "max" ],
@@ -424,13 +528,25 @@ def benchmark(
424528 "num_heads_v" ,
425529 "head_dim" ,
426530 "is_neox" ,
531+ "dtype" ,
427532 ],
428533 columns = "provider" ,
429534 values = "time_us" ,
430535 )
431536
432537 if "torch" in pivot .columns and "sglang" in pivot .columns :
433538 pivot ["speedup" ] = pivot ["torch" ] / pivot ["sglang" ]
434- print (f"\n Average speedup: { pivot ['speedup' ].mean ():.2f} x" )
435- print (f"Max speedup: { pivot ['speedup' ].max ():.2f} x" )
436- print (f"Min speedup: { pivot ['speedup' ].min ():.2f} x" )
539+ print (f"\n Overall average speedup: { pivot ['speedup' ].mean ():.2f} x" )
540+ print (f"Overall max speedup: { pivot ['speedup' ].max ():.2f} x" )
541+ print (f"Overall min speedup: { pivot ['speedup' ].min ():.2f} x" )
542+
543+ # Per-dtype speedup breakdown
544+ print ("\n Speedup by dtype:" )
545+ for dt in df ["dtype" ].unique ():
546+ mask = pivot .index .get_level_values ("dtype" ) == dt
547+ sp = pivot .loc [mask , "speedup" ]
548+ if not sp .empty :
549+ print (
550+ f" { dt :>12s} : avg={ sp .mean ():.2f} x max={ sp .max ():.2f} x min={ sp .min ():.2f} x"
551+ )
552+ print (f"Wrote results CSV: { out_csv } " )
0 commit comments