diff --git a/.github/workflows/pr-test-xpu.yml b/.github/workflows/pr-test-xpu.yml index 015766d2..c777a1ec 100644 --- a/.github/workflows/pr-test-xpu.yml +++ b/.github/workflows/pr-test-xpu.yml @@ -67,6 +67,8 @@ jobs: python3 bench_merge_states_v2.py 2>&1 | tee merge_states.py.log \ python3 bench_swiglu_alpha_limit.py 2>&1 | tee swiglu_alpha_limit.py.log \ python3 bench_fused_qk_norm_rope.py 2>&1 | tee fused_qk_norm_rope.py.log \ + python3 bench_per_token_group_quant_8bit.py 2>&1 | tee per_token_group_quant_8bit.py.log \ + python3 bench_per_token_group_quant_mxfp4.py 2>&1 | tee per_token_group_quant_mxfp4.py.log \ " - name: Copy logs from container diff --git a/CMakeLists.txt b/CMakeLists.txt index 1ba0e2ef..8147a33e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,7 +42,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla FetchContent_Declare( repo-cutlass-sycl GIT_REPOSITORY https://github.com/intel/sycl-tla.git - GIT_TAG 482b40e8bed0e9204311d1569c876b4573dfb952 + GIT_TAG 64584484b4279b1b4184b508af445698a4a1b603 GIT_SHALLOW OFF ) FetchContent_MakeAvailable(repo-cutlass-sycl) diff --git a/Dockerfile.xpu_kernel b/Dockerfile.xpu_kernel index 3c34ba22..8e2a25b7 100644 --- a/Dockerfile.xpu_kernel +++ b/Dockerfile.xpu_kernel @@ -22,10 +22,10 @@ ARG SG_LANG_KERNEL_BRANCH=main # Install the latest UMD driver for SYCL-TLA RUN apt-get install -y software-properties-common && \ add-apt-repository -y ppa:kobuk-team/intel-graphics && \ + apt-get update && \ apt-get install -y libze-intel-gpu1 libze1 intel-metrics-discovery intel-opencl-icd clinfo intel-gsc && \ apt-get install -y intel-media-va-driver-non-free libmfx-gen1 libvpl2 libvpl-tools libva-glx2 va-driver-all vainfo && \ - apt-get install -y libze-dev intel-ocloc && \ - apt-get update + apt-get install -y libze-dev intel-ocloc # Install Miniforge & PyTorch/Triton RUN curl -fsSL -v -o miniforge.sh -O https://github.com/conda-forge/miniforge/releases/download/25.1.1-0/Miniforge3-Linux-x86_64.sh && \ @@ -66,3 +66,4 @@ RUN --mount=type=secret,id=github_token \ # Set the default shell to bash SHELL ["bash", "-c"] CMD ["bash", "-c", "source /root/.bashrc && exec bash"] +USER root diff --git a/benchmark/bench_per_token_group_quant_mxfp4.py b/benchmark/bench_per_token_group_quant_mxfp4.py new file mode 100644 index 00000000..ae1e4671 --- /dev/null +++ b/benchmark/bench_per_token_group_quant_mxfp4.py @@ -0,0 +1,547 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark script for MXFP4 (E2M1) per-token group quantization on Intel XPU. + +Rounding: Per OCP MX spec (section 5.3.3), FP4 conversion uses +roundTiesToEven — at midpoints between representable values, the +value with even mantissa (mantissa bit = 0) is chosen. 
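+
+For example, 2.5 lies exactly halfway between the representable E2M1 values
+2.0 (even mantissa bit) and 3.0 (odd mantissa bit), so it rounds to 2.0.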
+""" + +import itertools +import os + +import pandas as pd +import torch +import triton + +MXFP4_BLOCK_SIZE = 32 +FLOAT4_E2M1_MAX = 6.0 + +# E2M1 format parameters (from Microsoft microxcaling formats.py) +E2M1_EBITS = 2 +E2M1_MBITS = 3 # includes sign bit and implicit one +E2M1_EMAX = 2 ** (E2M1_EBITS - 1) # = 2 +E2M1_MAX_NORM = ( + 2**E2M1_EMAX * float(2 ** (E2M1_MBITS - 1) - 1) / 2 ** (E2M1_MBITS - 2) +) # = 6.0 + +FP32_EXPONENT_BIAS = 127 +FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1) # 2^(-126) + +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) + + +def is_xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _round_mantissa_even(A: torch.Tensor) -> torch.Tensor: + """Round mantissa using roundTiesToEven (from Microsoft microxcaling). + + At exact 0.5 midpoints (i.e., values like 0.5, 2.5, 4.5, ...), + round to the nearest even integer (the one whose LSB is 0). + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + absA = torch.abs(A) + # Identify exact midpoints: 0.5, 2.5, 4.5, ... i.e. (absA - 0.5) % 2 == 0 + maskA = ((absA - 0.5) % 2 == torch.zeros_like(A)).type(A.dtype) + # round half up, then subtract 1 at midpoints to get even + return torch.sign(A) * (torch.floor(absA + 0.5) - maskA) + + +def _quantize_elemwise_core_e2m1( + A: torch.Tensor, saturate_normals: bool = True +) -> torch.Tensor: + """Element-wise quantization to E2M1 using Microsoft microxcaling's + _quantize_elemwise_core algorithm with round='even'. + + E2M1 format: ebits=2, mbits=3, emax=2, max_norm=6.0 + min_exp = -(2^(ebits-1)) + 2 = 0 + + Algorithm (from Microsoft microxcaling elemwise_ops.py): + 1. Compute per-element private exponent = floor(log2(|A|)), + clamped to min_exp. + 2. Left-shift: out = A / 2^private_exp * 2^(mbits-2) + 3. Round mantissa with roundTiesToEven + 4. Right-shift: out = out / 2^(mbits-2) * 2^private_exp + 5. Clamp to [-max_norm, max_norm] if saturate_normals + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + ebits = E2M1_EBITS # 2 + mbits = E2M1_MBITS # 3 + max_norm = E2M1_MAX_NORM # 6.0 + + # min representable exponent: -(2^(ebits-1)) + 2 = 0 + min_exp = -(2 ** (ebits - 1)) + 2 # 0 + + out = A.clone() + + # Per-element private exponent: floor(log2(|A|)) + # Add guard for zeros: log2(0) is -inf, we use (A==0) to avoid that + private_exp = torch.floor(torch.log2(torch.abs(A) + (A == 0).type(A.dtype))) + private_exp = private_exp.clip(min=min_exp) + + # Left-shift: scale up so mantissa bits land in integer portion + # out = A / 2^private_exp * 2^(mbits-2) + shift = mbits - 2 # = 1 + out = out / (2**private_exp) * (2**shift) + + # Round mantissa with roundTiesToEven + out = _round_mantissa_even(out) + + # Right-shift: undo scaling + # out = out / 2^(mbits-2) * 2^private_exp + out = out / (2**shift) * (2**private_exp) + + # Saturate to [-max_norm, max_norm] + if saturate_normals: + out = torch.clamp(out, min=-max_norm, max=max_norm) + + return out + + +def _float_to_e2m1_code(val: torch.Tensor) -> torch.Tensor: + """Convert quantized float values back to E2M1 4-bit codes. + + After _quantize_elemwise_core_e2m1, values are one of the 8 representable + E2M1 magnitudes: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}. + This maps them to 4-bit codes (sign in bit 3, magnitude in bits 0-2). 
+ """ + sign = (val < 0).to(torch.uint8) + abs_val = val.abs() + + # Map representable magnitudes to 3-bit indices via the kE2M1ToFloat LUT. + indices = torch.zeros_like(abs_val, dtype=torch.uint8) + lut = kE2M1ToFloat.to(device=val.device) + for i in range(8): + indices = torch.where( + torch.isclose(abs_val, lut[i].expand_as(abs_val), atol=1e-6, rtol=0), + torch.tensor(i, dtype=torch.uint8, device=val.device), + indices, + ) + + return (sign << 3) | indices + + +def quantize_to_e2m1(tensor: torch.Tensor) -> torch.Tensor: + """Quantize tensor values to E2M1 format (4-bit indices). + + Uses the Microsoft microxcaling _quantize_elemwise_core algorithm + with roundTiesToEven, then maps the resulting float values to 4-bit codes. + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + quantized_float = _quantize_elemwise_core_e2m1( + tensor.float(), saturate_normals=True + ) + return _float_to_e2m1_code(quantized_float) + + +def pack_fp4(tensor: torch.Tensor) -> torch.Tensor: + """Pack two 4-bit values into one uint8.""" + assert tensor.shape[-1] % 2 == 0 + shape = tensor.shape[:-1] + (tensor.shape[-1] // 2, 2) + paired = tensor.reshape(shape) + packed = (paired[..., 0] & 0x0F) | ((paired[..., 1] & 0x0F) << 4) + return packed.to(torch.uint8) + + +def _normalize_packed_fp4_signed_zero(packed: torch.Tensor) -> torch.Tensor: + """Canonicalize signed zeros in packed FP4 bytes. + + In E2M1, code 0x0 is +0.0 and code 0x8 is -0.0. Both represent + the same value, but different implementations may emit either form. + This helper rewrites every -0.0 nibble (0x8) to +0.0 (0x0) so that + byte-level comparisons are not tripped up by this harmless difference. + """ + lo = packed & 0x0F + hi = (packed >> 4) & 0x0F + lo = torch.where(lo == 0x08, torch.zeros_like(lo), lo) + hi = torch.where(hi == 0x08, torch.zeros_like(hi), hi) + return (lo | (hi << 4)).to(torch.uint8) + + +def unpack_fp4(packed: torch.Tensor) -> torch.Tensor: + """Unpack uint8 into two 4-bit values.""" + low = packed & 0x0F + high = (packed >> 4) & 0x0F + return torch.stack([low, high], dim=-1).reshape(*packed.shape[:-1], -1) + + +def dequantize_e2m1( + quantized: torch.Tensor, dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + """Dequantize E2M1 values back to float.""" + sign = ((quantized >> 3) & 1).to(torch.bool) + magnitude_idx = (quantized & 0x07).to(torch.long) + kE2M1 = kE2M1ToFloat.to(device=quantized.device) + magnitude = kE2M1[magnitude_idx] + result = torch.where(sign, -magnitude, magnitude) + return result.to(dtype) + + +def _shared_exponents(A: torch.Tensor, axis: int) -> torch.Tensor: + """Compute shared exponents per block using Microsoft microxcaling's + _shared_exponents algorithm with method="max". + + Algorithm: + 1. shared_exp = max(|A|) along axis (per block) + 2. shared_exp = floor(log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0))) + The FP32_MIN_NORMAL guard ensures log2(0) doesn't produce -inf. + 3. 
Offset by emax: shared_exp = shared_exp - emax + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/mx_ops.py + """ + shared_exp = torch.max(torch.abs(A), dim=axis, keepdim=True).values + + # floor(log2(...)) with zero-guard from microxcaling + shared_exp = torch.floor( + torch.log2( + shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype) + ) + ) + + # Offset by the largest representable exponent in E2M1 + shared_exp = shared_exp - E2M1_EMAX + + return shared_exp + + +def quantize_to_mxfp4_ref( + tensor: torch.Tensor, block_size: int = MXFP4_BLOCK_SIZE, eps: float = 1e-10 +) -> tuple: + """Reference implementation for MXFP4 quantization using Microsoft + microxcaling's _quantize_mx algorithm. + + Algorithm (from mx_ops.py _quantize_mx): + 1. Reshape into blocks + 2. Compute shared exponent per block via _shared_exponents + 3. Clamp shared_exp to scale_emax range [-127, 127] + 4. Scale elements: A = A / 2^shared_exp + 5. Quantize element-wise with _quantize_elemwise_core (saturate_normals=True) + 6. Rescale: A = A * 2^shared_exp (implicitly stored in UE8M0 scale) + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/mx_ops.py + """ + assert tensor.dim() == 2 + m, k = tensor.shape + assert k % block_size == 0 + assert k % 2 == 0 + + tensor_fp32 = tensor.float() + num_blocks = k // block_size + tensor_blocks = tensor_fp32.reshape(m, num_blocks, block_size) + + # Compute shared exponents (microxcaling _shared_exponents + offset by emax) + shared_exp = _shared_exponents(tensor_blocks, axis=-1) + + # Clamp to UE8M0 scale range: scale_bits=8, scale_emax = 2^(8-1)-1 = 127 + scale_emax = 127 + shared_exp = shared_exp.clamp(min=-scale_emax, max=scale_emax) + + # Encode as UE8M0: stored_scale = shared_exp + 127 + scales_ue8m0 = (shared_exp.to(torch.int32) + 127).to(torch.uint8).squeeze(-1) + + # Scale elements by shared exponent: A = A / 2^shared_exp + scaled_tensor = tensor_blocks / (2.0**shared_exp) + + # Quantize element-wise with microxcaling core (roundTiesToEven, saturate) + quantized_float = _quantize_elemwise_core_e2m1(scaled_tensor, saturate_normals=True) + + # Convert quantized float values to 4-bit E2M1 codes + quantized_blocks = _float_to_e2m1_code(quantized_float) + + quantized = quantized_blocks.reshape(m, k) + packed = pack_fp4(quantized) + + return packed, scales_ue8m0 + + +def dequantize_mxfp4( + packed: torch.Tensor, + scales: torch.Tensor, + dtype: torch.dtype = torch.float32, + block_size: int = MXFP4_BLOCK_SIZE, +) -> torch.Tensor: + """Dequantize MXFP4 packed values back to float.""" + m, packed_k = packed.shape + k = packed_k * 2 + + unpacked = unpack_fp4(packed) + dequantized = dequantize_e2m1(unpacked, dtype) + + num_blocks = k // block_size + dequantized_blocks = dequantized.reshape(m, num_blocks, block_size) + + scale_exp = scales.to(torch.int32) - 127 + scale_values = torch.pow(2.0, scale_exp.float()).unsqueeze(-1) + scaled = dequantized_blocks * scale_values + + return scaled.reshape(m, k).to(dtype) + + +def reference_per_token_group_quant_mxfp4( + x: torch.Tensor, group_size: int, eps: float = 1e-10 +) -> tuple: + """Reference implementation using PyTorch operations.""" + assert x.shape[-1] % group_size == 0 + assert x.is_contiguous() + + x_cpu = x.cpu().float() + x_q, x_s = quantize_to_mxfp4_ref(x_cpu, group_size, eps) + return x_q.to(x.device), x_s.to(x.device) + + +def sglang_per_token_group_quant_mxfp4( + x: torch.Tensor, group_size: int, eps: float = 1e-10 +) -> tuple: + """SGL kernel wrapper for MXFP4 quantization.""" + from 
sgl_kernel import sgl_per_token_group_quant_fp4 + + assert x.shape[-1] % group_size == 0 + assert x.is_contiguous() + + x_q, x_s = sgl_per_token_group_quant_fp4(x=x, group_size=group_size, eps=eps) + return x_q, x_s + + +def calculate_diff( + batch_size: int, + seq_len: int, + hidden_dim: int, + group_size: int, + src_dtype: torch.dtype, + eps: float = 1e-10, +): + """Verify correctness by comparing reference and kernel implementations.""" + device = torch.device("xpu") + + x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=src_dtype) + + x_q_ref, x_s_ref = reference_per_token_group_quant_mxfp4(x.clone(), group_size, eps) + x_q_sgl, x_s_sgl = sglang_per_token_group_quant_mxfp4(x.clone(), group_size, eps) + + # Compare quantized outputs directly (packed uint8 and scales). + # Normalise signed zeros first: in E2M1 code 0x0 (+0.0) and 0x8 + # (-0.0) are semantically identical. The kernel may preserve the + # sign of the original float while the reference always emits +0.0, + # so we canonicalise before comparing. + x_q_ref_norm = _normalize_packed_fp4_signed_zero(x_q_ref.cpu()) + x_q_sgl_norm = _normalize_packed_fp4_signed_zero(x_q_sgl.cpu()) + q_match = torch.equal(x_q_ref_norm, x_q_sgl_norm) + s_match = torch.equal(x_s_ref.cpu(), x_s_sgl.cpu()) + + if q_match and s_match: + print( + f" \u2705 Quantized values match (batch={batch_size}, seq={seq_len}, hidden={hidden_dim}, group={group_size}, dtype={src_dtype})" + ) + else: + q_mismatches = (x_q_ref_norm != x_q_sgl_norm).sum().item() if not q_match else 0 + s_mismatches = ( + (x_s_ref.cpu() != x_s_sgl.cpu()).sum().item() if not s_match else 0 + ) + print( + f" \u274c Quantized values differ: " + f"packed_q({q_mismatches} mismatches) " + f"scales({s_mismatches} mismatches)" + ) + + # Compare dequantized outputs + x_dq_ref = dequantize_mxfp4(x_q_ref.cpu(), x_s_ref.cpu(), torch.float32, group_size) + x_dq_sgl = dequantize_mxfp4(x_q_sgl.cpu(), x_s_sgl.cpu(), torch.float32, group_size) + + if torch.allclose(x_dq_ref, x_dq_sgl, rtol=0.2, atol=0.5): + print( + f" \u2705 Dequantized values match (batch={batch_size}, seq={seq_len}, hidden={hidden_dim}, group={group_size}, dtype={src_dtype})" + ) + else: + max_diff = (x_dq_ref - x_dq_sgl).abs().max().item() + print(f" \u274c Dequantized values differ (max_diff={max_diff:.4f})") + + +def calculate_flops(num_elements: int, num_groups: int) -> int: + """Calculate FLOPs for MXFP4 per-token-group quantization.""" + flops_per_element = 5 + flops_per_group = 8 + return (num_elements * flops_per_element) + (num_groups * flops_per_group) + + +def calculate_effective_bandwidth( + batch_size: int, + seq_len: int, + hidden_dim: int, + group_size: int, + src_dtype: torch.dtype, + time_ms: float, +) -> dict: + """Calculate effective bandwidth and FLOPs for MXFP4 quantization kernel.""" + num_tokens = batch_size * seq_len + num_elements = num_tokens * hidden_dim + num_groups = num_elements // group_size + + dtype_size = 2 if src_dtype in (torch.float16, torch.bfloat16) else 4 + input_bytes = num_elements * dtype_size + output_bytes = num_elements // 2 + scale_bytes = num_groups + total_bytes = input_bytes + output_bytes + scale_bytes + + time_s = time_ms / 1000.0 + bandwidth_gbs = (total_bytes / 1e9) / time_s + + total_flops = calculate_flops(num_elements, num_groups) + gflops = (total_flops / 1e9) / time_s + + return { + "num_tokens": num_tokens, + "num_elements": num_elements, + "num_groups": num_groups, + "total_bytes": total_bytes, + "bandwidth_gbs": bandwidth_gbs, + "total_flops": 
total_flops, + "gflops": gflops, + } + + +batch_size_range = [1, 2, 4, 8, 16, 32, 64] if not IS_CI else [1, 4, 16] +seq_len_range = [64, 128, 256, 512, 1024, 2048] if not IS_CI else [64, 256] +# Only group_size=32 is supported for MXFP4 (per OCP MX spec block size) +group_size_range = [32] +src_dtype_range = [torch.bfloat16] + +configs = list( + itertools.product( + batch_size_range, seq_len_range, group_size_range, src_dtype_range + ) +) + +all_results = [] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "group_size", "src_dtype"], + x_vals=configs, + line_arg="provider", + line_vals=["sglang"], + line_names=["SGL Kernel"], + styles=[("green", "-")], + ylabel="us", + plot_name="per-token-group-quant-mxfp4-performance", + args={}, + ) +) +def benchmark(batch_size, seq_len, group_size, src_dtype, provider): + device = torch.device("xpu") + hidden_dim = 7168 + + x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=src_dtype) + + quantiles = [0.5, 0.2, 0.8] + + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sglang_per_token_group_quant_mxfp4(x, group_size), + quantiles=quantiles, + ) + + bw_metrics = calculate_effective_bandwidth( + batch_size, seq_len, hidden_dim, group_size, src_dtype, ms + ) + + all_results.append( + { + "batch_size": batch_size, + "seq_len": seq_len, + "num_tokens": bw_metrics["num_tokens"], + "hidden_dim": hidden_dim, + "group_size": group_size, + "src_dtype": str(src_dtype), + "provider": provider, + "time_us": 1000 * ms, + "bandwidth_gbs": bw_metrics["bandwidth_gbs"], + "total_bytes_mb": bw_metrics["total_bytes"] / 1e6, + "total_flops_m": bw_metrics["total_flops"] / 1e6, + "gflops": bw_metrics["gflops"], + } + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +def print_summary(results: list): + """Print summary statistics from benchmark results.""" + print("\n" + "=" * 100) + print("MXFP4 Per-Token Group Quantization Benchmark Results") + print("=" * 100) + + df = pd.DataFrame(results) + df["bandwidth_gbs"] = df["bandwidth_gbs"].round(2) + df["total_bytes_mb"] = df["total_bytes_mb"].round(2) + df["time_us"] = df["time_us"].round(2) + df["total_flops_m"] = df["total_flops_m"].round(2) + df["gflops"] = df["gflops"].round(2) + + print("\nDetailed Results:") + print(df.to_markdown(index=False)) + + print("\n" + "=" * 100) + print("Summary Statistics by Provider") + print("=" * 100) + summary = df.groupby("provider").agg( + { + "bandwidth_gbs": ["mean", "min", "max"], + "time_us": ["mean", "min", "max"], + "gflops": ["mean", "min", "max"], + } + ) + print(summary.to_markdown()) + + +def main(): + if not is_xpu_available(): + print("Error: Intel XPU not available") + return + + try: + from sgl_kernel import sgl_per_token_group_quant_fp4 + + assert callable(sgl_per_token_group_quant_fp4) + except ImportError: + print("Error: sgl_per_token_group_quant_fp4 kernel not available") + return + + print("Running MXFP4 Per-Token Group Quantization Benchmark") + print(" Device: Intel XPU") + print(f" MXFP4 block size: {MXFP4_BLOCK_SIZE}") + + print("\n" + "=" * 80) + print("Correctness Verification") + print("=" * 80) + calculate_diff( + batch_size=2, + seq_len=64, + hidden_dim=128, + group_size=32, + src_dtype=torch.bfloat16, + ) + calculate_diff( + batch_size=1, seq_len=32, hidden_dim=128, group_size=32, src_dtype=torch.float32 + ) + + print("\n" + "=" * 80) + print("Performance Benchmark") + print("=" * 80) + benchmark.run(print_data=True) + + print_summary(all_results) + + +if __name__ == 
"__main__": + main() diff --git a/include/sgl_kernel_ops.h b/include/sgl_kernel_ops.h index 6dec3842..212518fb 100644 --- a/include/sgl_kernel_ops.h +++ b/include/sgl_kernel_ops.h @@ -157,6 +157,8 @@ void fused_qk_norm_rope( double high, double attention_factor, int64_t rotary_dim); +void sgl_per_token_group_quant_fp4( + at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, double eps); } // namespace at::native::xpu void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); diff --git a/python/sgl_kernel/__init__.py b/python/sgl_kernel/__init__.py index a243f4df..9f78df88 100755 --- a/python/sgl_kernel/__init__.py +++ b/python/sgl_kernel/__init__.py @@ -43,6 +43,7 @@ scaled_fp4_quant, sgl_per_tensor_quant_fp8, sgl_per_token_group_quant_8bit, + sgl_per_token_group_quant_fp4, sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8, sgl_per_token_quant_fp8, diff --git a/python/sgl_kernel/gemm.py b/python/sgl_kernel/gemm.py index 30316544..284d9eac 100644 --- a/python/sgl_kernel/gemm.py +++ b/python/sgl_kernel/gemm.py @@ -124,6 +124,67 @@ def sgl_per_tensor_quant_fp8( ) +def sgl_per_token_group_quant_fp4( + x: torch.Tensor, + group_size: int = 32, + eps: float = 1e-10, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP4 (E2M1) format with per-token group scaling. + + MXFP4 follows the OpenCompute MX (Microscaling) format specification: + - Data type: E2M1 (4-bit float with 2-bit exponent, 1-bit mantissa) + - Block size: 32 elements per scale factor (default) + - Scale format: UE8M0 (unsigned 8-bit exponent-only, no mantissa) + + Args: + x: Input tensor with shape (..., K) where K is divisible by group_size. + Must be contiguous and dtype float16, bfloat16, or float32. + group_size: Number of elements per quantization group. Must be 32 for MXFP4. + eps: Small epsilon to avoid division by zero. Default is 1e-10. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - output_q: Packed FP4 tensor with shape (..., K // 2) and dtype uint8. + Two E2M1 values are packed into each byte. + - output_s: Scale tensor with shape (..., K // group_size) and dtype uint8. + Scales are stored in UE8M0 format (exponent + 127 bias). 
+ """ + assert ( + x.shape[-1] % group_size == 0 + ), f"the last dimension of `x` ({x.shape[-1]}) must be divisible by `group_size` ({group_size})" + assert x.is_contiguous(), "`x` is not contiguous" + assert group_size == 32, f"group_size must be 32 for MXFP4, got {group_size}" + + # Ensure input is 2D for the kernel + original_shape = x.shape + if x.dim() == 1: + x = x.unsqueeze(0) + elif x.dim() > 2: + x = x.view(-1, x.shape[-1]) + + m, k = x.shape + num_groups_per_row = k // group_size + + # Output is packed FP4 (2 values per byte) + output_q = torch.empty((m, k // 2), device=x.device, dtype=torch.uint8) + + # Scales in row-major layout: (m, num_groups_per_row) + # Each row has the scales for that token's groups + output_s = torch.empty((m, num_groups_per_row), device=x.device, dtype=torch.uint8) + + if x.shape[0] > 0: + torch.ops.sgl_kernel.sgl_per_token_group_quant_fp4.default( + x, output_q, output_s, group_size, eps + ) + + # Reshape output to match input shape + output_shape_q = original_shape[:-1] + (original_shape[-1] // 2,) + output_shape_s = original_shape[:-1] + (original_shape[-1] // group_size,) + + return output_q.view(output_shape_q), output_s.view(output_shape_s) + + def sgl_per_token_quant_fp8( input: torch.Tensor, output_q: torch.Tensor, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 115dcbd4..8f59214c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -19,6 +19,8 @@ foreach(file ${device_cpp}) endforeach() +include(FMHADecodeXe20.cmake) + message(STATUS "BMG files: ${device_cpp_xe20}") message(STATUS "Common files: ${device_cpp_common}") @@ -26,6 +28,8 @@ list(APPEND ATen_XPU_CPP_SRCS ${host_cpp}) list(APPEND ATen_XPU_SYCL_COMMON ${device_cpp_common}) list(APPEND ATen_XPU_SYCL_XE20 ${device_cpp_xe20}) +include(${CMAKE_CURRENT_SOURCE_DIR}/GroupGemmXe20.cmake) + set(ATen_XPU_CPP_SRCS ${ATen_XPU_CPP_SRCS} PARENT_SCOPE) set(ATen_XPU_SYCL_COMMON ${ATen_XPU_SYCL_COMMON} PARENT_SCOPE) set(ATen_XPU_SYCL_XE20 ${ATen_XPU_SYCL_XE20} PARENT_SCOPE) diff --git a/src/FMHADecodeXe20.cmake b/src/FMHADecodeXe20.cmake new file mode 100644 index 00000000..f8e0dd8f --- /dev/null +++ b/src/FMHADecodeXe20.cmake @@ -0,0 +1,29 @@ +# Generate FMHA decode kernel instantiation files. +# Each (QG_SZ, HEAD_DIM, PAGE_SIZE) combination is compiled as a separate +# library to parallelize and speed up compilation. 
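+# For example, (QG_SZ=8, HEAD_DIM=128, PAGE_SIZE=64) expands the two templates below into
+# xe_fmha_fwd_decode_kernel_8_128_64.cpp and xe_fmha_fwd_split_decode_kernel_8_128_64.cpp;
+# the 5 x 5 x 2 = 50 supported combinations therefore yield 100 generated source files.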
+ +set(FMHA_DECODE_QG_SIZES 1 2 4 8 16) +set(FMHA_DECODE_HEAD_DIMS 64 96 128 192 256) +set(FMHA_DECODE_PAGE_SIZES 64 128) + +set(FMHA_DECODE_TEMPLATE + "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_decode_kernel.cpp.in") + +set(FMHA_SPLIT_DECODE_TEMPLATE + "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in") + +foreach(QG_SZ ${FMHA_DECODE_QG_SIZES}) + foreach(HEAD_DIM ${FMHA_DECODE_HEAD_DIMS}) + foreach(PAGE_SIZE ${FMHA_DECODE_PAGE_SIZES}) + set(GENERATED_FILE + "${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_decode_kernel_${QG_SZ}_${HEAD_DIM}_${PAGE_SIZE}.cpp") + configure_file(${FMHA_DECODE_TEMPLATE} ${GENERATED_FILE} @ONLY) + list(APPEND device_cpp_common ${GENERATED_FILE}) + + set(GENERATED_SPLIT_FILE + "${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_split_decode_kernel_${QG_SZ}_${HEAD_DIM}_${PAGE_SIZE}.cpp") + configure_file(${FMHA_SPLIT_DECODE_TEMPLATE} ${GENERATED_SPLIT_FILE} @ONLY) + list(APPEND device_cpp_common ${GENERATED_SPLIT_FILE}) + endforeach() + endforeach() +endforeach() diff --git a/src/GroupGemmXe20.cmake b/src/GroupGemmXe20.cmake new file mode 100644 index 00000000..267d7233 --- /dev/null +++ b/src/GroupGemmXe20.cmake @@ -0,0 +1,32 @@ +set(GROUP_GEMM_XE20_TEMPLATE "${CMAKE_CURRENT_SOURCE_DIR}/sycl/GroupGemmXe20LauncherInstance.cpp.in") +set(GROUP_GEMM_XE20_GEN_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated/group_gemm_xe20") +set(GROUP_GEMM_XE20_INST_SRCS) +file(MAKE_DIRECTORY ${GROUP_GEMM_XE20_GEN_DIR}) + +function(add_group_gemm_xe20_inst TILE_M TILE_N TILE_K SG_SHAPE SG_STRIDE ACT_TYPE FUSE_ACT WITH_BIAS) + set(TILE "Shape<${TILE_M}, ${TILE_N}, ${TILE_K}>") + set(SGLAYOUT "Layout, Stride<${SG_STRIDE}>>") + set(GEN_SRC + "${GROUP_GEMM_XE20_GEN_DIR}/GroupGemmXe20_inst_${TILE_M}_${TILE_N}_${TILE_K}_a${ACT_TYPE}_f${FUSE_ACT}_b${WITH_BIAS}.cpp") + + configure_file(${GROUP_GEMM_XE20_TEMPLATE} ${GEN_SRC} @ONLY) + list(APPEND GROUP_GEMM_XE20_INST_SRCS ${GEN_SRC}) + set(GROUP_GEMM_XE20_INST_SRCS ${GROUP_GEMM_XE20_INST_SRCS} PARENT_SCOPE) +endfunction() + +foreach(act_type 0 1) + foreach(with_bias true false) + foreach(fuse_act true false) + add_group_gemm_xe20_inst("_8" "_64" "_32" "_1, _4, _1" "_4, _1, _0" ${act_type} ${fuse_act} ${with_bias}) + add_group_gemm_xe20_inst("_16" "_64" "_32" "_1, _4, _1" "_4, _1, _0" ${act_type} ${fuse_act} ${with_bias}) + add_group_gemm_xe20_inst("_32" "_64" "_32" "_1, _4, _1" "_4, _1, _0" ${act_type} ${fuse_act} ${with_bias}) + endforeach() + + add_group_gemm_xe20_inst("_128" "_64" "_32" "_4, _2, _1" "_2, _1, _0" ${act_type} true ${with_bias}) + add_group_gemm_xe20_inst("_128" "_128" "_32" "_4, _2, _1" "_2, _1, _0" ${act_type} false ${with_bias}) + add_group_gemm_xe20_inst("_256" "_64" "_32" "_8, _2, _1" "_2, _1, _0" ${act_type} true ${with_bias}) + add_group_gemm_xe20_inst("_256" "_256" "_32" "_8, _4, _1" "_4, _1, _0" ${act_type} false ${with_bias}) + endforeach() +endforeach() + +list(APPEND ATen_XPU_SYCL_XE20 ${GROUP_GEMM_XE20_INST_SRCS}) diff --git a/src/sycl/GroupGemmXe20.cpp b/src/sycl/GroupGemmXe20.cpp index f2e8ab2d..e0c7a364 100644 --- a/src/sycl/GroupGemmXe20.cpp +++ b/src/sycl/GroupGemmXe20.cpp @@ -15,10 +15,6 @@ using namespace cute; using ElementAccumulator = float; // <- data type of accumulator -template -class GemmXe20Name; - -// ActType: 0=silu, 1=gelu template void Xe20MoEGEMMLauncher( sycl::queue q, @@ -31,69 +27,62 @@ void Xe20MoEGEMMLauncher( const int gemm_k, const int* num_rows_per_expert_device, const int num_experts, - int* workspace) { - using Element = cutlass::bfloat16_t; - - auto make_dummy_tensor = 
[&](auto val, auto stride) { - return make_tensor(make_gmem_ptr(&val), make_layout(repeat>(1), stride)); - }; - auto make_dummy_bias = [&](auto val) { - return make_tensor(make_gmem_ptr(&val), make_layout(Shape{}, Stride<_1>{})); - }; - using StrideA = Stride; - using StrideB = Stride; - using StrideD = Stride; - using TensorA = decltype(make_dummy_tensor(Element{}, StrideA{})); - using TensorB = decltype(make_dummy_tensor(Element{}, StrideB{})); - using TensorD = decltype(make_dummy_tensor(Element{}, StrideD{})); - using TensorBias = decltype(make_dummy_bias(Element{})); - - using ElementA_non_CV = cutlass::platform::remove_cv_t; - using MMA = - typename TiledMMAHelper>, Layout, SGLayout>::TiledMMA; - auto mma = MMA{}; - - int sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); - auto MaxThreadsPerWorkgroup = size(mma); - - static constexpr int MaxThreadsPerSM = 512; - - TORCH_CHECK( - MaxThreadsPerSM % MaxThreadsPerWorkgroup == 0, "MaxThreadsPerSM must be divisible by MaxThreadsPerWorkgroup") - - sycl::range<3> local(1, 1, MaxThreadsPerWorkgroup); - sycl::range<3> global(1, sm_count * MaxThreadsPerSM / MaxThreadsPerWorkgroup, 1); - - namespace syclex = sycl::ext::oneapi::experimental; - namespace intelex = sycl::ext::intel::experimental; - - syclex::properties kernel_props{syclex::sub_group_size<16>, intelex::grf_size<256>}; - - using Kernel = - MoE::MoEGEMM; - typename Kernel::Params params{ - static_cast(activations), - static_cast(weights), - static_cast(bias), - static_cast(outputs), - num_rows_per_expert_device, - gemm_n, - gemm_k, - num_experts, - workspace, - mma, - }; - - auto event = q.submit([&](sycl::handler& h) { - sycl::local_accessor local_mem(sycl::range<1>(1), h); - h.parallel_for>( - sycl::nd_range<3>(global * local, local), kernel_props, [=](sycl::nd_item<3> item) { - int32_t* slm_mem = - static_cast(local_mem.template get_multi_ptr().get()); - Kernel{}(params, item, slm_mem); - }); - }); -} + int* workspace); + +using Tile_8_64_32 = Shape<_8, _64, _32>; +using Tile_16_64_32 = Shape<_16, _64, _32>; +using Tile_32_64_32 = Shape<_32, _64, _32>; +using Tile_128_64_32 = Shape<_128, _64, _32>; +using Tile_128_128_32 = Shape<_128, _128, _32>; +using Tile_256_64_32 = Shape<_256, _64, _32>; +using Tile_256_256_32 = Shape<_256, _256, _32>; + +using SG_1_4_1 = Layout, Stride<_4, _1, _0>>; +using SG_4_2_1 = Layout, Stride<_2, _1, _0>>; +using SG_8_2_1 = Layout, Stride<_2, _1, _0>>; +using SG_8_4_1 = Layout, Stride<_4, _1, _0>>; + +#define DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, ActType, FuseAct, WithBias) \ + extern template void Xe20MoEGEMMLauncher( \ + sycl::queue, \ + const void*, \ + const void*, \ + const void*, \ + const void*, \ + void*, \ + const int, \ + const int, \ + const int*, \ + const int, \ + int*); + +#define DECLARE_XE20_MOE_TILE_ALL_FUSES(Tile, SGLayout) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, true, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, true, false) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, false, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, false, false) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, true, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, true, false) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, false, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, false, false) + +#define DECLARE_XE20_MOE_TILE_FUSE(Tile, SGLayout, FuseAct) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, FuseAct, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, FuseAct, false) \ + 
DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, FuseAct, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, FuseAct, false) + +DECLARE_XE20_MOE_TILE_ALL_FUSES(Tile_8_64_32, SG_1_4_1) +DECLARE_XE20_MOE_TILE_ALL_FUSES(Tile_16_64_32, SG_1_4_1) +DECLARE_XE20_MOE_TILE_ALL_FUSES(Tile_32_64_32, SG_1_4_1) +DECLARE_XE20_MOE_TILE_FUSE(Tile_128_64_32, SG_4_2_1, true) +DECLARE_XE20_MOE_TILE_FUSE(Tile_128_128_32, SG_4_2_1, false) +DECLARE_XE20_MOE_TILE_FUSE(Tile_256_64_32, SG_8_2_1, true) +DECLARE_XE20_MOE_TILE_FUSE(Tile_256_256_32, SG_8_4_1, false) + +#undef DECLARE_XE20_MOE_TILE_FUSE +#undef DECLARE_XE20_MOE_TILE_ALL_FUSES +#undef DECLARE_XE20_MOE_EXTERN #define LAUNCH_MOE(...) \ Xe20MoEGEMMLauncher<__VA_ARGS__>( \ diff --git a/src/sycl/GroupGemmXe20LauncherInstance.cpp.in b/src/sycl/GroupGemmXe20LauncherInstance.cpp.in new file mode 100644 index 00000000..e9a0b0f1 --- /dev/null +++ b/src/sycl/GroupGemmXe20LauncherInstance.cpp.in @@ -0,0 +1,109 @@ +#define SYCL_INTEL_TARGET 20 + +#include +#include +#include + +#include + +#include "sycl/Utils.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "sycl/kernels/moe/xe20/moe_kernel.hpp" + +using namespace cute; + +template +class GemmXe20Name; + +template +void Xe20MoEGEMMLauncher( + sycl::queue q, + const void* activations, + const void* weights, + const void* scales, + const void* bias, + void* outputs, + const int gemm_n, + const int gemm_k, + const int* num_rows_per_expert_device, + const int num_experts, + int* workspace) { + (void)scales; + using Element = cutlass::bfloat16_t; + + auto make_dummy_tensor = [&](auto val, auto stride) { + return make_tensor(make_gmem_ptr(&val), make_layout(repeat>(1), stride)); + }; + auto make_dummy_bias = [&](auto val) { + return make_tensor(make_gmem_ptr(&val), make_layout(Shape{}, Stride<_1>{})); + }; + using StrideA = Stride; + using StrideB = Stride; + using StrideD = Stride; + using TensorA = decltype(make_dummy_tensor(Element{}, StrideA{})); + using TensorB = decltype(make_dummy_tensor(Element{}, StrideB{})); + using TensorD = decltype(make_dummy_tensor(Element{}, StrideD{})); + using TensorBias = decltype(make_dummy_bias(Element{})); + + using ElementA_non_CV = cutlass::platform::remove_cv_t; + using MMA = + typename TiledMMAHelper>, Layout, SGLayout>::TiledMMA; + auto mma = MMA{}; + + int sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + auto MaxThreadsPerWorkgroup = size(mma); + + static constexpr int MaxThreadsPerSM = 512; + + TORCH_CHECK( + MaxThreadsPerSM % MaxThreadsPerWorkgroup == 0, "MaxThreadsPerSM must be divisible by MaxThreadsPerWorkgroup"); + + sycl::range<3> local(1, 1, MaxThreadsPerWorkgroup); + sycl::range<3> global(1, sm_count * MaxThreadsPerSM / MaxThreadsPerWorkgroup, 1); + + namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; + + syclex::properties kernel_props{syclex::sub_group_size<16>, intelex::grf_size<256>}; + + using Kernel = + MoE::MoEGEMM; + typename Kernel::Params params{ + static_cast(activations), + static_cast(weights), + static_cast(bias), + static_cast(outputs), + num_rows_per_expert_device, + gemm_n, + gemm_k, + num_experts, + workspace, + mma, + }; + + q.submit([&](sycl::handler& h) { + sycl::local_accessor local_mem(sycl::range<1>(1), h); + h.parallel_for>( + sycl::nd_range<3>(global * local, local), kernel_props, [=](sycl::nd_item<3> item) { + int32_t* slm_mem = + static_cast(local_mem.template get_multi_ptr().get()); + 
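+            // Run the statically instantiated MoE GEMM kernel for this work-item,
+            // using the work-group's shared local memory as scratch space.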
Kernel{}(params, item, slm_mem); + }); + }); +} + +template void Xe20MoEGEMMLauncher<@TILE@, @SGLAYOUT@, @ACT_TYPE@, @FUSE_ACT@, @WITH_BIAS@>( + sycl::queue, + const void*, + const void*, + const void*, + const void*, + void*, + const int, + const int, + const int*, + const int, + int*); + +#undef SYCL_INTEL_TARGET diff --git a/src/sycl/Norm.h b/src/sycl/Norm.h index b8e39a15..e432f5db 100644 --- a/src/sycl/Norm.h +++ b/src/sycl/Norm.h @@ -13,7 +13,7 @@ constexpr int NUM_REDUCE_STAGES = 16; #define DECLARE_SYCL_GLOBAL_FENCE sycl::access::fence_space::global_space #define DECLARE_SYCL_GLOBAL_AND_LOCAL_FENCE dpcpp_global_and_local_fence = sycl::access::fence_space::global_and_local -inline std::pair _check_layer_norm_inputs( +inline std::tuple _check_layer_norm_inputs( const torch::Tensor& input, IntArrayRef normalized_shape, std::optional& weight /* optional */, @@ -34,8 +34,9 @@ inline std::pair _check_layer_norm_inputs( unsigned int batch_size = input.size(0); unsigned int hidden_size = input.size(1); + unsigned int batch_stride = input.stride(0); - return std::make_pair(batch_size, hidden_size); + return std::make_tuple(batch_size, hidden_size, batch_stride); } template @@ -98,8 +99,14 @@ static inline void norm_group_reduce( class NormConfig { public: - NormConfig(int Batch, int Plane, int problem_dim, int element_size_bytes) - : Batch(Batch), Plane(Plane), problem_dim(problem_dim), element_size_bytes(element_size_bytes) { + NormConfig( + int Batch, int Plane, int problem_dim, int element_size_bytes, int input_batch_stride, int output_batch_stride) + : Batch(Batch), + Plane(Plane), + problem_dim(problem_dim), + element_size_bytes(element_size_bytes), + input_batch_stride(input_batch_stride), + output_batch_stride(output_batch_stride) { semaphores_ptr = nullptr; scratchpad_ptr = nullptr; sub_group_num_global = 1; @@ -126,6 +133,8 @@ class NormConfig { int workgroup_size; int sub_group_num; + int input_batch_stride; + int output_batch_stride; int* semaphores_ptr; void* scratchpad_ptr; int sub_group_num_global; diff --git a/src/sycl/RMSNorm.cpp b/src/sycl/RMSNorm.cpp index 4b0e06ee..32b5ce5f 100644 --- a/src/sycl/RMSNorm.cpp +++ b/src/sycl/RMSNorm.cpp @@ -36,7 +36,7 @@ class RMSNormForward : public NormForward { auto group_id = item_id.get_group(0); auto group_id_foreach = item_id.get_group(1); auto local_id = item_id.get_local_id(2); - index_t group_offset = group_id * cfg.Plane; + index_t group_offset = group_id * cfg.input_batch_stride; for (index_t j = local_id * vec_size; j < cfg.WGPlane; j += cfg.workgroup_size * vec_size) { index_t plane_offset = group_id_foreach * cfg.WGPlane + j; @@ -63,7 +63,8 @@ class RMSNormForward : public NormForward { auto group_id_foreach = item_id.get_group(1); auto local_id = item_id.get_local_id(2); - index_t group_offset = group_id * cfg.Plane; + index_t x_group_offset = group_id * cfg.input_batch_stride; + index_t y_group_offset = group_id * cfg.output_batch_stride; if (cfg.workgroup_num_foreach == 1) { if (local_id == 0) { reduce_project(item_id, sum_value, sum_tmp, cfg); @@ -75,14 +76,14 @@ class RMSNormForward : public NormForward { for (index_t j = local_id * vec_size; j < cfg.WGPlane; j += cfg.workgroup_size * vec_size) { index_t plane_offset = group_id_foreach * cfg.WGPlane + j; if (plane_offset < cfg.Plane) { - vec_t X_val = *(reinterpret_cast(NF::X_data + group_offset + plane_offset)); + vec_t X_val = *(reinterpret_cast(NF::X_data + x_group_offset + plane_offset)); vec_t Y_val; weight_vec_t gamma_val = *(reinterpret_cast(NF::gamma_data + 
plane_offset)); for (int v = 0; v < vec_size; ++v) { Y_val[v] = static_cast(gamma_val[v] * var_val * X_val[v]); } - *(reinterpret_cast(NF::Y_data + group_offset + plane_offset)) = Y_val; + *(reinterpret_cast(NF::Y_data + y_group_offset + plane_offset)) = Y_val; } } } @@ -113,7 +114,7 @@ class AddRMSNormForward : public RMSNormForward { auto group_id = item_id.get_group(0); auto group_id_foreach = item_id.get_group(1); auto local_id = item_id.get_local_id(2); - index_t group_offset = group_id * cfg.Plane; + index_t group_offset = group_id * cfg.input_batch_stride; for (index_t j = local_id * vec_size; j < cfg.WGPlane; j += cfg.workgroup_size * vec_size) { index_t plane_offset = group_id_foreach * cfg.WGPlane + j; @@ -148,7 +149,8 @@ class GemmaRMSNormForward : public RMSNormForward { auto group_id_foreach = item_id.get_group(1); auto local_id = item_id.get_local_id(2); - index_t group_offset = group_id * cfg.Plane; + index_t x_group_offset = group_id * cfg.input_batch_stride; + index_t y_group_offset = group_id * cfg.output_batch_stride; if (cfg.workgroup_num_foreach == 1) { if (local_id == 0) { RNF::reduce_project(item_id, sum_value, sum_tmp, cfg); @@ -160,14 +162,14 @@ class GemmaRMSNormForward : public RMSNormForward { for (index_t j = local_id * vec_size; j < cfg.WGPlane; j += cfg.workgroup_size * vec_size) { index_t plane_offset = group_id_foreach * cfg.WGPlane + j; if (plane_offset < cfg.Plane) { - vec_t X_val = *(reinterpret_cast(NF::X_data + group_offset + plane_offset)); + vec_t X_val = *(reinterpret_cast(NF::X_data + x_group_offset + plane_offset)); vec_t Y_val; weight_vec_t gamma_val = *(reinterpret_cast(NF::gamma_data + plane_offset)); for (int v = 0; v < vec_size; ++v) { Y_val[v] = static_cast((accscalar_t(1.0) + gamma_val[v]) * var_val * X_val[v]); } - *(reinterpret_cast(NF::Y_data + group_offset + plane_offset)) = Y_val; + *(reinterpret_cast(NF::Y_data + y_group_offset + plane_offset)) = Y_val; } } } @@ -310,13 +312,21 @@ void launch_vectorized_fused_norm_kernel(Norm& norm, template void RMSNormKernelImplInternal( - const Tensor& X, const Tensor& gemma, int64_t M, int64_t N, acc_type eps, Tensor& Y, Tensor& rstd) { + const Tensor& X, + const Tensor& gemma, + int64_t M, + int64_t N, + acc_type eps, + Tensor& Y, + Tensor& rstd, + int64_t input_batch_stride, + int64_t output_batch_stride) { scalar_t* X_data = X.data_ptr(); scalar_t* Y_data = Y.data_ptr(); mean_t* var_data = rstd.data_ptr(); weight_t* gemma_data = gemma.defined() ? gemma.data_ptr() : nullptr; - auto config = NormConfig(M, N, 1, sizeof(scalar_t)); + auto config = NormConfig(M, N, 1, sizeof(scalar_t), input_batch_stride, output_batch_stride); RMSNormForward rms_norm_forward(X_data, Y_data, var_data, gemma_data, eps, M, N); config.workgroup_num_foreach = 1; config.WGPlane = config.Plane; @@ -338,7 +348,7 @@ void FusedAddRMSNormKernelImplInternal( weight_t* gemma_data = gemma.defined() ? 
gemma.data_ptr() : nullptr; scalar_t* residual_data = residual.data_ptr(); - auto config = NormConfig(M, N, 1, sizeof(scalar_t)); + auto config = NormConfig(M, N, 1, sizeof(scalar_t), N, N); AddRMSNormForward add_rms_norm_forward( X_data, X_data, var_data, gemma_data, eps, residual_data, M, N); config.workgroup_num_foreach = 1; @@ -349,13 +359,21 @@ void FusedAddRMSNormKernelImplInternal( template void GemmaRMSNormKernelImplInternal( - const Tensor& X, const Tensor& gemma, int64_t M, int64_t N, acc_type eps, Tensor& Y, Tensor& rstd) { + const Tensor& X, + const Tensor& gemma, + int64_t M, + int64_t N, + acc_type eps, + Tensor& Y, + Tensor& rstd, + int64_t input_batch_stride, + int64_t output_batch_stride) { scalar_t* X_data = X.data_ptr(); scalar_t* Y_data = Y.data_ptr(); mean_t* var_data = rstd.data_ptr(); weight_t* gemma_data = gemma.defined() ? gemma.data_ptr() : nullptr; - auto config = NormConfig(M, N, 1, sizeof(scalar_t)); + auto config = NormConfig(M, N, 1, sizeof(scalar_t), input_batch_stride, output_batch_stride); GemmaRMSNormForward gemma_rms_norm_forward(X_data, Y_data, var_data, gemma_data, eps, M, N); config.workgroup_num_foreach = 1; config.WGPlane = config.Plane; @@ -377,7 +395,7 @@ void GemmaFusedAddRMSNormKernelImplInternal( weight_t* gemma_data = gemma.defined() ? gemma.data_ptr() : nullptr; scalar_t* residual_data = residual.data_ptr(); - auto config = NormConfig(M, N, 1, sizeof(scalar_t)); + auto config = NormConfig(M, N, 1, sizeof(scalar_t), N, N); GemmaAddRMSNormForward gemma_add_rms_norm_forward( X_data, X_data, var_data, gemma_data, eps, residual_data, M, N); config.workgroup_num_foreach = 1; @@ -390,28 +408,40 @@ void GemmaFusedAddRMSNormKernelImplInternal( void rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight, double eps) { std::optional opt_weight = weight; std::optional opt_bias; - auto M_N = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); - auto M = M_N.first; - auto N = M_N.second; + auto M_N_S = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); + auto M = std::get<0>(M_N_S); + auto N = std::get<1>(M_N_S); + auto input_batch_stride = std::get<2>(M_N_S); Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input; Tensor output_ = (output.dim() == 1) ? output.reshape({M, N}) : output; Tensor weight_ = (weight.dim() == 1) ? weight.reshape({N}) : weight; Tensor rstd = at::empty({M}, input_.options().dtype(kFloat)); + int64_t output_batch_stride = (output.dim() >= 2) ? 
output.stride(0) : N; SYCL_DISPATCH_FLOATING_TYPES( at::ScalarType::Half, at::ScalarType::BFloat16, input_.scalar_type(), "RMSNormKernelImpl", [&]() { RMSNormKernelImplInternal( - input_, weight_, M, N, static_cast>(eps), output_, rstd); + input_, + weight_, + M, + N, + static_cast>(eps), + output_, + rstd, + input_batch_stride, + output_batch_stride); }); } void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) { + TORCH_CHECK(input.is_contiguous(), "fused_add_rmsnorm: input must be contiguous"); + TORCH_CHECK(residual.is_contiguous(), "fused_add_rmsnorm: residual must be contiguous"); std::optional opt_weight = weight; std::optional opt_bias; - auto M_N = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); - auto M = M_N.first; - auto N = M_N.second; + auto M_N_S = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); + auto M = std::get<0>(M_N_S); + auto N = std::get<1>(M_N_S); Tensor rstd = at::empty({M}, input.options().dtype(kFloat)); @@ -425,28 +455,40 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tenso void gemma_rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight, double eps) { std::optional opt_weight = weight; std::optional opt_bias; - auto M_N = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); - auto M = M_N.first; - auto N = M_N.second; + auto M_N_S = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); + auto M = std::get<0>(M_N_S); + auto N = std::get<1>(M_N_S); + auto input_batch_stride = std::get<2>(M_N_S); Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input; Tensor output_ = (output.dim() == 1) ? output.reshape({M, N}) : output; Tensor weight_ = (weight.dim() == 1) ? weight.reshape({N}) : weight; Tensor rstd = at::empty({M}, input_.options().dtype(kFloat)); + int64_t output_batch_stride = (output.dim() >= 2) ? output.stride(0) : N; SYCL_DISPATCH_FLOATING_TYPES( at::ScalarType::Half, at::ScalarType::BFloat16, input_.scalar_type(), "GemmaRMSNormKernelImpl", [&]() { GemmaRMSNormKernelImplInternal( - input_, weight_, M, N, static_cast>(eps), output_, rstd); + input_, + weight_, + M, + N, + static_cast>(eps), + output_, + rstd, + input_batch_stride, + output_batch_stride); }); } void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double eps) { + TORCH_CHECK(input.is_contiguous(), "gemma_fused_add_rmsnorm: input must be contiguous"); + TORCH_CHECK(residual.is_contiguous(), "gemma_fused_add_rmsnorm: residual must be contiguous"); std::optional opt_weight = weight; std::optional opt_bias; - auto M_N = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); - auto M = M_N.first; - auto N = M_N.second; + auto M_N_S = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); + auto M = std::get<0>(M_N_S); + auto N = std::get<1>(M_N_S); Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input; Tensor residual_ = (residual.dim() == 1) ? 
residual.reshape({M, N}) : residual; diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index 1f467fc7..0fc8639c 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -35,10 +35,392 @@ #include #include -#include - #include "kernels/chunk_prefill/chunk_prefill_runner.hpp" -#include "kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" +#include "kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp" + +namespace decode { + +// Dispatch macros following the GroupGemmXe20.cpp pattern. +// Directly call struct operator() - no function pointers. + +#define DISPATCH_DECODE_KERNEL(QG, HD, PS) \ + do { \ + if (params.use_split_kv_decode) { \ + FmhaSplitDecodeRunner{}(params); \ + } else { \ + FmhaDecodeRunner{}(params); \ + } \ + } while (0) + +#define DISPATCH_DECODE_PAGE_SIZE(QG, HD) \ + do { \ + switch (params.page_size) { \ + case 64: \ + DISPATCH_DECODE_KERNEL(QG, HD, 64); \ + break; \ + case 128: \ + DISPATCH_DECODE_KERNEL(QG, HD, 128); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported page_size for decode attention: ", params.page_size); \ + } \ + } while (0) + +#define DISPATCH_DECODE_HEAD_DIM(QG) \ + do { \ + switch (params.d) { \ + case 64: \ + DISPATCH_DECODE_PAGE_SIZE(QG, 64); \ + break; \ + case 96: \ + DISPATCH_DECODE_PAGE_SIZE(QG, 96); \ + break; \ + case 128: \ + DISPATCH_DECODE_PAGE_SIZE(QG, 128); \ + break; \ + case 192: \ + DISPATCH_DECODE_PAGE_SIZE(QG, 192); \ + break; \ + case 256: \ + DISPATCH_DECODE_PAGE_SIZE(QG, 256); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size for decode attention: ", params.d); \ + } \ + } while (0) + +#define DISPATCH_DECODE(qg_sz) \ + do { \ + switch (qg_sz) { \ + case 1: \ + DISPATCH_DECODE_HEAD_DIM(1); \ + break; \ + case 2: \ + DISPATCH_DECODE_HEAD_DIM(2); \ + break; \ + case 4: \ + DISPATCH_DECODE_HEAD_DIM(4); \ + break; \ + case 8: \ + DISPATCH_DECODE_HEAD_DIM(8); \ + break; \ + case 16: \ + DISPATCH_DECODE_HEAD_DIM(16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported q_group_size for decode attention: ", params.q_group_size); \ + } \ + } while (0) + +std::vector mha_fwd( + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, + // h_k, d) if there is page_table. + const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, + // page_size, h_k, dv) if there is page_table. + std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + int max_seqlen_q, + int max_seqlen_k, + std::optional& page_table, // (b_k, max_num_pages_per_seq) + std::optional& kv_batch_idx_, // b. 
indices to index into the KV cache + std::optional& leftpad_k_, // b + std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional& seqlens_rotary_, // b + std::optional& q_descale_, // (b, h_k), not (b, h) + std::optional& k_descale_, // (b, h_k) + std::optional& v_descale_, // (b, h_k) + const float softmax_scale_, + std::optional& sinks_, + bool is_causal, + int window_size_left, + int window_size_right, + float const softcap, + bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional& scheduler_metadata_, // (b + 1) + // int num_kv_splits, + std::optional pack_gqa_, + int const sm_margin) { + auto q_type = q.scalar_type(); + TORCH_CHECK( + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "mha_fwd only supports Half and BFloat16, got", + q_type); + + TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + CHECK_INPUT(q); + CHECK_INPUT(k); + CHECK_INPUT(v); + TORCH_CHECK( + q.stride(-1) == 1 && k.stride(-1) == 1 && v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + TORCH_CHECK(page_table.value().dtype() == torch::kInt32, "page_table must have dtype torch.int32"); + TORCH_CHECK(page_table.value().stride(-1) == 1, "page_table must have contiguous last dimension"); + + TORCH_CHECK(q.dim() == 3, "query must be in ragged format"); + CHECK_INPUT(cu_seqlens_q); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); + + CHECK_INPUT(cu_seqlens_k); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); + + auto const sizes = q.sizes(); + const int batch_size = cu_seqlens_q.size(0) - 1; + int seqlen_q = max_seqlen_q; + int total_q = q.size(0); + int num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const max_num_pages_per_seq = page_table.value().size(1); + int const num_pages = k.size(0); + int const page_size = k.size(1); + int const seqlen_k = page_table.has_value() ? 
max_num_pages_per_seq * page_size : max_seqlen_k; + int const total_k = num_pages * page_size; + int const num_heads_k = k.size(-2); + + int const batch_size_k = page_table.value().size(0); + float softmax_scale = softmax_scale_; + + if (!kv_batch_idx_.has_value()) { + TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); + } + + // Currently only support head dims <= 256 + static constexpr int max_headdim = 256; + TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most ", max_headdim); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + // TODO: check this + + if (window_size_left >= seqlen_k - 1) { + window_size_left = -1; + } + window_size_right = min(window_size_right, seqlen_q); + // causal=true is the same as causal=false in this case + if (is_causal) { + window_size_right = 0; + } + + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table.value(), batch_size_k, max_num_pages_per_seq); + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_INPUT(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + static constexpr int alignment = 8; + TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + + auto opts = q.options(); + at::Tensor out; + at::Tensor temp_out; // [batch, num_kv_splits, num_head_q, seq_q, head_size] + at::Tensor exp_sums; // [batch, num_head_q, seq_q, num_kv_splits] + at::Tensor max_logits; // [batch, num_head_q, seq_q, num_kv_splits] + int num_kv_splits = 1; + out = torch::empty({total_q, num_heads, head_size_v}, opts); + Arguments params; + params.use_split_kv_decode = true; + if (params.use_split_kv_decode) { + auto get_num_splits = [](int batch_size, int num_heads_kv, int max_seqlen_k, int block_size) { + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + auto device = queue.get_device(); + int num_xe_cores = device.get_info() * + device.get_info(); + int parallel_ = num_xe_cores; + int parallel_2 = num_xe_cores * 2; + int cur_parallel_d = batch_size * num_heads_kv; + int num_splits = (parallel_ + cur_parallel_d - 1) / cur_parallel_d; + if (cur_parallel_d * num_splits > parallel_ && num_splits > 1) { + num_splits = std::ceil(parallel_2 / static_cast(cur_parallel_d)) - 1; + } + + int max_splits = (max_seqlen_k + block_size - 1) / block_size; + max_splits = std::min(max_splits, parallel_); + return std::min(num_splits, max_splits); + }; + num_kv_splits = get_num_splits(batch_size, num_heads_k, seqlen_k, page_size); + temp_out = num_kv_splits == 1 + ? 
out + : torch::empty({total_q, num_kv_splits * num_heads, head_size_v}, q.options().device(q.device())); + + exp_sums = torch::empty({total_q, num_heads, num_kv_splits}, q.options().dtype(at::kFloat).device(q.device())); + max_logits = torch::empty({total_q, num_heads, num_kv_splits}, q.options().dtype(at::kFloat).device(q.device())); + params.temp_out_ptr = temp_out.data_ptr(); + params.exp_sums_ptr = exp_sums.data_ptr(); + params.max_logits_ptr = max_logits.data_ptr(); + } + int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + c10::DeviceGuard device_guard(q.device()); + + at::Tensor softmax_lse; + softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + + // align with FA3 + + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_dim_stride = v.stride(-1); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + params.cu_seqlens_q = cu_seqlens_q.data_ptr(); + params.cu_seqlens_k = cu_seqlens_k.data_ptr(); + params.num_kv_splits = num_kv_splits; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse.data_ptr(); + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.q_group_size = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.softmax_scale = softmax_scale; + bool use_sink = sinks_.has_value(); + params.softmax_sink_ptr = use_sink ? sinks_.value().data_ptr() : nullptr; + params.use_sink = use_sink; + + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ params.is_causal = window_size_left < 0 && window_size_right == 0; + params.use_causal_mask = params.is_causal; + params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0) { + window_size_left = seqlen_k - 1; + } + if (window_size_right < 0) { + window_size_right = seqlen_q - 1; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.total_q = total_q; + params.total_k = total_k; + params.b_k = batch_size_k; + params.dv = head_size_v; + params.page_table = page_table.value().data_ptr(); + params.page_table_batch_stride = page_table.value().stride(0); + params.max_num_pages_per_seq = max_num_pages_per_seq; + params.page_size = page_size; + params.num_pages = num_pages; + + if (q_v_.has_value()) { + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK( + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(false, "q_v is not supported yet"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. + params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + } + + if (rotary_cos_.has_value()) { + auto rotary_cos = rotary_cos_.value(); + CHECK_INPUT(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_INPUT(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + at::Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_INPUT(seqlens_rotary); + TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = seqlens_rotary.data_ptr(); + } + } else { + params.rotary_dim = 0; + } + + if (kv_batch_idx_.has_value()) { + auto kv_batch_idx = kv_batch_idx_.value(); + CHECK_INPUT(kv_batch_idx); + TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); + params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); + } + + params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device()); + + at::Tensor out_accum, softmax_lse_accum; + + int qg_sz = nextPowerOf2(params.q_group_size); + TORCH_CHECK(qg_sz >= 1 && qg_sz <= 16, "Unsupported q_group_size for 
decode attention: ", params.q_group_size); + TORCH_CHECK( + params.d == 64 || params.d == 96 || params.d == 128 || params.d == 192 || params.d == 256, + "Unsupported head size for decode attention: ", + params.d); + TORCH_CHECK( + params.page_size == 64 || params.page_size == 128, + "Unsupported page size for decode attention: ", + params.page_size); + + DISPATCH_DECODE(qg_sz); + + return {out, softmax_lse, out_accum, softmax_lse_accum}; +} + +#undef DISPATCH_DECODE_KERNEL +#undef DISPATCH_DECODE_PAGE_SIZE +#undef DISPATCH_DECODE_HEAD_DIM +#undef DISPATCH_DECODE + +} // namespace decode std::vector mha_fwd( const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp new file mode 100644 index 00000000..19d7f3cc --- /dev/null +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "xe_fmha_fwd_decode_runner.hpp" + +namespace decode { + +// Struct functor declarations for FMHA decode kernel launchers. +// Each template specialization is explicitly instantiated in a separate +// generated .cpp file (from xe_fmha_fwd_decode_kernel.cpp.in / +// xe_fmha_fwd_split_decode_kernel.cpp.in). +// +// QG_SZ in {1, 2, 4, 8, 16} +// HEAD_DIM in {64, 96, 128, 192, 256} +// PAGE_SIZE in {64, 128} + +// Explicit instantiation declarations — tell the compiler these are compiled +// in separate translation units (generated from the .cpp.in templates). 
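As a quick sanity check on the size of that specialization grid, the combinations the macros below expand to can be enumerated in a few lines of Python (a sketch, not part of the build):

from itertools import product

QG_SIZES = [1, 2, 4, 8, 16]
HEAD_DIMS = [64, 96, 128, 192, 256]
PAGE_SIZES = [64, 128]

grid = list(product(QG_SIZES, HEAD_DIMS, PAGE_SIZES))
print(len(grid))  # 50 combinations per runner; with both FmhaDecodeRunner and
                  # FmhaSplitDecodeRunner this suggests on the order of 100
                  # generated translation units, as the comment above describes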
+ +#define EXTERN_FMHA_DECODE_RUNNER(QG, HD, PS) \ + extern template struct FmhaDecodeRunner; + +#define EXTERN_FMHA_SPLIT_DECODE_RUNNER(QG, HD, PS) \ + extern template struct FmhaSplitDecodeRunner; + +#define EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(QG, HD) \ + EXTERN_FMHA_DECODE_RUNNER(QG, HD, 64) \ + EXTERN_FMHA_DECODE_RUNNER(QG, HD, 128) + +#define EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(QG, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER(QG, HD, 64) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER(QG, HD, 128) + +#define EXTERN_FMHA_DECODE_RUNNER_ALL_QG(HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(1, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(2, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(4, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(8, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(16, HD) + +#define EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(1, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(2, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(4, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(8, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(16, HD) + +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(64) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(96) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(128) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(192) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(256) + +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(64) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(96) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(128) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(192) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(256) + +#undef EXTERN_FMHA_DECODE_RUNNER +#undef EXTERN_FMHA_SPLIT_DECODE_RUNNER +#undef EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES +#undef EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES +#undef EXTERN_FMHA_DECODE_RUNNER_ALL_QG +#undef EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG + +} // namespace decode diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp index d4d53931..f889e26a 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp @@ -182,6 +182,9 @@ struct Arguments { bool is_causal; bool is_local; + bool use_sink = false; + bool use_causal_mask = false; + bool is_rotary_interleaved; torch::TensorOptions tensor_opts; @@ -717,7 +720,7 @@ template < typename GmemTiledCopyK = void, typename GmemTiledCopyV = void, typename GmemTiledCopyO = void> -struct SplitDeodeConfig { +struct SplitDecodeConfig { static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))(); using MMAOperation = cute::conditional_t, XE_DPAS_TT, MMAOperation_>; @@ -790,379 +793,20 @@ struct SplitDeodeConfig { } }; -std::vector mha_fwd( - const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, - // h_k, d) if there is page_table. - const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, - // page_size, h_k, dv) if there is page_table. - std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - const at::Tensor& cu_seqlens_q, // b+1 - const at::Tensor& cu_seqlens_k, // b+1 - int max_seqlen_q, - int max_seqlen_k, - std::optional& page_table, // (b_k, max_num_pages_per_seq) - std::optional& kv_batch_idx_, // b. 
indices to index into the KV cache - std::optional& leftpad_k_, // b - std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional& seqlens_rotary_, // b - std::optional& q_descale_, // (b, h_k), not (b, h) - std::optional& k_descale_, // (b, h_k) - std::optional& v_descale_, // (b, h_k) - const float softmax_scale_, - std::optional& sinks_, - bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - std::optional& scheduler_metadata_, // (b + 1) - // int num_kv_splits, - std::optional pack_gqa_, - int const sm_margin) { - auto q_type = q.scalar_type(); - TORCH_CHECK( - q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "mha_fwd only supports Half and BFloat16, got", - q_type); - - TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); - TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); - CHECK_INPUT(q); - CHECK_INPUT(k); - CHECK_INPUT(v); - TORCH_CHECK( - q.stride(-1) == 1 && k.stride(-1) == 1 && v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - TORCH_CHECK(page_table.value().dtype() == torch::kInt32, "page_table must have dtype torch.int32"); - TORCH_CHECK(page_table.value().stride(-1) == 1, "page_table must have contiguous last dimension"); - - TORCH_CHECK(q.dim() == 3, "query must be in ragged format"); - CHECK_INPUT(cu_seqlens_q); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); - - CHECK_INPUT(cu_seqlens_k); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); - - auto const sizes = q.sizes(); - const int batch_size = cu_seqlens_q.size(0) - 1; - int seqlen_q = max_seqlen_q; - int total_q = q.size(0); - int num_heads = q.size(-2); - int const head_size = q.size(-1); - int const head_size_v = v.size(-1); - int const max_num_pages_per_seq = page_table.value().size(1); - int const num_pages = k.size(0); - int const page_size = k.size(1); - const bool has_page_table = page_table.has_value(); - int const seqlen_k = has_page_table ? max_num_pages_per_seq * page_size : max_seqlen_k; - int const total_k = num_pages * page_size; - int const num_heads_k = k.size(-2); - - int const batch_size_k = page_table.value().size(0); - float softmax_scale = softmax_scale_; - - if (!kv_batch_idx_.has_value()) { - TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); - } - - // Currently only support head dims <= 256 - static constexpr int max_headdim = 256; - TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most ", max_headdim); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); +// Struct functors for decode kernel dispatch. +// operator() is declared here; each specialization's body is defined in a +// generated .cpp file (from xe_fmha_fwd_decode_kernel.cpp.in / +// xe_fmha_fwd_split_decode_kernel.cpp.in) so the compiler only emits code +// for the combinations that are actually needed. 
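The generation step itself is not shown in this diff, but the @QG_SZ@/@HEAD_DIM@/@PAGE_SIZE@ placeholders in the .cpp.in files suggest a CMake configure_file-style text substitution. A hypothetical Python equivalent, purely to illustrate the mechanism (the output directory and file-name pattern are assumptions, not taken from the build scripts):

from itertools import product
from pathlib import Path

template = Path("xe_fmha_fwd_decode_kernel.cpp.in").read_text()
Path("generated").mkdir(exist_ok=True)

for qg, hd, ps in product([1, 2, 4, 8, 16], [64, 96, 128, 192, 256], [64, 128]):
    src = (template.replace("@QG_SZ@", str(qg))
                   .replace("@HEAD_DIM@", str(hd))
                   .replace("@PAGE_SIZE@", str(ps)))
    # hypothetical naming scheme; the real build may organize outputs differently
    Path(f"generated/xe_fmha_fwd_decode_{qg}_{hd}_{ps}.cpp").write_text(src)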
- // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM - // TODO: check this - - if (window_size_left >= seqlen_k - 1) { - window_size_left = -1; - } - window_size_right = min(window_size_right, seqlen_q); - // causal=true is the same as causal=false in this case - if (is_causal) { - window_size_right = 0; - } - - CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); - CHECK_SHAPE(page_table.value(), batch_size_k, max_num_pages_per_seq); - - if (leftpad_k_.has_value()) { - auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); - CHECK_INPUT(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - } - - static constexpr int alignment = 8; - TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); - TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); - - auto opts = q.options(); - auto device_opts = opts.device(q.device()); - at::Tensor out; - at::Tensor temp_out; // [batch, num_kv_splits, num_head_q, seq_q, head_size] - at::Tensor exp_sums; // [batch, num_head_q, seq_q, num_kv_splits] - at::Tensor max_logits; // [batch, num_head_q, seq_q, num_kv_splits] - int num_kv_splits = 1; - out = torch::empty({total_q, num_heads, head_size_v}, opts); - Arguments params; - params.use_split_kv_decode = true; - if (params.use_split_kv_decode) { - // lambda to calculate num_splits based on batch_size, num_heads_kv, max_seqlen_k and block_size - auto get_num_splits = [](int batch_size, int num_heads_kv, int max_seqlen_k, int block_size) { - auto stream = at::xpu::getCurrentXPUStream(); - auto queue = stream.queue(); - auto device = queue.get_device(); - int num_xe_cores = device.get_info() * - device.get_info(); - int parallel_ = num_xe_cores; - int parallel_2 = num_xe_cores * 2; - int cur_parallel_d = batch_size * num_heads_kv; - int num_splits = (parallel_ + cur_parallel_d - 1) / cur_parallel_d; - if (cur_parallel_d * num_splits > parallel_ && num_splits > 1) { - num_splits = std::ceil(parallel_2 / static_cast(cur_parallel_d)) - 1; - } - - int max_splits = (max_seqlen_k + block_size - 1) / block_size; - max_splits = std::min(max_splits, parallel_); - return std::min(num_splits, max_splits); - }; - // lambda end - // For split-kv, we split the kv sequence into num_kv_splits splits and run the kernel for each split, then do a - // reduction to get the final output. - num_kv_splits = get_num_splits(batch_size, num_heads_k, seqlen_k, page_size); - temp_out = num_kv_splits == 1 - ? out - : torch::empty({total_q, num_kv_splits * num_heads, head_size_v}, q.options().device(q.device())); - // auto float_options = opts.dtype(at::kFloat).device(q.device()); - - exp_sums = torch::empty({total_q, num_heads, num_kv_splits}, q.options().dtype(at::kFloat).device(q.device())); - max_logits = torch::empty({total_q, num_heads, num_kv_splits}, q.options().dtype(at::kFloat).device(q.device())); - params.temp_out_ptr = temp_out.data_ptr(); - params.exp_sums_ptr = exp_sums.data_ptr(); - params.max_logits_ptr = max_logits.data_ptr(); - } - int const head_size_rounded = round_up_headdim(head_size); - int const head_size_v_rounded = head_size_v == head_size ? 
head_size_rounded : round_up_headdim(head_size_v); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - c10::DeviceGuard device_guard(q.device()); - - at::Tensor softmax_lse; - softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); - - // align with FA3 - - params.is_bf16 = q.dtype() == torch::kBFloat16; - - // Set the pointers and strides. - params.q_ptr = q.data_ptr(); - params.k_ptr = k.data_ptr(); - params.v_ptr = v.data_ptr(); - // All stride are in elements, not bytes. - params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.q_head_stride = q.stride(-2); - params.k_head_stride = k.stride(-2); - params.v_head_stride = v.stride(-2); - params.v_dim_stride = v.stride(-1); - params.o_ptr = out.data_ptr(); - params.o_row_stride = out.stride(-3); - params.o_head_stride = out.stride(-2); - - params.cu_seqlens_q = cu_seqlens_q.data_ptr(); - params.cu_seqlens_k = cu_seqlens_k.data_ptr(); - params.num_kv_splits = num_kv_splits; - // Softmax sum - params.softmax_lse_ptr = softmax_lse.data_ptr(); - - // Set the dimensions. - params.b = batch_size; - params.h = num_heads; - params.h_k = num_heads_k; - params.q_group_size = num_heads / num_heads_k; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.d = head_size; - params.d_rounded = head_size_rounded; - - // Set the different scale values. - params.softmax_scale = softmax_scale; - bool use_sink = sinks_.has_value(); - params.softmax_sink_ptr = use_sink ? sinks_.value().data_ptr() : nullptr; - - params.softcap = softcap; - - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f; - - // Causal is the special case where window_size_right == 0 and window_size_left < 0. - // LocalMask is the more general case where window_size_right >= 0 or window_size_left >= 0. - params.is_causal = window_size_left < 0 && window_size_right == 0; - params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; - - // TODO: check this - if (window_size_left < 0) { - window_size_left = seqlen_k - 1; - } - if (window_size_right < 0) { - window_size_right = seqlen_q - 1; - } - params.window_size_left = window_size_left; - params.window_size_right = window_size_right; - params.total_q = total_q; - params.total_k = total_k; - params.b_k = batch_size_k; - params.dv = head_size_v; - params.page_table = page_table.value().data_ptr(); - params.page_table_batch_stride = page_table.value().stride(0); - params.max_num_pages_per_seq = max_num_pages_per_seq; - params.page_size = page_size; - params.num_pages = num_pages; - - if (q_v_.has_value()) { - TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); - TORCH_CHECK( - q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "q_v is only supported for fp16 and bf16 data type"); - TORCH_CHECK(false, "q_v is not supported yet"); - at::Tensor q_v = q_v_.value(); - TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); - TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); - CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); - params.qv_ptr = q_v.data_ptr(); - // All stride are in elements, not bytes. 
- params.qv_row_stride = q_v.stride(-3); - params.qv_head_stride = q_v.stride(-2); - } +template +struct FmhaDecodeRunner { + void operator()(const Arguments& params) const; +}; - if (rotary_cos_.has_value()) { - auto rotary_cos = rotary_cos_.value(); - CHECK_INPUT(rotary_cos); - params.rotary_dim = rotary_cos.size(1) * 2; - TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); - TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); - const int seqlen_ro = rotary_cos.size(0); - TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); - CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - - TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); - auto rotary_sin = rotary_sin_.value(); - CHECK_INPUT(rotary_sin); - CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - params.rotary_cos_ptr = rotary_cos.data_ptr(); - params.rotary_sin_ptr = rotary_sin.data_ptr(); - params.is_rotary_interleaved = is_rotary_interleaved; - if (seqlens_rotary_.has_value()) { - at::Tensor seqlens_rotary = seqlens_rotary_.value(); - CHECK_INPUT(seqlens_rotary); - TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); - CHECK_SHAPE(seqlens_rotary, batch_size); - params.seqlens_rotary = seqlens_rotary.data_ptr(); - } - } else { - params.rotary_dim = 0; - } +template +struct FmhaSplitDecodeRunner { + void operator()(const Arguments& params) const; +}; - params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device()); - - at::Tensor out_accum, softmax_lse_accum; - auto outaccum_type = at::ScalarType::Float; - - constexpr bool Causal = false; // The decode kernel does not support causal mode. It must be set to false. 
- // using ShapeQK = Shape<_8, _64, _64>; - // using ShapePV = Shape<_8, _32, _64>; - // using ShapeOut = Shape<_8, _128>; - // using SubgroupLayoutQK = Layout>; - // // SplitDeodeConfig::run(params); - // SplitDeodeConfig::run(params); - - auto launch_kernel = [&](auto _QG_SZ, auto _HEAD_DIM, auto _PAGE_SIZE, auto _NUM_SG) { - using TileShapeQK = cute::Shape; - using TileShapePV = cute::Shape; - using TileShapeOutput = cute::Shape; - using SubgroupLayoutQK = cute::Layout>; - - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - if (params.use_split_kv_decode) { - SplitDeodeConfig::run( - params); - } else { - DecodeConfig::run( - params); - } - }); - }); - }; - - auto dispatch_page_size = [&](auto _QG_SZ, auto _HEAD_DIM) { - switch (params.page_size) { - // case 32: - // launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{}); - // break; - case 64: - launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{}); - break; - case 128: - launch_kernel(_QG_SZ, _HEAD_DIM, _128{}, _8{}); - break; - default: - TORCH_CHECK(false, "Unsupported page size for decode attention: ", params.page_size); - } - }; - - auto dispatch_q_group = [&](auto _HEAD_DIM) { - switch (nextPowerOf2(params.q_group_size)) { - case 1: - dispatch_page_size(_1{}, _HEAD_DIM); - break; - case 2: - dispatch_page_size(_2{}, _HEAD_DIM); - break; - case 4: - dispatch_page_size(_4{}, _HEAD_DIM); - break; - case 8: - dispatch_page_size(_8{}, _HEAD_DIM); - break; - case 16: - dispatch_page_size(_16{}, _HEAD_DIM); - break; - // case 32: - // dispatch_page_size(_32{}, _HEAD_DIM); - break; - default: - TORCH_CHECK(false, "Unsupported qgroup_size for decode attention: ", max_seqlen_q); - } - }; - - switch (params.d) { - case 64: - dispatch_q_group(_64{}); - break; - case 96: - dispatch_q_group(_96{}); - break; - case 128: - dispatch_q_group(_128{}); - break; - case 192: - dispatch_q_group(_192{}); - break; - case 256: - dispatch_q_group(_256{}); - break; - default: - TORCH_CHECK(false, "Unsupported head size for decode attention: ", params.d); - } - return {out, softmax_lse, out_accum, softmax_lse_accum}; -} } // namespace decode diff --git a/src/sycl/per_token_group_quant_fp4.cpp b/src/sycl/per_token_group_quant_fp4.cpp new file mode 100644 index 00000000..4c276f96 --- /dev/null +++ b/src/sycl/per_token_group_quant_fp4.cpp @@ -0,0 +1,334 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * SYCL kernel for per-token group quantization to MXFP4 (E2M1) format. + * + * MXFP4 follows the OpenCompute MX (Microscaling) format specification: + * - Data type: E2M1 (4-bit float with 2-bit exponent, 1-bit mantissa) + * - Block size: 32 elements per scale factor + * - Scale format: UE8M0 (unsigned 8-bit exponent-only, no mantissa) + * + * E2M1 representable values (magnitude): 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0 + * With sign bit, we have 16 total values. + * + * Bit layout of E2M1: + * Bit 3: Sign (0 = positive, 1 = negative) + * Bits 0-2: Magnitude index (0-7) + * + * Two FP4 values are packed into a single uint8_t: + * - Lower nibble (bits 0-3): First value + * - Upper nibble (bits 4-7): Second value + * + * Rounding: Per OCP MX spec (section 5.3.3), FP4 conversion uses + * roundTiesToEven — at midpoints between representable values, the + * value with even mantissa (mantissa bit = 0) is chosen. 
+ */ + +#include +#include +#include + +#include +#include + +#include "SYCLHelpers.h" +#include "Utils.h" + +namespace at::native::xpu { + +constexpr float FLOAT4_E2M1_MAX = 6.0f; + +template +inline T QuantGroupReduceMaxFP4(T val, sycl::nd_item<1> item) { + auto sg = item.get_sub_group(); + + val = sycl::fmax(val, sycl::permute_group_by_xor(sg, val, 8)); + val = sycl::fmax(val, sycl::permute_group_by_xor(sg, val, 4)); + val = sycl::fmax(val, sycl::permute_group_by_xor(sg, val, 2)); + val = sycl::fmax(val, sycl::permute_group_by_xor(sg, val, 1)); + + return val; +} + +// E2M1 format (4-bit float): 1 sign bit, 2 exponent bits, 1 mantissa bit +// Encoding: exp=00 (subnormal), exp=01/10/11 (normal with bias=1) +// Result: bits[3]=sign, bits[2:1]=exponent, bits[0]=mantissa +// +// Representable values and their codes: +// 0.0 -> 0b000 (subnormal, m=0, even) +// 0.5 -> 0b001 (subnormal, m=1, odd) +// 1.0 -> 0b010 (e=01, m=0, even) +// 1.5 -> 0b011 (e=01, m=1, odd) +// 2.0 -> 0b100 (e=10, m=0, even) +// 3.0 -> 0b101 (e=10, m=1, odd) +// 4.0 -> 0b110 (e=11, m=0, even) +// 6.0 -> 0b111 (e=11, m=1, odd) +// +// RoundTiesToEven: At exact midpoints between two representable values, +// we round to the one whose mantissa bit is 0 (even). +// +// Midpoints and their rounding targets: +// 0.25 -> midpoint of (0.0, 0.5) -> round to 0.0 (m=0, even) +// 0.75 -> midpoint of (0.5, 1.0) -> round to 1.0 (m=0, even) +// 1.25 -> midpoint of (1.0, 1.5) -> round to 1.0 (m=0, even) +// 1.75 -> midpoint of (1.5, 2.0) -> round to 2.0 (m=0, even) +// 2.5 -> midpoint of (2.0, 3.0) -> round to 2.0 (m=0, even) +// 3.5 -> midpoint of (3.0, 4.0) -> round to 4.0 (m=0, even) +// 5.0 -> midpoint of (4.0, 6.0) -> round to 4.0 (m=0, even) +inline uint8_t quantize_to_e2m1(float val) { + uint8_t sign = (val < 0.0f) ? 1 : 0; + float abs_val = sycl::fabs(val); + + uint8_t code; + // RoundTiesToEven: at midpoints, round to the value with even mantissa (m=0). + // Midpoints use strict < for the upper bound so ties go to the even value. + // TODO(sspintel): Optimize this logic under a LUT to avoid branch divergence. 
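The branch ladder that follows implements exactly the boundaries listed above. For checking values by hand, the same mapping can be written as a small Python reference (a sketch of the thresholds, not the kernel itself):

def e2m1_code(val: float) -> int:
    """Map a pre-scaled float to a 4-bit E2M1 code with roundTiesToEven."""
    sign = 0x8 if val < 0.0 else 0x0
    a = abs(val)
    # boundaries chosen so ties (0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0) land on even mantissas
    if   a <= 0.25: mag = 0b000  # 0.0
    elif a <  0.75: mag = 0b001  # 0.5
    elif a <= 1.25: mag = 0b010  # 1.0
    elif a <  1.75: mag = 0b011  # 1.5
    elif a <= 2.5:  mag = 0b100  # 2.0
    elif a <  3.5:  mag = 0b101  # 3.0
    elif a <= 5.0:  mag = 0b110  # 4.0
    else:           mag = 0b111  # 6.0
    return sign | mag

assert e2m1_code(2.5) == 0b100    # tie -> 2.0 (even mantissa)
assert e2m1_code(-3.5) == 0b1110  # tie -> -4.0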
+ if (abs_val <= 0.25f) { + code = 0b000; // 0.0 (subnormal: exp=00, m=0) + } else if (abs_val < 0.75f) { + code = 0b001; // 0.5 (subnormal: exp=00, m=1) + } else if (abs_val <= 1.25f) { + code = 0b010; // 1.0 (exp=01, m=0) + } else if (abs_val < 1.75f) { + code = 0b011; // 1.5 (exp=01, m=1) + } else if (abs_val <= 2.5f) { + code = 0b100; // 2.0 (exp=10, m=0) + } else if (abs_val < 3.5f) { + code = 0b101; // 3.0 (exp=10, m=1) + } else if (abs_val <= 5.0f) { + code = 0b110; // 4.0 (exp=11, m=0) + } else { + code = 0b111; // 6.0 (exp=11, m=1) + } + + return (sign << 3) | code; +} + +// Use SYCL native vector type for efficient loading +template +using vec_t = sycl::vec; + +// Compile-time constants for group sizes +template +struct FP4GroupSizeTraits { + static constexpr int THREADS_PER_GROUP = 16; + static constexpr int SUB_GROUP_SIZE = 32; +}; + +template +struct PerTokenGroupQuantFP4Kernel : public __SYCL_KER_CONFIG_CONVENTION__ { + static constexpr uint32_t VEC_SIZE = 16 / sizeof(T); + static constexpr int32_t NUM_VEC_ELEMS = GROUP_SIZE / VEC_SIZE; + static constexpr int32_t THREADS_PER_GROUP = FP4GroupSizeTraits::THREADS_PER_GROUP; + static constexpr int32_t VECS_PER_THREAD = (NUM_VEC_ELEMS + THREADS_PER_GROUP - 1) / THREADS_PER_GROUP; + + PerTokenGroupQuantFP4Kernel( + const T* input, uint8_t* output_q, uint8_t* output_s, int num_groups, int groups_per_block, float eps) + : input(input), + output_q(output_q), + output_s(output_s), + num_groups(num_groups), + groups_per_block(groups_per_block), + eps(eps) {} + + void sycl_ker_config_convention(sycl::handler& cgh) {} + + [[sycl::reqd_sub_group_size(32)]] void operator()(sycl::nd_item<1> item) const { + const int64_t local_group_id = item.get_local_id(0) / THREADS_PER_GROUP; + const int lane_id = item.get_local_id(0) % THREADS_PER_GROUP; + + const int64_t block_group_id = item.get_group(0) * groups_per_block; + const int64_t global_group_id = block_group_id + local_group_id; + + if (global_group_id >= num_groups) return; + + const int64_t block_group_offset = global_group_id * GROUP_SIZE; + + float local_absmax = eps; + + const T* group_input = input + block_group_offset; + // Output is packed FP4 (2 values per byte), so offset is halved + uint8_t* group_output = output_q + (block_group_offset / 2); + + // Calculate scale output position (row-major layout) + // Each row has num_groups_per_row scales, stored contiguously + uint8_t* scale_output = output_s + global_group_id; + + using vec_type = vec_t; + using float_vec_type = vec_t; + + vec_type input_vecs[VECS_PER_THREAD]; + float_vec_type input_vals[VECS_PER_THREAD]; + +#pragma unroll + for (int32_t v = 0; v < VECS_PER_THREAD; ++v) { + const int32_t i = lane_id + v * THREADS_PER_GROUP; + if (i < NUM_VEC_ELEMS) { + input_vecs[v].load( + 0, sycl::multi_ptr(group_input + i * VEC_SIZE)); + +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + float val = static_cast(input_vecs[v][j]); + input_vals[v][j] = val; + local_absmax = sycl::fmax(local_absmax, sycl::fabs(val)); + } + } + } + + // Reduce across the threads in the quantization group to find the maximum + local_absmax = QuantGroupReduceMaxFP4(local_absmax, item); + + // Shared exponent per OCP MX spec / Microsoft micro-scaling: + // shared_exp = floor(log2(absmax)) - E2M1_EMAX + // where E2M1_EMAX = 2. eps already lower-limits local_absmax so + // log2 is well-defined. 
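A couple of worked values make the scale computation below concrete. This is only a hand check of the arithmetic, assuming the UE8M0 bias of 127 used throughout this diff:

import math

def ue8m0_scale(absmax: float):
    shared_exp = math.floor(math.log2(absmax)) - 2       # E2M1_EMAX = 2
    shared_exp = max(-127, min(127, shared_exp))          # clamp to UE8M0 range
    return shared_exp + 127, 2.0 ** shared_exp            # (stored byte, scale value)

print(ue8m0_scale(5.0))   # (127, 1.0):    floor(log2 5)   = 2  -> exponent 0
print(ue8m0_scale(48.0))  # (130, 8.0):    floor(log2 48)  = 5  -> exponent 3
print(ue8m0_scale(0.25))  # (123, 0.0625): floor(log2 .25) = -2 -> exponent -4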
+ float log2_scale = sycl::floor(sycl::log2(local_absmax)) - 2.0f; + int clamped_exponent = sycl::clamp(static_cast(log2_scale), -127, 127); + float scale_value = sycl::exp2(static_cast(clamped_exponent)); + + if (lane_id == 0) { + // Store scale as UE8M0: exponent + 127 bias + uint8_t scale_ue8m0 = static_cast(clamped_exponent + 127); + *scale_output = scale_ue8m0; + } + + const float inv_scale = 1.0f / scale_value; + + // Second pass: quantize and pack values + // Each thread processes VEC_SIZE elements at a time + // Two FP4 values are packed into one byte +#pragma unroll + for (int32_t v = 0; v < VECS_PER_THREAD; ++v) { + const int32_t i = lane_id + v * THREADS_PER_GROUP; + if (i < NUM_VEC_ELEMS) { + // Process VEC_SIZE elements, packing pairs into bytes + uint8_t packed_output[VEC_SIZE / 2]; + +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j += 2) { + float val0 = input_vals[v][j] * inv_scale; + float val1 = input_vals[v][j + 1] * inv_scale; + + uint8_t q0 = quantize_to_e2m1(val0); + uint8_t q1 = quantize_to_e2m1(val1); + + // Pack: first value in lower nibble, second in upper nibble + // No masking needed — quantize_to_e2m1 returns values in [0, 15] + packed_output[j / 2] = q0 | (q1 << 4); + } + + // Store packed output + // Each vec of VEC_SIZE elements becomes VEC_SIZE/2 packed bytes + uint8_t* out_ptr = group_output + i * (VEC_SIZE / 2); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE / 2; ++j) { + out_ptr[j] = packed_output[j]; + } + } + } + } + + private: + const T* input; + uint8_t* output_q; + uint8_t* output_s; + int num_groups; + int groups_per_block; + float eps; +}; + +void sgl_per_token_group_quant_fp4( + torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, int64_t group_size, double eps) { + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output_q); + + TORCH_CHECK(group_size == 32, "sgl_per_token_group_quant_fp4: group_size must be 32 for MXFP4, got ", group_size); + + TORCH_CHECK( + input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 || + input.scalar_type() == at::ScalarType::Float, + "sgl_per_token_group_quant_fp4: input dtype must be Float16, BFloat16, or Float32, got ", + input.scalar_type()); + + TORCH_CHECK( + output_q.scalar_type() == at::ScalarType::Byte, + "output_q must be uint8 (packed FP4), got ", + output_q.scalar_type()); + TORCH_CHECK( + output_s.scalar_type() == at::ScalarType::Byte, + "output_s must be uint8 (UE8M0 scales), got ", + output_s.scalar_type()); + + TORCH_CHECK(input.dim() >= 1, "input must have at least 1 dimension"); + TORCH_CHECK( + input.size(-1) % group_size == 0, + "sgl_per_token_group_quant_fp4: last dimension of input (", + input.size(-1), + ") must be divisible by group_size (", + group_size, + ")"); + + const int num_groups = input.numel() / group_size; + + // Output should be half the size (2 FP4 values per byte) + CHECK_EQ(output_q.numel(), input.numel() / 2); + + // Ensure eps is positive to prevent NaN from log2(0) + float eps_f = static_cast(eps); + if (eps_f <= 0.0f) { + eps_f = 1e-10f; + } + + auto queue = dpcppGetCurrentQueue(); + + constexpr int THREADS_PER_GROUP = 16; + + int groups_per_block = 1; + + if (num_groups % 16 == 0) { + groups_per_block = 16; + } else if (num_groups % 8 == 0) { + groups_per_block = 8; + } else if (num_groups % 4 == 0) { + groups_per_block = 4; + } else if (num_groups % 2 == 0) { + groups_per_block = 2; + } + + const int num_blocks = num_groups / groups_per_block; + const int num_threads = groups_per_block * THREADS_PER_GROUP; + + 
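The launch geometry derived above is simple enough to restate numerically before the nd_range is built below. A sketch; the constants mirror the host code in this file, and the helper name is illustrative:

THREADS_PER_GROUP = 16  # lanes cooperating on one 32-element quantization group

def launch_geometry(num_groups: int):
    # largest power-of-two divisor of num_groups, capped at 16 groups per work-group
    groups_per_block = next(g for g in (16, 8, 4, 2, 1) if num_groups % g == 0)
    num_blocks = num_groups // groups_per_block
    num_threads = groups_per_block * THREADS_PER_GROUP
    return num_blocks, num_threads  # global = num_blocks * num_threads, local = num_threads

print(launch_geometry(4096))  # (256, 256): 16 groups x 16 lanes per work-group
print(launch_geometry(6))     # (3, 32):    falls back to 2 groups per work-group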
sycl::range<1> global_range(num_blocks * num_threads); + sycl::range<1> local_range(num_threads); + +#define LAUNCH_FP4_KERNEL_WITH_GROUP_SIZE(T, GS) \ + do { \ + auto kernel = PerTokenGroupQuantFP4Kernel( \ + static_cast(input.data_ptr()), \ + static_cast(output_q.data_ptr()), \ + static_cast(output_s.data_ptr()), \ + num_groups, \ + groups_per_block, \ + eps_f); \ + sycl_kernel_submit(global_range, local_range, queue, kernel); \ + } while (0) + +#define LAUNCH_FP4_KERNEL(T) \ + do { \ + LAUNCH_FP4_KERNEL_WITH_GROUP_SIZE(T, 32); \ + } while (0) + + // Dispatch based on input type + if (input.scalar_type() == at::ScalarType::Half) { + LAUNCH_FP4_KERNEL(sycl::half); + } else if (input.scalar_type() == at::ScalarType::BFloat16) { + LAUNCH_FP4_KERNEL(sycl::ext::oneapi::bfloat16); + } else if (input.scalar_type() == at::ScalarType::Float) { + LAUNCH_FP4_KERNEL(float); + } + +#undef LAUNCH_FP4_KERNEL +#undef LAUNCH_FP4_KERNEL_WITH_GROUP_SIZE +} + +} // namespace at::native::xpu diff --git a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in new file mode 100644 index 00000000..d8414d2b --- /dev/null +++ b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +// Auto-generated from xe_fmha_fwd_decode_kernel.cpp.in +// Template parameters: QG_SZ=@QG_SZ@, HEAD_DIM=@HEAD_DIM@, PAGE_SIZE=@PAGE_SIZE@ +#define SYCL_INTEL_TARGET 20 + +#include "sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" + +namespace decode { + +template <> +void FmhaDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>::operator()(const Arguments& params) const { + using TileShapeQK = cute::Shape, cute::Int<@PAGE_SIZE@>, cute::_64>; + using TileShapePV = cute::Shape, cute::_32, cute::Int<@PAGE_SIZE@>>; + using TileShapeOutput = cute::Shape, cute::Int<@HEAD_DIM@>>; + using SubgroupLayoutQK = cute::Layout, cute::_1>>; + + AT_DISPATCH_BOOL_NO_RETURN(params.use_causal_mask, Causal, { + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + DecodeConfig::run(params); + }); + }); + }); +} + +template struct FmhaDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>; + +} // namespace decode diff --git a/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in new file mode 100644 index 00000000..9bfc40f3 --- /dev/null +++ b/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +// Auto-generated from xe_fmha_fwd_split_decode_kernel.cpp.in +// Template parameters: QG_SZ=@QG_SZ@, HEAD_DIM=@HEAD_DIM@, PAGE_SIZE=@PAGE_SIZE@ +#define SYCL_INTEL_TARGET 20 + +#include "sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" + +namespace decode { + +template <> +void FmhaSplitDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>::operator()(const Arguments& params) const { + using TileShapeQK = cute::Shape, cute::Int<@PAGE_SIZE@>, cute::_64>; + using TileShapePV = cute::Shape, cute::_32, cute::Int<@PAGE_SIZE@>>; + using TileShapeOutput = cute::Shape, cute::Int<@HEAD_DIM@>>; + using SubgroupLayoutQK = cute::Layout, cute::_1>>; + + AT_DISPATCH_BOOL_NO_RETURN(params.use_causal_mask, Causal, { + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + SplitDecodeConfig::run( + params); + }); + }); + }); +} + +template struct FmhaSplitDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>; + +} // namespace decode diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 523e5e2e..7d68caf2 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -129,6 +129,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size," " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()"); m.impl("sgl_per_token_group_quant_8bit", torch::kXPU, &at::native::xpu::sgl_per_token_group_quant_8bit); + m.def( + "sgl_per_token_group_quant_fp4(Tensor input, Tensor output_q, Tensor output_s, int group_size," + " float eps) -> ()"); + m.impl("sgl_per_token_group_quant_fp4", torch::kXPU, &at::native::xpu::sgl_per_token_group_quant_fp4); m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()"); m.impl("sgl_per_tensor_quant_fp8", torch::kXPU, &sgl_per_tensor_quant_fp8); diff --git a/tests/run_suite.py b/tests/run_suite.py index 0b3923b4..58539deb 100644 --- a/tests/run_suite.py +++ b/tests/run_suite.py @@ -23,6 +23,7 @@ class TestFile: TestFile("test_moe_prepare_input.py"), TestFile("test_swiglu_with_alpha_limit.py"), TestFile("test_per_token_group_quant_8bit.py"), + TestFile("test_per_token_group_quant_mxfp4.py"), TestFile("test_moe_fused_gate.py"), TestFile("test_per_tensor_quant_fp8.py"), TestFile("test_fused_qk_norm_rope.py"), diff --git a/tests/test_norm.py b/tests/test_norm.py index 6b0623db..a49ed003 100644 --- a/tests/test_norm.py +++ b/tests/test_norm.py @@ -134,5 +134,46 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype): torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) +############################################################################### +# Non-contiguous input tests (DeepSeek split pattern: stride[0] != hidden_size) +############################################################################### + + +def _make_non_contiguous(batch_size, hidden_size, dtype, extra=64): + """Create a non-contiguous tensor by slicing a larger tensor, + mimicking latent_cache.split([hidden_size, extra], dim=-1)[0].""" + full = torch.randn(batch_size, hidden_size + extra, device=device, dtype=dtype) + view = full[:, :hidden_size] # stride = (hidden_size + extra, 1) + # assert not view.is_contiguous() + assert view.stride(0) == hidden_size + extra + return view + + +@pytest.mark.parametrize("batch_size", [1, 19, 99]) 
+@pytest.mark.parametrize("hidden_size", [512, 1024, 3072]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_norm_non_contiguous(batch_size, hidden_size, dtype): + x_nc = _make_non_contiguous(batch_size, hidden_size, dtype) + w = torch.randn(hidden_size, device=device, dtype=dtype) + + y_ref = llama_rms_norm(x_nc.clone(), w) + y = sgl_kernel.rmsnorm(x_nc, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99]) +@pytest.mark.parametrize("hidden_size", [512, 1024, 3072]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_gemma_norm_non_contiguous(batch_size, hidden_size, dtype): + x_nc = _make_non_contiguous(batch_size, hidden_size, dtype) + w = torch.randn(hidden_size, device=device, dtype=dtype) + + y_ref = gemma_rms_norm(x_nc.clone(), w) + y = sgl_kernel.gemma_rmsnorm(x_nc, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": sys.exit(pytest.main([__file__])) diff --git a/tests/test_per_token_group_quant_mxfp4.py b/tests/test_per_token_group_quant_mxfp4.py new file mode 100644 index 00000000..2e129719 --- /dev/null +++ b/tests/test_per_token_group_quant_mxfp4.py @@ -0,0 +1,546 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for MXFP4 (E2M1) Per-Token Group Quantization on Intel XPU + +MXFP4 follows the OpenCompute MX (Microscaling) format specification: +- Data type: E2M1 (4-bit float with 2-bit exponent, 1-bit mantissa) +- Block size: 32 elements per scale factor +- Scale format: UE8M0 (unsigned 8-bit exponent-only, no mantissa) + +Rounding: Per OCP MX spec (section 5.3.3), FP4 conversion uses +roundTiesToEven — at midpoints between representable values, the +value with even mantissa (mantissa bit = 0) is chosen. + +Usage: + pytest test_per_token_group_quant_mxfp4.py -v +""" + +import pytest +import torch + +MXFP4_BLOCK_SIZE = 32 +FLOAT4_E2M1_MAX = 6.0 + +# E2M1 format parameters (from Microsoft microxcaling formats.py) +E2M1_EBITS = 2 +E2M1_MBITS = 3 # includes sign bit and implicit one +E2M1_EMAX = 2 ** (E2M1_EBITS - 1) # = 2 +E2M1_MAX_NORM = ( + 2**E2M1_EMAX * float(2 ** (E2M1_MBITS - 1) - 1) / 2 ** (E2M1_MBITS - 2) +) # = 6.0 + +FP32_EXPONENT_BIAS = 127 +FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1) # 2^(-126) + +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) + + +def is_xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _round_mantissa_even(A: torch.Tensor) -> torch.Tensor: + """Round mantissa using roundTiesToEven (from Microsoft microxcaling). + + At exact 0.5 midpoints (i.e., values like 0.5, 2.5, 4.5, ...), + round to the nearest even integer (the one whose LSB is 0). + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + absA = torch.abs(A) + # Identify exact midpoints: 0.5, 2.5, 4.5, ... i.e. (absA - 0.5) % 2 == 0 + maskA = ((absA - 0.5) % 2 == torch.zeros_like(A)).type(A.dtype) + # round half up, then subtract 1 at midpoints to get even + return torch.sign(A) * (torch.floor(absA + 0.5) - maskA) + + +def _quantize_elemwise_core_e2m1( + A: torch.Tensor, saturate_normals: bool = True +) -> torch.Tensor: + """Element-wise quantization to E2M1 using Microsoft microxcaling's + _quantize_elemwise_core algorithm with round='even'. + + E2M1 format: ebits=2, mbits=3, emax=2, max_norm=6.0 + min_exp = -(2^(ebits-1)) + 2 = 0 + + Algorithm (from Microsoft microxcaling elemwise_ops.py): + 1. 
Compute per-element private exponent = floor(log2(|A|)), + clamped to min_exp. + 2. Left-shift: out = A / 2^private_exp * 2^(mbits-2) + 3. Round mantissa with roundTiesToEven + 4. Right-shift: out = out / 2^(mbits-2) * 2^private_exp + 5. Clamp to [-max_norm, max_norm] if saturate_normals + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + ebits = E2M1_EBITS # 2 + mbits = E2M1_MBITS # 3 + max_norm = E2M1_MAX_NORM # 6.0 + + # min representable exponent: -(2^(ebits-1)) + 2 = 0 + min_exp = -(2 ** (ebits - 1)) + 2 # 0 + + out = A.clone() + + # Per-element private exponent: floor(log2(|A|)) + # Add guard for zeros: log2(0) is -inf, we use (A==0) to avoid that + private_exp = torch.floor(torch.log2(torch.abs(A) + (A == 0).type(A.dtype))) + private_exp = private_exp.clip(min=min_exp) + + # Left-shift: scale up so mantissa bits land in integer portion + # out = A / 2^private_exp * 2^(mbits-2) + shift = mbits - 2 # = 1 + out = out / (2**private_exp) * (2**shift) + + # Round mantissa with roundTiesToEven + out = _round_mantissa_even(out) + + # Right-shift: undo scaling + # out = out / 2^(mbits-2) * 2^private_exp + out = out / (2**shift) * (2**private_exp) + + # Saturate to [-max_norm, max_norm] + if saturate_normals: + out = torch.clamp(out, min=-max_norm, max=max_norm) + + return out + + +def _float_to_e2m1_code(val: torch.Tensor) -> torch.Tensor: + """Convert quantized float values back to E2M1 4-bit codes. + + After _quantize_elemwise_core_e2m1, values are one of the 8 representable + E2M1 magnitudes: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}. + This maps them to 4-bit codes (sign in bit 3, magnitude in bits 0-2). + """ + sign = (val < 0).to(torch.uint8) + abs_val = val.abs() + + # Map representable magnitudes to 3-bit indices via the kE2M1ToFloat LUT. + # Use a tolerance-based comparison since values are exact after quantization. + indices = torch.zeros_like(abs_val, dtype=torch.uint8) + lut = kE2M1ToFloat.to(device=val.device) + for i in range(8): + indices = torch.where( + torch.isclose(abs_val, lut[i].expand_as(abs_val), atol=1e-6, rtol=0), + torch.tensor(i, dtype=torch.uint8, device=val.device), + indices, + ) + + return (sign << 3) | indices + + +def quantize_to_e2m1(tensor: torch.Tensor) -> torch.Tensor: + """Quantize tensor values to E2M1 format (4-bit indices). + + Uses the Microsoft microxcaling _quantize_elemwise_core algorithm + with roundTiesToEven, then maps the resulting float values to 4-bit codes. + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + quantized_float = _quantize_elemwise_core_e2m1( + tensor.float(), saturate_normals=True + ) + return _float_to_e2m1_code(quantized_float) + + +def pack_fp4(tensor: torch.Tensor) -> torch.Tensor: + assert tensor.shape[-1] % 2 == 0 + shape = tensor.shape[:-1] + (tensor.shape[-1] // 2, 2) + paired = tensor.reshape(shape) + packed = (paired[..., 0] & 0x0F) | ((paired[..., 1] & 0x0F) << 4) + return packed.to(torch.uint8) + + +def _normalize_packed_fp4_signed_zero(packed: torch.Tensor) -> torch.Tensor: + """Canonicalize signed zeros in packed FP4 bytes. + + In E2M1, code 0x0 is +0.0 and code 0x8 is -0.0. Both represent + the same value, but different implementations may emit either form. + This helper rewrites every -0.0 nibble (0x8) to +0.0 (0x0) so that + byte-level comparisons are not tripped up by this harmless difference. + """ + # For each nibble, 0x8 is the only code that equals -0.0 + # (sign=1, exponent=0, mantissa=0). 
Clear bit 3 whenever the + # lower 3 bits (magnitude) are zero — i.e. the nibble is 0x0 or 0x8. + lo = packed & 0x0F + hi = (packed >> 4) & 0x0F + lo = torch.where(lo == 0x08, torch.zeros_like(lo), lo) + hi = torch.where(hi == 0x08, torch.zeros_like(hi), hi) + return (lo | (hi << 4)).to(torch.uint8) + + +def unpack_fp4(packed: torch.Tensor) -> torch.Tensor: + low = packed & 0x0F + high = (packed >> 4) & 0x0F + unpacked = torch.stack([low, high], dim=-1).reshape(*packed.shape[:-1], -1) + return unpacked + + +def dequantize_e2m1( + quantized: torch.Tensor, dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + sign = ((quantized >> 3) & 1).to(torch.bool) + magnitude_idx = (quantized & 0x07).to(torch.long) + kE2M1 = kE2M1ToFloat.to(device=quantized.device) + magnitude = kE2M1[magnitude_idx] + result = torch.where(sign, -magnitude, magnitude) + return result.to(dtype) + + +def _shared_exponents(A: torch.Tensor, axis: int) -> torch.Tensor: + """Compute shared exponents per block using Microsoft microxcaling's + _shared_exponents algorithm with method="max". + + Algorithm: + 1. shared_exp = max(|A|) along axis (per block) + 2. shared_exp = floor(log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0))) + The FP32_MIN_NORMAL guard ensures log2(0) doesn't produce -inf. + 3. Offset by emax: shared_exp = shared_exp - emax + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/mx_ops.py + """ + shared_exp = torch.max(torch.abs(A), dim=axis, keepdim=True).values + + # floor(log2(...)) with zero-guard from microxcaling + shared_exp = torch.floor( + torch.log2( + shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype) + ) + ) + + # Offset by the largest representable exponent in E2M1 + shared_exp = shared_exp - E2M1_EMAX + + return shared_exp + + +def quantize_to_mxfp4( + tensor: torch.Tensor, block_size: int = MXFP4_BLOCK_SIZE, eps: float = 1e-10 +) -> tuple: + """Quantize to MXFP4 using Microsoft microxcaling's _quantize_mx algorithm. + + Algorithm (from mx_ops.py _quantize_mx): + 1. Reshape into blocks + 2. Compute shared exponent per block via _shared_exponents + 3. Clamp shared_exp to scale_emax range [-127, 127] + 4. Scale elements: A = A / 2^shared_exp + 5. Quantize element-wise with _quantize_elemwise_core (saturate_normals=True) + 6. 
Rescale: A = A * 2^shared_exp (implicitly stored in UE8M0 scale) + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/mx_ops.py + """ + assert tensor.dim() == 2 + m, k = tensor.shape + assert k % block_size == 0 + assert k % 2 == 0 + + tensor_fp32 = tensor.float() + num_blocks = k // block_size + tensor_blocks = tensor_fp32.reshape(m, num_blocks, block_size) + + # Compute shared exponents (microxcaling _shared_exponents + offset by emax) + shared_exp = _shared_exponents(tensor_blocks, axis=-1) + + # Clamp to UE8M0 scale range: scale_bits=8, scale_emax = 2^(8-1)-1 = 127 + scale_emax = 127 + shared_exp = shared_exp.clamp(min=-scale_emax, max=scale_emax) + + # Encode as UE8M0: stored_scale = shared_exp + 127 + scales_ue8m0 = (shared_exp.to(torch.int32) + 127).to(torch.uint8).squeeze(-1) + + # Scale elements by shared exponent: A = A / 2^shared_exp + scaled_tensor = tensor_blocks / (2.0**shared_exp) + + # Quantize element-wise with microxcaling core (roundTiesToEven, saturate) + quantized_float = _quantize_elemwise_core_e2m1(scaled_tensor, saturate_normals=True) + + # Convert quantized float values to 4-bit E2M1 codes + quantized_blocks = _float_to_e2m1_code(quantized_float) + + quantized = quantized_blocks.reshape(m, k) + packed = pack_fp4(quantized) + + return packed, scales_ue8m0 + + +def dequantize_mxfp4( + packed: torch.Tensor, + scales: torch.Tensor, + dtype: torch.dtype = torch.float32, + block_size: int = MXFP4_BLOCK_SIZE, +) -> torch.Tensor: + m, packed_k = packed.shape + k = packed_k * 2 + + unpacked = unpack_fp4(packed) + dequantized = dequantize_e2m1(unpacked, dtype) + + num_blocks = k // block_size + dequantized_blocks = dequantized.reshape(m, num_blocks, block_size) + + scale_exp = scales.to(torch.int32) - 127 + scale_values = torch.pow(2.0, scale_exp.float()).unsqueeze(-1) + scaled = dequantized_blocks * scale_values + + return scaled.reshape(m, k).to(dtype) + + +class TestMXFP4ReferenceQuantization: + def test_e2m1_roundtrip(self): + device = torch.device("cpu") + test_values = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, + device=device, + ) + quantized = quantize_to_e2m1(test_values) + dequantized = dequantize_e2m1(quantized) + torch.testing.assert_close(dequantized, test_values, atol=0.0, rtol=0.0) + + def test_e2m1_round_ties_to_even(self): + """Test that midpoints between representable values round to even (m=0). + + Per OCP MX spec section 5.3.3, FP4 must use roundTiesToEven. + At midpoints, the value with even mantissa (m=0) is chosen. 
+ """ + device = torch.device("cpu") + # Midpoint values and their expected quantized results + # (midpoint_value, expected_dequantized_value) + midpoint_tests = [ + (0.25, 0.0), # midpoint of (0.0, 0.5) -> 0.0 (m=0, even) + (0.75, 1.0), # midpoint of (0.5, 1.0) -> 1.0 (m=0, even) + (1.25, 1.0), # midpoint of (1.0, 1.5) -> 1.0 (m=0, even) + (1.75, 2.0), # midpoint of (1.5, 2.0) -> 2.0 (m=0, even) + (2.5, 2.0), # midpoint of (2.0, 3.0) -> 2.0 (m=0, even) + (3.5, 4.0), # midpoint of (3.0, 4.0) -> 4.0 (m=0, even) + (5.0, 4.0), # midpoint of (4.0, 6.0) -> 4.0 (m=0, even) + # Negative midpoints + (-0.25, 0.0), # -> -0.0 = 0.0 + (-0.75, -1.0), + (-1.25, -1.0), + (-1.75, -2.0), + (-2.5, -2.0), + (-3.5, -4.0), + (-5.0, -4.0), + ] + for midpoint, expected in midpoint_tests: + tensor = torch.tensor([midpoint], dtype=torch.float32, device=device) + quantized = quantize_to_e2m1(tensor) + dequantized = dequantize_e2m1(quantized) + # For -0.25, dequantized is -0.0 which equals 0.0 + assert dequantized.item() == expected or ( + expected == 0.0 and abs(dequantized.item()) == 0.0 + ), f"Midpoint {midpoint}: expected {expected}, got {dequantized.item()}" + + def test_pack_unpack_roundtrip(self): + device = torch.device("cpu") + m, k = 16, 64 + original = torch.randint(0, 16, (m, k), dtype=torch.uint8, device=device) + packed = pack_fp4(original) + unpacked = unpack_fp4(packed) + torch.testing.assert_close(unpacked, original) + + def test_mxfp4_quantization_shape(self): + device = torch.device("cpu") + m, k = 32, 128 + original = torch.randn(m, k, dtype=torch.float32, device=device) + packed, scales = quantize_to_mxfp4(original) + assert packed.shape == (m, k // 2) + assert scales.shape == (m, k // MXFP4_BLOCK_SIZE) + assert packed.dtype == torch.uint8 + assert scales.dtype == torch.uint8 + + def test_mxfp4_dequantization_accuracy(self): + device = torch.device("cpu") + m, k = 32, 128 + original = torch.randn(m, k, dtype=torch.float32, device=device) * 3.0 + packed, scales = quantize_to_mxfp4(original) + dequantized = dequantize_mxfp4(packed, scales, torch.float32) + assert dequantized.shape == original.shape + relative_error = (dequantized - original).abs() / (original.abs() + 1e-6) + mean_error = relative_error.mean().item() + assert mean_error < 0.5 + + +@pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") +class TestPerTokenGroupQuantFP4XPU: + @pytest.fixture(autouse=True) + def setup(self): + import utils + + self.device = utils.get_device() + self.eps = 1e-10 + + def _import_kernel(self): + try: + from sgl_kernel import sgl_per_token_group_quant_fp4 + + return sgl_per_token_group_quant_fp4 + except ImportError: + pytest.skip("sgl_per_token_group_quant_fp4 kernel not available") + + def _test_against_reference( + self, + num_tokens: int, + hidden_dim: int, + src_dtype: torch.dtype = torch.bfloat16, + seed: int = 42, + ): + sgl_per_token_group_quant_fp4 = self._import_kernel() + group_size = MXFP4_BLOCK_SIZE + + torch.manual_seed(seed) + + x_cpu = torch.randn(num_tokens, hidden_dim, dtype=src_dtype, device="cpu") + x_q_ref, scales_ref = quantize_to_mxfp4(x_cpu.float(), group_size, eps=self.eps) + + x_xpu = x_cpu.to(self.device) + x_q_xpu, scales_xpu = sgl_per_token_group_quant_fp4( + x=x_xpu, + group_size=group_size, + eps=self.eps, + ) + + x_q_xpu_cpu = x_q_xpu.cpu() + scales_xpu_cpu = scales_xpu.cpu() + + assert ( + x_q_xpu_cpu.shape == x_q_ref.shape + ), f"Quantized shape mismatch: {x_q_xpu_cpu.shape} vs {x_q_ref.shape}" + assert ( + scales_xpu_cpu.shape == scales_ref.shape + ), 
f"Scales shape mismatch: {scales_xpu_cpu.shape} vs {scales_ref.shape}" + assert x_q_xpu_cpu.dtype == torch.uint8 + assert scales_xpu_cpu.dtype == torch.uint8 + + # Compare quantized values directly (packed uint8). + # Normalise signed zeros first: in E2M1 code 0x0 (+0.0) and 0x8 + # (-0.0) are semantically identical. The kernel may preserve the + # sign of the original float while the reference always emits +0.0, + # so we canonicalise before comparing. + x_q_xpu_norm = _normalize_packed_fp4_signed_zero(x_q_xpu_cpu) + x_q_ref_norm = _normalize_packed_fp4_signed_zero(x_q_ref) + q_match = torch.equal(x_q_xpu_norm, x_q_ref_norm) + if not q_match: + q_mismatches = (x_q_xpu_norm != x_q_ref_norm).sum().item() + total = x_q_ref_norm.numel() + assert ( + q_mismatches / total < 0.05 + ), f"Too many quantized value mismatches: {q_mismatches}/{total}" + + # Compare scale exponents (allow ±1 difference due to rounding) + scale_exp_ref = scales_ref.to(torch.int32) - 127 + scale_exp_xpu = scales_xpu_cpu.to(torch.int32) - 127 + exp_diff = (scale_exp_ref - scale_exp_xpu).abs() + assert exp_diff.max() == 0, f"Scale exponent difference: {exp_diff.max()}" + + # Compare dequantized outputs + x_dq_ref = dequantize_mxfp4(x_q_ref, scales_ref, torch.float32, group_size) + x_dq_xpu = dequantize_mxfp4( + x_q_xpu_cpu, scales_xpu_cpu, torch.float32, group_size + ) + torch.testing.assert_close(x_dq_xpu, x_dq_ref, rtol=0.0, atol=0.0) + + @pytest.mark.parametrize( + "num_tokens,hidden_dim,src_dtype", + [ + (128, 256, torch.bfloat16), + (64, 128, torch.float16), + (64, 128, torch.float32), + (256, 2048, torch.bfloat16), + ], + ) + def test_quantization_vs_reference(self, num_tokens, hidden_dim, src_dtype): + self._test_against_reference(num_tokens, hidden_dim, src_dtype) + + def test_quantize_dequantize_roundtrip(self): + sgl_per_token_group_quant_fp4 = self._import_kernel() + + torch.manual_seed(42) + m, k = 128, 256 + + x_cpu = torch.randn(m, k, dtype=torch.bfloat16, device="cpu") + x_xpu = x_cpu.to(self.device) + + x_q, scales = sgl_per_token_group_quant_fp4( + x=x_xpu, group_size=MXFP4_BLOCK_SIZE + ) + + x_dq = dequantize_mxfp4( + x_q.cpu(), scales.cpu(), torch.float32, MXFP4_BLOCK_SIZE + ) + + correlation = torch.corrcoef( + torch.stack([x_dq.flatten(), x_cpu.float().flatten()]) + )[0, 1] + assert correlation > 0.9, f"Correlation too low: {correlation}" + + def test_round_ties_to_even_on_xpu(self): + """Test that the kernel implements roundTiesToEven at midpoints.""" + sgl_per_token_group_quant_fp4 = self._import_kernel() + + # Create a tensor of exactly 32 elements (one group) containing + # midpoint values. 
Scale will be 1.0 (exponent=0) since max abs is 5.0 + # which maps to scale = 2^(floor(log2(5.0)) - 2) = 2^(2 - 2) = 2^0 = 1.0 + midpoints = [ + 0.25, + 0.75, + 1.25, + 1.75, + 2.5, + 3.5, + 5.0, + -0.25, + -0.75, + -1.25, + -1.75, + -2.5, + -3.5, + -5.0, + ] + # Pad to 32 elements with zeros + padded = midpoints + [0.0] * (32 - len(midpoints)) + x = torch.tensor([padded], dtype=torch.float32, device=self.device) + + x_q, scales = sgl_per_token_group_quant_fp4( + x=x, group_size=MXFP4_BLOCK_SIZE, eps=self.eps + ) + + # Reference + x_q_ref, scales_ref = quantize_to_mxfp4( + x.cpu().float(), MXFP4_BLOCK_SIZE, eps=self.eps + ) + + x_dq_xpu = dequantize_mxfp4( + x_q.cpu(), scales.cpu(), torch.float32, MXFP4_BLOCK_SIZE + ) + x_dq_ref = dequantize_mxfp4( + x_q_ref, scales_ref, torch.float32, MXFP4_BLOCK_SIZE + ) + + torch.testing.assert_close(x_dq_xpu, x_dq_ref, atol=0.0, rtol=0.0) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__, "-v"]))
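+
+
+# Illustrative usage sketch of the pure-Python reference above (not part of
+# the test suite; the helper name below is made up for this example). It runs
+# one 32-element block through quantize_to_mxfp4/dequantize_mxfp4. Every entry
+# is E2M1-representable at scale 1.0, so the round trip is exact: max |x| = 6.0
+# gives shared exponent floor(log2(6.0)) - 2 = 0, i.e. a UE8M0 scale byte of
+# 0 + 127 = 127 (scale = 2^0 = 1.0).
+def _example_mxfp4_reference_roundtrip() -> None:
+    values = [6.0, 4.0, 3.0, 1.5, 0.5, -0.5, -2.0, -6.0]
+    # Pad the single row to one full MXFP4 block of 32 elements.
+    row = torch.tensor([values + [0.0] * (MXFP4_BLOCK_SIZE - len(values))])
+    packed, scales = quantize_to_mxfp4(row)
+    assert scales.item() == 127  # UE8M0 code for scale 2^0
+    restored = dequantize_mxfp4(packed, scales, torch.float32)
+    torch.testing.assert_close(restored, row, rtol=0.0, atol=0.0)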