Skip to content

Commit ffccd5c

Browse files
adubey-ai and Copilot authored
FP8 support for FusedQKNormRope (#133)
* FP8 support for FusedQKNormRope
  Signed-off-by: Adarsh Dubey <adarsh.dubey@intel.com>
* Attached Benchmark File for FP16, BF16 and FP8
  Signed-off-by: adubey <adarsh.dubey@intel.com>
* Run CI
  Signed-off-by: adubey <adarsh.dubey@intel.com>
* src/sycl/FusedQKNormRope.cpp
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Used SYCL_DISPATCH_FLOATING_TYPES-style
  Signed-off-by: Adarsh Dubey <adarsh.dubey@intel.com>
---------
Signed-off-by: Adarsh Dubey <adarsh.dubey@intel.com>
Signed-off-by: adubey <adarsh.dubey@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent f7975d2 commit ffccd5c

4 files changed

Lines changed: 336 additions & 68 deletions

File tree

benchmark/bench_fused_qk_norm_rope.py

Lines changed: 153 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
11
import itertools
2+
import os
23

34
import pandas as pd
45
import torch
56
import triton
67
from sgl_kernel import fused_qk_norm_rope
78

9+
# Supported dtypes and their properties
10+
DTYPE_MAP = {
11+
"fp16": torch.float16,
12+
"bf16": torch.bfloat16,
13+
"fp8_e4m3fn": torch.float8_e4m3fn,
14+
}
15+
DTYPE_BYTES = {
16+
"fp16": 2,
17+
"bf16": 2,
18+
"fp8_e4m3fn": 1,
19+
}
20+
821

922
def llama_rms_norm(x, w, eps=1e-6):
1023
"""PyTorch reference implementation of RMS normalization."""
@@ -175,15 +188,46 @@ def fused_qk_norm_rope_reference(
175188
(128, 128, 128, 128), # DeepSeek-V3 style
176189
]
177190
is_neox_range = [True, False]
191+
dtype_range = ["fp16", "bf16", "fp8_e4m3fn"]
178192

179193
configs = []
180-
for batch_size, seq_len, (nq, nk, nv, hd), is_neox in itertools.product(
181-
batch_size_range, seq_len_range, head_config_range, is_neox_range
194+
for batch_size, seq_len, (nq, nk, nv, hd), is_neox, dtype in itertools.product(
195+
batch_size_range, seq_len_range, head_config_range, is_neox_range, dtype_range
182196
):
183-
configs.append((batch_size, seq_len, nq, nk, nv, hd, is_neox))
197+
configs.append((batch_size, seq_len, nq, nk, nv, hd, is_neox, dtype))
184198

185199
all_results = []
186200

201+
# Optional environment-variable controls for which configs this process runs:
#   NUM_CHUNKS    - total number of chunks the config list is split into
#   CHUNK_IDX     - 0-based index of the chunk this process should run
#   SINGLE_CONFIG - run exactly one config, given as the CSV
#                   "batch_size,seq_len,nq,nk,nv,head_dim,is_neox,dtype";
#                   when set it overrides NUM_CHUNKS / CHUNK_IDX
num_chunks = int(os.environ.get("NUM_CHUNKS", "1"))
chunk_idx = int(os.environ.get("CHUNK_IDX", "0"))
single_cfg = os.environ.get("SINGLE_CONFIG", "")
if single_cfg:
    parts = [p.strip() for p in single_cfg.split(",")]
    if len(parts) != 8:
        raise RuntimeError("SINGLE_CONFIG must have 8 comma-separated fields")
    bsz, sl, nq, nk, nv, hd = (int(p) for p in parts[:6])
    # Any value other than 1/true/yes (case-insensitive) selects is_neox=False.
    is_neox_val = parts[6].lower() in ("1", "true", "yes")
    dtype_val = parts[7]
    # Fail here with a clear message instead of a bare KeyError later, when the
    # benchmark body looks the name up in DTYPE_MAP.
    if dtype_val not in DTYPE_MAP:
        raise RuntimeError(
            f"SINGLE_CONFIG dtype must be one of {sorted(DTYPE_MAP)}, "
            f"got {dtype_val!r}"
        )
    configs = [(bsz, sl, nq, nk, nv, hd, is_neox_val, dtype_val)]
    # A single explicit config is never chunked.
    num_chunks = 1
    chunk_idx = 0
if num_chunks > 1:
    # A negative or too-large index would silently select the wrong slice (or
    # nothing at all), so reject it before any benchmark work starts.
    if not 0 <= chunk_idx < num_chunks:
        raise RuntimeError(
            f"CHUNK_IDX must be in [0, {num_chunks - 1}], got {chunk_idx}"
        )
    total = len(configs)
    # Ceil division: every chunk but the last has chunk_size configs; trailing
    # chunks may be shorter (or empty when num_chunks does not divide evenly).
    chunk_size = (total + num_chunks - 1) // num_chunks
    start = chunk_idx * chunk_size
    end = min(start + chunk_size, total)
    configs = configs[start:end]
230+
187231

188232
def calculate_flops(
189233
num_tokens: int,
@@ -220,6 +264,7 @@ def calculate_effective_bandwidth(
220264
head_dim: int,
221265
rotary_dim: int,
222266
time_ms: float,
267+
bytes_per_elem: int = 2,
223268
) -> dict:
224269
"""
225270
Calculate effective bandwidth and FLOPs for fused QK norm + RoPE kernel.
@@ -229,11 +274,11 @@ def calculate_effective_bandwidth(
229274
num_tokens = batch_size * seq_len
230275
num_heads = num_heads_q + num_heads_k + num_heads_v
231276

232-
# Input/output QKV tensor (bf16)
233-
qkv_bytes = num_tokens * num_heads * head_dim * 2
277+
# Input/output QKV tensor
278+
qkv_bytes = num_tokens * num_heads * head_dim * bytes_per_elem
234279

235-
# Weight tensors (bf16)
236-
weight_bytes = 2 * head_dim * 2 # q_weight + k_weight
280+
# Weight tensors
281+
weight_bytes = 2 * head_dim * bytes_per_elem # q_weight + k_weight
237282

238283
# Total bytes (read QKV + write QKV + read weights)
239284
total_bytes = 2 * qkv_bytes + weight_bytes
@@ -265,6 +310,7 @@ def calculate_effective_bandwidth(
265310
"num_heads_v",
266311
"head_dim",
267312
"is_neox",
313+
"dtype",
268314
],
269315
x_vals=configs,
270316
line_arg="provider",
@@ -284,17 +330,41 @@ def benchmark(
284330
num_heads_v,
285331
head_dim,
286332
is_neox,
333+
dtype,
287334
provider,
288335
):
289336
device = torch.device("xpu")
290337
num_tokens = batch_size * seq_len
291338
num_heads = num_heads_q + num_heads_k + num_heads_v
339+
torch_dtype = DTYPE_MAP[dtype]
340+
is_fp8 = dtype == "fp8_e4m3fn"
341+
342+
if is_fp8:
343+
# FP8 tensors: create in float32, clamp to representable range, convert
344+
qkv = (
345+
torch.randn(
346+
num_tokens, num_heads * head_dim, device=device, dtype=torch.float32
347+
)
348+
.clamp(-448.0, 448.0)
349+
.to(torch_dtype)
350+
)
351+
q_weight = (
352+
torch.randn(head_dim, device=device, dtype=torch.float32)
353+
.clamp(-448.0, 448.0)
354+
.to(torch_dtype)
355+
)
356+
k_weight = (
357+
torch.randn(head_dim, device=device, dtype=torch.float32)
358+
.clamp(-448.0, 448.0)
359+
.to(torch_dtype)
360+
)
361+
else:
362+
qkv = torch.randn(
363+
num_tokens, num_heads * head_dim, device=device, dtype=torch_dtype
364+
)
365+
q_weight = torch.randn(head_dim, device=device, dtype=torch_dtype)
366+
k_weight = torch.randn(head_dim, device=device, dtype=torch_dtype)
292367

293-
qkv = torch.randn(
294-
num_tokens, num_heads * head_dim, device=device, dtype=torch.bfloat16
295-
)
296-
q_weight = torch.randn(head_dim, device=device, dtype=torch.bfloat16)
297-
k_weight = torch.randn(head_dim, device=device, dtype=torch.bfloat16)
298368
position_ids = torch.arange(num_tokens, device=device, dtype=torch.int32)
299369

300370
eps = 1e-6
@@ -304,24 +374,48 @@ def benchmark(
304374
quantiles = [0.5, 0.2, 0.8]
305375

306376
if provider == "torch":
307-
fn = lambda: fused_qk_norm_rope_reference(
308-
qkv.clone(),
309-
num_heads_q,
310-
num_heads_k,
311-
num_heads_v,
312-
head_dim,
313-
eps,
314-
q_weight,
315-
k_weight,
316-
base,
317-
is_neox,
318-
position_ids,
319-
factor=1.0,
320-
low=1.0,
321-
high=1.0,
322-
attention_factor=1.0,
323-
rotary_dim=rotary_dim,
324-
)
377+
if is_fp8:
378+
# PyTorch has no native FP8 compute; upcast to float32 for reference
379+
qkv_ref = qkv.to(torch.float32)
380+
q_weight_ref = q_weight.to(torch.float32)
381+
k_weight_ref = k_weight.to(torch.float32)
382+
fn = lambda: fused_qk_norm_rope_reference(
383+
qkv_ref.clone(),
384+
num_heads_q,
385+
num_heads_k,
386+
num_heads_v,
387+
head_dim,
388+
eps,
389+
q_weight_ref,
390+
k_weight_ref,
391+
base,
392+
is_neox,
393+
position_ids,
394+
factor=1.0,
395+
low=1.0,
396+
high=1.0,
397+
attention_factor=1.0,
398+
rotary_dim=rotary_dim,
399+
).to(torch_dtype)
400+
else:
401+
fn = lambda: fused_qk_norm_rope_reference(
402+
qkv.clone(),
403+
num_heads_q,
404+
num_heads_k,
405+
num_heads_v,
406+
head_dim,
407+
eps,
408+
q_weight,
409+
k_weight,
410+
base,
411+
is_neox,
412+
position_ids,
413+
factor=1.0,
414+
low=1.0,
415+
high=1.0,
416+
attention_factor=1.0,
417+
rotary_dim=rotary_dim,
418+
)
325419
elif provider == "sglang":
326420
fn = lambda: fused_qk_norm_rope(
327421
qkv.clone(),
@@ -344,7 +438,7 @@ def benchmark(
344438

345439
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
346440

347-
# Calculate effective bandwidth
441+
# Calculate effective bandwidth using the correct bytes-per-element for this dtype
348442
bw_metrics = calculate_effective_bandwidth(
349443
batch_size,
350444
seq_len,
@@ -354,6 +448,7 @@ def benchmark(
354448
head_dim,
355449
rotary_dim,
356450
ms,
451+
bytes_per_elem=DTYPE_BYTES[dtype],
357452
)
358453

359454
all_results.append(
@@ -366,6 +461,7 @@ def benchmark(
366461
"num_heads_v": num_heads_v,
367462
"head_dim": head_dim,
368463
"is_neox": is_neox,
464+
"dtype": dtype,
369465
"provider": provider,
370466
"time_us": 1000 * ms,
371467
"bandwidth_gbs": bw_metrics["bandwidth_gbs"],
@@ -382,12 +478,20 @@ def benchmark(
382478
print("Running benchmarks...")
383479
benchmark.run(print_data=True)
384480

481+
# Ensure results dir exists and write per-chunk CSV
482+
os.makedirs("benchmark/results", exist_ok=True)
483+
df = pd.DataFrame(all_results)
484+
chunk_label = (
485+
f"chunk_{os.environ.get('CHUNK_IDX','0')}_of_{os.environ.get('NUM_CHUNKS','1')}"
486+
)
487+
out_csv = os.path.join("benchmark/results", f"results_{chunk_label}.csv")
488+
df.to_csv(out_csv, index=False)
489+
print(f"Wrote results CSV: {out_csv}")
490+
385491
# Print bandwidth results
386492
print("\n" + "=" * 80)
387493
print("Effective Bandwidth Results")
388494
print("=" * 80)
389-
390-
df = pd.DataFrame(all_results)
391495
df["bandwidth_gbs"] = df["bandwidth_gbs"].round(2)
392496
df["total_bytes_mb"] = df["total_bytes_mb"].round(2)
393497
df["time_us"] = df["time_us"].round(2)
@@ -400,7 +504,7 @@ def benchmark(
400504
print("\n" + "=" * 80)
401505
print("Summary Statistics by Provider")
402506
print("=" * 80)
403-
summary = df.groupby("provider").agg(
507+
summary = df.groupby(["dtype", "provider"]).agg(
404508
{
405509
"bandwidth_gbs": ["mean", "min", "max"],
406510
"time_us": ["mean", "min", "max"],
@@ -424,13 +528,25 @@ def benchmark(
424528
"num_heads_v",
425529
"head_dim",
426530
"is_neox",
531+
"dtype",
427532
],
428533
columns="provider",
429534
values="time_us",
430535
)
431536

432537
if "torch" in pivot.columns and "sglang" in pivot.columns:
433538
pivot["speedup"] = pivot["torch"] / pivot["sglang"]
434-
print(f"\nAverage speedup: {pivot['speedup'].mean():.2f}x")
435-
print(f"Max speedup: {pivot['speedup'].max():.2f}x")
436-
print(f"Min speedup: {pivot['speedup'].min():.2f}x")
539+
print(f"\nOverall average speedup: {pivot['speedup'].mean():.2f}x")
540+
print(f"Overall max speedup: {pivot['speedup'].max():.2f}x")
541+
print(f"Overall min speedup: {pivot['speedup'].min():.2f}x")
542+
543+
# Per-dtype speedup breakdown
544+
print("\nSpeedup by dtype:")
545+
for dt in df["dtype"].unique():
546+
mask = pivot.index.get_level_values("dtype") == dt
547+
sp = pivot.loc[mask, "speedup"]
548+
if not sp.empty:
549+
print(
550+
f" {dt:>12s}: avg={sp.mean():.2f}x max={sp.max():.2f}x min={sp.min():.2f}x"
551+
)
552+
print(f"Wrote results CSV: {out_csv}")

cmake/Modules/FindSYCL.cmake

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,16 @@ macro(SYCL_LINK_DEVICE_OBJECTS output_file sycl_target sycl_offline_compiler_fla
401401
endif()
402402

403403
# Build the generated file and dependency file ##########################
404+
# Only pass -Xs when there are offline compiler flags (AOT targets).
405+
# For JIT (spir64) targets SYCL_OFFLINE_COMPILER_FLAGS is empty and
406+
# passing a bare -Xs causes icx to consume the following -o flag as
407+
# its argument, producing "no such file or directory" errors.
408+
if(SYCL_OFFLINE_COMPILER_FLAGS)
409+
set(_sycl_xs_flags -Xs ${SYCL_OFFLINE_COMPILER_FLAGS})
410+
else()
411+
set(_sycl_xs_flags)
412+
endif()
413+
404414
add_custom_command(
405415
OUTPUT ${output_file}
406416
DEPENDS ${object_files}

0 commit comments

Comments
 (0)