From 94a2786bd9562aafe51d13f876965d3c90a82862 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 23 Apr 2026 15:33:34 +0200 Subject: [PATCH 01/87] init --- src/transformers/integrations/deepgemm.py | 343 ++++++++++++++++++ .../integrations/finegrained_fp8.py | 230 +----------- src/transformers/integrations/moe.py | 2 + 3 files changed, 348 insertions(+), 227 deletions(-) create mode 100644 src/transformers/integrations/deepgemm.py diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py new file mode 100644 index 000000000000..7a8fb0786446 --- /dev/null +++ b/src/transformers/integrations/deepgemm.py @@ -0,0 +1,343 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DeepGEMM integration: fused grouped GEMM kernels from `kernels-community/deep-gemm`. + +Provides: +- `fp8_deepgemm_matmul`: FP8 dense matmul used as a fast path inside the finegrained-fp8 Linear. +- `fp8_deepgemm_experts_forward`: FP8 M-grouped experts forward, registered as "deepgemm" in the FP8 ExpertsInterface. +- `bf16_deepgemm_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. + +Requirements: CUDA, Hopper (SM90+), CUDA runtime >= 12.3, `kernels`. +""" + +import functools + +import torch + +from ..utils import logging +from ..utils.import_utils import get_cuda_runtime_version, resolve_internal_import +from .hub_kernels import lazy_load_kernel + + +logger = logging.get_logger(__name__) + +# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. +# TMA is an H100 hardware addition that allows applications to asynchronously and +# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. +_DEEPGEMM_M_ALIGNMENT = 128 + + +@functools.cache +def _load_deepgemm_kernel(): + """ + Load deep-gemm once and return its required symbols. + + Raises: + ImportError if CUDA/hardware requirements are not met, or the kernel or + required symbols are not found. + + Returns: + Tuple of (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, + m_grouped_bf16_gemm_nt_contiguous, per_token_cast_to_fp8) from the deep-gemm kernel. + """ + if not torch.cuda.is_available(): + raise ImportError( + "deep-gemm kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." + ) + + # deep-gemm requires Hopper (SM90) or newer for FP8 WGMMA instructions + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError( + f"deep-gemm requires a Hopper (SM90+) or newer GPU, but the current device " + f"has compute capability {major}.x. Use a different `experts_implementation`." + ) + + # deep-gemm requires CUDA runtime >= 12.3 + cuda_major, cuda_minor = get_cuda_runtime_version() + if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): + raise ImportError( + f"deep-gemm requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. 
" + "Please upgrade your CUDA toolkit or use a different `experts_implementation`." + ) + + kernel = lazy_load_kernel("deep-gemm") + if kernel is None: + raise ImportError( + "deep-gemm kernel not found. Make sure you have the `kernels` package installed (`pip install -U kernels`)." + ) + + fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) + m_grouped_fp8_gemm_nt_contiguous = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) + m_grouped_bf16_gemm_nt_contiguous = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) + per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + + missing = [ + name + for name, attr in [ + ("fp8_gemm_nt", fp8_gemm_nt), + ("m_grouped_fp8_gemm_nt_contiguous", m_grouped_fp8_gemm_nt_contiguous), + ("m_grouped_bf16_gemm_nt_contiguous", m_grouped_bf16_gemm_nt_contiguous), + ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), + ] + if attr is None + ] + if missing: + raise ImportError( + f"deep-gemm kernel is missing required symbols: {', '.join(missing)}. " + "Please update the `kernels` package (`pip install -U kernels`)." + ) + + return fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, m_grouped_bf16_gemm_nt_contiguous, per_token_cast_to_fp8 + + +def fp8_deepgemm_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + FP8 dense matmul via deep-gemm's `fp8_gemm_nt`. Block-wise 128x128 scales expected. + + Args: + A: (M, K) float8_e4m3fn — quantized activations + B: (N, K) float8_e4m3fn — quantized weights + As: (M, K//128) float32 — per-block activation scales + Bs: (N//128, K//128) float32 — per-block weight scales + output_dtype: desired output dtype. + """ + fp8_gemm_nt, _, _, _ = _load_deepgemm_kernel() + A_2d = A.view(-1, A.shape[-1]) + As_2d = As.view(-1, As.shape[-1]) + output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) + fp8_gemm_nt((A_2d, As_2d.float()), (B, Bs.float()), output) + return output.view(A.shape[:-1] + (B.shape[0],)) + + +def _build_deepgemm_contiguous_layout(expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int) -> tuple: + """Build a TMA-aligned contiguous layout for deep-gemm's grouped GEMM. + + deep-gemm requires M-dimension alignment per expert for TMA. This computes + the mapping from sorted token positions to padded row positions, and the + layout tensor that deep-gemm uses to identify expert boundaries. + + Returns: + sorted_to_padded: (num_tokens,) index map from sorted position to padded row + grouped_layout: expert layout tensor (format depends on GPU architecture) + total_padded_rows: total number of rows including alignment padding + """ + device = expert_ids_sorted.device + num_tokens = expert_ids_sorted.size(0) + tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() + aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment + # Upper bound avoids GPU->CPU sync; padding rows are skipped by deep-gemm. 
+ total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + + padding_per_expert = aligned_tokens_per_expert - tokens_per_expert + cumulative_padding = padding_per_expert.cumsum(0) - padding_per_expert + sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] + + if torch.cuda.get_device_capability(device)[0] >= 10: # Blackwell (SM100+) + grouped_layout = tokens_per_expert.cumsum(0).int() + else: + # Hopper: per-row expert id, -1 for padding rows + grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) + grouped_layout[sorted_to_padded] = expert_ids_sorted.int() + + return sorted_to_padded, grouped_layout, total_padded_rows + + +def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: + """Pad a sorted tensor into the TMA-aligned contiguous layout.""" + padded = torch.zeros(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + padded[sorted_to_padded] = x + return padded + + +def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: + """Remove padding rows from the TMA-aligned contiguous layout.""" + return x_padded[sorted_to_padded] + + +def fp8_deepgemm_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if self.activation_scheme == "static": + raise NotImplementedError( + "deepgemm experts dispatch does not support activation_scheme='static'. " + "Use the default eager dispatch or switch to activation_scheme='dynamic'." + ) + if self.block_size is None: + raise ValueError( + "deep-gemm requires block-wise quantization (block_size=[128, 128]), " + "but got per-tensor quantization (block_size=None)." 
+ ) + if self.block_size[0] != 128 or self.block_size[1] != 128: + raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") + + _, m_grouped_fp8_gemm_nt_contiguous, _, per_token_cast_to_fp8 = _load_deepgemm_kernel() + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # Sort by expert for grouped processing + perm = torch.argsort(expert_ids) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) + + expert_ids_g = expert_ids[perm] + sample_weights_g = sample_weights[perm] + selected_hidden_states_g = hidden_states[token_idx[perm]] + + sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT + ) + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + + # --- Up projection per expert (deep-gemm grouped contiguous) --- + w_up = self.gate_up_proj if self.has_gate else self.up_proj + ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv + act_fp8, act_scales = per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) + act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) + act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) + proj_out = torch.zeros(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) + m_grouped_fp8_gemm_nt_contiguous( + (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout + ) + + # Apply gating or activation + if self.has_gate: + proj_out = self._apply_gate(proj_out) + else: + proj_out = self.act_fn(proj_out) + + # --- Down projection per expert (deep-gemm grouped contiguous) --- + proj_fp8, proj_scales = per_token_cast_to_fp8(proj_out, use_ue8m0=False) + proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + m_grouped_fp8_gemm_nt_contiguous( + (proj_fp8, proj_scales), + (self.down_proj, self.down_proj_scale_inv.float()), + proj_out, + grouped_layout, + use_psum_layout=use_psum_layout, + ) + + # Remove padding rows + proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) + + # Apply routing weights + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + + # Restore original order + weighted_out = weighted_out[inv_perm] + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) + + +def bf16_deepgemm_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if self.is_transposed: + raise ValueError("deepgemm bf16 path requires non-transposed weights (is_transposed=False)") + if not self.has_gate: + raise ValueError("deepgemm bf16 path requires gated experts (has_gate=True)") + if self.has_bias: + raise ValueError("deepgemm bf16 path 
does not support bias (m_grouped_bf16_gemm_nt_contiguous has no bias input)") + if hidden_states.device.type != "cuda": + raise ValueError("deepgemm bf16 path requires CUDA device") + + _, _, m_grouped_bf16_gemm_nt_contiguous, _ = _load_deepgemm_kernel() + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # Handle invalid expert IDs from Expert Parallelism (EP) + invalid_mask = expert_ids >= self.num_experts + expert_ids = expert_ids.clamp(0, self.num_experts - 1) + + # Sort by expert for grouped processing + perm = torch.argsort(expert_ids) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) + + expert_ids_g = expert_ids[perm] + sample_weights_g = sample_weights[perm] + invalid_mask_g = invalid_mask[perm] + selected_hidden_states_g = hidden_states[token_idx[perm]] + + sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT + ) + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + + # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- + act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) + proj_out = torch.zeros( + total_padded_rows, self.gate_up_proj.shape[1], device=device, dtype=hidden_states.dtype + ) + m_grouped_bf16_gemm_nt_contiguous( + act, self.gate_up_proj, proj_out, grouped_layout, use_psum_layout=use_psum_layout + ) + + # Apply gating + proj_out = self._apply_gate(proj_out) + + # --- Down projection per expert (deep-gemm grouped contiguous, bf16) --- + out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) + m_grouped_bf16_gemm_nt_contiguous( + proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout + ) + + # Remove padding rows + out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) + + # Apply routing weights and zero out invalid expert contributions + weighted_out = out * sample_weights_g.to(out.dtype).unsqueeze(-1) # (S, hidden_dim) + weighted_out.masked_fill_(invalid_mask_g.unsqueeze(-1), 0.0) + + # Restore original order + weighted_out = weighted_out[inv_perm] + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index a6b9a517b20d..5f583533792e 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -19,7 +19,7 @@ from ..core_model_loading import ConversionOps, _IdentityOp from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging -from ..utils.import_utils import get_cuda_runtime_version, resolve_internal_import +from .deepgemm import fp8_deepgemm_experts_forward, fp8_deepgemm_matmul from .hub_kernels import lazy_load_kernel from .moe import 
ExpertsInterface, use_experts_implementation @@ -31,11 +31,6 @@ _FP8_MIN = torch.finfo(_FP8_DTYPE).min _FP8_MAX = torch.finfo(_FP8_DTYPE).max -# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM -# TMA is an H100 hardware addition that allows applications to asynchronously and -# bi-directionally transfer 1D-5D tensors between GPU global and shared memory -_DEEPGEMM_M_ALIGNMENT = 128 - # Lazily-loaded finegrained-fp8 Triton kernel functions (populated by _load_triton_kernel) triton_fp8_matmul = None triton_fp8_act_quant = None @@ -44,13 +39,6 @@ # _triton_available: None = not yet attempted, True = loaded, False = failed (won't retry) _triton_available = None -# Lazily-loaded DeepGEMM kernel functions (populated by _load_deepgemm_kernel) -deepgemm_fp8_matmul = None -deepgemm_grouped_fp8_matmul = None -deepgemm_per_token_cast_to_fp8 = None -# _deepgemm_available: None = not yet attempted, True = loaded, False = failed (won't retry) -_deepgemm_available = None - def _load_triton_kernel(): """Lazily load the finegrained-fp8 Triton kernel and extract functions. @@ -97,67 +85,6 @@ def _load_triton_kernel(): _triton_available = True -def _load_deepgemm_kernel(): - """Lazily load the DeepGEMM kernel and extract functions with proper names. - - Uses the hub kernels lazy loading pattern. Raises an error if the kernel - cannot be loaded, required functions are missing, or the hardware is insufficient. - Only attempts loading once. - """ - global _deepgemm_available, deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 - - if _deepgemm_available is not None: - if not _deepgemm_available: - raise ImportError("DeepGEMM kernel is not available (previous load attempt failed).") - return - - _deepgemm_available = False # mark attempted before any early exit - - # DeepGEMM requires CUDA and a compatible GPU - if not torch.cuda.is_available(): - raise ImportError( - "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." - ) - - # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions - major = torch.cuda.get_device_capability()[0] - if major < 9: - raise ImportError( - f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " - f"has compute capability {major}.x. Use a different `experts_implementation`." - ) - - # DeepGEMM requires CUDA runtime ≥ 12.3. - cuda_major, cuda_minor = get_cuda_runtime_version() - if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): - raise ImportError( - f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " - "Please upgrade your CUDA toolkit or use a different `experts_implementation`." - ) - - kernel = lazy_load_kernel("deep-gemm") - deepgemm_fp8_matmul = getattr(kernel, "fp8_gemm_nt") - deepgemm_grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous") - deepgemm_per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") - - missing = [ - name - for name, attr in [ - ("fp8_gemm_nt", deepgemm_fp8_matmul), - ("m_grouped_fp8_gemm_nt_contiguous", deepgemm_grouped_fp8_matmul), - ("utils.per_token_cast_to_fp8", deepgemm_per_token_cast_to_fp8), - ] - if attr is None - ] - if missing: - raise ImportError( - f"DeepGEMM kernel is missing required functions: {', '.join(missing)}. " - "Please update the `kernels` package (`pip install -U kernels`)." 
- ) - - _deepgemm_available = True - - def _cdiv(a: int, b: int) -> int: """Ceiling division.""" return (a + b - 1) // b @@ -191,21 +118,14 @@ def w8a8_fp8_matmul( """ if block_size is not None and block_size[0] == block_size[1] == 128: try: - _load_deepgemm_kernel() - global deepgemm_fp8_matmul + # 3-6x faster than Triton + return fp8_deepgemm_matmul(A, B, As, Bs, output_dtype=output_dtype) except ImportError: logger.warning_once( "DeepGEMM kernel is not available or compatible, falling back to Triton finegrained-fp8 kernel. " "To use DeepGEMM FP8 matmul, ensure you have a Hopper (SM90+) or newer GPU with CUDA runtime 12.3+, " "and that the `kernels` package is installed and up to date (`pip install -U kernels`)." ) - else: - # 3-6x faster than Triton - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) _load_triton_kernel() global triton_fp8_matmul @@ -434,150 +354,6 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) -def _build_deepgemm_contiguous_layout(expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int) -> tuple: - """Build a TMA-aligned contiguous layout for DeepGEMM grouped GEMM. - - DeepGEMM requires M-dimension alignment per expert for TMA. This computes - the mapping from sorted token positions to padded row positions, and the - layout tensor that DeepGEMM uses to identify expert boundaries. - - Returns: - sorted_to_padded: (num_tokens,) index map from sorted position to padded row - grouped_layout: expert layout tensor (format depends on GPU architecture) - total_padded_rows: total number of rows including alignment padding - """ - device = expert_ids_sorted.device - num_tokens = expert_ids_sorted.size(0) - tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() - aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment - # Upper bound avoids GPU→CPU sync; padding rows are skipped by DeepGEMM. 
- total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) - - padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = padding_per_expert.cumsum(0) - padding_per_expert - sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - - if torch.cuda.get_device_capability(device)[0] >= 10: # Blackwell (SM100+) - grouped_layout = tokens_per_expert.cumsum(0).int() - else: - # Hopper: per-row expert id, -1 for padding rows - grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = expert_ids_sorted.int() - - return sorted_to_padded, grouped_layout, total_padded_rows - - -def _pad_to_deepgemm_contiguous_layout( - hidden_states: torch.Tensor, - scales: torch.Tensor, - sorted_to_padded: torch.Tensor, - total_padded_rows: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """Pad sorted hidden states and scales into the TMA-aligned contiguous layout.""" - hidden_padded = torch.zeros( - total_padded_rows, hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype - ) - hidden_padded[sorted_to_padded] = hidden_states - scales_padded = torch.zeros(total_padded_rows, scales.shape[1], device=hidden_states.device, dtype=torch.float32) - scales_padded[sorted_to_padded] = scales - return hidden_padded, scales_padded - - -def _unpad_from_deepgemm_contiguous_layout( - hidden_states_padded: torch.Tensor, sorted_to_padded: torch.Tensor -) -> torch.Tensor: - """Remove padding rows from the TMA-aligned contiguous layout.""" - return hidden_states_padded[sorted_to_padded] - - -def fp8_deepgemm_experts_forward( - self: torch.nn.Module, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - if self.activation_scheme == "static": - raise NotImplementedError( - "deepgemm experts dispatch does not support activation_scheme='static'. " - "Use the default eager dispatch or switch to activation_scheme='dynamic'." - ) - if self.block_size is None: - raise ValueError( - "DeepGEMM requires block-wise quantization (block_size=[128, 128]), " - "but got per-tensor quantization (block_size=None)." 
- ) - if self.block_size[0] != 128 or self.block_size[1] != 128: - raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") - - _load_deepgemm_kernel() - global deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 - - device = hidden_states.device - num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) - - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) - sample_weights = top_k_weights.reshape(-1) # (S,) - expert_ids = top_k_index.reshape(-1) # (S,) - - # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] - sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] - - # Build TMA-aligned contiguous layout for DeepGEMM - sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT - ) - - # --- Up projection per expert (DeepGEMM grouped contiguous) --- - w_up = self.gate_up_proj if self.has_gate else self.up_proj - ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv - act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) - act_fp8, act_scales = _pad_to_deepgemm_contiguous_layout(act_fp8, act_scales, sorted_to_padded, total_padded_rows) - proj_out = torch.zeros(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - deepgemm_grouped_fp8_matmul( - (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) - - # Apply gating or activation - if self.has_gate: - proj_out = self._apply_gate(proj_out) - else: - proj_out = self.act_fn(proj_out) - - # --- Down projection per expert (DeepGEMM grouped contiguous) --- - w_down = self.down_proj - ws_down = self.down_proj_scale_inv - proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) - proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) - deepgemm_grouped_fp8_matmul( - (proj_fp8, proj_scales), (w_down, ws_down.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) - - # Remove padding rows - proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) - - # Apply routing weights - weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) - - # Restore original order - weighted_out = weighted_out[inv_perm] - - # Accumulate results using deterministic reshape+sum instead of index_add_ - # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) - final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - - return final_hidden_states.to(hidden_states.dtype) - - class FP8Experts(nn.Module): def __init__( self, diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index d17522d26daa..622b0ceb2fa6 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -23,6 +23,7 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) +from .deepgemm import bf16_deepgemm_experts_forward if 
is_torch_available(): @@ -460,6 +461,7 @@ class ExpertsInterface(GeneralInterface): _global_mapping = { "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, + "deepgemm": bf16_deepgemm_experts_forward, } def get_interface(self, experts_implementation: str, default: Callable) -> Callable: From 357a0355c9f6f6a9df20c85a163d5711b5635a76 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 23 Apr 2026 15:34:29 +0200 Subject: [PATCH 02/87] style --- src/transformers/integrations/deepgemm.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 7a8fb0786446..f2951deda99c 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -273,7 +273,9 @@ def bf16_deepgemm_experts_forward( if not self.has_gate: raise ValueError("deepgemm bf16 path requires gated experts (has_gate=True)") if self.has_bias: - raise ValueError("deepgemm bf16 path does not support bias (m_grouped_bf16_gemm_nt_contiguous has no bias input)") + raise ValueError( + "deepgemm bf16 path does not support bias (m_grouped_bf16_gemm_nt_contiguous has no bias input)" + ) if hidden_states.device.type != "cuda": raise ValueError("deepgemm bf16 path requires CUDA device") @@ -310,9 +312,7 @@ def bf16_deepgemm_experts_forward( # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) - proj_out = torch.zeros( - total_padded_rows, self.gate_up_proj.shape[1], device=device, dtype=hidden_states.dtype - ) + proj_out = torch.zeros(total_padded_rows, self.gate_up_proj.shape[1], device=device, dtype=hidden_states.dtype) m_grouped_bf16_gemm_nt_contiguous( act, self.gate_up_proj, proj_out, grouped_layout, use_psum_layout=use_psum_layout ) @@ -322,9 +322,7 @@ def bf16_deepgemm_experts_forward( # --- Down projection per expert (deep-gemm grouped contiguous, bf16) --- out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) - m_grouped_bf16_gemm_nt_contiguous( - proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout - ) + m_grouped_bf16_gemm_nt_contiguous(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) # Remove padding rows out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) From 741b5eb717829ba7cfba22bc823fc48e73381b40 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 23 Apr 2026 15:44:38 +0200 Subject: [PATCH 03/87] full support --- src/transformers/integrations/deepgemm.py | 71 +++++++++++++++-------- src/transformers/integrations/sonicmoe.py | 19 +++++- 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index f2951deda99c..98c2b83032e2 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -27,7 +27,7 @@ import torch from ..utils import logging -from ..utils.import_utils import get_cuda_runtime_version, resolve_internal_import +from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import from .hub_kernels import lazy_load_kernel @@ -50,8 +50,12 @@ def _load_deepgemm_kernel(): Returns: Tuple of (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, - m_grouped_bf16_gemm_nt_contiguous, per_token_cast_to_fp8) from the deep-gemm kernel. 
+ m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nn_contiguous, + per_token_cast_to_fp8) from the deep-gemm kernel. """ + if not is_kernels_available(): + raise ImportError("deep-gemm kernel requires the `kernels` package. Install it with `pip install -U kernels`.") + if not torch.cuda.is_available(): raise ImportError( "deep-gemm kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." @@ -82,6 +86,7 @@ def _load_deepgemm_kernel(): fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) m_grouped_fp8_gemm_nt_contiguous = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) m_grouped_bf16_gemm_nt_contiguous = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) + m_grouped_bf16_gemm_nn_contiguous = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") missing = [ @@ -90,6 +95,7 @@ def _load_deepgemm_kernel(): ("fp8_gemm_nt", fp8_gemm_nt), ("m_grouped_fp8_gemm_nt_contiguous", m_grouped_fp8_gemm_nt_contiguous), ("m_grouped_bf16_gemm_nt_contiguous", m_grouped_bf16_gemm_nt_contiguous), + ("m_grouped_bf16_gemm_nn_contiguous", m_grouped_bf16_gemm_nn_contiguous), ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), ] if attr is None @@ -100,7 +106,13 @@ def _load_deepgemm_kernel(): "Please update the `kernels` package (`pip install -U kernels`)." ) - return fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, m_grouped_bf16_gemm_nt_contiguous, per_token_cast_to_fp8 + return ( + fp8_gemm_nt, + m_grouped_fp8_gemm_nt_contiguous, + m_grouped_bf16_gemm_nt_contiguous, + m_grouped_bf16_gemm_nn_contiguous, + per_token_cast_to_fp8, + ) def fp8_deepgemm_matmul( @@ -120,7 +132,7 @@ def fp8_deepgemm_matmul( Bs: (N//128, K//128) float32 — per-block weight scales output_dtype: desired output dtype. 
""" - fp8_gemm_nt, _, _, _ = _load_deepgemm_kernel() + fp8_gemm_nt, _, _, _, _ = _load_deepgemm_kernel() A_2d = A.view(-1, A.shape[-1]) As_2d = As.view(-1, As.shape[-1]) output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) @@ -192,7 +204,7 @@ def fp8_deepgemm_experts_forward( if self.block_size[0] != 128 or self.block_size[1] != 128: raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") - _, m_grouped_fp8_gemm_nt_contiguous, _, per_token_cast_to_fp8 = _load_deepgemm_kernel() + _, m_grouped_fp8_gemm_nt_contiguous, _, _, per_token_cast_to_fp8 = _load_deepgemm_kernel() device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -268,18 +280,15 @@ def bf16_deepgemm_experts_forward( top_k_index: torch.Tensor, top_k_weights: torch.Tensor, ) -> torch.Tensor: - if self.is_transposed: - raise ValueError("deepgemm bf16 path requires non-transposed weights (is_transposed=False)") - if not self.has_gate: - raise ValueError("deepgemm bf16 path requires gated experts (has_gate=True)") - if self.has_bias: - raise ValueError( - "deepgemm bf16 path does not support bias (m_grouped_bf16_gemm_nt_contiguous has no bias input)" - ) - if hidden_states.device.type != "cuda": - raise ValueError("deepgemm bf16 path requires CUDA device") - - _, _, m_grouped_bf16_gemm_nt_contiguous, _ = _load_deepgemm_kernel() + if hidden_states.dtype != torch.bfloat16: + raise ValueError(f"deepgemm bf16 path requires bfloat16 hidden states, got {hidden_states.dtype}") + + _, _, m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nn_contiguous, _ = _load_deepgemm_kernel() + # Non-transposed HF experts have weight layout (E, N, K) -> NT kernel. + # Transposed HF experts have weight layout (E, K, N) -> NN kernel. + m_grouped_bf16_gemm = ( + m_grouped_bf16_gemm_nn_contiguous if self.is_transposed else m_grouped_bf16_gemm_nt_contiguous + ) device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -301,8 +310,8 @@ def bf16_deepgemm_experts_forward( inv_perm[perm] = torch.arange(perm.size(0), device=device) expert_ids_g = expert_ids[perm] - sample_weights_g = sample_weights[perm] invalid_mask_g = invalid_mask[perm] + sample_weights_g = sample_weights[perm] selected_hidden_states_g = hidden_states[token_idx[perm]] sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( @@ -311,18 +320,30 @@ def bf16_deepgemm_experts_forward( use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- + w_up = self.gate_up_proj if self.has_gate else self.up_proj + # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). + up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) - proj_out = torch.zeros(total_padded_rows, self.gate_up_proj.shape[1], device=device, dtype=hidden_states.dtype) - m_grouped_bf16_gemm_nt_contiguous( - act, self.gate_up_proj, proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) + proj_out = torch.zeros(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) + m_grouped_bf16_gemm(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) + + # The kernel has no bias input -> add per-expert bias post-GEMM; padding rows get discarded at unpad time. 
+ if self.has_bias: + up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias + proj_out = proj_out + _pad_for_deepgemm(up_bias[expert_ids_g], sorted_to_padded, total_padded_rows) - # Apply gating - proj_out = self._apply_gate(proj_out) + # Apply gating or activation + if self.has_gate: + proj_out = self._apply_gate(proj_out) + else: + proj_out = self.act_fn(proj_out) # --- Down projection per expert (deep-gemm grouped contiguous, bf16) --- out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) - m_grouped_bf16_gemm_nt_contiguous(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) + m_grouped_bf16_gemm(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) + + if self.has_bias: + out = out + _pad_for_deepgemm(self.down_proj_bias[expert_ids_g], sorted_to_padded, total_padded_rows) # Remove padding rows out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index e322bb4bc061..df6bfbbd8f1a 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -23,6 +23,7 @@ import torch from ..utils import logging +from ..utils.import_utils import is_kernels_available from .hub_kernels import lazy_load_kernel @@ -38,11 +39,27 @@ def _load_sonic_kernel(): Load sonic-moe once and return its required symbols. Raises: - ImportError if the kernel or required symbols are not found. + ImportError if CUDA/hardware requirements are not met, or if the kernel or + required symbols are not found. Returns: Tuple of (ActivationType, moe_general_routing_inputs function) from the sonic-moe kernel. """ + if not is_kernels_available(): + raise ImportError("sonic-moe kernel requires the `kernels` package. Install it with `pip install -U kernels`.") + + if not torch.cuda.is_available(): + raise ImportError( + "sonic-moe kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." + ) + + # sonic-moe requires Hopper (SM90) or newer + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError( + f"sonic-moe requires a Hopper (SM90+) or newer GPU, but the current device " + f"has compute capability {major}.x. Use a different `experts_implementation`." + ) kernel = lazy_load_kernel("sonic-moe") if kernel is None: From 9fc3662d1f9ec22ff94615513b1d5c189772a5ef Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 09:45:11 +0200 Subject: [PATCH 04/87] support EP better using offsets ! --- src/transformers/integrations/deepgemm.py | 118 ++++++++++-------- .../integrations/finegrained_fp8.py | 104 ++++++++------- src/transformers/integrations/moe.py | 51 +++----- src/transformers/integrations/sonicmoe.py | 5 +- .../integrations/tensor_parallel.py | 14 +++ 5 files changed, 154 insertions(+), 138 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 98c2b83032e2..4d5fcf8095b4 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -17,11 +17,13 @@ Provides: - `fp8_deepgemm_matmul`: FP8 dense matmul used as a fast path inside the finegrained-fp8 Linear. - `fp8_deepgemm_experts_forward`: FP8 M-grouped experts forward, registered as "deepgemm" in the FP8 ExpertsInterface. -- `bf16_deepgemm_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. 
+- `deepgemm_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. Requirements: CUDA, Hopper (SM90+), CUDA runtime >= 12.3, `kernels`. """ +from __future__ import annotations + import functools import torch @@ -80,7 +82,8 @@ def _load_deepgemm_kernel(): kernel = lazy_load_kernel("deep-gemm") if kernel is None: raise ImportError( - "deep-gemm kernel not found. Make sure you have the `kernels` package installed (`pip install -U kernels`)." + "Failed to load the deep-gemm kernel — check that `kernels-community/deep-gemm` " + "has a build matching the current torch/CUDA." ) fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) @@ -140,42 +143,56 @@ def fp8_deepgemm_matmul( return output.view(A.shape[:-1] + (B.shape[0],)) -def _build_deepgemm_contiguous_layout(expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int) -> tuple: - """Build a TMA-aligned contiguous layout for deep-gemm's grouped GEMM. +def _build_deepgemm_contiguous_layout( + expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool +) -> tuple: + """Build the TMA-aligned layout deep-gemm's grouped GEMM expects. - deep-gemm requires M-dimension alignment per expert for TMA. This computes - the mapping from sorted token positions to padded row positions, and the - layout tensor that deep-gemm uses to identify expert boundaries. + Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes + expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or + per-row expert ids with -1 for padding on Hopper. - Returns: - sorted_to_padded: (num_tokens,) index map from sorted position to padded row - grouped_layout: expert layout tensor (format depends on GPU architecture) - total_padded_rows: total number of rows including alignment padding + Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) + are routed past the last aligned expert block and marked `-1` in the Hopper layout (and + excluded from the Blackwell cumsum), so deep-gemm skips them. """ device = expert_ids_sorted.device num_tokens = expert_ids_sorted.size(0) + # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment # Upper bound avoids GPU->CPU sync; padding rows are skipped by deep-gemm. total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the + # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, + # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the + # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. 
padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = padding_per_expert.cumsum(0) - padding_per_expert + cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - if torch.cuda.get_device_capability(device)[0] >= 10: # Blackwell (SM100+) - grouped_layout = tokens_per_expert.cumsum(0).int() + if use_psum_layout: # Blackwell (SM100+) + # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= + # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler + # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` + # between experts only matches the padded tensor when the stored cumsum is over aligned counts. + grouped_layout = aligned_tokens_per_expert.cumsum(0).int() else: - # Hopper: per-row expert id, -1 for padding rows + # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = expert_ids_sorted.int() + grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) return sorted_to_padded, grouped_layout, total_padded_rows def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: - """Pad a sorted tensor into the TMA-aligned contiguous layout.""" - padded = torch.zeros(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + """Pad a sorted tensor into the TMA-aligned contiguous layout. + + Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) + or via the psum offsets (Blackwell), so their values never enter the computation. 
+ """ + padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) padded[sorted_to_padded] = x return padded @@ -212,23 +229,18 @@ def fp8_deepgemm_experts_forward( hidden_dim = hidden_states.size(-1) # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 # --- Up projection per expert (deep-gemm grouped contiguous) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj @@ -236,7 +248,7 @@ def fp8_deepgemm_experts_forward( act_fp8, act_scales = per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) - proj_out = torch.zeros(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) + proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) m_grouped_fp8_gemm_nt_contiguous( (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout ) @@ -249,7 +261,7 @@ def fp8_deepgemm_experts_forward( # --- Down projection per expert (deep-gemm grouped contiguous) --- proj_fp8, proj_scales = per_token_cast_to_fp8(proj_out, use_ue8m0=False) - proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) m_grouped_fp8_gemm_nt_contiguous( (proj_fp8, proj_scales), (self.down_proj, self.down_proj_scale_inv.float()), @@ -262,9 +274,11 @@ def fp8_deepgemm_experts_forward( proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) # Apply routing weights - weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -274,14 +288,14 @@ def fp8_deepgemm_experts_forward( return final_hidden_states.to(hidden_states.dtype) -def bf16_deepgemm_experts_forward( +def deepgemm_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, ) -> torch.Tensor: if hidden_states.dtype != torch.bfloat16: - raise ValueError(f"deepgemm bf16 path requires bfloat16 hidden states, got 
{hidden_states.dtype}") + raise ValueError(f"deepgemm path requires bfloat16 hidden states, got {hidden_states.dtype}") _, _, m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nn_contiguous, _ = _load_deepgemm_kernel() # Non-transposed HF experts have weight layout (E, N, K) -> NT kernel. @@ -296,41 +310,40 @@ def bf16_deepgemm_experts_forward( hidden_dim = hidden_states.size(-1) # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Handle invalid expert IDs from Expert Parallelism (EP) - invalid_mask = expert_ids >= self.num_experts - expert_ids = expert_ids.clamp(0, self.num_experts - 1) - + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail + # and `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond + # the cumsum on Blackwell) — so deep-gemm performs no real GEMM work for them. # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] - invalid_mask_g = invalid_mask[perm] + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + + # Clamp now that the layout has been built — needed for the per-row bias gather below to stay + # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. + expert_ids_g.clamp_(0, self.num_experts - 1) # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) + # `torch.zeros` so sentinel rows read back as 0 at unpad time (kernel leaves them untouched). proj_out = torch.zeros(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) m_grouped_bf16_gemm(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) - # The kernel has no bias input -> add per-expert bias post-GEMM; padding rows get discarded at unpad time. + # The kernel has no bias input -> add per-expert bias in-place on the unpadded slice; + # padding rows get discarded at unpad time. 
if self.has_bias: up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias - proj_out = proj_out + _pad_for_deepgemm(up_bias[expert_ids_g], sorted_to_padded, total_padded_rows) + proj_out.index_add_(0, sorted_to_padded, up_bias[expert_ids_g]) # Apply gating or activation if self.has_gate: @@ -343,16 +356,17 @@ def bf16_deepgemm_experts_forward( m_grouped_bf16_gemm(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) if self.has_bias: - out = out + _pad_for_deepgemm(self.down_proj_bias[expert_ids_g], sorted_to_padded, total_padded_rows) + out.index_add_(0, sorted_to_padded, self.down_proj_bias[expert_ids_g]) # Remove padding rows out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) - # Apply routing weights and zero out invalid expert contributions - weighted_out = out * sample_weights_g.to(out.dtype).unsqueeze(-1) # (S, hidden_dim) - weighted_out.masked_fill_(invalid_mask_g.unsqueeze(-1), 0.0) + # Apply routing weights + weighted_out = out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # Accumulate results using deterministic reshape+sum instead of index_add_ diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 5f583533792e..9579d50c5fd7 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import functools + import torch import torch.nn as nn from torch.nn import functional as F @@ -19,9 +23,11 @@ from ..core_model_loading import ConversionOps, _IdentityOp from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging +from ..utils.import_utils import is_kernels_available from .deepgemm import fp8_deepgemm_experts_forward, fp8_deepgemm_matmul from .hub_kernels import lazy_load_kernel from .moe import ExpertsInterface, use_experts_implementation +from .tensor_parallel import neutralize_ep_sentinels logger = logging.get_logger(__name__) @@ -31,40 +37,36 @@ _FP8_MIN = torch.finfo(_FP8_DTYPE).min _FP8_MAX = torch.finfo(_FP8_DTYPE).max -# Lazily-loaded finegrained-fp8 Triton kernel functions (populated by _load_triton_kernel) -triton_fp8_matmul = None -triton_fp8_act_quant = None -triton_batched_fp8_matmul = None -triton_grouped_fp8_matmul = None -# _triton_available: None = not yet attempted, True = loaded, False = failed (won't retry) -_triton_available = None - +@functools.cache def _load_triton_kernel(): - """Lazily load the finegrained-fp8 Triton kernel and extract functions. - - Uses the hub kernels lazy loading pattern. Raises an error if the kernel - cannot be loaded or required functions are missing. Only attempts loading once. """ - global \ - _triton_available, \ - triton_fp8_act_quant, \ - triton_fp8_matmul, \ - triton_batched_fp8_matmul, \ - triton_grouped_fp8_matmul + Load the finegrained-fp8 Triton kernel once and return its required symbols. 
- if _triton_available is not None: - if not _triton_available: - raise ImportError("finegrained-fp8 kernel is not available (previous load attempt failed).") - return + Raises: + ImportError if the `kernels` package is missing, or the kernel or required + symbols cannot be found. - _triton_available = False # mark attempted before any early exit + Returns: + Tuple of (w8a8_fp8_matmul, fp8_act_quant, w8a8_fp8_matmul_batched, + w8a8_fp8_matmul_grouped) from the finegrained-fp8 kernel. + """ + if not is_kernels_available(): + raise ImportError( + "finegrained-fp8 kernel requires the `kernels` package. Install it with `pip install -U kernels`." + ) kernel = lazy_load_kernel("finegrained-fp8") - triton_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul") - triton_fp8_act_quant = getattr(kernel, "fp8_act_quant") - triton_batched_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_batched") - triton_grouped_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_grouped") + if kernel is None: + raise ImportError( + "Failed to load the finegrained-fp8 kernel — check that `kernels-community/finegrained-fp8` " + "has a build matching the current torch/CUDA." + ) + + triton_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul", None) + triton_fp8_act_quant = getattr(kernel, "fp8_act_quant", None) + triton_batched_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_batched", None) + triton_grouped_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_grouped", None) missing = [ name @@ -78,11 +80,11 @@ def _load_triton_kernel(): ] if missing: raise ImportError( - f"finegrained-fp8 kernel is missing required functions: {', '.join(missing)}. " + f"finegrained-fp8 kernel is missing required symbols: {', '.join(missing)}. " "Please update the `kernels` package (`pip install -U kernels`)." ) - _triton_available = True + return triton_fp8_matmul, triton_fp8_act_quant, triton_batched_fp8_matmul, triton_grouped_fp8_matmul def _cdiv(a: int, b: int) -> int: @@ -127,8 +129,7 @@ def w8a8_fp8_matmul( "and that the `kernels` package is installed and up to date (`pip install -U kernels`)." ) - _load_triton_kernel() - global triton_fp8_matmul + triton_fp8_matmul, _, _, _ = _load_triton_kernel() return triton_fp8_matmul(A, B, As, Bs, block_size, output_dtype) @@ -182,8 +183,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: scale_inv = self.weight_scale_inv.contiguous() if self.activation_scheme == "dynamic": - _load_triton_kernel() - global triton_fp8_act_quant + _, triton_fp8_act_quant, _, _ = _load_triton_kernel() qinput, scale = triton_fp8_act_quant( input, self.block_size[1] if self.block_size is not None else input.shape[-1] ) @@ -203,7 +203,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ) if self.bias is not None: - output = output + self.bias + output.add_(self.bias) return output.to(dtype=input.dtype) @@ -220,21 +220,20 @@ def fp8_batched_mm_experts_forward( "Use the default eager dispatch or switch to activation_scheme='dynamic'." ) - _load_triton_kernel() - global triton_batched_fp8_matmul + _, _, triton_batched_fp8_matmul, _ = _load_triton_kernel() - device = hidden_states.device num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) hidden_dim = hidden_states.size(-1) # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + # Replicate each token num_top_k times to align with the flattened (S,) routing tensors. 
+ selected_hidden_states = hidden_states.repeat_interleave(num_top_k, dim=0) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Get current hidden states for selected samples - selected_hidden_states = hidden_states[token_idx] + # Handle invalid expert IDs from Expert Parallelism (EP) + neutralize_ep_sentinels(expert_ids, sample_weights, self.num_experts) # --- Up projection per expert (FP8 batched) --- proj_out = triton_batched_fp8_matmul( @@ -263,7 +262,8 @@ def fp8_batched_mm_experts_forward( ) # (S, hidden_dim) # Apply routing weights - weighted_out = proj_out * sample_weights.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + # Let torch promote bf16 `proj_out` × fp32 `sample_weights` to fp32 for the reduction below. + weighted_out = proj_out * sample_weights.unsqueeze(-1) # (S, hidden_dim) # Accumulate results using deterministic reshape+sum instead of index_add_ # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) @@ -284,8 +284,7 @@ def fp8_grouped_mm_experts_forward( "Use the default eager dispatch or switch to activation_scheme='dynamic'." ) - _load_triton_kernel() - global triton_grouped_fp8_matmul + _, _, _, triton_grouped_fp8_matmul = _load_triton_kernel() device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -293,22 +292,18 @@ def fp8_grouped_mm_experts_forward( hidden_dim = hidden_states.size(-1) # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] # Compute offsets for grouped processing. # histc instead of bincount avoids cuda-graph issues; # CPU requires float input, CUDA requires int input (deterministic mode). + # histc drops values > max, so sentinels (== num_experts) are excluded from the per-expert count. 
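+    # For example (hypothetical values): num_experts=4 and expert_ids_g=[0, 0, 2, 4, 4] give
+    # tokens_per_expert=[2, 0, 1, 0] and offsets=[2, 2, 3, 3]; the two sentinel rows sorted to the
+    # tail lie past offsets[-1]=3, so the grouped matmul never touches them.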
histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) @@ -342,9 +337,11 @@ def fp8_grouped_mm_experts_forward( ) # (S, hidden_dim) # Apply routing weights - weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -472,8 +469,7 @@ def linear( scale = activation_scale.to(torch.float32) qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) else: - _load_triton_kernel() - global triton_fp8_act_quant + _, triton_fp8_act_quant, _, _ = _load_triton_kernel() qinput, scale = triton_fp8_act_quant( input, self.block_size[1] if self.block_size is not None else input.shape[-1] ) @@ -685,5 +681,5 @@ def convert( } @property - def reverse_op(self) -> "ConversionOps": + def reverse_op(self) -> ConversionOps: return _IdentityOp() diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index b8015a0505b4..2c3ea91eafb6 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from collections.abc import Callable from functools import wraps @@ -23,8 +24,9 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) -from .deepgemm import bf16_deepgemm_experts_forward +from .deepgemm import deepgemm_experts_forward from .sonicmoe import sonicmoe_experts_forward +from .tensor_parallel import neutralize_ep_sentinels if is_torch_available(): @@ -103,7 +105,7 @@ def _batched_linear( out = torch.bmm(weight, input.unsqueeze(-1)).squeeze(-1) if bias is not None: - out = out + bias + out.add_(bias) return out @@ -114,24 +116,18 @@ def batched_mm_experts_forward( top_k_index: torch.Tensor, top_k_weights: torch.Tensor, ) -> torch.Tensor: - device = hidden_states.device num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) hidden_dim = hidden_states.size(-1) - # Reshape for easier indexing # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + # Replicate each token num_top_k times to align with the flattened (S,) routing tensors. 
+ selected_hidden_states = hidden_states.repeat_interleave(num_top_k, dim=0) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) # Handle invalid expert IDs from Expert Parallelism (EP) - # When EP is enabled, tokens assigned to experts on other devices are marked with sentinel value >= num_experts - invalid_mask = expert_ids >= self.num_experts - expert_ids = expert_ids.clamp(0, self.num_experts - 1) - - # Get current hidden states for selected samples - selected_hidden_states = hidden_states[token_idx] + neutralize_ep_sentinels(expert_ids, sample_weights, self.num_experts) # Select gate_up or just up projection weights and biases if self.has_gate: @@ -163,9 +159,8 @@ def batched_mm_experts_forward( proj_out, selected_weights, bias=selected_biases, is_transposed=self.is_transposed ) # (S, hidden_dim) - # Apply routing weights and zero out invalid expert contributions + # Apply routing weights weighted_out = proj_out * sample_weights.unsqueeze(-1) # (S, hidden_dim) - weighted_out.masked_fill_(invalid_mask.unsqueeze(-1), 0.0) # Zero out invalid expert contributions # Accumulate results using deterministic reshape+sum instead of index_add_ # index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd @@ -364,7 +359,7 @@ def _grouped_linear( if bias is not None: # We should be able to pass bias to the grouped_mm call, but it's not yet supported. - out = out + bias + out.add_(bias) return out @@ -380,32 +375,26 @@ def grouped_mm_experts_forward( num_tokens = hidden_states.size(0) hidden_dim = hidden_states.size(-1) - # Reshape for easier indexing # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Handle invalid expert IDs from Expert Parallelism (EP) - invalid_mask = expert_ids >= self.num_experts - expert_ids = expert_ids.clamp(0, self.num_experts - 1) - # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] # Compute offsets for grouped_mm # using histc instead of bincount to avoid cuda graph issues # With deterministic algorithms, CPU only supports float input, CUDA only supports int input. + # `max=num_experts-1` drops unclamped sentinels (value == num_experts) from the per-expert count. histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) + # Clamp now that offsets are built. We only need this for the per-row bias gather below to stay in-bounds. + expert_ids_g.clamp_(0, self.num_experts - 1) + # Select expert weights and biases # NOTE: We keep all experts here and rely on offsets to target the active ones. 
# I have already implemented a version that only passes the active experts, but @@ -440,12 +429,12 @@ def grouped_mm_experts_forward( proj_out, selected_weights, offsets, bias=selected_biases, is_transposed=self.is_transposed ) # (S, hidden_dim) - # Apply routing weights and zero out invalid expert contributions from EP + # Apply routing weights weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) - invalid_mask_g = invalid_mask[perm] - weighted_out.masked_fill_(invalid_mask_g.unsqueeze(-1), 0.0) # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # (S, hidden_dim) # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -461,10 +450,10 @@ class ExpertsInterface(GeneralInterface): """Interface for registering custom experts forward functions.""" _global_mapping = { - "sonicmoe": sonicmoe_experts_forward, "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, - "deepgemm": bf16_deepgemm_experts_forward, + "deepgemm": deepgemm_experts_forward, + "sonicmoe": sonicmoe_experts_forward, } def get_interface(self, experts_implementation: str, default: Callable) -> Callable: diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index df6bfbbd8f1a..d6eee485fea7 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -18,6 +18,8 @@ Requirements: CUDA, `kernels`, `nvidia-cutlass-dsl`, has_gate=True. """ +from __future__ import annotations + import functools import torch @@ -64,7 +66,8 @@ def _load_sonic_kernel(): kernel = lazy_load_kernel("sonic-moe") if kernel is None: raise ImportError( - "sonic-moe kernel not found. Make sure you have the `kernels` and `nvidia-cutlass-dsl` packages installed." + "Failed to load the sonic-moe kernel — check that `kernels-community/sonic-moe` " + "has a build matching the current torch/CUDA." ) ActivationType = getattr(getattr(kernel, "enums", None), "ActivationType", None) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 82d6d284f052..0c4557e4d3d7 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1079,6 +1079,20 @@ def update_module_attributes(self, module: nn.Module): module.num_experts = self.get_expected_sharded_shape((self.empty_param.shape[0],))[0] +def neutralize_ep_sentinels(expert_ids, sample_weights, num_experts) -> None: + """Make EP sentinel slots (`expert_ids >= num_experts`) no-ops for indexing backends. + + Mutates in place: clamps `expert_ids` in-range (so weight indexing stays valid) and zeros + `sample_weights` at sentinel slots (so their expert GEMM output contributes nothing). + + Sentinel tokens still go through the expert GEMMs; filtering them beforehand needs a host sync + or dynamic-shape kernels, both of which break CUDA graphs — so we keep the shape-preserving path. + Grouped-GEMM backends can skip sentinels via offsets instead — see `grouped_mm_experts_forward`. + """ + sample_weights.masked_fill_(expert_ids >= num_experts, 0.0) + expert_ids.clamp_(0, num_experts - 1) + + class RouterParallel(TensorParallelLayer): """ Allows to reshape the router scores to support running expert parallel. 
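A minimal sketch of the two in-place operations `neutralize_ep_sentinels` performs, on hypothetical routing tensors (`num_experts=4`; the values below are illustrative, not taken from this series):

import torch

num_experts = 4
expert_ids = torch.tensor([1, 4, 2])            # 4 is an EP sentinel (>= num_experts)
sample_weights = torch.tensor([0.7, 0.2, 0.1])

sample_weights.masked_fill_(expert_ids >= num_experts, 0.0)  # -> [0.7, 0.0, 0.1]
expert_ids.clamp_(0, num_experts - 1)                        # -> [1, 3, 2]

The sentinel slot still runs through an expert GEMM (now pointed at expert 3), but its zeroed weight removes that contribution from the final sum.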
From 84552ae98465ad2ed13bbbc67ff08b79b9bcb1bd Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 10:48:41 +0200 Subject: [PATCH 05/87] comments --- src/transformers/integrations/deepgemm.py | 19 +++++++++++-------- .../integrations/finegrained_fp8.py | 5 ++++- src/transformers/integrations/moe.py | 12 +++++++++--- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 4d5fcf8095b4..10fb7adcda8e 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -313,9 +313,11 @@ def deepgemm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail - # and `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond - # the cumsum on Blackwell) — so deep-gemm performs no real GEMM work for them. + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the + # cumsum on Blackwell), and deep-gemm skips them — so sentinels cost no real GEMM compute. Their + # routing weights are already zero (RouterParallel masks them at dispatch) so the weighted mul + # contributes nothing. # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -326,17 +328,17 @@ def deepgemm_experts_forward( expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) - # Clamp now that the layout has been built — needed for the per-row bias gather below to stay - # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. - expert_ids_g.clamp_(0, self.num_experts - 1) + if self.has_bias: + # Clamp now that the layout has been built — needed for the per-row bias gather below to stay + # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. + expert_ids_g.clamp_(0, self.num_experts - 1) # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) - # `torch.zeros` so sentinel rows read back as 0 at unpad time (kernel leaves them untouched). - proj_out = torch.zeros(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) + proj_out = torch.empty(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) m_grouped_bf16_gemm(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) # The kernel has no bias input -> add per-expert bias in-place on the unpadded slice; @@ -352,6 +354,7 @@ def deepgemm_experts_forward( proj_out = self.act_fn(proj_out) # --- Down projection per expert (deep-gemm grouped contiguous, bf16) --- + # Zero-init: unpad later reads sentinel-row positions the kernel never writes. 
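+    # If these rows were left uninitialized, NaN/Inf garbage in them would survive the zero routing
+    # weight at the weighted mul below (0.0 * nan == nan), hence torch.zeros rather than torch.empty.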
out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) m_grouped_bf16_gemm(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 9579d50c5fd7..e8d3f25c3edc 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -295,6 +295,10 @@ def fp8_grouped_mm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and the grouped matmul skips + # rows beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Their routing weights are + # already zero (RouterParallel masks them at dispatch) so the weighted mul contributes nothing. # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -303,7 +307,6 @@ def fp8_grouped_mm_experts_forward( # Compute offsets for grouped processing. # histc instead of bincount avoids cuda-graph issues; # CPU requires float input, CUDA requires int input (deterministic mode). - # histc drops values > max, so sentinels (== num_experts) are excluded from the per-expert count. histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 2c3ea91eafb6..705e07763bd4 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -379,6 +379,10 @@ def grouped_mm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and grouped_mm skips rows + # beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Their routing weights are + # already zero (RouterParallel masks them at dispatch) so the weighted mul contributes nothing. # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -387,19 +391,21 @@ def grouped_mm_experts_forward( # Compute offsets for grouped_mm # using histc instead of bincount to avoid cuda graph issues # With deterministic algorithms, CPU only supports float input, CUDA only supports int input. - # `max=num_experts-1` drops unclamped sentinels (value == num_experts) from the per-expert count. histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) - # Clamp now that offsets are built. We only need this for the per-row bias gather below to stay in-bounds. - expert_ids_g.clamp_(0, self.num_experts - 1) + if self.has_bias: + # Clamp now that the layout has been built — needed for the per-row bias gather below to stay + # in-bounds. 
Bias added to sentinel positions falls in rows the kernel skips, so harmless. + expert_ids_g.clamp_(0, self.num_experts - 1) # Select expert weights and biases # NOTE: We keep all experts here and rely on offsets to target the active ones. # I have already implemented a version that only passes the active experts, but # to do so I had to use torch.unique which breaks the graph capture (data-dependent). # Also there were no speedup gains from it in my experiments, even in eager mode. + # NOTE: The grouped_mm kernel only targets the active experts / tokens via the offsets if self.has_gate: selected_weights = self.gate_up_proj selected_biases = self.gate_up_proj_bias[expert_ids_g] if self.has_bias else None From 1d9f319b9623d414ca8e8b7b931c3081efc27100 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 11:00:48 +0200 Subject: [PATCH 06/87] get rid of neutralize_ep_sentinels --- src/transformers/integrations/finegrained_fp8.py | 7 ++++--- src/transformers/integrations/moe.py | 7 ++++--- src/transformers/integrations/tensor_parallel.py | 14 -------------- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index e8d3f25c3edc..f08329003df4 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -27,7 +27,6 @@ from .deepgemm import fp8_deepgemm_experts_forward, fp8_deepgemm_matmul from .hub_kernels import lazy_load_kernel from .moe import ExpertsInterface, use_experts_implementation -from .tensor_parallel import neutralize_ep_sentinels logger = logging.get_logger(__name__) @@ -232,8 +231,10 @@ def fp8_batched_mm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Handle invalid expert IDs from Expert Parallelism (EP) - neutralize_ep_sentinels(expert_ids, sample_weights, self.num_experts) + # Clamp EP sentinels so per-token weight indexing stays in-bounds. Routing weights are already + # zero at sentinel slots (RouterParallel masks them at dispatch), so the weighted mul drops + # those contributions — we pay the wasted GEMM compute because batched_mm has no offset to skip. + expert_ids.clamp_(0, self.num_experts - 1) # --- Up projection per expert (FP8 batched) --- proj_out = triton_batched_fp8_matmul( diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 705e07763bd4..4f1f9c315959 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -26,7 +26,6 @@ ) from .deepgemm import deepgemm_experts_forward from .sonicmoe import sonicmoe_experts_forward -from .tensor_parallel import neutralize_ep_sentinels if is_torch_available(): @@ -126,8 +125,10 @@ def batched_mm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Handle invalid expert IDs from Expert Parallelism (EP) - neutralize_ep_sentinels(expert_ids, sample_weights, self.num_experts) + # Clamp EP sentinels so `gate_up_proj[expert_ids]` stays in-bounds. Routing weights are already + # zero at sentinel slots (RouterParallel masks them at dispatch), so the weighted mul drops + # those contributions — we pay the wasted GEMM compute because batched_mm has no offset to skip. 
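+    # For example (hypothetical values): num_experts=4, expert_ids=[1, 4, 2] and routing weights
+    # [0.7, 0.0, 0.3] (sentinel slot already zeroed by the router) clamp to expert_ids=[1, 3, 2];
+    # expert 3's output for that row is still computed but is multiplied by 0.0 in the weighted sum.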
+ expert_ids.clamp_(0, self.num_experts - 1) # Select gate_up or just up projection weights and biases if self.has_gate: diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 0c4557e4d3d7..82d6d284f052 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1079,20 +1079,6 @@ def update_module_attributes(self, module: nn.Module): module.num_experts = self.get_expected_sharded_shape((self.empty_param.shape[0],))[0] -def neutralize_ep_sentinels(expert_ids, sample_weights, num_experts) -> None: - """Make EP sentinel slots (`expert_ids >= num_experts`) no-ops for indexing backends. - - Mutates in place: clamps `expert_ids` in-range (so weight indexing stays valid) and zeros - `sample_weights` at sentinel slots (so their expert GEMM output contributes nothing). - - Sentinel tokens still go through the expert GEMMs; filtering them beforehand needs a host sync - or dynamic-shape kernels, both of which break CUDA graphs — so we keep the shape-preserving path. - Grouped-GEMM backends can skip sentinels via offsets instead — see `grouped_mm_experts_forward`. - """ - sample_weights.masked_fill_(expert_ids >= num_experts, 0.0) - expert_ids.clamp_(0, num_experts - 1) - - class RouterParallel(TensorParallelLayer): """ Allows to reshape the router scores to support running expert parallel. From 9b8604341198535dd11016bbf35100139dd9a2bd Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 11:24:05 +0200 Subject: [PATCH 07/87] remove deepgemm stuff --- src/transformers/integrations/deepgemm.py | 379 ------------------ .../integrations/finegrained_fp8.py | 251 +++++++++++- src/transformers/integrations/moe.py | 2 - 3 files changed, 249 insertions(+), 383 deletions(-) delete mode 100644 src/transformers/integrations/deepgemm.py diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py deleted file mode 100644 index 10fb7adcda8e..000000000000 --- a/src/transformers/integrations/deepgemm.py +++ /dev/null @@ -1,379 +0,0 @@ -# Copyright 2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DeepGEMM integration: fused grouped GEMM kernels from `kernels-community/deep-gemm`. - -Provides: -- `fp8_deepgemm_matmul`: FP8 dense matmul used as a fast path inside the finegrained-fp8 Linear. -- `fp8_deepgemm_experts_forward`: FP8 M-grouped experts forward, registered as "deepgemm" in the FP8 ExpertsInterface. -- `deepgemm_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. - -Requirements: CUDA, Hopper (SM90+), CUDA runtime >= 12.3, `kernels`. 
-""" - -from __future__ import annotations - -import functools - -import torch - -from ..utils import logging -from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import -from .hub_kernels import lazy_load_kernel - - -logger = logging.get_logger(__name__) - -# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. -# TMA is an H100 hardware addition that allows applications to asynchronously and -# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. -_DEEPGEMM_M_ALIGNMENT = 128 - - -@functools.cache -def _load_deepgemm_kernel(): - """ - Load deep-gemm once and return its required symbols. - - Raises: - ImportError if CUDA/hardware requirements are not met, or the kernel or - required symbols are not found. - - Returns: - Tuple of (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, - m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nn_contiguous, - per_token_cast_to_fp8) from the deep-gemm kernel. - """ - if not is_kernels_available(): - raise ImportError("deep-gemm kernel requires the `kernels` package. Install it with `pip install -U kernels`.") - - if not torch.cuda.is_available(): - raise ImportError( - "deep-gemm kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." - ) - - # deep-gemm requires Hopper (SM90) or newer for FP8 WGMMA instructions - major = torch.cuda.get_device_capability()[0] - if major < 9: - raise ImportError( - f"deep-gemm requires a Hopper (SM90+) or newer GPU, but the current device " - f"has compute capability {major}.x. Use a different `experts_implementation`." - ) - - # deep-gemm requires CUDA runtime >= 12.3 - cuda_major, cuda_minor = get_cuda_runtime_version() - if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): - raise ImportError( - f"deep-gemm requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " - "Please upgrade your CUDA toolkit or use a different `experts_implementation`." - ) - - kernel = lazy_load_kernel("deep-gemm") - if kernel is None: - raise ImportError( - "Failed to load the deep-gemm kernel — check that `kernels-community/deep-gemm` " - "has a build matching the current torch/CUDA." - ) - - fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) - m_grouped_fp8_gemm_nt_contiguous = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) - m_grouped_bf16_gemm_nt_contiguous = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) - m_grouped_bf16_gemm_nn_contiguous = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) - per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") - - missing = [ - name - for name, attr in [ - ("fp8_gemm_nt", fp8_gemm_nt), - ("m_grouped_fp8_gemm_nt_contiguous", m_grouped_fp8_gemm_nt_contiguous), - ("m_grouped_bf16_gemm_nt_contiguous", m_grouped_bf16_gemm_nt_contiguous), - ("m_grouped_bf16_gemm_nn_contiguous", m_grouped_bf16_gemm_nn_contiguous), - ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), - ] - if attr is None - ] - if missing: - raise ImportError( - f"deep-gemm kernel is missing required symbols: {', '.join(missing)}. " - "Please update the `kernels` package (`pip install -U kernels`)." 
- ) - - return ( - fp8_gemm_nt, - m_grouped_fp8_gemm_nt_contiguous, - m_grouped_bf16_gemm_nt_contiguous, - m_grouped_bf16_gemm_nn_contiguous, - per_token_cast_to_fp8, - ) - - -def fp8_deepgemm_matmul( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - output_dtype: torch.dtype = torch.float32, -) -> torch.Tensor: - """ - FP8 dense matmul via deep-gemm's `fp8_gemm_nt`. Block-wise 128x128 scales expected. - - Args: - A: (M, K) float8_e4m3fn — quantized activations - B: (N, K) float8_e4m3fn — quantized weights - As: (M, K//128) float32 — per-block activation scales - Bs: (N//128, K//128) float32 — per-block weight scales - output_dtype: desired output dtype. - """ - fp8_gemm_nt, _, _, _, _ = _load_deepgemm_kernel() - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - fp8_gemm_nt((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) - - -def _build_deepgemm_contiguous_layout( - expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool -) -> tuple: - """Build the TMA-aligned layout deep-gemm's grouped GEMM expects. - - Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes - expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or - per-row expert ids with -1 for padding on Hopper. - - Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) - are routed past the last aligned expert block and marked `-1` in the Hopper layout (and - excluded from the Blackwell cumsum), so deep-gemm skips them. - """ - device = expert_ids_sorted.device - num_tokens = expert_ids_sorted.size(0) - # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. - tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() - aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment - # Upper bound avoids GPU->CPU sync; padding rows are skipped by deep-gemm. - total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) - - # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the - # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, - # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the - # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. - padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) - sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - - if use_psum_layout: # Blackwell (SM100+) - # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= - # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler - # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` - # between experts only matches the padded tensor when the stored cumsum is over aligned counts. - grouped_layout = aligned_tokens_per_expert.cumsum(0).int() - else: - # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). 
- grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) - - return sorted_to_padded, grouped_layout, total_padded_rows - - -def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: - """Pad a sorted tensor into the TMA-aligned contiguous layout. - - Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) - or via the psum offsets (Blackwell), so their values never enter the computation. - """ - padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) - padded[sorted_to_padded] = x - return padded - - -def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: - """Remove padding rows from the TMA-aligned contiguous layout.""" - return x_padded[sorted_to_padded] - - -def fp8_deepgemm_experts_forward( - self: torch.nn.Module, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - if self.activation_scheme == "static": - raise NotImplementedError( - "deepgemm experts dispatch does not support activation_scheme='static'. " - "Use the default eager dispatch or switch to activation_scheme='dynamic'." - ) - if self.block_size is None: - raise ValueError( - "deep-gemm requires block-wise quantization (block_size=[128, 128]), " - "but got per-tensor quantization (block_size=None)." - ) - if self.block_size[0] != 128 or self.block_size[1] != 128: - raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") - - _, m_grouped_fp8_gemm_nt_contiguous, _, _, per_token_cast_to_fp8 = _load_deepgemm_kernel() - - device = hidden_states.device - num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) - - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - sample_weights = top_k_weights.reshape(-1) # (S,) - expert_ids = top_k_index.reshape(-1) # (S,) - - # Sort by expert for grouped processing - expert_ids_g, perm = torch.sort(expert_ids) - selected_hidden_states_g = hidden_states[perm // num_top_k] - sample_weights_g = sample_weights[perm] - - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout - ) - - # --- Up projection per expert (deep-gemm grouped contiguous) --- - w_up = self.gate_up_proj if self.has_gate else self.up_proj - ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv - act_fp8, act_scales = per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) - act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) - act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) - proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) - m_grouped_fp8_gemm_nt_contiguous( - (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) - - # Apply gating or activation - if self.has_gate: - proj_out = self._apply_gate(proj_out) - else: - proj_out = self.act_fn(proj_out) - - # --- Down projection per expert (deep-gemm grouped contiguous) --- - proj_fp8, proj_scales = 
per_token_cast_to_fp8(proj_out, use_ue8m0=False) - proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) - m_grouped_fp8_gemm_nt_contiguous( - (proj_fp8, proj_scales), - (self.down_proj, self.down_proj_scale_inv.float()), - proj_out, - grouped_layout, - use_psum_layout=use_psum_layout, - ) - - # Remove padding rows - proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) - - # Apply routing weights - weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) - - # Restore original order - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - weighted_out = weighted_out[inv_perm] - - # Accumulate results using deterministic reshape+sum instead of index_add_ - # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) - final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - - return final_hidden_states.to(hidden_states.dtype) - - -def deepgemm_experts_forward( - self: torch.nn.Module, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - if hidden_states.dtype != torch.bfloat16: - raise ValueError(f"deepgemm path requires bfloat16 hidden states, got {hidden_states.dtype}") - - _, _, m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nn_contiguous, _ = _load_deepgemm_kernel() - # Non-transposed HF experts have weight layout (E, N, K) -> NT kernel. - # Transposed HF experts have weight layout (E, K, N) -> NN kernel. - m_grouped_bf16_gemm = ( - m_grouped_bf16_gemm_nn_contiguous if self.is_transposed else m_grouped_bf16_gemm_nt_contiguous - ) - - device = hidden_states.device - num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) - - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - sample_weights = top_k_weights.reshape(-1) # (S,) - expert_ids = top_k_index.reshape(-1) # (S,) - - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, - # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the - # cumsum on Blackwell), and deep-gemm skips them — so sentinels cost no real GEMM compute. Their - # routing weights are already zero (RouterParallel masks them at dispatch) so the weighted mul - # contributes nothing. - # Sort by expert for grouped processing - expert_ids_g, perm = torch.sort(expert_ids) - selected_hidden_states_g = hidden_states[perm // num_top_k] - sample_weights_g = sample_weights[perm] - - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout - ) - - if self.has_bias: - # Clamp now that the layout has been built — needed for the per-row bias gather below to stay - # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. - expert_ids_g.clamp_(0, self.num_experts - 1) - - # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- - w_up = self.gate_up_proj if self.has_gate else self.up_proj - # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). 
- up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] - act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) - proj_out = torch.empty(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) - m_grouped_bf16_gemm(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) - - # The kernel has no bias input -> add per-expert bias in-place on the unpadded slice; - # padding rows get discarded at unpad time. - if self.has_bias: - up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias - proj_out.index_add_(0, sorted_to_padded, up_bias[expert_ids_g]) - - # Apply gating or activation - if self.has_gate: - proj_out = self._apply_gate(proj_out) - else: - proj_out = self.act_fn(proj_out) - - # --- Down projection per expert (deep-gemm grouped contiguous, bf16) --- - # Zero-init: unpad later reads sentinel-row positions the kernel never writes. - out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) - m_grouped_bf16_gemm(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) - - if self.has_bias: - out.index_add_(0, sorted_to_padded, self.down_proj_bias[expert_ids_g]) - - # Remove padding rows - out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) - - # Apply routing weights - weighted_out = out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) - - # Restore original order - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - weighted_out = weighted_out[inv_perm] - - # Accumulate results using deterministic reshape+sum instead of index_add_ - # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) - final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - - return final_hidden_states.to(hidden_states.dtype) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index f08329003df4..c51d2322fe36 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -23,8 +23,7 @@ from ..core_model_loading import ConversionOps, _IdentityOp from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging -from ..utils.import_utils import is_kernels_available -from .deepgemm import fp8_deepgemm_experts_forward, fp8_deepgemm_matmul +from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import from .hub_kernels import lazy_load_kernel from .moe import ExpertsInterface, use_experts_implementation @@ -86,6 +85,162 @@ def _load_triton_kernel(): return triton_fp8_matmul, triton_fp8_act_quant, triton_batched_fp8_matmul, triton_grouped_fp8_matmul +# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. +# TMA is an H100 hardware addition that allows applications to asynchronously and +# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. +_DEEPGEMM_M_ALIGNMENT = 128 + + +@functools.cache +def _load_deepgemm_kernel(): + """ + Load deep-gemm once and return its required symbols. + + Raises: + ImportError if CUDA/hardware requirements are not met, or the kernel or + required symbols are not found. + + Returns: + Tuple of (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8) + from the deep-gemm kernel. 
+ """ + if not is_kernels_available(): + raise ImportError("deep-gemm kernel requires the `kernels` package. Install it with `pip install -U kernels`.") + + if not torch.cuda.is_available(): + raise ImportError( + "deep-gemm kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." + ) + + # deep-gemm requires Hopper (SM90) or newer for FP8 WGMMA instructions + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError( + f"deep-gemm requires a Hopper (SM90+) or newer GPU, but the current device " + f"has compute capability {major}.x. Use a different `experts_implementation`." + ) + + # deep-gemm requires CUDA runtime >= 12.3 + cuda_major, cuda_minor = get_cuda_runtime_version() + if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): + raise ImportError( + f"deep-gemm requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " + "Please upgrade your CUDA toolkit or use a different `experts_implementation`." + ) + + kernel = lazy_load_kernel("deep-gemm") + if kernel is None: + raise ImportError( + "Failed to load the deep-gemm kernel — check that `kernels-community/deep-gemm` " + "has a build matching the current torch/CUDA." + ) + + fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) + m_grouped_fp8_gemm_nt_contiguous = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) + per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + + missing = [ + name + for name, attr in [ + ("fp8_gemm_nt", fp8_gemm_nt), + ("m_grouped_fp8_gemm_nt_contiguous", m_grouped_fp8_gemm_nt_contiguous), + ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), + ] + if attr is None + ] + if missing: + raise ImportError( + f"deep-gemm kernel is missing required symbols: {', '.join(missing)}. " + "Please update the `kernels` package (`pip install -U kernels`)." + ) + + return fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8 + + +def fp8_deepgemm_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + FP8 dense matmul via deep-gemm's `fp8_gemm_nt`. Block-wise 128x128 scales expected. + + Args: + A: (M, K) float8_e4m3fn — quantized activations + B: (N, K) float8_e4m3fn — quantized weights + As: (M, K//128) float32 — per-block activation scales + Bs: (N//128, K//128) float32 — per-block weight scales + output_dtype: desired output dtype. + """ + fp8_gemm_nt, _, _ = _load_deepgemm_kernel() + A_2d = A.view(-1, A.shape[-1]) + As_2d = As.view(-1, As.shape[-1]) + output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) + fp8_gemm_nt((A_2d, As_2d.float()), (B, Bs.float()), output) + return output.view(A.shape[:-1] + (B.shape[0],)) + + +def _build_deepgemm_contiguous_layout( + expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool +) -> tuple: + """Build the TMA-aligned layout deep-gemm's grouped GEMM expects. + + Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes + expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or + per-row expert ids with -1 for padding on Hopper. + + Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) + are routed past the last aligned expert block and marked `-1` in the Hopper layout (and + excluded from the Blackwell cumsum), so deep-gemm skips them. 
+ """ + device = expert_ids_sorted.device + num_tokens = expert_ids_sorted.size(0) + # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. + tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() + aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment + # Upper bound avoids GPU->CPU sync; padding rows are skipped by deep-gemm. + total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + + # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the + # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, + # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the + # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. + padding_per_expert = aligned_tokens_per_expert - tokens_per_expert + cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) + sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] + + if use_psum_layout: # Blackwell (SM100+) + # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= + # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler + # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` + # between experts only matches the padded tensor when the stored cumsum is over aligned counts. + grouped_layout = aligned_tokens_per_expert.cumsum(0).int() + else: + # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). + grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) + grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) + + return sorted_to_padded, grouped_layout, total_padded_rows + + +def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: + """Pad a sorted tensor into the TMA-aligned contiguous layout. + + Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) + or via the psum offsets (Blackwell), so their values never enter the computation. + """ + padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + padded[sorted_to_padded] = x + return padded + + +def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: + """Remove padding rows from the TMA-aligned contiguous layout.""" + return x_padded[sorted_to_padded] + + def _cdiv(a: int, b: int) -> int: """Ceiling division.""" return (a + b - 1) // b @@ -355,6 +510,98 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) +def fp8_deepgemm_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if self.activation_scheme == "static": + raise NotImplementedError( + "deepgemm experts dispatch does not support activation_scheme='static'. " + "Use the default eager dispatch or switch to activation_scheme='dynamic'." + ) + if self.block_size is None: + raise ValueError( + "deep-gemm requires block-wise quantization (block_size=[128, 128]), " + "but got per-tensor quantization (block_size=None)." 
+ ) + if self.block_size[0] != 128 or self.block_size[1] != 128: + raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") + + _, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8 = _load_deepgemm_kernel() + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the + # cumsum on Blackwell), and deep-gemm skips them — so sentinels cost no real GEMM compute. Their + # routing weights are already zero (RouterParallel masks them at dispatch) so the weighted mul + # contributes nothing. + # Sort by expert for grouped processing + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] + sample_weights_g = sample_weights[perm] # inherits zeros at invalid EP slots + + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout + ) + + # --- Up projection per expert (deep-gemm grouped contiguous) --- + w_up = self.gate_up_proj if self.has_gate else self.up_proj + ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv + act_fp8, act_scales = per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) + act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) + act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) + proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) + m_grouped_fp8_gemm_nt_contiguous( + (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout + ) + + # Apply gating or activation + if self.has_gate: + proj_out = self._apply_gate(proj_out) + else: + proj_out = self.act_fn(proj_out) + + # --- Down projection per expert (deep-gemm grouped contiguous) --- + proj_fp8, proj_scales = per_token_cast_to_fp8(proj_out, use_ue8m0=False) + # Zero-init: unpad later reads sentinel-row positions the kernel never writes. 
+ proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + m_grouped_fp8_gemm_nt_contiguous( + (proj_fp8, proj_scales), + (self.down_proj, self.down_proj_scale_inv.float()), + proj_out, + grouped_layout, + use_psum_layout=use_psum_layout, + ) + + # Remove padding rows + proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) + + # Apply routing weights + weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + + # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) + weighted_out = weighted_out[inv_perm] + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) + + class FP8Experts(nn.Module): def __init__( self, diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 4f1f9c315959..1ceb9e167409 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -24,7 +24,6 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) -from .deepgemm import deepgemm_experts_forward from .sonicmoe import sonicmoe_experts_forward @@ -459,7 +458,6 @@ class ExpertsInterface(GeneralInterface): _global_mapping = { "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, - "deepgemm": deepgemm_experts_forward, "sonicmoe": sonicmoe_experts_forward, } From 996d67d0ce9fa46d46a82c4d552215305ee960cd Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 11:26:27 +0200 Subject: [PATCH 08/87] fix --- src/transformers/integrations/finegrained_fp8.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index b6b437761eed..c75d66087cf7 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -36,6 +36,13 @@ _FP8_MAX = torch.finfo(_FP8_DTYPE).max +def _first_attr(obj, *names): + for name in names: + if hasattr(obj, name): + return getattr(obj, name) + raise AttributeError(f"{type(obj).__name__} has none of: {names}") + + @functools.cache def _load_triton_kernel(): """ From d033a8309a538bb476298d931ec70032792000dd Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 11:38:43 +0200 Subject: [PATCH 09/87] prefix --- .../integrations/finegrained_fp8.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index c75d66087cf7..6a4ae50c8f6a 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -108,7 +108,7 @@ def _load_deepgemm_kernel(): required symbols are not found. Returns: - Tuple of (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8) + Tuple of (deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8) from the deep-gemm kernel. """ if not is_kernels_available(): @@ -142,16 +142,16 @@ def _load_deepgemm_kernel(): "has a build matching the current torch/CUDA." 
) - fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) - m_grouped_fp8_gemm_nt_contiguous = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) - per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + deepgemm_fp8_matmul = getattr(kernel, "fp8_gemm_nt", None) + deepgemm_grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) + deepgemm_per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") missing = [ name for name, attr in [ - ("fp8_gemm_nt", fp8_gemm_nt), - ("m_grouped_fp8_gemm_nt_contiguous", m_grouped_fp8_gemm_nt_contiguous), - ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), + ("fp8_gemm_nt", deepgemm_fp8_matmul), + ("m_grouped_fp8_gemm_nt_contiguous", deepgemm_grouped_fp8_matmul), + ("utils.per_token_cast_to_fp8", deepgemm_per_token_cast_to_fp8), ] if attr is None ] @@ -161,7 +161,7 @@ def _load_deepgemm_kernel(): "Please update the `kernels` package (`pip install -U kernels`)." ) - return fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8 + return deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 def fp8_deepgemm_matmul( @@ -181,11 +181,11 @@ def fp8_deepgemm_matmul( Bs: (N//128, K//128) float32 — per-block weight scales output_dtype: desired output dtype. """ - fp8_gemm_nt, _, _ = _load_deepgemm_kernel() + deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() A_2d = A.view(-1, A.shape[-1]) As_2d = As.view(-1, As.shape[-1]) output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - fp8_gemm_nt((A_2d, As_2d.float()), (B, Bs.float()), output) + deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) return output.view(A.shape[:-1] + (B.shape[0],)) @@ -536,7 +536,7 @@ def fp8_deepgemm_experts_forward( if self.block_size[0] != 128 or self.block_size[1] != 128: raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") - _, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8 = _load_deepgemm_kernel() + _, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 = _load_deepgemm_kernel() device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -565,11 +565,11 @@ def fp8_deepgemm_experts_forward( # --- Up projection per expert (deep-gemm grouped contiguous) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv - act_fp8, act_scales = per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) + act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) - m_grouped_fp8_gemm_nt_contiguous( + deepgemm_grouped_fp8_matmul( (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout ) @@ -580,10 +580,10 @@ def fp8_deepgemm_experts_forward( proj_out = self.act_fn(proj_out) # --- Down projection per expert (deep-gemm grouped contiguous) --- - proj_fp8, proj_scales = per_token_cast_to_fp8(proj_out, use_ue8m0=False) + proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) # Zero-init: unpad later reads sentinel-row positions the kernel never writes. 
proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) - m_grouped_fp8_gemm_nt_contiguous( + deepgemm_grouped_fp8_matmul( (proj_fp8, proj_scales), (self.down_proj, self.down_proj_scale_inv.float()), proj_out, From e15cfe6ad62f13e87cfe07353c787f0ba7fcb3d0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 12:15:40 +0200 Subject: [PATCH 10/87] move --- .../integrations/finegrained_fp8.py | 201 +++++++++--------- 1 file changed, 100 insertions(+), 101 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 6a4ae50c8f6a..f268018314ac 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -101,7 +101,7 @@ def _load_triton_kernel(): @functools.cache def _load_deepgemm_kernel(): """ - Load deep-gemm once and return its required symbols. + Load DeepGEMM once and return its required symbols. Raises: ImportError if CUDA/hardware requirements are not met, or the kernel or @@ -109,36 +109,36 @@ def _load_deepgemm_kernel(): Returns: Tuple of (deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8) - from the deep-gemm kernel. + from the DeepGEMM kernel. """ if not is_kernels_available(): - raise ImportError("deep-gemm kernel requires the `kernels` package. Install it with `pip install -U kernels`.") + raise ImportError("DeepGEMM kernel requires the `kernels` package. Install it with `pip install -U kernels`.") if not torch.cuda.is_available(): raise ImportError( - "deep-gemm kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." + "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." ) - # deep-gemm requires Hopper (SM90) or newer for FP8 WGMMA instructions + # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions major = torch.cuda.get_device_capability()[0] if major < 9: raise ImportError( - f"deep-gemm requires a Hopper (SM90+) or newer GPU, but the current device " + f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " f"has compute capability {major}.x. Use a different `experts_implementation`." ) - # deep-gemm requires CUDA runtime >= 12.3 + # DeepGEMM requires CUDA runtime >= 12.3 cuda_major, cuda_minor = get_cuda_runtime_version() if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): raise ImportError( - f"deep-gemm requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " + f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " "Please upgrade your CUDA toolkit or use a different `experts_implementation`." ) - kernel = lazy_load_kernel("deep-gemm") + kernel = lazy_load_kernel("DeepGEMM") if kernel is None: raise ImportError( - "Failed to load the deep-gemm kernel — check that `kernels-community/deep-gemm` " + "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " "has a build matching the current torch/CUDA." ) @@ -157,97 +157,13 @@ def _load_deepgemm_kernel(): ] if missing: raise ImportError( - f"deep-gemm kernel is missing required symbols: {', '.join(missing)}. " + f"DeepGEMM kernel is missing required symbols: {', '.join(missing)}. " "Please update the `kernels` package (`pip install -U kernels`)." 
) return deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 -def fp8_deepgemm_matmul( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - output_dtype: torch.dtype = torch.float32, -) -> torch.Tensor: - """ - FP8 dense matmul via deep-gemm's `fp8_gemm_nt`. Block-wise 128x128 scales expected. - - Args: - A: (M, K) float8_e4m3fn — quantized activations - B: (N, K) float8_e4m3fn — quantized weights - As: (M, K//128) float32 — per-block activation scales - Bs: (N//128, K//128) float32 — per-block weight scales - output_dtype: desired output dtype. - """ - deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) - - -def _build_deepgemm_contiguous_layout( - expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool -) -> tuple: - """Build the TMA-aligned layout deep-gemm's grouped GEMM expects. - - Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes - expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or - per-row expert ids with -1 for padding on Hopper. - - Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) - are routed past the last aligned expert block and marked `-1` in the Hopper layout (and - excluded from the Blackwell cumsum), so deep-gemm skips them. - """ - device = expert_ids_sorted.device - num_tokens = expert_ids_sorted.size(0) - # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. - tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() - aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment - # Upper bound avoids GPU->CPU sync; padding rows are skipped by deep-gemm. - total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) - - # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the - # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, - # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the - # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. - padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) - sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - - if use_psum_layout: # Blackwell (SM100+) - # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= - # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler - # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` - # between experts only matches the padded tensor when the stored cumsum is over aligned counts. - grouped_layout = aligned_tokens_per_expert.cumsum(0).int() - else: - # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). 
- grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) - - return sorted_to_padded, grouped_layout, total_padded_rows - - -def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: - """Pad a sorted tensor into the TMA-aligned contiguous layout. - - Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) - or via the psum offsets (Blackwell), so their values never enter the computation. - """ - padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) - padded[sorted_to_padded] = x - return padded - - -def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: - """Remove padding rows from the TMA-aligned contiguous layout.""" - return x_padded[sorted_to_padded] - - def _cdiv(a: int, b: int) -> int: """Ceiling division.""" return (a + b - 1) // b @@ -425,7 +341,6 @@ def fp8_batched_mm_experts_forward( ) # (S, hidden_dim) # Apply routing weights - # Let torch promote bf16 `proj_out` × fp32 `sample_weights` to fp32 for the reduction below. weighted_out = proj_out * sample_weights.unsqueeze(-1) # (S, hidden_dim) # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -517,6 +432,90 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) +def fp8_deepgemm_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + FP8 dense matmul via DeepGEMM's `fp8_gemm_nt`. Block-wise 128x128 scales expected. + + Args: + A: (M, K) float8_e4m3fn — quantized activations + B: (N, K) float8_e4m3fn — quantized weights + As: (M, K//128) float32 — per-block activation scales + Bs: (N//128, K//128) float32 — per-block weight scales + output_dtype: desired output dtype. + """ + deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() + A_2d = A.view(-1, A.shape[-1]) + As_2d = As.view(-1, As.shape[-1]) + output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) + deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) + return output.view(A.shape[:-1] + (B.shape[0],)) + + +def _build_deepgemm_contiguous_layout( + expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool +) -> tuple: + """Build the TMA-aligned layout DeepGEMM's grouped GEMM expects. + + Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes + expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or + per-row expert ids with -1 for padding on Hopper. + + Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) + are routed past the last aligned expert block and marked `-1` in the Hopper layout (and + excluded from the Blackwell cumsum), so DeepGEMM skips them. + """ + device = expert_ids_sorted.device + num_tokens = expert_ids_sorted.size(0) + # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. 
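+    # e.g. num_experts=4 with expert_ids_sorted=[0, 0, 1, 3, 4, 4] (the 4s are EP sentinels) gives
+    # tokens_per_expert=[2, 1, 0, 1]: the sentinel rows are simply never counted.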
+ tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() + aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment + # Upper bound avoids GPU->CPU sync; padding rows are skipped by DeepGEMM. + total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + + # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the + # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, + # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the + # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. + padding_per_expert = aligned_tokens_per_expert - tokens_per_expert + cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) + sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] + + if use_psum_layout: # Blackwell (SM100+) + # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= + # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler + # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` + # between experts only matches the padded tensor when the stored cumsum is over aligned counts. + grouped_layout = aligned_tokens_per_expert.cumsum(0).int() + else: + # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). + grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) + grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) + + return sorted_to_padded, grouped_layout, total_padded_rows + + +def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: + """Pad a sorted tensor into the TMA-aligned contiguous layout. + + Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) + or via the psum offsets (Blackwell), so their values never enter the computation. + """ + padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + padded[sorted_to_padded] = x + return padded + + +def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: + """Remove padding rows from the TMA-aligned contiguous layout.""" + return x_padded[sorted_to_padded] + + def fp8_deepgemm_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, @@ -530,11 +529,11 @@ def fp8_deepgemm_experts_forward( ) if self.block_size is None: raise ValueError( - "deep-gemm requires block-wise quantization (block_size=[128, 128]), " + "DeepGEMM requires block-wise quantization (block_size=[128, 128]), " "but got per-tensor quantization (block_size=None)." 
) if self.block_size[0] != 128 or self.block_size[1] != 128: - raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") + raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") _, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 = _load_deepgemm_kernel() @@ -549,7 +548,7 @@ def fp8_deepgemm_experts_forward( # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the - # cumsum on Blackwell), and deep-gemm skips them — so sentinels cost no real GEMM compute. Their + # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. Their # routing weights are already zero (RouterParallel masks them at dispatch) so the weighted mul # contributes nothing. # Sort by expert for grouped processing @@ -562,7 +561,7 @@ def fp8_deepgemm_experts_forward( expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) - # --- Up projection per expert (deep-gemm grouped contiguous) --- + # --- Up projection per expert (DeepGEMM grouped contiguous) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) @@ -579,7 +578,7 @@ def fp8_deepgemm_experts_forward( else: proj_out = self.act_fn(proj_out) - # --- Down projection per expert (deep-gemm grouped contiguous) --- + # --- Down projection per expert (DeepGEMM grouped contiguous) --- proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) # Zero-init: unpad later reads sentinel-row positions the kernel never writes. proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) From 10b6d904105bef0850afd69a9f177e0ff9d22389 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 12:17:59 +0200 Subject: [PATCH 11/87] fix --- src/transformers/integrations/finegrained_fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index f268018314ac..bd20894c382c 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -135,7 +135,7 @@ def _load_deepgemm_kernel(): "Please upgrade your CUDA toolkit or use a different `experts_implementation`." ) - kernel = lazy_load_kernel("DeepGEMM") + kernel = lazy_load_kernel("deep-gemm") if kernel is None: raise ImportError( "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " @@ -524,7 +524,7 @@ def fp8_deepgemm_experts_forward( ) -> torch.Tensor: if self.activation_scheme == "static": raise NotImplementedError( - "deepgemm experts dispatch does not support activation_scheme='static'. " + "DeepGEMM experts dispatch does not support activation_scheme='static'. " "Use the default eager dispatch or switch to activation_scheme='dynamic'." 
) if self.block_size is None: From d4a6b3056f701dc0c307b02b73a04a890b0bfc30 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 12:30:20 +0200 Subject: [PATCH 12/87] remove comment --- src/transformers/integrations/finegrained_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index bd20894c382c..910eab7838c1 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -554,7 +554,7 @@ def fp8_deepgemm_experts_forward( # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] - sample_weights_g = sample_weights[perm] # inherits zeros at invalid EP slots + sample_weights_g = sample_weights[perm] use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( From 1d6054ff5904407bda8e47bbddd95971f85582e0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 13:00:12 +0200 Subject: [PATCH 13/87] fix unintilized outputs leaking --- .../integrations/finegrained_fp8.py | 24 +++++++++++++------ src/transformers/integrations/moe.py | 10 ++++++-- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 910eab7838c1..e5a4479f178e 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -375,8 +375,9 @@ def fp8_grouped_mm_experts_forward( # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and the grouped matmul skips - # rows beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Their routing weights are - # already zero (RouterParallel masks them at dispatch) so the weighted mul contributes nothing. + # rows beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Sentinel rows are zeroed + # post-weighted-mul (see below), since the kernel leaves them uninitialized. + # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -420,6 +421,11 @@ def fp8_grouped_mm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + # EP sentinel handling: `proj_out` rows past `offsets[-1]` are left uninitialized by the kernel, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. + weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Restore original order inv_perm = torch.empty_like(perm) inv_perm[perm] = torch.arange(perm.size(0), device=device) @@ -548,9 +554,9 @@ def fp8_deepgemm_experts_forward( # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the - # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. Their - # routing weights are already zero (RouterParallel masks them at dispatch) so the weighted mul - # contributes nothing. 
+ # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. + # Sentinel rows are zeroed post-weighted-mul (see below), since the kernel leaves them uninitialized. + # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -580,8 +586,7 @@ def fp8_deepgemm_experts_forward( # --- Down projection per expert (DeepGEMM grouped contiguous) --- proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) - # Zero-init: unpad later reads sentinel-row positions the kernel never writes. - proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) deepgemm_grouped_fp8_matmul( (proj_fp8, proj_scales), (self.down_proj, self.down_proj_scale_inv.float()), @@ -596,6 +601,11 @@ def fp8_deepgemm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. + weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Restore original order inv_perm = torch.empty_like(perm) inv_perm[perm] = torch.arange(perm.size(0), device=device) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 1ceb9e167409..4ef11fe029b7 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -381,8 +381,9 @@ def grouped_mm_experts_forward( # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and grouped_mm skips rows - # beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Their routing weights are - # already zero (RouterParallel masks them at dispatch) so the weighted mul contributes nothing. + # beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Sentinel rows are zeroed + # post-weighted-mul (see below), since the kernel leaves them uninitialized. + # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -438,6 +439,11 @@ def grouped_mm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + # EP sentinel handling: `proj_out` rows past `offsets[-1]` are left uninitialized by grouped_mm, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. 
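+    # (In IEEE float math 0.0 * float("nan") is still nan, so a zero routing weight alone cannot
+    # hide whatever stale values the skipped rows happen to contain.)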
+ weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Restore original order inv_perm = torch.empty_like(perm) inv_perm[perm] = torch.arange(perm.size(0), device=device) From 137393cda9bc902f7f8dce942dd68ed25be28c2a Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 13:07:37 +0200 Subject: [PATCH 14/87] revert unnecessary changes --- .../integrations/finegrained_fp8.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index e5a4479f178e..684da70f8610 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -419,7 +419,7 @@ def fp8_grouped_mm_experts_forward( ) # (S, hidden_dim) # Apply routing weights - weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) # EP sentinel handling: `proj_out` rows past `offsets[-1]` are left uninitialized by the kernel, # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here @@ -506,20 +506,27 @@ def _build_deepgemm_contiguous_layout( return sorted_to_padded, grouped_layout, total_padded_rows -def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: - """Pad a sorted tensor into the TMA-aligned contiguous layout. - - Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) - or via the psum offsets (Blackwell), so their values never enter the computation. - """ - padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) - padded[sorted_to_padded] = x - return padded +def _pad_to_deepgemm_contiguous_layout( + hidden_states: torch.Tensor, + scales: torch.Tensor, + sorted_to_padded: torch.Tensor, + total_padded_rows: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pad sorted hidden states and scales into the TMA-aligned contiguous layout.""" + hidden_padded = torch.zeros( + total_padded_rows, hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype + ) + hidden_padded[sorted_to_padded] = hidden_states + scales_padded = torch.zeros(total_padded_rows, scales.shape[1], device=hidden_states.device, dtype=torch.float32) + scales_padded[sorted_to_padded] = scales + return hidden_padded, scales_padded -def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: +def _unpad_from_deepgemm_contiguous_layout( + hidden_states_padded: torch.Tensor, sorted_to_padded: torch.Tensor +) -> torch.Tensor: """Remove padding rows from the TMA-aligned contiguous layout.""" - return x_padded[sorted_to_padded] + return hidden_states_padded[sorted_to_padded] def fp8_deepgemm_experts_forward( @@ -571,8 +578,7 @@ def fp8_deepgemm_experts_forward( w_up = self.gate_up_proj if self.has_gate else self.up_proj ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) - act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) - act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) + act_fp8, act_scales = _pad_to_deepgemm_contiguous_layout(act_fp8, act_scales, sorted_to_padded, total_padded_rows) proj_out = 
torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) deepgemm_grouped_fp8_matmul( (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout @@ -599,7 +605,7 @@ def fp8_deepgemm_experts_forward( proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) # Apply routing weights - weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here From 774f90181dc5d7f8cea1e25b8dd46444b4ac524a Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 13:09:03 +0200 Subject: [PATCH 15/87] more unnecessary changes --- src/transformers/integrations/finegrained_fp8.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 684da70f8610..a07a0cdd37e2 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -36,6 +36,12 @@ _FP8_MAX = torch.finfo(_FP8_DTYPE).max +# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. +# TMA is an H100 hardware addition that allows applications to asynchronously and +# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. +_DEEPGEMM_M_ALIGNMENT = 128 + + def _first_attr(obj, *names): for name in names: if hasattr(obj, name): @@ -92,12 +98,6 @@ def _load_triton_kernel(): return triton_fp8_matmul, triton_fp8_act_quant, triton_batched_fp8_matmul, triton_grouped_fp8_matmul -# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. -# TMA is an H100 hardware addition that allows applications to asynchronously and -# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. 
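+# For example, an expert that receives 300 routed rows is padded up to 384 rows, since
+# ceil(300 / 128) * 128 = 384; the extra padding rows are skipped by the kernel.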
-_DEEPGEMM_M_ALIGNMENT = 128 - - @functools.cache def _load_deepgemm_kernel(): """ From 81230feeaf3c9234399755186394019bd5a21ee4 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 13:30:07 +0200 Subject: [PATCH 16/87] revert downcast --- src/transformers/integrations/finegrained_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index a07a0cdd37e2..64e9c3722c28 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -341,7 +341,7 @@ def fp8_batched_mm_experts_forward( ) # (S, hidden_dim) # Apply routing weights - weighted_out = proj_out * sample_weights.unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) # Accumulate results using deterministic reshape+sum instead of index_add_ # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) From 9f2ff08915bf865791ac8ef2ddbb79ccac317b5b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 14:52:26 +0200 Subject: [PATCH 17/87] keep it simple --- .../integrations/finegrained_fp8.py | 44 +++++++------------ 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 64e9c3722c28..61190b480be3 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -128,7 +128,14 @@ def _load_deepgemm_kernel(): ) # DeepGEMM requires CUDA runtime >= 12.3 - cuda_major, cuda_minor = get_cuda_runtime_version() + try: + cuda_major, cuda_minor = get_cuda_runtime_version() + except OSError as e: + raise ImportError( + f"DeepGEMM requires CUDA runtime 12.3+, but libcudart could not be loaded ({e}). " + "Use a different `experts_implementation`." + ) from e + if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): raise ImportError( f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " @@ -197,14 +204,20 @@ def w8a8_fp8_matmul( """ if block_size is not None and block_size[0] == block_size[1] == 128: try: - # 3-6x faster than Triton - return fp8_deepgemm_matmul(A, B, As, Bs, output_dtype=output_dtype) + deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() except ImportError: logger.warning_once( "DeepGEMM kernel is not available or compatible, falling back to Triton finegrained-fp8 kernel. " "To use DeepGEMM FP8 matmul, ensure you have a Hopper (SM90+) or newer GPU with CUDA runtime 12.3+, " "and that the `kernels` package is installed and up to date (`pip install -U kernels`)." ) + else: + # 3-6x faster than Triton + A_2d = A.view(-1, A.shape[-1]) + As_2d = As.view(-1, As.shape[-1]) + output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) + deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) + return output.view(A.shape[:-1] + (B.shape[0],)) triton_fp8_matmul, _, _, _ = _load_triton_kernel() return triton_fp8_matmul(A, B, As, Bs, block_size, output_dtype) @@ -438,31 +451,6 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) -def fp8_deepgemm_matmul( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - output_dtype: torch.dtype = torch.float32, -) -> torch.Tensor: - """ - FP8 dense matmul via DeepGEMM's `fp8_gemm_nt`. 
Block-wise 128x128 scales expected. - - Args: - A: (M, K) float8_e4m3fn — quantized activations - B: (N, K) float8_e4m3fn — quantized weights - As: (M, K//128) float32 — per-block activation scales - Bs: (N//128, K//128) float32 — per-block weight scales - output_dtype: desired output dtype. - """ - deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) - - def _build_deepgemm_contiguous_layout( expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool ) -> tuple: From c55b7b7863e864224390ea79f412a2a1830dfab5 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 14:59:15 +0200 Subject: [PATCH 18/87] guard deepgemm cuda version --- .../integrations/finegrained_fp8.py | 9 +------- src/transformers/utils/import_utils.py | 21 +++++++++++++++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 61190b480be3..f423f2f6b830 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -128,14 +128,7 @@ def _load_deepgemm_kernel(): ) # DeepGEMM requires CUDA runtime >= 12.3 - try: - cuda_major, cuda_minor = get_cuda_runtime_version() - except OSError as e: - raise ImportError( - f"DeepGEMM requires CUDA runtime 12.3+, but libcudart could not be loaded ({e}). " - "Use a different `experts_implementation`." - ) from e - + cuda_major, cuda_minor = get_cuda_runtime_version() if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): raise ImportError( f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index de11d23cbecf..756363ea6c52 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -222,14 +222,27 @@ def is_cuda_platform() -> bool: def get_cuda_runtime_version() -> tuple[int, int]: """Return the CUDA runtime version as (major, minor). - Unlike ``torch.version.cuda`` which reports the compile-time version, - this queries ``cudaRuntimeGetVersion`` from ``libcudart.so`` to get the - actual runtime version installed on the system. + Prefers a direct query of ``cudaRuntimeGetVersion`` via ``libcudart.so``. If that's + not on the system loader path (common with pip-installed torch that bundles its own + CUDA runtime), falls back to ``torch.version.cuda`` — which equals the bundled + runtime's version for pip wheels. Returns ``(0, 0)`` for CPU-only torch. 
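+    For example, a raw value of 12040 from ``cudaRuntimeGetVersion`` maps to ``(12, 4)``.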
""" import ctypes + try: + cudart = ctypes.CDLL("libcudart.so") + except OSError: + if not is_torch_available(): + return 0, 0 + import torch + + if getattr(torch.version, "cuda", None) is None: + return 0, 0 + + major, minor, *_ = torch.version.cuda.split(".") + return int(major), int(minor) + version = ctypes.c_int() - cudart = ctypes.CDLL("libcudart.so") cudart.cudaRuntimeGetVersion(ctypes.byref(version)) return version.value // 1000, (version.value % 1000) // 10 From 20858db8159171d2ca430766b21baf1a49493bd9 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 15:19:19 +0200 Subject: [PATCH 19/87] fix style --- src/transformers/utils/import_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 756363ea6c52..8654bd083ba2 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -236,10 +236,11 @@ def get_cuda_runtime_version() -> tuple[int, int]: return 0, 0 import torch - if getattr(torch.version, "cuda", None) is None: + cuda_version = getattr(torch.version, "cuda", None) + if cuda_version is None: return 0, 0 - major, minor, *_ = torch.version.cuda.split(".") + major, minor, *_ = cuda_version.split(".") return int(major), int(minor) version = ctypes.c_int() From bfea94f06dbd07b02a2ff0fd85075e5de4d7a54c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 16:37:12 +0200 Subject: [PATCH 20/87] update --- src/transformers/integrations/deepgemm.py | 389 ++++++++++++++++++ .../integrations/finegrained_fp8.py | 246 +---------- src/transformers/integrations/moe.py | 2 + 3 files changed, 395 insertions(+), 242 deletions(-) create mode 100644 src/transformers/integrations/deepgemm.py diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py new file mode 100644 index 000000000000..c3af8f61c5fb --- /dev/null +++ b/src/transformers/integrations/deepgemm.py @@ -0,0 +1,389 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DeepGEMM integration: fused grouped GEMM kernels from `kernels-community/deep-gemm`. + +Provides: +- `fp8_deepgemm_matmul`: FP8 dense matmul used as a fast path inside the finegrained-fp8 Linear. +- `fp8_deepgemm_experts_forward`: FP8 M-grouped experts forward, registered as "deepgemm" in the FP8 ExpertsInterface. +- `deepgemm_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. + +Requirements: CUDA, Hopper (SM90+), CUDA runtime >= 12.3, `kernels`. +""" + +from __future__ import annotations + +import functools + +import torch + +from ..utils import logging +from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import +from .hub_kernels import lazy_load_kernel + + +logger = logging.get_logger(__name__) + +# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. 
+# TMA is an H100 hardware addition that allows applications to asynchronously and +# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. +_DEEPGEMM_M_ALIGNMENT = 128 + + +@functools.cache +def _load_deepgemm_kernel(): + """ + Load DeepGEMM once and return its required symbols. + + Raises: + ImportError if CUDA/hardware requirements are not met, or the kernel or + required symbols are not found. + + Returns: + Tuple of (deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, + deepgemm_grouped_bf16_matmul_nt, deepgemm_grouped_bf16_matmul_nn, + deepgemm_per_token_cast_to_fp8) from the DeepGEMM kernel. + """ + if not is_kernels_available(): + raise ImportError("DeepGEMM kernel requires the `kernels` package. Install it with `pip install -U kernels`.") + + if not torch.cuda.is_available(): + raise ImportError( + "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." + ) + + # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError( + f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " + f"has compute capability {major}.x. Use a different `experts_implementation`." + ) + + # DeepGEMM requires CUDA runtime >= 12.3 + cuda_major, cuda_minor = get_cuda_runtime_version() + if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): + raise ImportError( + f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " + "Please upgrade your CUDA toolkit or use a different `experts_implementation`." + ) + + kernel = lazy_load_kernel("deep-gemm") + if kernel is None: + raise ImportError( + "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " + "has a build matching the current torch/CUDA." + ) + + deepgemm_fp8_matmul = getattr(kernel, "fp8_gemm_nt", None) + deepgemm_grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) + deepgemm_grouped_bf16_matmul_nt = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) + deepgemm_grouped_bf16_matmul_nn = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) + deepgemm_per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + + missing = [ + name + for name, attr in [ + ("fp8_gemm_nt", deepgemm_fp8_matmul), + ("m_grouped_fp8_gemm_nt_contiguous", deepgemm_grouped_fp8_matmul), + ("m_grouped_bf16_gemm_nt_contiguous", deepgemm_grouped_bf16_matmul_nt), + ("m_grouped_bf16_gemm_nn_contiguous", deepgemm_grouped_bf16_matmul_nn), + ("utils.per_token_cast_to_fp8", deepgemm_per_token_cast_to_fp8), + ] + if attr is None + ] + if missing: + raise ImportError( + f"DeepGEMM kernel is missing required symbols: {', '.join(missing)}. " + "Please update the `kernels` package (`pip install -U kernels`)." + ) + + return ( + deepgemm_fp8_matmul, + deepgemm_grouped_fp8_matmul, + deepgemm_grouped_bf16_matmul_nt, + deepgemm_grouped_bf16_matmul_nn, + deepgemm_per_token_cast_to_fp8, + ) + + +def fp8_deepgemm_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + FP8 dense matmul via DeepGEMM's `fp8_gemm_nt`. Block-wise 128x128 scales expected. 
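+    For example, an activation A of shape (256, 4096) pairs with As of shape (256, 32),
+    since 4096 / 128 = 32 scale blocks per row.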
+ + Args: + A: (M, K) float8_e4m3fn — quantized activations + B: (N, K) float8_e4m3fn — quantized weights + As: (M, K//128) float32 — per-block activation scales + Bs: (N//128, K//128) float32 — per-block weight scales + output_dtype: desired output dtype. + """ + deepgemm_fp8_matmul, _, _, _, _ = _load_deepgemm_kernel() + A_2d = A.view(-1, A.shape[-1]) + As_2d = As.view(-1, As.shape[-1]) + output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) + deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) + return output.view(A.shape[:-1] + (B.shape[0],)) + + +def _build_deepgemm_contiguous_layout( + expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool +) -> tuple: + """Build the TMA-aligned layout DeepGEMM's grouped GEMM expects. + + Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes + expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or + per-row expert ids with -1 for padding on Hopper. + + Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) + are routed past the last aligned expert block and marked `-1` in the Hopper layout (and + excluded from the Blackwell cumsum), so DeepGEMM skips them. + """ + device = expert_ids_sorted.device + num_tokens = expert_ids_sorted.size(0) + # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. + tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() + aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment + # Upper bound avoids GPU->CPU sync; padding rows are skipped by DeepGEMM. + total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + + # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the + # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, + # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the + # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. + padding_per_expert = aligned_tokens_per_expert - tokens_per_expert + cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) + sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] + + if use_psum_layout: # Blackwell (SM100+) + # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= + # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler + # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` + # between experts only matches the padded tensor when the stored cumsum is over aligned counts. + grouped_layout = aligned_tokens_per_expert.cumsum(0).int() + else: + # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). + grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) + grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) + + return sorted_to_padded, grouped_layout, total_padded_rows + + +def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: + """Pad a sorted tensor into the TMA-aligned contiguous layout. 
+ + Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) + or via the psum offsets (Blackwell), so their values never enter the computation. + """ + padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + padded[sorted_to_padded] = x + return padded + + +def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: + """Remove padding rows from the TMA-aligned contiguous layout.""" + return x_padded[sorted_to_padded] + + +def fp8_deepgemm_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if self.activation_scheme == "static": + raise NotImplementedError( + "DeepGEMM experts dispatch does not support activation_scheme='static'. " + "Use the default eager dispatch or switch to activation_scheme='dynamic'." + ) + if self.block_size is None: + raise ValueError( + "DeepGEMM requires block-wise quantization (block_size=[128, 128]), " + "but got per-tensor quantization (block_size=None)." + ) + if self.block_size[0] != 128 or self.block_size[1] != 128: + raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") + + _, deepgemm_grouped_fp8_matmul, _, _, deepgemm_per_token_cast_to_fp8 = _load_deepgemm_kernel() + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the + # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. + # Sentinel rows are zeroed post-weighted-mul (see below), since the kernel leaves them uninitialized. 
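+    # e.g. with num_experts=4 and expert_ids=[2, 4, 0, 4] (4 being the sentinel), the sort below
+    # yields expert_ids_g=[0, 2, 4, 4], so both sentinel rows sit in the tail past every real expert block.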
+ expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] + sample_weights_g = sample_weights[perm] + + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout + ) + + # --- Up projection per expert (DeepGEMM grouped contiguous) --- + w_up = self.gate_up_proj if self.has_gate else self.up_proj + ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv + act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) + act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) + act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) + proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) + deepgemm_grouped_fp8_matmul( + (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout + ) + + # Apply gating or activation + if self.has_gate: + proj_out = self._apply_gate(proj_out) + else: + proj_out = self.act_fn(proj_out) + + # --- Down projection per expert (DeepGEMM grouped contiguous) --- + proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) + proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + deepgemm_grouped_fp8_matmul( + (proj_fp8, proj_scales), + (self.down_proj, self.down_proj_scale_inv.float()), + proj_out, + grouped_layout, + use_psum_layout=use_psum_layout, + ) + + # Remove padding rows + proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) + + # Apply routing weights + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + + # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. + weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + + # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) + weighted_out = weighted_out[inv_perm] + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) + + +def deepgemm_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if hidden_states.dtype != torch.bfloat16: + raise ValueError(f"DeepGEMM experts path requires bfloat16 hidden states, got {hidden_states.dtype}") + + # Non-transposed HF experts have weight layout (E, N, K) -> NT kernel. + # Transposed HF experts have weight layout (E, K, N) -> NN kernel. 
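+    # Roughly: the NT kernel computes y = x @ w[e].T from (E, N, K) weights, while the NN kernel
+    # computes y = x @ w[e] from (E, K, N) weights; same projection, different storage layout.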
+ _, _, deepgemm_grouped_bf16_matmul_nt, deepgemm_grouped_bf16_matmul_nn, _ = _load_deepgemm_kernel() + deepgemm_grouped_bf16_matmul = ( + deepgemm_grouped_bf16_matmul_nn if self.is_transposed else deepgemm_grouped_bf16_matmul_nt + ) + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the + # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. + # Sentinel rows are zeroed post-weighted-mul (see below), since the kernel leaves them uninitialized. + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] + sample_weights_g = sample_weights[perm] + + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout + ) + + if self.has_bias: + # Clamp now that the layout has been built — needed for the per-row bias gather below to stay + # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. + expert_ids_g.clamp_(0, self.num_experts - 1) + + # --- Up projection per expert (DeepGEMM grouped contiguous, bf16) --- + w_up = self.gate_up_proj if self.has_gate else self.up_proj + # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). + up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] + act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) + proj_out = torch.empty(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) + deepgemm_grouped_bf16_matmul(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) + + # The kernel has no bias input -> add per-expert bias in-place on the unpadded slice; + # padding rows get discarded at unpad time. + if self.has_bias: + up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias + proj_out.index_add_(0, sorted_to_padded, up_bias[expert_ids_g]) + + # Apply gating or activation + if self.has_gate: + proj_out = self._apply_gate(proj_out) + else: + proj_out = self.act_fn(proj_out) + + # --- Down projection per expert (DeepGEMM grouped contiguous, bf16) --- + out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) + deepgemm_grouped_bf16_matmul(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) + + if self.has_bias: + out.index_add_(0, sorted_to_padded, self.down_proj_bias[expert_ids_g]) + + # Remove padding rows + out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) + + # Apply routing weights + weighted_out = out * sample_weights_g.to(out.dtype).unsqueeze(-1) # (S, hidden_dim) + + # EP sentinel handling: `out` rows past the valid expert blocks are left uninitialized by the kernel, + # so `out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. 
+ weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + + # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) + weighted_out = weighted_out[inv_perm] + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index f423f2f6b830..dce3159a3bd7 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -23,7 +23,8 @@ from ..core_model_loading import ConversionOps, _IdentityOp from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging -from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import +from ..utils.import_utils import is_kernels_available +from .deepgemm import fp8_deepgemm_experts_forward, fp8_deepgemm_matmul from .hub_kernels import lazy_load_kernel from .moe import ExpertsInterface, use_experts_implementation @@ -36,12 +37,6 @@ _FP8_MAX = torch.finfo(_FP8_DTYPE).max -# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. -# TMA is an H100 hardware addition that allows applications to asynchronously and -# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. -_DEEPGEMM_M_ALIGNMENT = 128 - - def _first_attr(obj, *names): for name in names: if hasattr(obj, name): @@ -98,72 +93,6 @@ def _load_triton_kernel(): return triton_fp8_matmul, triton_fp8_act_quant, triton_batched_fp8_matmul, triton_grouped_fp8_matmul -@functools.cache -def _load_deepgemm_kernel(): - """ - Load DeepGEMM once and return its required symbols. - - Raises: - ImportError if CUDA/hardware requirements are not met, or the kernel or - required symbols are not found. - - Returns: - Tuple of (deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8) - from the DeepGEMM kernel. - """ - if not is_kernels_available(): - raise ImportError("DeepGEMM kernel requires the `kernels` package. Install it with `pip install -U kernels`.") - - if not torch.cuda.is_available(): - raise ImportError( - "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." - ) - - # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions - major = torch.cuda.get_device_capability()[0] - if major < 9: - raise ImportError( - f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " - f"has compute capability {major}.x. Use a different `experts_implementation`." - ) - - # DeepGEMM requires CUDA runtime >= 12.3 - cuda_major, cuda_minor = get_cuda_runtime_version() - if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): - raise ImportError( - f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " - "Please upgrade your CUDA toolkit or use a different `experts_implementation`." - ) - - kernel = lazy_load_kernel("deep-gemm") - if kernel is None: - raise ImportError( - "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " - "has a build matching the current torch/CUDA." 
- ) - - deepgemm_fp8_matmul = getattr(kernel, "fp8_gemm_nt", None) - deepgemm_grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) - deepgemm_per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") - - missing = [ - name - for name, attr in [ - ("fp8_gemm_nt", deepgemm_fp8_matmul), - ("m_grouped_fp8_gemm_nt_contiguous", deepgemm_grouped_fp8_matmul), - ("utils.per_token_cast_to_fp8", deepgemm_per_token_cast_to_fp8), - ] - if attr is None - ] - if missing: - raise ImportError( - f"DeepGEMM kernel is missing required symbols: {', '.join(missing)}. " - "Please update the `kernels` package (`pip install -U kernels`)." - ) - - return deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 - - def _cdiv(a: int, b: int) -> int: """Ceiling division.""" return (a + b - 1) // b @@ -197,20 +126,14 @@ def w8a8_fp8_matmul( """ if block_size is not None and block_size[0] == block_size[1] == 128: try: - deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() + # 3-6x faster than Triton + return fp8_deepgemm_matmul(A, B, As, Bs, output_dtype=output_dtype) except ImportError: logger.warning_once( "DeepGEMM kernel is not available or compatible, falling back to Triton finegrained-fp8 kernel. " "To use DeepGEMM FP8 matmul, ensure you have a Hopper (SM90+) or newer GPU with CUDA runtime 12.3+, " "and that the `kernels` package is installed and up to date (`pip install -U kernels`)." ) - else: - # 3-6x faster than Triton - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) triton_fp8_matmul, _, _, _ = _load_triton_kernel() return triton_fp8_matmul(A, B, As, Bs, block_size, output_dtype) @@ -444,167 +367,6 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) -def _build_deepgemm_contiguous_layout( - expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool -) -> tuple: - """Build the TMA-aligned layout DeepGEMM's grouped GEMM expects. - - Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes - expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or - per-row expert ids with -1 for padding on Hopper. - - Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) - are routed past the last aligned expert block and marked `-1` in the Hopper layout (and - excluded from the Blackwell cumsum), so DeepGEMM skips them. - """ - device = expert_ids_sorted.device - num_tokens = expert_ids_sorted.size(0) - # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. - tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() - aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment - # Upper bound avoids GPU->CPU sync; padding rows are skipped by DeepGEMM. - total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) - - # Zero-prepended inclusive cumsum of per-expert padding. 
Indices [0, num_experts) give the - # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, - # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the - # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. - padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) - sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - - if use_psum_layout: # Blackwell (SM100+) - # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= - # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler - # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` - # between experts only matches the padded tensor when the stored cumsum is over aligned counts. - grouped_layout = aligned_tokens_per_expert.cumsum(0).int() - else: - # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). - grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) - - return sorted_to_padded, grouped_layout, total_padded_rows - - -def _pad_to_deepgemm_contiguous_layout( - hidden_states: torch.Tensor, - scales: torch.Tensor, - sorted_to_padded: torch.Tensor, - total_padded_rows: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """Pad sorted hidden states and scales into the TMA-aligned contiguous layout.""" - hidden_padded = torch.zeros( - total_padded_rows, hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype - ) - hidden_padded[sorted_to_padded] = hidden_states - scales_padded = torch.zeros(total_padded_rows, scales.shape[1], device=hidden_states.device, dtype=torch.float32) - scales_padded[sorted_to_padded] = scales - return hidden_padded, scales_padded - - -def _unpad_from_deepgemm_contiguous_layout( - hidden_states_padded: torch.Tensor, sorted_to_padded: torch.Tensor -) -> torch.Tensor: - """Remove padding rows from the TMA-aligned contiguous layout.""" - return hidden_states_padded[sorted_to_padded] - - -def fp8_deepgemm_experts_forward( - self: torch.nn.Module, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - if self.activation_scheme == "static": - raise NotImplementedError( - "DeepGEMM experts dispatch does not support activation_scheme='static'. " - "Use the default eager dispatch or switch to activation_scheme='dynamic'." - ) - if self.block_size is None: - raise ValueError( - "DeepGEMM requires block-wise quantization (block_size=[128, 128]), " - "but got per-tensor quantization (block_size=None)." 
- ) - if self.block_size[0] != 128 or self.block_size[1] != 128: - raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") - - _, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 = _load_deepgemm_kernel() - - device = hidden_states.device - num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) - - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - sample_weights = top_k_weights.reshape(-1) # (S,) - expert_ids = top_k_index.reshape(-1) # (S,) - - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, - # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the - # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. - # Sentinel rows are zeroed post-weighted-mul (see below), since the kernel leaves them uninitialized. - - # Sort by expert for grouped processing - expert_ids_g, perm = torch.sort(expert_ids) - selected_hidden_states_g = hidden_states[perm // num_top_k] - sample_weights_g = sample_weights[perm] - - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout - ) - - # --- Up projection per expert (DeepGEMM grouped contiguous) --- - w_up = self.gate_up_proj if self.has_gate else self.up_proj - ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv - act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) - act_fp8, act_scales = _pad_to_deepgemm_contiguous_layout(act_fp8, act_scales, sorted_to_padded, total_padded_rows) - proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) - deepgemm_grouped_fp8_matmul( - (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) - - # Apply gating or activation - if self.has_gate: - proj_out = self._apply_gate(proj_out) - else: - proj_out = self.act_fn(proj_out) - - # --- Down projection per expert (DeepGEMM grouped contiguous) --- - proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) - proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) - deepgemm_grouped_fp8_matmul( - (proj_fp8, proj_scales), - (self.down_proj, self.down_proj_scale_inv.float()), - proj_out, - grouped_layout, - use_psum_layout=use_psum_layout, - ) - - # Remove padding rows - proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) - - # Apply routing weights - weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) - - # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, - # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here - # so the downstream reduction stays finite even when the routing weight was already zero. 
- weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) - - # Restore original order - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - weighted_out = weighted_out[inv_perm] - - # Accumulate results using deterministic reshape+sum instead of index_add_ - # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) - final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - - return final_hidden_states.to(hidden_states.dtype) - - class FP8Experts(nn.Module): def __init__( self, diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 4ef11fe029b7..76fb2b7f70ef 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -24,6 +24,7 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) +from .deepgemm import deepgemm_experts_forward from .sonicmoe import sonicmoe_experts_forward @@ -465,6 +466,7 @@ class ExpertsInterface(GeneralInterface): "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, "sonicmoe": sonicmoe_experts_forward, + "deepgemm": deepgemm_experts_forward, } def get_interface(self, experts_implementation: str, default: Callable) -> Callable: From eada47e1592ff5d24f9289b56d06c0f95c19de7a Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 16:41:43 +0200 Subject: [PATCH 21/87] add deepgemm testing --- tests/test_modeling_common.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bc8f65891445..a7d44177e192 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -53,6 +53,7 @@ ) from transformers.integrations.moe import ( batched_mm_experts_forward, + deepgemm_experts_forward, grouped_mm_experts_forward, sonicmoe_experts_forward, ) @@ -118,6 +119,7 @@ is_torch_bf16_available_on_device, is_torch_fp16_available_on_device, ) +from transformers.utils.import_utils import get_cuda_runtime_version from transformers.utils.output_capturing import CompileableContextVar from .generation.test_utils import GenerationTesterMixin @@ -600,6 +602,17 @@ def _test_eager_matches_batched_and_grouped_inference(self, name, dtype): mocks["sonicmoe"] = Mock(wraps=sonicmoe_experts_forward) implementations.append("sonicmoe") + if ( + dtype == torch.bfloat16 + and is_kernels_available() + and torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and get_cuda_runtime_version() >= (12, 3) + ): + # DeepGEMM BF16 grouped forward requires Hopper+, CUDA runtime 12.3+, and bf16 hidden states + mocks["deepgemm"] = Mock(wraps=deepgemm_experts_forward) + implementations.append("deepgemm") + outputs = {} # This is needed because we call the functions through the interface's global mapping with patch.dict("transformers.integrations.moe.ALL_EXPERTS_FUNCTIONS._global_mapping", mocks): From 89d2f0bb3eeb38c1e6431013f6c67b5cbf30a388 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Sun, 26 Apr 2026 12:43:08 +0200 Subject: [PATCH 22/87] moe sentinel support --- src/transformers/integrations/hub_kernels.py | 6 ++++-- src/transformers/integrations/sonicmoe.py | 6 ++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 70a343424aa8..a362b9e114f2 100644 --- a/src/transformers/integrations/hub_kernels.py +++ 
b/src/transformers/integrations/hub_kernels.py @@ -289,7 +289,7 @@ def register_kernel_mapping_transformers(*args, **kwargs): "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, - "sonic-moe": {"repo_id": "kernels-community/sonic-moe", "version": 1}, + "sonic-moe": {"repo_id": "IlyasMoutawwakil/sonic-moe", "revision": "main"}, } _KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {} @@ -376,7 +376,9 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None) version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) - kernel = get_kernel(repo_id, revision=revision, version=version) + # Entries in `_HUB_KERNEL_MAPPING` are vetted in-tree, so we trust non-`kernels-community` + # repos (e.g. user/team forks) without requiring the per-call `allow_all_kernels` flag. + kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True) mapping[kernel_name] = kernel except FileNotFoundError: mapping[kernel_name] = None diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index d6eee485fea7..d32b698d5d74 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -112,6 +112,12 @@ def sonicmoe_experts_forward( router_scores = top_k_weights.reshape(-1).to(hidden_states.dtype) expert_ids = top_k_index.reshape(-1).int() + # EP sentinel handling: leave `expert_ids` unclamped — the kernel's metadata stage drops + # `expert_ids >= num_experts` from the per-expert histogram and masks them out of the + # scatter indices, so sentinels never enter the grouped GEMM. Their routing weights are + # already zero (RouterParallel masks them at dispatch), so the per-token reduction + # contributes nothing for sentinel slots. + # Map activation function act_name = getattr(self.config, "hidden_act", "silu").lower() activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) From 60db1ca0706885deec9efb185d097d88d5dc0277 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Sun, 26 Apr 2026 13:00:34 +0000 Subject: [PATCH 23/87] fix --- src/transformers/integrations/moe.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 4ef11fe029b7..9cf262de0358 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -30,6 +30,12 @@ if is_torch_available(): import torch + # Patch the version-check helpers so dynamo doesn't trace into them — they transitively call + # `importlib.util.find_spec`, which dynamo refuses to trace. `assume_constant_result` makes + # dynamo evaluate them once at trace time and inline the bool, no body tracing. 
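# Illustrative aside (not part of the patch): the same trick on a toy predicate. The helper body
# calls importlib.util.find_spec (which dynamo refuses to trace); wrapping it with
# assume_constant_result makes dynamo call it once at trace time and inline the returned bool.
import importlib.util

import torch


def _toy_backend_available() -> bool:  # hypothetical helper, not a transformers function
    return importlib.util.find_spec("numpy") is not None


_toy_backend_available = torch._dynamo.assume_constant_result(_toy_backend_available)


@torch.compile
def _toy_scale(x: torch.Tensor) -> torch.Tensor:
    # The branch condition is baked in as a constant during tracing — no graph break.
    return x * 2 if _toy_backend_available() else x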
+ is_torch_greater_or_equal = torch._dynamo.assume_constant_result(is_torch_greater_or_equal) + is_torch_less_or_equal = torch._dynamo.assume_constant_result(is_torch_less_or_equal) + logger = logging.get_logger(__name__) From 68b7b0fe2dc4e1877ad7af6e20b0e700de37e69c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 27 Apr 2026 15:56:52 +0200 Subject: [PATCH 24/87] compilable sonicmoe --- src/transformers/integrations/sonicmoe.py | 78 ++++++++++++++++------- 1 file changed, 56 insertions(+), 22 deletions(-) diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index d32b698d5d74..912b98655519 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -25,7 +25,6 @@ import torch from ..utils import logging -from ..utils.import_utils import is_kernels_available from .hub_kernels import lazy_load_kernel @@ -47,8 +46,6 @@ def _load_sonic_kernel(): Returns: Tuple of (ActivationType, moe_general_routing_inputs function) from the sonic-moe kernel. """ - if not is_kernels_available(): - raise ImportError("sonic-moe kernel requires the `kernels` package. Install it with `pip install -U kernels`.") if not torch.cuda.is_available(): raise ImportError( @@ -90,6 +87,50 @@ def _load_sonic_kernel(): return ActivationType, moe_general_routing_inputs +@torch._dynamo.allow_in_graph +def _sonicmoe_wrapper( + hidden_states: torch.Tensor, + router_scores: torch.Tensor, + expert_ids: torch.Tensor, + token_idx: torch.Tensor, + w1: torch.Tensor, + b1: torch.Tensor | None, + w2: torch.Tensor, + b2: torch.Tensor | None, + act_name: str, + num_experts: int, + concat_layout: bool, + is_inference_mode_enabled: bool, +) -> torch.Tensor: + """Module-level shim around `moe_general_routing_inputs` so `allow_in_graph` can wrap it. + + sonicmoe asserts `not torch.compiler.is_compiling()` internally because it dispatches + CuteDSL kernels, which Dynamo can't trace. `allow_in_graph` keeps the call in the FX + graph as a single opaque node (no tracing into the body, no graph break) while still + running the real Python at runtime — autograd through `_UpProjection` / `_DownProjection` + flows normally. The decorator must be applied at module load time, not inside the compiled + function — hence this shim plus the `allow_in_graph` decorator above. 
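# Illustrative aside (not part of the patch): the shim pattern reduced to a toy op. The decorated
# body is never traced by dynamo, yet the call stays in the compiled FX graph as one opaque node,
# so there is no graph break and the real Python runs at execution time. Names are made up.
import torch


@torch._dynamo.allow_in_graph
def _opaque_toy_op(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Arbitrary Python here (e.g. launching a hand-written kernel) — dynamo does not look inside.
    return a @ b


@torch.compile
def _toy_forward(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return _opaque_toy_op(a, b) + 1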
+ """ + ActivationType, moe_general_routing_inputs = _load_sonic_kernel() + activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) + output, _ = moe_general_routing_inputs( + hidden_states, + router_scores, + token_idx, + expert_ids, + w1, + b1, + w2, + b2, + E=num_experts, + activation_type=activation_type, + is_inference_mode_enabled=is_inference_mode_enabled, + concat_layout=concat_layout, + stream_id=None, + ) + return output + + def sonicmoe_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, @@ -101,8 +142,6 @@ def sonicmoe_experts_forward( if hidden_states.device.type != "cuda": raise ValueError("sonicmoe requires CUDA device") - ActivationType, moe_general_routing_inputs = _load_sonic_kernel() - device = hidden_states.device num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) @@ -120,8 +159,6 @@ def sonicmoe_experts_forward( # Map activation function act_name = getattr(self.config, "hidden_act", "silu").lower() - activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) - # Permute weights as expected by sonic-moe (E=num_experts, H=hidden_size, I=intermediate_size). # Non-transposed: gate_up_proj is (E, 2*I, H), down_proj is (E, H, I) -> permute(1, 2, 0). # Transposed: gate_up_proj is (E, H, 2*I), down_proj is (E, I, H) -> permute(2, 1, 0). @@ -131,20 +168,17 @@ def sonicmoe_experts_forward( b1 = self.gate_up_proj_bias if self.has_bias else None b2 = self.down_proj_bias if self.has_bias else None - output, _ = moe_general_routing_inputs( - hidden_states, - router_scores, - token_idx, - expert_ids, - w1, - b1, - w2, - b2, - E=self.num_experts, - activation_type=activation_type, - stream_id=torch.cuda.current_stream(device).cuda_stream, - is_inference_mode_enabled=not torch.is_grad_enabled(), + return _sonicmoe_wrapper( + hidden_states=hidden_states, + router_scores=router_scores, + expert_ids=expert_ids, + token_idx=token_idx, + w1=w1, + b1=b1, + w2=w2, + b2=b2, + act_name=act_name, + num_experts=self.num_experts, concat_layout=self.is_concatenated, + is_inference_mode_enabled=not torch.is_grad_enabled(), ) - - return output From faaa7aaa6b37530d5d32796d71ec5100c039ae9f Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Wed, 29 Apr 2026 13:26:38 +0000 Subject: [PATCH 25/87] mega moe kernel support attempt --- src/transformers/integrations/deepgemm.py | 433 +++++++++++++----- .../integrations/finegrained_fp8.py | 320 ++++++++----- src/transformers/integrations/moe.py | 6 +- src/transformers/integrations/sonicmoe.py | 1 + .../integrations/tensor_parallel.py | 7 +- src/transformers/utils/quantization_config.py | 12 +- 6 files changed, 542 insertions(+), 237 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index c3af8f61c5fb..88ee8379b5d0 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -15,16 +15,23 @@ """DeepGEMM integration: fused grouped GEMM kernels from `kernels-community/deep-gemm`. Provides: -- `fp8_deepgemm_matmul`: FP8 dense matmul used as a fast path inside the finegrained-fp8 Linear. -- `fp8_deepgemm_experts_forward`: FP8 M-grouped experts forward, registered as "deepgemm" in the FP8 ExpertsInterface. -- `deepgemm_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. - -Requirements: CUDA, Hopper (SM90+), CUDA runtime >= 12.3, `kernels`. 
+- `deepgemm_bf16_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. +- `deepgemm_fp8_fp4_linear`: end-to-end FP8/FP4 linear (BF16 in, BF16 out) — quantizes activations + inside, dispatches cast settings on weight dtype, and runs the FP8/FP4 matmul. Used as the + DeepGEMM fast path inside `fp8_linear`. +- `deepgemm_fp8_fp4_experts_forward`: FP8 (or FP4 on SM100+) M-grouped experts forward, registered as "deepgemm" in the FP8 ExpertsInterface. +- `deepgemm_fp8_fp4_megamoe_experts_forward`: FP8 acts × FP4 weights Mega MoE forward (SM100+, + fuses EP dispatch + L1 + SwiGLU + L2 + EP combine via a `SymmBuffer`). + +Requirements: CUDA, Hopper (SM90+), CUDA runtime >= 12.3, `kernels-community/deep-gemm` (>= 2.5 +so the Mega MoE symbols are available — the loader raises if any required symbol is missing). +Mega MoE additionally requires SM100+ at call time. """ from __future__ import annotations import functools +from types import SimpleNamespace import torch @@ -40,20 +47,18 @@ # bi-directionally transfer 1D-5D tensors between GPU global and shared memory. _DEEPGEMM_M_ALIGNMENT = 128 +_FP8_DTYPE = torch.float8_e4m3fn +_FP8_MIN = torch.finfo(_FP8_DTYPE).min +_FP8_MAX = torch.finfo(_FP8_DTYPE).max + @functools.cache -def _load_deepgemm_kernel(): +def _load_deepgemm_kernel() -> SimpleNamespace: """ - Load DeepGEMM once and return its required symbols. - - Raises: - ImportError if CUDA/hardware requirements are not met, or the kernel or - required symbols are not found. + Load DeepGEMM once and return its entry points as a `SimpleNamespace`. - Returns: - Tuple of (deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, - deepgemm_grouped_bf16_matmul_nt, deepgemm_grouped_bf16_matmul_nn, - deepgemm_per_token_cast_to_fp8) from the DeepGEMM kernel. + Raises `ImportError` if CUDA/hardware requirements are not met or any required entry + point is missing. """ if not is_kernels_available(): raise ImportError("DeepGEMM kernel requires the `kernels` package. Install it with `pip install -U kernels`.") @@ -86,20 +91,28 @@ def _load_deepgemm_kernel(): "has a build matching the current torch/CUDA." 
) - deepgemm_fp8_matmul = getattr(kernel, "fp8_gemm_nt", None) - deepgemm_grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) - deepgemm_grouped_bf16_matmul_nt = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) - deepgemm_grouped_bf16_matmul_nn = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) - deepgemm_per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + fp8_fp4_matmul = getattr(kernel, "fp8_fp4_gemm_nt", None) + grouped_fp8_fp4_matmul = getattr(kernel, "m_grouped_fp8_fp4_gemm_nt_contiguous", None) + grouped_bf16_matmul_nt = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) + grouped_bf16_matmul_nn = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) + per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + symm_buffer_cls = getattr(kernel, "SymmBuffer", None) + fp8_fp4_mega_moe = getattr(kernel, "fp8_fp4_mega_moe", None) + get_symm_buffer_for_mega_moe = getattr(kernel, "get_symm_buffer_for_mega_moe", None) + transform_weights_for_mega_moe = getattr(kernel, "transform_weights_for_mega_moe", None) missing = [ name for name, attr in [ - ("fp8_gemm_nt", deepgemm_fp8_matmul), - ("m_grouped_fp8_gemm_nt_contiguous", deepgemm_grouped_fp8_matmul), - ("m_grouped_bf16_gemm_nt_contiguous", deepgemm_grouped_bf16_matmul_nt), - ("m_grouped_bf16_gemm_nn_contiguous", deepgemm_grouped_bf16_matmul_nn), - ("utils.per_token_cast_to_fp8", deepgemm_per_token_cast_to_fp8), + ("fp8_fp4_gemm_nt", fp8_fp4_matmul), + ("m_grouped_fp8_fp4_gemm_nt_contiguous", grouped_fp8_fp4_matmul), + ("m_grouped_bf16_gemm_nt_contiguous", grouped_bf16_matmul_nt), + ("m_grouped_bf16_gemm_nn_contiguous", grouped_bf16_matmul_nn), + ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), + ("SymmBuffer", symm_buffer_cls), + ("fp8_fp4_mega_moe", fp8_fp4_mega_moe), + ("get_symm_buffer_for_mega_moe", get_symm_buffer_for_mega_moe), + ("transform_weights_for_mega_moe", transform_weights_for_mega_moe), ] if attr is None ] @@ -109,38 +122,87 @@ def _load_deepgemm_kernel(): "Please update the `kernels` package (`pip install -U kernels`)." ) - return ( - deepgemm_fp8_matmul, - deepgemm_grouped_fp8_matmul, - deepgemm_grouped_bf16_matmul_nt, - deepgemm_grouped_bf16_matmul_nn, - deepgemm_per_token_cast_to_fp8, + return SimpleNamespace( + fp8_fp4_matmul=fp8_fp4_matmul, + grouped_fp8_fp4_matmul=grouped_fp8_fp4_matmul, + grouped_bf16_matmul_nt=grouped_bf16_matmul_nt, + grouped_bf16_matmul_nn=grouped_bf16_matmul_nn, + per_token_cast_to_fp8=per_token_cast_to_fp8, + transform_weights_for_mega_moe=transform_weights_for_mega_moe, + get_symm_buffer_for_mega_moe=get_symm_buffer_for_mega_moe, + fp8_fp4_mega_moe=fp8_fp4_mega_moe, + symm_buffer_cls=symm_buffer_cls, ) -def fp8_deepgemm_matmul( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - output_dtype: torch.dtype = torch.float32, +def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: + """Normalize a scale-factor tensor for the DeepGEMM kernel boundary. + + Two SF flavors are produced by our path: + - `float32` (DeepSeek V3-style): pass through; the kernel transforms float→int internally + on SM100 to feed the 1D1D path. + - `torch.float8_e8m0fnu` (DeepSeek V4-style, 1 byte per scale): reinterpret 4 contiguous + bytes as one `int32`. No copy; last-dim shrinks 4×. 
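# Illustrative aside (not part of the patch): the reinterpretation in shapes, with uint8 standing in
# for float8_e8m0fnu (the latter only exists on recent torch builds). Four 1-byte SFs fold into one
# int32 column, so a (M, K // 32) byte tensor becomes (M, K // 128) int32 without copying data.
import torch

sf_bytes = torch.zeros(16, 8, dtype=torch.uint8)       # e.g. K = 256 -> 8 per-32 scale bytes per row
sf_packed = sf_bytes.contiguous().view(torch.int32)    # zero-copy reinterpretation
assert sf_packed.shape == (16, 2)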
+ """ + if sf.dtype == torch.float8_e8m0fnu: + return sf.contiguous().view(torch.int32) + return sf + + +def deepgemm_fp8_fp4_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale_inv: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, + activation_scale: torch.Tensor | None = None, ) -> torch.Tensor: + """End-to-end DeepGEMM linear: per-token activation quant + FP8/FP4 matmul. + + Activation cast settings are inferred from the tensor dtypes: + - FP4 weights (`weight.dtype == torch.int8`): always gran_k=32 with packed-UE8M0 SF. Requires + SM100+ (Blackwell). + - FP8 weights + UE8M0 weight SFs (`weight_scale_inv.dtype == torch.float8_e8m0fnu`, + DeepSeek V4-style): gran_k=128 with packed-UE8M0 SF (skips the kernel-side float→int SF + transform on SM100). + - FP8 weights + float weight SFs (DeepSeek V3-style): gran_k=128 with float SF (works on + Hopper and Blackwell). + + Static (per-tensor) activation quantization is not supported — DeepGEMM's kernel needs per-row + SFs and rejects scalar SFs at its host-side check. Callers should route static activations + through the Triton fallback. """ - FP8 dense matmul via DeepGEMM's `fp8_gemm_nt`. Block-wise 128x128 scales expected. + if activation_scale is not None: + raise NotImplementedError( + "Static (per-tensor) activation quantization is not supported on the DeepGEMM path. " + "Use the Triton fallback for static activations." + ) - Args: - A: (M, K) float8_e4m3fn — quantized activations - B: (N, K) float8_e4m3fn — quantized weights - As: (M, K//128) float32 — per-block activation scales - Bs: (N//128, K//128) float32 — per-block weight scales - output_dtype: desired output dtype. - """ - deepgemm_fp8_matmul, _, _, _, _ = _load_deepgemm_kernel() - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) + is_fp4 = weight.dtype == torch.int8 + if is_fp4 and torch.cuda.get_device_capability(input.device)[0] < 10: + raise RuntimeError("FP4 weights (int8-packed e2m1) require SM100+ (Blackwell).") + + deepgemm = _load_deepgemm_kernel() + + if is_fp4: + cast_kwargs = {"use_ue8m0": True, "gran_k": 32, "use_packed_ue8m0": True} + elif weight_scale_inv.dtype == torch.float8_e8m0fnu: + cast_kwargs = {"use_ue8m0": True, "gran_k": 128, "use_packed_ue8m0": True} + else: + cast_kwargs = {"use_ue8m0": False, "gran_k": 128} + input_2d = input.view(-1, input.shape[-1]) + qinput_2d, scale_2d = deepgemm.per_token_cast_to_fp8(input_2d, **cast_kwargs) + + output = torch.empty(qinput_2d.shape[0], weight.shape[0], device=input.device, dtype=output_dtype) + deepgemm.fp8_fp4_matmul( + (qinput_2d, _coerce_sf_for_kernel(scale_2d)), + (weight, _coerce_sf_for_kernel(weight_scale_inv)), + output, + ) + output = output.view(input.shape[:-1] + (weight.shape[0],)) + if bias is not None: + output.add_(bias) + return output def _build_deepgemm_contiguous_layout( @@ -202,26 +264,20 @@ def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_pad return x_padded[sorted_to_padded] -def fp8_deepgemm_experts_forward( +def deepgemm_bf16_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, + process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: - if 
self.activation_scheme == "static": - raise NotImplementedError( - "DeepGEMM experts dispatch does not support activation_scheme='static'. " - "Use the default eager dispatch or switch to activation_scheme='dynamic'." - ) - if self.block_size is None: - raise ValueError( - "DeepGEMM requires block-wise quantization (block_size=[128, 128]), " - "but got per-tensor quantization (block_size=None)." - ) - if self.block_size[0] != 128 or self.block_size[1] != 128: - raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") + if hidden_states.dtype != torch.bfloat16: + raise ValueError(f"DeepGEMM experts path requires bfloat16 hidden states, got {hidden_states.dtype}") - _, deepgemm_grouped_fp8_matmul, _, _, deepgemm_per_token_cast_to_fp8 = _load_deepgemm_kernel() + # Non-transposed HF experts have weight layout (E, N, K) -> NT kernel. + # Transposed HF experts have weight layout (E, K, N) -> NN kernel. + deepgemm = _load_deepgemm_kernel() + grouped_bf16_matmul = deepgemm.grouped_bf16_matmul_nn if self.is_transposed else deepgemm.grouped_bf16_matmul_nt device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -245,16 +301,24 @@ def fp8_deepgemm_experts_forward( expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) - # --- Up projection per expert (DeepGEMM grouped contiguous) --- + if self.has_bias: + # Clamp now that the layout has been built — needed for the per-row bias gather below to stay + # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. + expert_ids_g.clamp_(0, self.num_experts - 1) + + # --- Up projection per expert (DeepGEMM grouped contiguous, bf16) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj - ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv - act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) - act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) - act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) - proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) - deepgemm_grouped_fp8_matmul( - (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) + # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). + up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] + act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) + proj_out = torch.empty(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) + grouped_bf16_matmul(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) + + # The kernel has no bias input -> add per-expert bias in-place on the unpadded slice; + # padding rows get discarded at unpad time. 
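# Illustrative aside (not part of the patch): the padded-layout index math used above, in miniature,
# with the alignment shrunk from 128 to 4 so the numbers stay readable. This mirrors the arithmetic
# in _build_deepgemm_contiguous_layout.
import torch

alignment = 4
ids_sorted = torch.tensor([0, 0, 0, 1, 1])                 # 3 tokens routed to expert 0, 2 to expert 1
tokens_per_expert = torch.tensor([3, 2])
aligned = ((tokens_per_expert + alignment - 1) // alignment) * alignment               # [4, 4]
pad_before = torch.nn.functional.pad((aligned - tokens_per_expert).cumsum(0), (1, 0))  # [0, 1, 3]
sorted_to_padded = torch.arange(5) + pad_before[ids_sorted]                            # [0, 1, 2, 4, 5]
# Rows 3, 6 and 7 of the padded activation buffer are alignment padding: the kernel skips them,
# bias is index_add_-ed only at the real positions, and gathering with sorted_to_padded drops them.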
+ if self.has_bias: + up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias + proj_out.index_add_(0, sorted_to_padded, up_bias[expert_ids_g]) # Apply gating or activation if self.has_gate: @@ -262,25 +326,21 @@ def fp8_deepgemm_experts_forward( else: proj_out = self.act_fn(proj_out) - # --- Down projection per expert (DeepGEMM grouped contiguous) --- - proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) - proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) - deepgemm_grouped_fp8_matmul( - (proj_fp8, proj_scales), - (self.down_proj, self.down_proj_scale_inv.float()), - proj_out, - grouped_layout, - use_psum_layout=use_psum_layout, - ) + # --- Down projection per expert (DeepGEMM grouped contiguous, bf16) --- + out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) + grouped_bf16_matmul(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) + + if self.has_bias: + out.index_add_(0, sorted_to_padded, self.down_proj_bias[expert_ids_g]) # Remove padding rows - proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) + out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) # Apply routing weights - weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + weighted_out = out * sample_weights_g.to(out.dtype).unsqueeze(-1) # (S, hidden_dim) - # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, - # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # EP sentinel handling: `out` rows past the valid expert blocks are left uninitialized by the kernel, + # so `out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here # so the downstream reduction stays finite even when the routing weight was already zero. weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) @@ -296,27 +356,56 @@ def fp8_deepgemm_experts_forward( return final_hidden_states.to(hidden_states.dtype) -def deepgemm_experts_forward( +def deepgemm_fp8_fp4_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, + process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: - if hidden_states.dtype != torch.bfloat16: - raise ValueError(f"DeepGEMM experts path requires bfloat16 hidden states, got {hidden_states.dtype}") + if self.activation_scheme == "static": + raise NotImplementedError( + "DeepGEMM experts dispatch does not support activation_scheme='static'. " + "Use the default eager dispatch or switch to activation_scheme='dynamic'." + ) - # Non-transposed HF experts have weight layout (E, N, K) -> NT kernel. - # Transposed HF experts have weight layout (E, K, N) -> NN kernel. - _, _, deepgemm_grouped_bf16_matmul_nt, deepgemm_grouped_bf16_matmul_nn, _ = _load_deepgemm_kernel() - deepgemm_grouped_bf16_matmul = ( - deepgemm_grouped_bf16_matmul_nn if self.is_transposed else deepgemm_grouped_bf16_matmul_nt - ) + deepgemm = _load_deepgemm_kernel() device = hidden_states.device num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) hidden_dim = hidden_states.size(-1) + # FP4 weights are int8-packed (2 e2m1 values per byte; `kPackedFP4 == torch::kInt8` in DeepGEMM). 
+ # `m_grouped_fp8_fp4_gemm_nt_contiguous` accepts both FP8 and FP4 weight dtypes. Activation cast + # tracks (weight dtype, weight SF dtype), mirroring `deepgemm_fp8_fp4_linear`: + # - FP4 weights: gran_k=32 packed-UE8M0 SF (SM100+). + # - FP8 weights + UE8M0 SFs: gran_k=128 packed-UE8M0 SF (skips the kernel-side float→int + # transform on SM100). + # - FP8 weights + float SFs: gran_k=128 float SF (Hopper or Blackwell). + w_up = self.gate_up_proj if self.has_gate else self.up_proj + ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv + is_fp4_weights = w_up.dtype == torch.int8 + + if is_fp4_weights: + if torch.cuda.get_device_capability(device)[0] < 10: + raise RuntimeError( + "FP4 expert weights (int8-packed e2m1) require SM100+ (Blackwell); use FP8 weights on Hopper." + ) + cast_kwargs = {"use_ue8m0": True, "gran_k": 32, "use_packed_ue8m0": True} + else: + if self.block_size is None: + raise ValueError( + "DeepGEMM requires block-wise quantization (block_size=[128, 128]), " + "but got per-tensor quantization (block_size=None)." + ) + if self.block_size[0] != 128 or self.block_size[1] != 128: + raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") + if ws_up.dtype == torch.float8_e8m0fnu: + cast_kwargs = {"use_ue8m0": True, "gran_k": 128, "use_packed_ue8m0": True} + else: + cast_kwargs = {"use_ue8m0": False, "gran_k": 128} + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) @@ -334,24 +423,18 @@ def deepgemm_experts_forward( expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) - if self.has_bias: - # Clamp now that the layout has been built — needed for the per-row bias gather below to stay - # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. - expert_ids_g.clamp_(0, self.num_experts - 1) - - # --- Up projection per expert (DeepGEMM grouped contiguous, bf16) --- - w_up = self.gate_up_proj if self.has_gate else self.up_proj - # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). - up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] - act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) - proj_out = torch.empty(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) - deepgemm_grouped_bf16_matmul(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) - - # The kernel has no bias input -> add per-expert bias in-place on the unpadded slice; - # padding rows get discarded at unpad time. 
- if self.has_bias: - up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias - proj_out.index_add_(0, sorted_to_padded, up_bias[expert_ids_g]) + # --- Up projection per expert (DeepGEMM grouped contiguous) --- + act_fp8, act_scales = deepgemm.per_token_cast_to_fp8(selected_hidden_states_g, **cast_kwargs) + act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) + act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) + proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) + deepgemm.grouped_fp8_fp4_matmul( + (act_fp8, act_scales), + (w_up, _coerce_sf_for_kernel(ws_up)), + proj_out, + grouped_layout, + use_psum_layout=use_psum_layout, + ) # Apply gating or activation if self.has_gate: @@ -359,21 +442,25 @@ def deepgemm_experts_forward( else: proj_out = self.act_fn(proj_out) - # --- Down projection per expert (DeepGEMM grouped contiguous, bf16) --- - out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) - deepgemm_grouped_bf16_matmul(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) - - if self.has_bias: - out.index_add_(0, sorted_to_padded, self.down_proj_bias[expert_ids_g]) + # --- Down projection per expert (DeepGEMM grouped contiguous) --- + proj_fp8, proj_scales = deepgemm.per_token_cast_to_fp8(proj_out, **cast_kwargs) + proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + deepgemm.grouped_fp8_fp4_matmul( + (proj_fp8, proj_scales), + (self.down_proj, _coerce_sf_for_kernel(self.down_proj_scale_inv)), + proj_out, + grouped_layout, + use_psum_layout=use_psum_layout, + ) # Remove padding rows - out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) + proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) # Apply routing weights - weighted_out = out * sample_weights_g.to(out.dtype).unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) - # EP sentinel handling: `out` rows past the valid expert blocks are left uninitialized by the kernel, - # so `out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here # so the downstream reduction stays finite even when the routing weight was already zero. weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) @@ -387,3 +474,99 @@ def deepgemm_experts_forward( final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) return final_hidden_states.to(hidden_states.dtype) + + +def deepgemm_fp8_fp4_megamoe_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + process_group: torch.distributed.ProcessGroup | None = None, +) -> torch.Tensor: + """FP8 acts × FP4 weights Mega MoE forward via DeepGEMM. + + Fuses EP dispatch + L1 (FP8×FP4) + SwiGLU + L2 (FP8×FP4) + EP combine into a single + kernel, overlapping NVLink communication with tensor-core compute. The kernel handles + the full `(num_tokens, hidden) -> (num_tokens, hidden)` MoE forward including the + weighted top-k reduction; the caller must NOT all-reduce the output (the EP combine is + already inside the kernel). 
+ + `process_group` (the EP group) is passed in by `MoeTensorParalellExperts._prepare_input_fn` + when the module is wrapped for TP — it is required for the symmetric-buffer rendezvous on + first forward. + + Caller-managed attributes on `self` (this dispatch does no quantization or weight + transformation — assume they are pre-set on the module): + - `gate_up_proj`: int8-packed FP4 L1 weight, + shape `(num_experts_per_rank, intermediate_hidden * 2, hidden // 2)`, + interleaved gate/up via `transform_weights_for_mega_moe`. + - `gate_up_proj_scale_inv`: int-packed UE8M0 SF for L1, UTCCP-transposed via + `transform_weights_for_mega_moe`. + - `down_proj`, `down_proj_scale_inv`: same conventions for L2. + + The `SymmBuffer` is lazily allocated on first call (and re-allocated if a later call + has more tokens than the cached buffer). The SwiGLU clamp is read from + `self.config.swiglu_limit` if present, otherwise the kernel runs unclamped. + + Args: + hidden_states: bf16 `(num_tokens, hidden)`. + top_k_index: int `(num_tokens, num_topk)` of GLOBAL expert ids; -1 marks skipped + slots (the kernel ignores them). Note: this differs from the `RouterParallel` + output used by the other dispatches, which remaps indices to local + sentinel. + top_k_weights: float `(num_tokens, num_topk)` routing weights. + + Returns: + `(num_tokens, hidden)` in `hidden_states.dtype` (already weighted-summed across + topk and reduced across EP ranks). + """ + # Mega MoE is Blackwell-only — the impl is `sm100_fp8_fp4_mega_moe.cuh` and there is + # no SM90 path. Use the regular "deepgemm" dispatch on Hopper. + if torch.cuda.get_device_capability(hidden_states.device)[0] < 10: + raise RuntimeError("DeepGEMM Mega MoE requires SM100+ (Blackwell). The 'deepgemm' dispatch supports SM90+.") + + deepgemm = _load_deepgemm_kernel() + + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + num_experts = self.gate_up_proj.size(0) + intermediate_hidden = self.gate_up_proj.size(1) // 2 + activation_clamp = getattr(getattr(self, "config", None), "swiglu_limit", None) + + # Lazily allocate the symmetric buffer on first call (re-allocate if the cached buffer is + # too small for this call). `process_group` is threaded in by `MoeTensorParalellExperts`. + if getattr(self, "symm_buffer", None) is None or self.symm_buffer.num_max_tokens_per_rank < num_tokens: + if process_group is None: + raise ValueError( + "DeepGEMM Mega MoE requires a `process_group` for the EP group; the TP wrapping " + "(`MoeTensorParalellExperts`) supplies it automatically. If you are calling this " + "dispatch directly, pass `process_group=...` explicitly." + ) + self.symm_buffer = deepgemm.get_symm_buffer_for_mega_moe( + process_group, + hidden=hidden_dim, + num_topk=num_top_k, + num_experts=num_experts, + num_max_tokens_per_rank=num_tokens, + intermediate_hidden=intermediate_hidden, + ) + + # Quantize activations to FP8 with packed UE8M0 per-32 SF — the layout the kernel expects. + x_fp8, x_sf = deepgemm.per_token_cast_to_fp8(hidden_states, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + + # Stage inputs into the symmetric buffer; the kernel reads from there during dispatch. 
+ self.symm_buffer.x[:num_tokens].copy_(x_fp8) + self.symm_buffer.x_sf[:num_tokens].copy_(x_sf) + self.symm_buffer.topk_idx[:num_tokens].copy_(top_k_index) + self.symm_buffer.topk_weights[:num_tokens].copy_(top_k_weights) + + y = torch.empty((num_tokens, hidden_dim), dtype=torch.bfloat16, device=hidden_states.device) + deepgemm.fp8_fp4_mega_moe( + y, + (self.gate_up_proj, self.gate_up_proj_scale_inv), + (self.down_proj, self.down_proj_scale_inv), + self.symm_buffer, + activation_clamp=activation_clamp, + ) + + return y.to(hidden_states.dtype) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index dce3159a3bd7..01083f760d70 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -14,6 +14,7 @@ from __future__ import annotations import functools +from types import SimpleNamespace import torch import torch.nn as nn @@ -24,7 +25,11 @@ from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging from ..utils.import_utils import is_kernels_available -from .deepgemm import fp8_deepgemm_experts_forward, fp8_deepgemm_matmul +from .deepgemm import ( + deepgemm_fp8_fp4_experts_forward, + deepgemm_fp8_fp4_linear, + deepgemm_fp8_fp4_megamoe_experts_forward, +) from .hub_kernels import lazy_load_kernel from .moe import ExpertsInterface, use_experts_implementation @@ -35,6 +40,7 @@ _FP8_DTYPE = torch.float8_e4m3fn _FP8_MIN = torch.finfo(_FP8_DTYPE).min _FP8_MAX = torch.finfo(_FP8_DTYPE).max +_UE8M0_SF_DTYPE = torch.float8_e8m0fnu def _first_attr(obj, *names): @@ -45,17 +51,11 @@ def _first_attr(obj, *names): @functools.cache -def _load_triton_kernel(): +def _load_finegrained_fp8_kernel() -> SimpleNamespace: """ - Load the finegrained-fp8 Triton kernel once and return its required symbols. + Load the finegrained-fp8 Triton kernel once and return its entry points as a `SimpleNamespace`. - Raises: - ImportError if the `kernels` package is missing, or the kernel or required - symbols cannot be found. - - Returns: - Tuple of (w8a8_fp8_matmul, fp8_act_quant, w8a8_fp8_matmul_batched, - w8a8_fp8_matmul_grouped) from the finegrained-fp8 kernel. + Raises `ImportError` if the `kernels` package is missing or any required entry point is absent. """ if not is_kernels_available(): raise ImportError( @@ -69,18 +69,18 @@ def _load_triton_kernel(): "has a build matching the current torch/CUDA." ) - triton_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul", None) - triton_fp8_act_quant = getattr(kernel, "fp8_act_quant", None) - triton_batched_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_batched", None) - triton_grouped_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_grouped", None) + matmul = getattr(kernel, "w8a8_fp8_matmul", None) + act_quant = getattr(kernel, "fp8_act_quant", None) + batched_matmul = getattr(kernel, "w8a8_fp8_matmul_batched", None) + grouped_matmul = getattr(kernel, "w8a8_fp8_matmul_grouped", None) missing = [ name for name, attr in [ - ("w8a8_fp8_matmul", triton_fp8_matmul), - ("fp8_act_quant", triton_fp8_act_quant), - ("w8a8_fp8_matmul_batched", triton_batched_fp8_matmul), - ("w8a8_fp8_matmul_grouped", triton_grouped_fp8_matmul), + ("w8a8_fp8_matmul", matmul), + ("fp8_act_quant", act_quant), + ("w8a8_fp8_matmul_batched", batched_matmul), + ("w8a8_fp8_matmul_grouped", grouped_matmul), ] if attr is None ] @@ -90,7 +90,12 @@ def _load_triton_kernel(): "Please update the `kernels` package (`pip install -U kernels`)." 
) - return triton_fp8_matmul, triton_fp8_act_quant, triton_batched_fp8_matmul, triton_grouped_fp8_matmul + return SimpleNamespace( + matmul=matmul, + act_quant=act_quant, + batched_matmul=batched_matmul, + grouped_matmul=grouped_matmul, + ) def _cdiv(a: int, b: int) -> int: @@ -98,36 +103,108 @@ def _cdiv(a: int, b: int) -> int: return (a + b - 1) // b -def w8a8_fp8_matmul( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype = torch.float32, +def _alloc_expert_proj( + num_experts: int, + proj_out: int, + proj_in: int, + weight_dtype: torch.dtype, + sf_dtype: torch.dtype, + weight_k_div: int = 1, + sf_gran_n: int | None = None, + sf_gran_k: int | None = None, +) -> tuple[nn.Parameter, nn.Parameter]: + """Allocate `(weight, weight_scale_inv)` parameters for one expert projection. + + `weight_k_div` halves the K dim for FP4-packed storage (2 e2m1 values per byte). + `sf_gran_n` / `sf_gran_k` set per-block (None → per-row/per-tensor) SF granularity. + """ + weight_t = torch.empty(num_experts, proj_out, proj_in // weight_k_div, dtype=weight_dtype) + weight = nn.Parameter(weight_t, requires_grad=weight_t.is_floating_point()) + sf_out = _cdiv(proj_out, sf_gran_n) if sf_gran_n is not None else 1 + sf_in = _cdiv(proj_in, sf_gran_k) if sf_gran_k is not None else 1 + sf_t = torch.empty(num_experts, sf_out, sf_in, dtype=sf_dtype) + sf = nn.Parameter(sf_t, requires_grad=sf_t.is_floating_point()) + return weight, sf + + +def finegrained_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale_inv: torch.Tensor, + block_size: list[int] | None = None, + bias: torch.Tensor | None = None, + activation_scale: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, ) -> torch.Tensor: - """FP8 matmul: C = dequant(A, As) @ dequant(B, Bs)^T. + """End-to-end Triton FP8 linear: per-token (or static per-tensor) act-quant + matmul + bias. + + Triton has no FP4 path — caller must guard FP4 weights before reaching here. + """ + finegrained_fp8 = _load_finegrained_fp8_kernel() + if activation_scale is not None: + scale = activation_scale.to(torch.float32) + qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) + else: + gran_k = block_size[1] if block_size is not None else input.shape[-1] + qinput, scale = finegrained_fp8.act_quant(input, gran_k) - Supports both per-tensor and block-wise quantization: - - block_size=None or block_size=[N, K]: per-tensor mode (As is scalar/per-row, Bs is scalar) - - block_size=[block_n, block_k]: block-wise mode (As and Bs are per-block scale grids) + output = finegrained_fp8.matmul(qinput, weight, scale, weight_scale_inv, block_size, output_dtype) + + if bias is not None: + output.add_(bias) + + return output + + +def fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale_inv: torch.Tensor, + block_size: list[int] | None = None, + bias: torch.Tensor | None = None, + activation_scale: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """End-to-end FP8/FP4 linear used by `FP8Linear` and the eager `FP8Experts` loop. Dispatch order: - 1. DeepGEMM (Hopper+, block_size 128x128) if available - 2. Triton finegrained-fp8 kernel (universal fallback) + 1. DeepGEMM full pipeline (`deepgemm_fp8_fp4_linear`) — handles both FP8 (`float8_e4m3fn`) + and FP4 (`int8`-packed e2m1) weights, paired with the matching activation cast inside. 
+ 3-6x faster than Triton on FP8; required for FP4 and for UE8M0 (`float8_e8m0fnu`) SFs. + 2. Triton finegrained-fp8 fallback (FP8 weights + float SFs) — applies on `ImportError` from + the DeepGEMM path or for static activations (DeepGEMM is dynamic-only). Raises if FP4 + weights or UE8M0 SFs reach this branch since Triton can't handle them. Args: - A: (M, K) float8_e4m3fn — quantized activations - B: (N, K) float8_e4m3fn — quantized weights - As: block-wise: (M, K//block_k) float32; per-tensor: (M,) per-row scales - Bs: block-wise: (N//block_n, K//block_k) float32; per-tensor: scalar or (1,) single weight scale - block_size: [block_n, block_k] for block-wise quantization, or None/[N, K] for per-tensor - output_dtype: desired output dtype + input: (..., K) bf16/fp16 activations. + weight: (N, K) `float8_e4m3fn` or (N, K // 2) `int8` (FP4-packed). + weight_scale_inv: per-block weight scales — `float32` (V3-style) or `float8_e8m0fnu` + (V4-style; reinterpreted as int32 at the DeepGEMM kernel boundary). + block_size: [block_n, block_k] for FP8 block-wise quant, or None/[N, K] for per-tensor. + Ignored for FP4 weights (the kernel infers SF granularity from the dtype). + bias: optional bias added to the matmul output. + activation_scale: pass a per-tensor scalar to use static activation quant; leave `None` + for dynamic (per-token) quant. + output_dtype: desired output dtype. """ - if block_size is not None and block_size[0] == block_size[1] == 128: + # Triton handles only FP8 weights + float SFs. FP4 weights and/or UE8M0 SFs (DeepSeek V4) + # must take the DeepGEMM path. Static activation (per-tensor scalar) is Triton-only — DeepGEMM's + # kernel expects per-row SFs and rejects scalar SFs at its host-side check. + deepgemm_required = weight.dtype == torch.int8 or weight_scale_inv.dtype == torch.float8_e8m0fnu + deepgemm_compatible = activation_scale is None and ( + deepgemm_required or (block_size is not None and block_size[0] == block_size[1] == 128) + ) + + if deepgemm_compatible: try: - # 3-6x faster than Triton - return fp8_deepgemm_matmul(A, B, As, Bs, output_dtype=output_dtype) + return deepgemm_fp8_fp4_linear( + input, + weight, + weight_scale_inv, + output_dtype=output_dtype, + activation_scale=activation_scale, + bias=bias, + ) except ImportError: logger.warning_once( "DeepGEMM kernel is not available or compatible, falling back to Triton finegrained-fp8 kernel. " @@ -135,8 +212,21 @@ def w8a8_fp8_matmul( "and that the `kernels` package is installed and up to date (`pip install -U kernels`)." ) - triton_fp8_matmul, _, _, _ = _load_triton_kernel() - return triton_fp8_matmul(A, B, As, Bs, block_size, output_dtype) + if deepgemm_required: + if activation_scale is not None: + raise RuntimeError( + "Static (per-tensor) activation quantization is not supported with FP4 weights or " + "UE8M0 weight scales — DeepGEMM expects per-row SFs and the Triton fallback can't " + "handle these formats. Use dynamic activation quantization instead." + ) + raise RuntimeError( + "FP4 weights and/or UE8M0 weight scales require the DeepGEMM path; the Triton fallback " + "handles FP8 weights with float32 SFs only. Make sure your system is compatible with the " + "DeepGEMM path: SM90+ GPU (SM100+ for FP4), CUDA runtime 12.3+, PyTorch ≥2.6, and the " + "`kernels` package installed." 
+ ) + + return finegrained_fp8_linear(input, weight, weight_scale_inv, block_size, bias, activation_scale, output_dtype) class FP8Linear(nn.Linear): @@ -146,6 +236,7 @@ def __init__( out_features: int, block_size: tuple[int, int] | None = None, activation_scheme: str = "dynamic", + scale_fmt: str = "float", has_bias: bool = False, dtype=_FP8_DTYPE, ): @@ -160,11 +251,10 @@ def __init__( # If block size is None, it means that we are doing per-tensor quantization self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32)) else: + sf_dtype = _UE8M0_SF_DTYPE if scale_fmt == "ue8m0" else torch.float32 scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0] scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1] - self.weight_scale_inv = nn.Parameter( - torch.empty(scale_out_features, scale_in_features, dtype=torch.float32) - ) + self.weight_scale_inv = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=sf_dtype)) if self.activation_scheme == "static": self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32)) @@ -188,37 +278,23 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: weight = self.weight.contiguous() scale_inv = self.weight_scale_inv.contiguous() - if self.activation_scheme == "dynamic": - _, triton_fp8_act_quant, _, _ = _load_triton_kernel() - qinput, scale = triton_fp8_act_quant( - input, self.block_size[1] if self.block_size is not None else input.shape[-1] - ) - elif self.activation_scheme == "static": - scale = self.activation_scale.to(torch.float32) - qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) - else: - raise NotImplementedError(f"Unsupported activation scheme: {self.activation_scheme}") - - output = w8a8_fp8_matmul( - qinput, + return fp8_linear( + input, weight, - scale, scale_inv, - self.block_size, + block_size=self.block_size, + activation_scale=self.activation_scale, output_dtype=input.dtype, + bias=self.bias, ) - if self.bias is not None: - output.add_(self.bias) - - return output.to(dtype=input.dtype) - def fp8_batched_mm_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, + process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: if self.activation_scheme == "static": raise NotImplementedError( @@ -226,7 +302,14 @@ def fp8_batched_mm_experts_forward( "Use the default eager dispatch or switch to activation_scheme='dynamic'." ) - _, _, triton_batched_fp8_matmul, _ = _load_triton_kernel() + w_up = self.gate_up_proj if self.has_gate else self.up_proj + if w_up.dtype == torch.int8: + raise NotImplementedError( + "'batched_mm' experts dispatch is Triton-only and does not support FP4 (int8-packed) " + "expert weights. Use experts_implementation='deepgemm' instead." 
+ ) + + finegrained_fp8 = _load_finegrained_fp8_kernel() num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) @@ -244,7 +327,7 @@ def fp8_batched_mm_experts_forward( expert_ids.clamp_(0, self.num_experts - 1) # --- Up projection per expert (FP8 batched) --- - proj_out = triton_batched_fp8_matmul( + proj_out = finegrained_fp8.batched_matmul( selected_hidden_states, self.gate_up_proj if self.has_gate else self.up_proj, self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv, @@ -261,7 +344,7 @@ def fp8_batched_mm_experts_forward( proj_out = self.act_fn(proj_out) # (S, intermediate_dim) # --- Down projection per expert (FP8 batched) --- - proj_out = triton_batched_fp8_matmul( + proj_out = finegrained_fp8.batched_matmul( proj_out, self.down_proj, self.down_proj_scale_inv, @@ -284,6 +367,7 @@ def fp8_grouped_mm_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, + process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: if self.activation_scheme == "static": raise NotImplementedError( @@ -291,7 +375,14 @@ def fp8_grouped_mm_experts_forward( "Use the default eager dispatch or switch to activation_scheme='dynamic'." ) - _, _, _, triton_grouped_fp8_matmul = _load_triton_kernel() + w_up = self.gate_up_proj if self.has_gate else self.up_proj + if w_up.dtype == torch.int8: + raise NotImplementedError( + "'grouped_mm' experts dispatch is Triton-only and does not support FP4 (int8-packed) " + "expert weights. Use experts_implementation='deepgemm' instead." + ) + + finegrained_fp8 = _load_finegrained_fp8_kernel() device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -320,7 +411,7 @@ def fp8_grouped_mm_experts_forward( offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) # --- Up projection per expert (FP8 grouped) --- - proj_out = triton_grouped_fp8_matmul( + proj_out = finegrained_fp8.grouped_matmul( selected_hidden_states_g, self.gate_up_proj if self.has_gate else self.up_proj, self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv, @@ -338,7 +429,7 @@ def fp8_grouped_mm_experts_forward( proj_out = self.act_fn(proj_out) # (S, intermediate_dim) # --- Down projection per expert (FP8 grouped) --- - proj_out = triton_grouped_fp8_matmul( + proj_out = finegrained_fp8.grouped_matmul( proj_out, self.down_proj, self.down_proj_scale_inv, @@ -373,6 +464,7 @@ def __init__( config, block_size: tuple[int, int] | None = None, activation_scheme: str = "dynamic", + scale_fmt: str = "float", has_bias: bool = False, has_gate: bool = True, dtype=_FP8_DTYPE, @@ -393,31 +485,40 @@ def __init__( self.intermediate_dim = _first_attr(config, "moe_intermediate_size", "intermediate_size") self.act_fn = ACT2FN[_first_attr(config, "hidden_activation", "hidden_act")] + # Expert weight precision is FP8 by default; DeepSeek V4-style models declare + # `config.expert_dtype = "fp4"` for FP4-packed expert weights. FP4 storage: + # - weight is `int8`, K dim halved (2 e2m1 values per byte). + # - SF is `float8_e8m0fnu` per-row at gran_k=32 (no block-wise SF; `block_size` ignored). 
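As a rough illustration of the two storage layouts `_alloc_expert_proj` produces (the dimensions below are made-up example values, not taken from any real config):

import torch

num_experts, hidden, intermediate = 8, 7168, 2048       # hypothetical sizes
proj_out, proj_in = 2 * intermediate, hidden             # gate_up projection

# FP8 (V3-style) with block_size=(128, 128): full-K weight, one scale per 128x128 block.
fp8_w = torch.empty(num_experts, proj_out, proj_in, dtype=torch.float8_e4m3fn)           # (8, 4096, 7168)
fp8_sf = torch.empty(num_experts, proj_out // 128, proj_in // 128, dtype=torch.float32)  # (8, 32, 56)

# FP4 (V4-style): K dim halved (two e2m1 values per int8 byte), per-row SF at gran_k=32.
fp4_w = torch.empty(num_experts, proj_out, proj_in // 2, dtype=torch.int8)                # (8, 4096, 3584)
fp4_sf = torch.empty(num_experts, proj_out, proj_in // 32, dtype=torch.float8_e8m0fnu)    # (8, 4096, 224)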
+ is_fp4 = getattr(config, "expert_dtype", "fp8") == "fp4" + if is_fp4: + alloc_kwargs = { + "weight_dtype": torch.int8, + "sf_dtype": _UE8M0_SF_DTYPE, + "weight_k_div": 2, + "sf_gran_n": 1, + "sf_gran_k": 32, + } + else: + alloc_kwargs = { + "weight_dtype": dtype, + "sf_dtype": _UE8M0_SF_DTYPE if scale_fmt == "ue8m0" else torch.float32, + "sf_gran_n": block_size[0] if block_size is not None else None, + "sf_gran_k": block_size[1] if block_size is not None else None, + } + if self.has_gate: - gu_proj_out, gu_proj_in = 2 * self.intermediate_dim, self.hidden_dim - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, gu_proj_out, gu_proj_in, dtype=dtype)) - gu_scale_out = _cdiv(gu_proj_out, self.block_size[0]) if self.block_size is not None else 1 - gu_scale_in = _cdiv(gu_proj_in, self.block_size[1]) if self.block_size is not None else 1 - self.gate_up_proj_scale_inv = nn.Parameter( - torch.empty(self.num_experts, gu_scale_out, gu_scale_in, dtype=torch.float32) + self.gate_up_proj, self.gate_up_proj_scale_inv = _alloc_expert_proj( + self.num_experts, 2 * self.intermediate_dim, self.hidden_dim, **alloc_kwargs ) self.register_parameter("gate_up_proj_bias", None) else: - u_proj_out, u_proj_in = self.intermediate_dim, self.hidden_dim - self.up_proj = nn.Parameter(torch.empty(self.num_experts, u_proj_out, u_proj_in, dtype=dtype)) - u_scale_out = _cdiv(u_proj_out, self.block_size[0]) if self.block_size is not None else 1 - u_scale_in = _cdiv(u_proj_in, self.block_size[1]) if self.block_size is not None else 1 - self.up_proj_scale_inv = nn.Parameter( - torch.empty(self.num_experts, u_scale_out, u_scale_in, dtype=torch.float32) + self.up_proj, self.up_proj_scale_inv = _alloc_expert_proj( + self.num_experts, self.intermediate_dim, self.hidden_dim, **alloc_kwargs ) self.register_parameter("up_proj_bias", None) - d_proj_out, d_proj_in = self.hidden_dim, self.intermediate_dim - self.down_proj = nn.Parameter(torch.empty(self.num_experts, d_proj_out, d_proj_in, dtype=dtype)) - d_scale_out = _cdiv(d_proj_out, self.block_size[0]) if self.block_size is not None else 1 - d_scale_in = _cdiv(d_proj_in, self.block_size[1]) if self.block_size is not None else 1 - self.down_proj_scale_inv = nn.Parameter( - torch.empty(self.num_experts, d_scale_out, d_scale_in, dtype=torch.float32) + self.down_proj, self.down_proj_scale_inv = _alloc_expert_proj( + self.num_experts, self.hidden_dim, self.intermediate_dim, **alloc_kwargs ) self.register_parameter("down_proj_bias", None) @@ -481,24 +582,14 @@ def linear( if weight.element_size() > 1: return F.linear(input, weight, None) - if self.activation_scheme == "static" and activation_scale is not None: - scale = activation_scale.to(torch.float32) - qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) - else: - _, triton_fp8_act_quant, _, _ = _load_triton_kernel() - qinput, scale = triton_fp8_act_quant( - input, self.block_size[1] if self.block_size is not None else input.shape[-1] - ) - - output = w8a8_fp8_matmul( - qinput, + return fp8_linear( + input, weight, - scale, weight_scale_inv, self.block_size, + activation_scale=activation_scale, output_dtype=input.dtype, ) - return output.to(dtype=input.dtype) class FP8ExpertsInterface(ExpertsInterface): @@ -507,7 +598,8 @@ class FP8ExpertsInterface(ExpertsInterface): _global_mapping = { "batched_mm": fp8_batched_mm_experts_forward, "grouped_mm": fp8_grouped_mm_experts_forward, - "deepgemm": fp8_deepgemm_experts_forward, + "deepgemm": deepgemm_fp8_fp4_experts_forward, + "deepgemm_megamoe": 
deepgemm_fp8_fp4_megamoe_experts_forward, } @@ -525,7 +617,7 @@ def replace_with_fp8_linear( Input model or `torch.nn.Module` as the function is run recursively. modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`): Names of the modules to not convert. In practice we keep the `lm_head` in full precision for numerical stability reasons. - quantization_config (`FbgemmFp8Config`): + quantization_config (`FineGrainedFP8Config`): The quantization config object that contains the quantization parameters. pre_quantized (`book`, defaults to `False`): Whether the model is pre-quantized or not @@ -557,6 +649,7 @@ def replace_with_fp8_linear( config=config, block_size=quantization_config.weight_block_size, activation_scheme=quantization_config.activation_scheme, + scale_fmt=quantization_config.scale_fmt, has_bias=has_bias, has_gate=has_gate, **module_kwargs, @@ -567,6 +660,7 @@ def replace_with_fp8_linear( out_features=module.out_features, block_size=quantization_config.weight_block_size, activation_scheme=quantization_config.activation_scheme, + scale_fmt=quantization_config.scale_fmt, has_bias=module.bias is not None, **module_kwargs, ) @@ -644,6 +738,15 @@ def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor] quantized = quantized.reshape(original_shape) inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles) + + # If the target is DeepSeek V4-style storage (`scale_fmt="ue8m0"`), round inv_scales to + # UE8M0-representable values (powers of 2) and cast to `float8_e8m0fnu` byte storage so + # the on-disk dtype matches the parameter allocation in `FP8Linear`/`FP8Experts`. + scale_fmt = getattr(self.hf_quantizer.quantization_config, "scale_fmt", "float") + if scale_fmt == "ue8m0": + inv_scales = torch.pow(2.0, torch.ceil(torch.log2(inv_scales.clamp(min=torch.finfo(torch.float32).tiny)))) + inv_scales = inv_scales.to(_UE8M0_SF_DTYPE) + if target_keys.endswith("weight"): scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv" else: @@ -686,9 +789,10 @@ def convert( raise ValueError( f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})." ) - quantized = quantized.to(scales.dtype) - reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n) - expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n) + # Cast both to float32 before the multiplication. Going through `scales.dtype` would + # corrupt the result for V4-style `float8_e8m0fnu` SFs (incompatible with FP8 e4m3 weights). 
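A minimal sketch of the scale handling described above, with arbitrary example values: inverse scales are rounded up to the nearest power of two before being stored as UE8M0, and dequantization upcasts both operands to float32 instead of casting weights to the scale dtype:

import torch

inv_scales = torch.tensor([0.73, 3.2, 4e-4])

# UE8M0 stores only a power-of-two exponent, so snap each scale up to the next power of two.
tiny = torch.finfo(torch.float32).tiny
pow2 = torch.pow(2.0, torch.ceil(torch.log2(inv_scales.clamp(min=tiny))))   # 2**0, 2**2, 2**-11
stored = pow2.to(torch.float8_e8m0fnu)                                       # 1-byte on-disk storage

# Dequantize by upcasting both operands to float32. The old `quantized.to(scales.dtype)`
# route would cast e4m3 weights to e8m0 (power-of-two values only) and destroy their mantissas.
w_q = torch.randn(128, 128).to(torch.float8_e4m3fn)
w = w_q.to(torch.float32) * pow2[0]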
+ reshaped = quantized.to(torch.float32).reshape(-1, rows // block_m, block_m, cols // block_n, block_n) + expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n) expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2) dequantized = reshaped * expanded_scales diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 76fb2b7f70ef..30797b262eac 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -24,7 +24,7 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) -from .deepgemm import deepgemm_experts_forward +from .deepgemm import deepgemm_bf16_experts_forward from .sonicmoe import sonicmoe_experts_forward @@ -114,6 +114,7 @@ def batched_mm_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, + process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) @@ -370,6 +371,7 @@ def grouped_mm_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, + process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -463,10 +465,10 @@ class ExpertsInterface(GeneralInterface): """Interface for registering custom experts forward functions.""" _global_mapping = { + "deepgemm": deepgemm_bf16_experts_forward, "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, "sonicmoe": sonicmoe_experts_forward, - "deepgemm": deepgemm_experts_forward, } def get_interface(self, experts_implementation: str, default: Callable) -> Callable: diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index d6eee485fea7..0adf51d30d11 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -95,6 +95,7 @@ def sonicmoe_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, + process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: if not self.has_gate: raise ValueError("sonicmoe requires gated experts (has_gate=True)") diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 82d6d284f052..6c02711fdadf 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1191,7 +1191,12 @@ def _prepare_input_fn(self, mod, inputs, device_mesh): # and partial_expert_output is different on each GPU before all-reduce top_k_weights = all_reduce_backward(top_k_weights, device_mesh) - return (hidden_states, top_k_index, top_k_weights) + # Pass the EP process group through to the experts forward. Used today by DeepGEMM Mega + # MoE for the symmetric-buffer rendezvous; future dispatches can use it for genuine EP + # all-to-all dispatch + combine (replacing the current compute-everywhere + all_reduce + # approach). Dispatches that don't need it accept it via a `process_group=None` default + # arg and ignore it. 
+ return (hidden_states, top_k_index, top_k_weights, device_mesh.get_group()) def _prepare_output_fn(self, mod, outputs, device_mesh): # all_reduce_forward to sum partial expert outputs across GPUs diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index bf085d87498c..95b1a1454034 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1682,6 +1682,12 @@ class FineGrainedFP8Config(QuantizationConfigMixin): The scheme used for activation, the defaults and only support scheme for now is "dynamic". weight_block_size (`typing.tuple[int, int]`, *optional*, defaults to `(128, 128)`): The size of the weight blocks for quantization, default is (128, 128). + scale_fmt (`str`, *optional*, defaults to `"float"`): + Storage dtype of the per-block weight scales: + - `"float"`: fp32 scales (DeepSeek V3-style; Hopper SM90 1D2D path). + - `"ue8m0"`: 1-byte `torch.float8_e8m0fnu` scales (DeepSeek V4-style). At GEMM time + the underlying memory is reinterpreted as `torch.int32` (4 UE8M0 bytes per int), + feeding the SM100 1D1D dispatch directly with no float→int transform per call. dequantize (`bool`, *optional*, defaults to `False`): Whether to dequantize the model during loading. modules_to_not_convert (`list`, *optional*): @@ -1692,8 +1698,9 @@ def __init__( self, activation_scheme: str = "dynamic", weight_block_size: tuple[int, int] = (128, 128), - dequantize: bool = False, modules_to_not_convert: list | None = None, + dequantize: bool = False, + scale_fmt: str = "float", **kwargs, ): self.quant_method = QuantizationMethod.FP8 @@ -1701,6 +1708,7 @@ def __init__( self.activation_scheme = activation_scheme self.weight_block_size = weight_block_size self.dequantize = dequantize + self.scale_fmt = scale_fmt self.post_init() def post_init(self): @@ -1714,6 +1722,8 @@ def post_init(self): raise ValueError("weight_block_size must be a tuple of two integers") if self.weight_block_size is not None and (self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0): raise ValueError("weight_block_size must be a tuple of two positive integers") + if self.scale_fmt not in ("float", "ue8m0"): + raise ValueError(f"scale_fmt must be 'float' or 'ue8m0'; got {self.scale_fmt!r}") def get_loading_attributes(self): return {"dequantize": self.dequantize} From b502cd62d02519aa01d92ea35629237d542afc3d Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Wed, 29 Apr 2026 14:06:32 +0000 Subject: [PATCH 26/87] use package for now --- src/transformers/integrations/deepgemm.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 88ee8379b5d0..d3e62b258740 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -84,12 +84,21 @@ def _load_deepgemm_kernel() -> SimpleNamespace: "Please upgrade your CUDA toolkit or use a different `experts_implementation`." ) - kernel = lazy_load_kernel("deep-gemm") - if kernel is None: - raise ImportError( - "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " - "has a build matching the current torch/CUDA." 
- ) + try: + import deep_gemm as kernel + except ImportError: + if not is_kernels_available(): + raise ImportError( + "DeepGEMM requires either the `deep_gemm` package (`pip install -U deep-gemm`) or " + "the `kernels` package (`pip install -U kernels`) for the `kernels-community/deep-gemm` " + "hub build." + ) from None + kernel = lazy_load_kernel("deep-gemm") + if kernel is None: + raise ImportError( + "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " + "has a build matching the current torch/CUDA." + ) from None fp8_fp4_matmul = getattr(kernel, "fp8_fp4_gemm_nt", None) grouped_fp8_fp4_matmul = getattr(kernel, "m_grouped_fp8_fp4_gemm_nt_contiguous", None) From 1c17452830bb8e5578b372b0c1418b88548069d3 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Wed, 29 Apr 2026 14:40:35 +0000 Subject: [PATCH 27/87] skip ep router and experts pre/post processing --- .../integrations/tensor_parallel.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 6c02711fdadf..258bc7dbf7f2 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1136,6 +1136,14 @@ def _prepare_output_fn(self, mod, outputs, device_mesh): + masked in grouped_mm/batched_mm. After the expert forward, an all_reduce sums partial outputs across EP ranks to produce the full result. """ + # Mega MoE: keep the router's raw output. The router still runs (it produces topk_idx + # / topk_weights per token), but we skip the EP-time post-processing — Mega MoE's kernel + # does the EP token dispatch itself and needs GLOBAL expert ids with unmasked routing + # weights. Mirrored on the experts side by `MoeTensorParalellExperts._prepare_output_fn` + # which skips the post-forward all_reduce. + if getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe": + return outputs + ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() num_experts = getattr(mod, "num_experts", None) if num_experts is None: @@ -1178,6 +1186,12 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def _prepare_input_fn(self, mod, inputs, device_mesh): + # Mega MoE handles EP dispatch + combine inside the kernel — no PyTorch-level cross-rank + # bookkeeping is needed at this boundary. Pass inputs through unchanged, just tack on the + # EP `process_group` for the kernel's symm-buffer rendezvous on first forward. + if getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe": + return (*inputs, device_mesh.get_group()) + # inputs = (hidden_states, top_k_index, top_k_weights) hidden_states = inputs[0] top_k_index = inputs[1] @@ -1199,6 +1213,10 @@ def _prepare_input_fn(self, mod, inputs, device_mesh): return (hidden_states, top_k_index, top_k_weights, device_mesh.get_group()) def _prepare_output_fn(self, mod, outputs, device_mesh): + # Mega MoE handles the EP combine inside the kernel — output is already fully reduced. 
+ if getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe": + return outputs + # all_reduce_forward to sum partial expert outputs across GPUs return all_reduce_forward(outputs, device_mesh) From 5a8ceaeae0c31ba280fa906c830fd0b9d25a6cc4 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Wed, 29 Apr 2026 14:53:42 +0000 Subject: [PATCH 28/87] simpler --- src/transformers/integrations/deepgemm.py | 8 +++++--- .../integrations/finegrained_fp8.py | 2 -- src/transformers/integrations/moe.py | 2 -- src/transformers/integrations/sonicmoe.py | 1 - .../integrations/tensor_parallel.py | 18 ++++++------------ 5 files changed, 11 insertions(+), 20 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index d3e62b258740..187f8988f616 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -278,7 +278,6 @@ def deepgemm_bf16_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, - process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: if hidden_states.dtype != torch.bfloat16: raise ValueError(f"DeepGEMM experts path requires bfloat16 hidden states, got {hidden_states.dtype}") @@ -370,7 +369,6 @@ def deepgemm_fp8_fp4_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, - process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: if self.activation_scheme == "static": raise NotImplementedError( @@ -551,11 +549,15 @@ def deepgemm_fp8_fp4_megamoe_experts_forward( "(`MoeTensorParalellExperts`) supplies it automatically. If you are calling this " "dispatch directly, pass `process_group=...` explicitly." ) + # `gate_up_proj.size(0)` is the per-rank expert count after `GroupedGemmParallel` + # sharding; the buffer needs the GLOBAL count (kernel asserts `num_experts % num_ranks + # == 0` and computes the per-rank slice itself). 
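With illustrative numbers (not from any real deployment), the relationship between the per-rank and global expert counts looks like this:

num_ranks = 8                                        # process_group.size()
num_experts_local = 32                               # gate_up_proj.size(0) after GroupedGemmParallel sharding
num_experts_global = num_experts_local * num_ranks   # 256 — what the symmetric buffer must be sized for
assert num_experts_global % num_ranks == 0           # the invariant the kernel asserts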
+ num_experts_global = num_experts * process_group.size() self.symm_buffer = deepgemm.get_symm_buffer_for_mega_moe( process_group, hidden=hidden_dim, num_topk=num_top_k, - num_experts=num_experts, + num_experts=num_experts_global, num_max_tokens_per_rank=num_tokens, intermediate_hidden=intermediate_hidden, ) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 01083f760d70..c4daa9d4a947 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -294,7 +294,6 @@ def fp8_batched_mm_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, - process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: if self.activation_scheme == "static": raise NotImplementedError( @@ -367,7 +366,6 @@ def fp8_grouped_mm_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, - process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: if self.activation_scheme == "static": raise NotImplementedError( diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 84983357d591..62d81832763e 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -120,7 +120,6 @@ def batched_mm_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, - process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) @@ -377,7 +376,6 @@ def grouped_mm_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, - process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: device = hidden_states.device num_top_k = top_k_index.size(-1) diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index 85cb6160863c..912b98655519 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -136,7 +136,6 @@ def sonicmoe_experts_forward( hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, - process_group: torch.distributed.ProcessGroup | None = None, # noqa: ARG001 (unused; for dispatch ABI) ) -> torch.Tensor: if not self.has_gate: raise ValueError("sonicmoe requires gated experts (has_gate=True)") diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 258bc7dbf7f2..06b970a85ee4 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1186,12 +1186,6 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def _prepare_input_fn(self, mod, inputs, device_mesh): - # Mega MoE handles EP dispatch + combine inside the kernel — no PyTorch-level cross-rank - # bookkeeping is needed at this boundary. Pass inputs through unchanged, just tack on the - # EP `process_group` for the kernel's symm-buffer rendezvous on first forward. 
- if getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe": - return (*inputs, device_mesh.get_group()) - # inputs = (hidden_states, top_k_index, top_k_weights) hidden_states = inputs[0] top_k_index = inputs[1] @@ -1205,12 +1199,12 @@ def _prepare_input_fn(self, mod, inputs, device_mesh): # and partial_expert_output is different on each GPU before all-reduce top_k_weights = all_reduce_backward(top_k_weights, device_mesh) - # Pass the EP process group through to the experts forward. Used today by DeepGEMM Mega - # MoE for the symmetric-buffer rendezvous; future dispatches can use it for genuine EP - # all-to-all dispatch + combine (replacing the current compute-everywhere + all_reduce - # approach). Dispatches that don't need it accept it via a `process_group=None` default - # arg and ignore it. - return (hidden_states, top_k_index, top_k_weights, device_mesh.get_group()) + # Mega MoE handles EP dispatch + combine inside the kernel — append the EP `process_group` + # so the forward can rendezvous the symm-buffer on first call. + if getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe": + return hidden_states, top_k_index, top_k_weights, device_mesh.get_group() + + return hidden_states, top_k_index, top_k_weights def _prepare_output_fn(self, mod, outputs, device_mesh): # Mega MoE handles the EP combine inside the kernel — output is already fully reduced. From a05aa39b74c870b3f7040ff9fd720dbdff72c551 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Wed, 29 Apr 2026 15:02:37 +0000 Subject: [PATCH 29/87] fix Co-authored-by: Copilot --- tests/test_modeling_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a7d44177e192..32aaf0e3d96d 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -53,7 +53,7 @@ ) from transformers.integrations.moe import ( batched_mm_experts_forward, - deepgemm_experts_forward, + deepgemm_bf16_experts_forward, grouped_mm_experts_forward, sonicmoe_experts_forward, ) @@ -610,7 +610,7 @@ def _test_eager_matches_batched_and_grouped_inference(self, name, dtype): and get_cuda_runtime_version() >= (12, 3) ): # DeepGEMM BF16 grouped forward requires Hopper+, CUDA runtime 12.3+, and bf16 hidden states - mocks["deepgemm"] = Mock(wraps=deepgemm_experts_forward) + mocks["deepgemm"] = Mock(wraps=deepgemm_bf16_experts_forward) implementations.append("deepgemm") outputs = {} From 053c9df6c79a170c58b35447a8545daf954acbae Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Wed, 29 Apr 2026 15:10:41 +0000 Subject: [PATCH 30/87] fix --- src/transformers/utils/quantization_config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 95b1a1454034..4068d0f6e7a1 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1682,16 +1682,16 @@ class FineGrainedFP8Config(QuantizationConfigMixin): The scheme used for activation, the defaults and only support scheme for now is "dynamic". weight_block_size (`typing.tuple[int, int]`, *optional*, defaults to `(128, 128)`): The size of the weight blocks for quantization, default is (128, 128). + modules_to_not_convert (`list`, *optional*): + A list of module names that should not be converted during quantization. 
+ dequantize (`bool`, *optional*, defaults to `False`): + Whether to dequantize the model during loading. scale_fmt (`str`, *optional*, defaults to `"float"`): Storage dtype of the per-block weight scales: - `"float"`: fp32 scales (DeepSeek V3-style; Hopper SM90 1D2D path). - `"ue8m0"`: 1-byte `torch.float8_e8m0fnu` scales (DeepSeek V4-style). At GEMM time the underlying memory is reinterpreted as `torch.int32` (4 UE8M0 bytes per int), feeding the SM100 1D1D dispatch directly with no float→int transform per call. - dequantize (`bool`, *optional*, defaults to `False`): - Whether to dequantize the model during loading. - modules_to_not_convert (`list`, *optional*): - A list of module names that should not be converted during quantization. """ def __init__( From 80a6fe5a33cb025f91f8ce86a1c82bf28f5b3abd Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 1 May 2026 12:24:12 +0200 Subject: [PATCH 31/87] fix --- .../integrations/finegrained_fp8.py | 38 ++++++++++--------- src/transformers/integrations/moe.py | 35 ++++++++++------- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index f423f2f6b830..86d35428d9ee 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -379,11 +379,6 @@ def fp8_grouped_mm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, - # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and the grouped matmul skips - # rows beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Sentinel rows are zeroed - # post-weighted-mul (see below), since the kernel leaves them uninitialized. - # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -396,6 +391,14 @@ def fp8_grouped_mm_experts_forward( tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and the grouped matmul skips + # rows beyond `offsets[-1]` — sentinels cost no real GEMM compute. The kernel writes only + # valid rows, so sentinel-tail `proj_out` rows are uninit; without the post-mask below, + # `proj_out[sentinel] * 0 = NaN * 0 = NaN` would poison the per-token reduction. FP8 + # quantized weights are inference-only, so no bwd pre-mask is needed. + sentinel_mask = (expert_ids_g >= self.num_experts).unsqueeze(-1) + # --- Up projection per expert (FP8 grouped) --- proj_out = triton_grouped_fp8_matmul( selected_hidden_states_g, @@ -427,10 +430,8 @@ def fp8_grouped_mm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) - # EP sentinel handling: `proj_out` rows past `offsets[-1]` are left uninitialized by the kernel, - # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here - # so the downstream reduction stays finite even when the routing weight was already zero. - weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Post-mask (fwd path). 
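The failure mode the post-mask guards against is easy to reproduce in isolation — a zero routing weight does not neutralize an uninitialized (possibly NaN) row:

import torch

proj_out = torch.tensor([[1.0, 2.0], [float("nan"), float("nan")]])  # row 1 stands in for an uninit sentinel row
weights = torch.tensor([[0.5], [0.0]])                               # the sentinel's routing weight is already 0

weighted = proj_out * weights                    # row 1 stays NaN: 0 * NaN == NaN
sentinel_mask = torch.tensor([[False], [True]])
weighted.masked_fill_(sentinel_mask, 0.0)        # only now is the downstream reduction finite
print(weighted.sum(dim=0))                       # tensor([0.5000, 1.0000])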
+ weighted_out.masked_fill_(sentinel_mask, 0.0) # Restore original order inv_perm = torch.empty_like(perm) @@ -540,11 +541,6 @@ def fp8_deepgemm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, - # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the - # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. - # Sentinel rows are zeroed post-weighted-mul (see below), since the kernel leaves them uninitialized. - # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -555,6 +551,14 @@ def fp8_deepgemm_experts_forward( expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond + # the cumsum on Blackwell), and DeepGEMM skips them — sentinels cost no real GEMM compute. + # The kernel writes only valid rows, so sentinel-tail `proj_out` rows are uninit; without the + # post-mask below, `proj_out[sentinel] * 0 = NaN * 0 = NaN` would poison the per-token + # reduction. DeepGEMM is inference-only, so no bwd pre-mask is needed. + sentinel_mask = (expert_ids_g >= self.num_experts).unsqueeze(-1) + # --- Up projection per expert (DeepGEMM grouped contiguous) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv @@ -588,10 +592,8 @@ def fp8_deepgemm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) - # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, - # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here - # so the downstream reduction stays finite even when the routing weight was already zero. - weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Post-mask (fwd path). + weighted_out.masked_fill_(sentinel_mask, 0.0) # Restore original order inv_perm = torch.empty_like(perm) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 9cf262de0358..6d4bb7e71edd 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -385,11 +385,6 @@ def grouped_mm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, - # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and grouped_mm skips rows - # beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Sentinel rows are zeroed - # post-weighted-mul (see below), since the kernel leaves them uninitialized. 
- # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -402,10 +397,23 @@ def grouped_mm_experts_forward( tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) - if self.has_bias: - # Clamp now that the layout has been built — needed for the per-row bias gather below to stay - # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. - expert_ids_g.clamp_(0, self.num_experts - 1) + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and grouped_mm skips rows + # beyond `offsets[-1]` — sentinels cost no real GEMM compute. The kernel leaves sentinel-tail + # rows of its output uninit (both fwd output and bwd `d_input`), but ONE pre-mask + ONE + # post-mask covers the whole forward — no per-grouped_mm masking is needed, because + # intermediate sentinel-row NaN is only ever consumed by the next grouped_mm, which itself + # only reads rows `< offsets[-1]`: + # - fwd post-mask on `weighted_out`: kills `proj_out[sentinel] * 0 = NaN * 0 = NaN` + # before the per-token reduction sums it. + # - bwd pre-mask on `selected_hidden_states_g`: its `masked_fill_` backward zeros sentinel + # rows of `d_selected_hidden_states_g` after the up grouped_mm bwd writes them as + # uninit, and before the gather's scatter-add pushes them into `d_hidden_states`. + # In-place clamp on `expert_ids_g` keeps the per-row bias gather in-bounds (bias added at + # sentinel positions falls in rows the kernel skips, so harmless). Safe to mutate now — + # nothing downstream needs the sentinel info from `expert_ids_g` itself. + sentinel_mask = (expert_ids_g >= self.num_experts).unsqueeze(-1) + expert_ids_g.clamp_(max=self.num_experts - 1) # Select expert weights and biases # NOTE: We keep all experts here and rely on offsets to target the active ones. @@ -420,6 +428,9 @@ def grouped_mm_experts_forward( selected_weights = self.up_proj selected_biases = self.up_proj_bias[expert_ids_g] if self.has_bias else None + # Pre-mask (bwd path). + selected_hidden_states_g.masked_fill_(sentinel_mask, 0.0) + # --- Up projection per expert (grouped) --- proj_out = _grouped_linear( selected_hidden_states_g, selected_weights, offsets, bias=selected_biases, is_transposed=self.is_transposed @@ -445,10 +456,8 @@ def grouped_mm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) - # EP sentinel handling: `proj_out` rows past `offsets[-1]` are left uninitialized by grouped_mm, - # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here - # so the downstream reduction stays finite even when the routing weight was already zero. - weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Post-mask (fwd path). 
+ weighted_out.masked_fill_(sentinel_mask, 0.0) # Restore original order inv_perm = torch.empty_like(perm) From ad8226ce7ceb34405b631eb7fcaaf249871a6e22 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 1 May 2026 14:32:04 +0200 Subject: [PATCH 32/87] dtensor support --- src/transformers/integrations/sonicmoe.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index 912b98655519..72fdbcf38a6d 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -23,6 +23,7 @@ import functools import torch +from torch.distributed.tensor import DTensor from ..utils import logging from .hub_kernels import lazy_load_kernel @@ -157,16 +158,27 @@ def sonicmoe_experts_forward( # already zero (RouterParallel masks them at dispatch), so the per-token reduction # contributes nothing for sentinel slots. + # FSDP2 / EP wraps weights as DTensors but the kernel takes raw CUTLASS / CuteDSL pointers, + # so unwrap to local shards before reshaping. `to_local()` is autograd-aware — backward + # will rewrap the gradient as a DTensor matching each parameter's placements. + w1 = self.gate_up_proj + w2 = self.down_proj + b1 = self.gate_up_proj_bias if self.has_bias else None + b2 = self.down_proj_bias if self.has_bias else None + if isinstance(w1, DTensor): + w1 = w1.to_local() + w2 = w2.to_local() + b1 = b1.to_local() if b1 is not None else None + b2 = b2.to_local() if b2 is not None else None + # Map activation function act_name = getattr(self.config, "hidden_act", "silu").lower() # Permute weights as expected by sonic-moe (E=num_experts, H=hidden_size, I=intermediate_size). # Non-transposed: gate_up_proj is (E, 2*I, H), down_proj is (E, H, I) -> permute(1, 2, 0). # Transposed: gate_up_proj is (E, H, 2*I), down_proj is (E, I, H) -> permute(2, 1, 0). perm = (2, 1, 0) if self.is_transposed else (1, 2, 0) - w1 = self.gate_up_proj.permute(*perm) # (2*I, H, E) - w2 = self.down_proj.permute(*perm) # (I, H, E) - b1 = self.gate_up_proj_bias if self.has_bias else None - b2 = self.down_proj_bias if self.has_bias else None + w1 = w1.permute(*perm) # (2*I, H, E) + w2 = w2.permute(*perm) # (I, H, E) return _sonicmoe_wrapper( hidden_states=hidden_states, From a663f4d79c0e4e04d79ac110eedf31e0dd3522d9 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 1 May 2026 14:50:17 +0200 Subject: [PATCH 33/87] more dtensor --- .../integrations/finegrained_fp8.py | 34 +++++++++++++++---- src/transformers/integrations/sonicmoe.py | 3 +- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 86d35428d9ee..650d4122cceb 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -399,11 +399,23 @@ def fp8_grouped_mm_experts_forward( # quantized weights are inference-only, so no bwd pre-mask is needed. sentinel_mask = (expert_ids_g >= self.num_experts).unsqueeze(-1) + # FSDP2 / EP wraps weights as DTensors but the kernel takes raw pointers — unwrap to + # local shards. Inference-only path, so `to_local()` autograd-awareness is moot. 
+ w_up = self.gate_up_proj if self.has_gate else self.up_proj + ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv + w_down = self.down_proj + ws_down = self.down_proj_scale_inv + if isinstance(w_up, torch.distributed.tensor.DTensor): + w_up = w_up.to_local() + ws_up = ws_up.to_local() + w_down = w_down.to_local() + ws_down = ws_down.to_local() + # --- Up projection per expert (FP8 grouped) --- proj_out = triton_grouped_fp8_matmul( selected_hidden_states_g, - self.gate_up_proj if self.has_gate else self.up_proj, - self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv, + w_up, + ws_up, tokens_per_expert=tokens_per_expert, block_size=self.block_size, offsets=offsets, @@ -420,8 +432,8 @@ def fp8_grouped_mm_experts_forward( # --- Down projection per expert (FP8 grouped) --- proj_out = triton_grouped_fp8_matmul( proj_out, - self.down_proj, - self.down_proj_scale_inv, + w_down, + ws_down, tokens_per_expert=tokens_per_expert, block_size=self.block_size, offsets=offsets, @@ -559,9 +571,19 @@ def fp8_deepgemm_experts_forward( # reduction. DeepGEMM is inference-only, so no bwd pre-mask is needed. sentinel_mask = (expert_ids_g >= self.num_experts).unsqueeze(-1) - # --- Up projection per expert (DeepGEMM grouped contiguous) --- + # FSDP2 / EP wraps weights as DTensors but the kernel takes raw pointers — unwrap to + # local shards. Inference-only path, so `to_local()` autograd-awareness is moot. w_up = self.gate_up_proj if self.has_gate else self.up_proj ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv + w_down = self.down_proj + ws_down = self.down_proj_scale_inv + if isinstance(w_up, torch.distributed.tensor.DTensor): + w_up = w_up.to_local() + ws_up = ws_up.to_local() + w_down = w_down.to_local() + ws_down = ws_down.to_local() + + # --- Up projection per expert (DeepGEMM grouped contiguous) --- act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) act_fp8, act_scales = _pad_to_deepgemm_contiguous_layout(act_fp8, act_scales, sorted_to_padded, total_padded_rows) proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) @@ -580,7 +602,7 @@ def fp8_deepgemm_experts_forward( proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) deepgemm_grouped_fp8_matmul( (proj_fp8, proj_scales), - (self.down_proj, self.down_proj_scale_inv.float()), + (w_down, ws_down.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout, diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index 72fdbcf38a6d..ff59b8327f18 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -23,7 +23,6 @@ import functools import torch -from torch.distributed.tensor import DTensor from ..utils import logging from .hub_kernels import lazy_load_kernel @@ -165,7 +164,7 @@ def sonicmoe_experts_forward( w2 = self.down_proj b1 = self.gate_up_proj_bias if self.has_bias else None b2 = self.down_proj_bias if self.has_bias else None - if isinstance(w1, DTensor): + if isinstance(w1, torch.distributed.tensor.DTensor): w1 = w1.to_local() w2 = w2.to_local() b1 = b1.to_local() if b1 is not None else None From 74c3f2e3bd89e0c720b49b0690fa3c9005a4a719 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 1 May 2026 14:58:53 +0200 Subject: [PATCH 34/87] simpler --- src/transformers/integrations/finegrained_fp8.py | 12 +++++------- 1 file changed, 5 insertions(+), 
7 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 650d4122cceb..d0285f83bdef 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -257,13 +257,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.weight.element_size() > 1: return F.linear(input, self.weight, self.bias) - if isinstance(self.weight, torch.distributed.tensor.DTensor): - weight = self.weight._local_tensor.contiguous() - scale_inv = self.weight_scale_inv._local_tensor.contiguous() - else: - # why wouldn't it be contiguous? - weight = self.weight.contiguous() - scale_inv = self.weight_scale_inv.contiguous() + weight = self.weight + scale_inv = self.weight_scale_inv + if isinstance(weight, torch.distributed.tensor.DTensor): + weight = weight.to_local() + scale_inv = scale_inv.to_local() if self.activation_scheme == "dynamic": _, triton_fp8_act_quant, _, _ = _load_triton_kernel() From d3cae33e30a381d420d0377edeb4cc99fdab847c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 1 May 2026 14:59:43 +0200 Subject: [PATCH 35/87] remove comment --- src/transformers/integrations/hub_kernels.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index a362b9e114f2..da6f06ca9291 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -376,8 +376,6 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None) version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) - # Entries in `_HUB_KERNEL_MAPPING` are vetted in-tree, so we trust non-`kernels-community` - # repos (e.g. user/team forks) without requiring the per-call `allow_all_kernels` flag. 
kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True) mapping[kernel_name] = kernel except FileNotFoundError: From 0528e0e27d223615dc47641c487c649ebf77fede Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 4 May 2026 14:14:47 +0200 Subject: [PATCH 36/87] revert --- .../integrations/finegrained_fp8.py | 1 - src/transformers/integrations/hub_kernels.py | 2 +- src/transformers/integrations/sonicmoe.py | 44 ------------------- 3 files changed, 1 insertion(+), 46 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 8dc92536065b..337a4c5f9b89 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -287,7 +287,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ) - def fp8_batched_mm_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index b8ed8556edd1..d8d30f13416f 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -376,7 +376,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None) version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) - kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True) + kernel = get_kernel(repo_id, revision=revision, version=version) mapping[kernel_name] = kernel except FileNotFoundError: mapping[kernel_name] = None diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index 1fa6a4f6e787..fe9d67d6a2e6 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -143,50 +143,6 @@ def _sonicmoe_wrapper( return output -@torch._dynamo.allow_in_graph -def _sonicmoe_wrapper( - hidden_states: torch.Tensor, - router_scores: torch.Tensor, - expert_ids: torch.Tensor, - token_idx: torch.Tensor, - w1: torch.Tensor, - b1: torch.Tensor | None, - w2: torch.Tensor, - b2: torch.Tensor | None, - act_name: str, - num_experts: int, - concat_layout: bool, - is_inference_mode_enabled: bool, -) -> torch.Tensor: - """Module-level shim around `moe_general_routing_inputs` so `allow_in_graph` can wrap it. - - sonicmoe asserts `not torch.compiler.is_compiling()` internally because it dispatches - CuteDSL kernels, which Dynamo can't trace. `allow_in_graph` keeps the call in the FX - graph as a single opaque node (no tracing into the body, no graph break) while still - running the real Python at runtime — autograd through `_UpProjection` / `_DownProjection` - flows normally. The decorator must be applied at module load time, not inside the compiled - function — hence this shim plus the `allow_in_graph` decorator above. 
- """ - ActivationType, moe_general_routing_inputs = _load_sonic_kernel() - activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) - output, _ = moe_general_routing_inputs( - hidden_states, - router_scores, - token_idx, - expert_ids, - w1, - b1, - w2, - b2, - E=num_experts, - activation_type=activation_type, - is_inference_mode_enabled=is_inference_mode_enabled, - concat_layout=concat_layout, - stream_id=None, - ) - return output - - def sonicmoe_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, From 2bfd0298b5a61a24b535ac9115211d924aa78314 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 4 May 2026 14:16:56 +0200 Subject: [PATCH 37/87] bc order --- src/transformers/utils/quantization_config.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index e0ec4ee547cf..f74db09665f0 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1678,28 +1678,28 @@ class FineGrainedFP8Config(QuantizationConfigMixin): FineGrainedFP8Config is a configuration class for fine-grained FP8 quantization used mainly for deepseek models. Args: - activation_scheme (`str`, *optional*, defaults to `"dynamic"`): - The scheme used for activation, the defaults and only support scheme for now is "dynamic". - weight_block_size (`typing.tuple[int, int]`, *optional*, defaults to `(128, 128)`): - The size of the weight blocks for quantization, default is (128, 128). - modules_to_not_convert (`list`, *optional*): - A list of module names that should not be converted during quantization. - dequantize (`bool`, *optional*, defaults to `False`): - Whether to dequantize the model during loading. - scale_fmt (`str`, *optional*, defaults to `"float"`): - Storage dtype of the per-block weight scales: - - `"float"`: fp32 scales (DeepSeek V3-style; Hopper SM90 1D2D path). - - `"ue8m0"`: 1-byte `torch.float8_e8m0fnu` scales (DeepSeek V4-style). At GEMM time - the underlying memory is reinterpreted as `torch.int32` (4 UE8M0 bytes per int), - feeding the SM100 1D1D dispatch directly with no float→int transform per call. + activation_scheme (`str`, *optional*, defaults to `"dynamic"`): + The scheme used for activation, the defaults and only support scheme for now is "dynamic". + weight_block_size (`typing.tuple[int, int]`, *optional*, defaults to `(128, 128)`): + The size of the weight blocks for quantization, default is (128, 128). + dequantize (`bool`, *optional*, defaults to `False`): + Whether to dequantize the model during loading. + modules_to_not_convert (`list`, *optional*): + A list of module names that should not be converted during quantization. + scale_fmt (`str`, *optional*, defaults to `"float"`): + Storage dtype of the per-block weight scales: + - `"float"`: fp32 scales (DeepSeek V3-style; Hopper SM90 1D2D path). + - `"ue8m0"`: 1-byte `torch.float8_e8m0fnu` scales (DeepSeek V4-style). At GEMM time + the underlying memory is reinterpreted as `torch.int32` (4 UE8M0 bytes per int), + feeding the SM100 1D1D dispatch directly with no float→int transform per call. 
""" def __init__( self, activation_scheme: str = "dynamic", weight_block_size: tuple[int, int] = (128, 128), - modules_to_not_convert: list | None = None, dequantize: bool = False, + modules_to_not_convert: list | None = None, scale_fmt: str = "float", **kwargs, ): From 368db007316843abcb8a1072b11ea27d55bc5da5 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 4 May 2026 14:18:14 +0200 Subject: [PATCH 38/87] revert extra indent --- src/transformers/utils/quantization_config.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index f74db09665f0..11371a81505c 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1678,20 +1678,20 @@ class FineGrainedFP8Config(QuantizationConfigMixin): FineGrainedFP8Config is a configuration class for fine-grained FP8 quantization used mainly for deepseek models. Args: - activation_scheme (`str`, *optional*, defaults to `"dynamic"`): - The scheme used for activation, the defaults and only support scheme for now is "dynamic". - weight_block_size (`typing.tuple[int, int]`, *optional*, defaults to `(128, 128)`): - The size of the weight blocks for quantization, default is (128, 128). - dequantize (`bool`, *optional*, defaults to `False`): - Whether to dequantize the model during loading. - modules_to_not_convert (`list`, *optional*): - A list of module names that should not be converted during quantization. - scale_fmt (`str`, *optional*, defaults to `"float"`): - Storage dtype of the per-block weight scales: - - `"float"`: fp32 scales (DeepSeek V3-style; Hopper SM90 1D2D path). - - `"ue8m0"`: 1-byte `torch.float8_e8m0fnu` scales (DeepSeek V4-style). At GEMM time - the underlying memory is reinterpreted as `torch.int32` (4 UE8M0 bytes per int), - feeding the SM100 1D1D dispatch directly with no float→int transform per call. + activation_scheme (`str`, *optional*, defaults to `"dynamic"`): + The scheme used for activation, the defaults and only support scheme for now is "dynamic". + weight_block_size (`typing.tuple[int, int]`, *optional*, defaults to `(128, 128)`): + The size of the weight blocks for quantization, default is (128, 128). + dequantize (`bool`, *optional*, defaults to `False`): + Whether to dequantize the model during loading. + modules_to_not_convert (`list`, *optional*): + A list of module names that should not be converted during quantization. + scale_fmt (`str`, *optional*, defaults to `"float"`): + Storage dtype of the per-block weight scales: + - `"float"`: fp32 scales (DeepSeek V3-style; Hopper SM90 1D2D path). + - `"ue8m0"`: 1-byte `torch.float8_e8m0fnu` scales (DeepSeek V4-style). At GEMM time + the underlying memory is reinterpreted as `torch.int32` (4 UE8M0 bytes per int), + feeding the SM100 1D1D dispatch directly with no float→int transform per call. 
""" def __init__( From 28fbb842e5c433a87287aee11e69a7293045fe02 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 4 May 2026 14:24:00 +0200 Subject: [PATCH 39/87] revert unnecessary change --- .../integrations/finegrained_fp8.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 337a4c5f9b89..4792b7d872ef 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -14,7 +14,8 @@ from __future__ import annotations import functools -from types import SimpleNamespace +from collections.abc import Callable +from dataclasses import dataclass import torch import torch.nn as nn @@ -50,12 +51,23 @@ def _first_attr(obj, *names): raise AttributeError(f"{type(obj).__name__} has none of: {names}") +@dataclass(frozen=True) +class FineGrainedFP8: + """Entry points exposed by the `kernels-community/finegrained-fp8` Triton kernel.""" + + matmul: Callable + act_quant: Callable + batched_matmul: Callable + grouped_matmul: Callable + + @functools.cache -def _load_finegrained_fp8_kernel() -> SimpleNamespace: +def _load_finegrained_fp8_kernel() -> FineGrainedFP8: """ - Load the finegrained-fp8 Triton kernel once and return its entry points as a `SimpleNamespace`. + Load the finegrained-fp8 Triton kernel once and return its entry points. - Raises `ImportError` if the `kernels` package is missing or any required entry point is absent. + Raises `ImportError` if the `kernels` package is missing, or the kernel or required + symbols cannot be found. """ if not is_kernels_available(): raise ImportError( @@ -90,7 +102,7 @@ def _load_finegrained_fp8_kernel() -> SimpleNamespace: "Please update the `kernels` package (`pip install -U kernels`)." ) - return SimpleNamespace( + return FineGrainedFP8( matmul=matmul, act_quant=act_quant, batched_matmul=batched_matmul, From 6ef27abe5f7d2bd1776c3c49cbccf2499278e46b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 4 May 2026 14:46:06 +0200 Subject: [PATCH 40/87] update --- src/transformers/integrations/deepgemm.py | 39 +++++++++++++------ .../integrations/tensor_parallel.py | 16 ++++++-- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 187f8988f616..8963046f3baf 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -31,7 +31,8 @@ from __future__ import annotations import functools -from types import SimpleNamespace +from collections.abc import Callable +from dataclasses import dataclass import torch @@ -52,10 +53,29 @@ _FP8_MAX = torch.finfo(_FP8_DTYPE).max +@dataclass(frozen=True) +class DeepGEMM: + """Entry points exposed by the `kernels-community/deep-gemm` kernel. + + Mega MoE entry points are always importable on a current build — they raise at call + time on SM90 (Hopper), guarded by a runtime device-capability check in + `deepgemm_fp8_fp4_megamoe_experts_forward`. 
+ """ + + fp8_fp4_matmul: Callable + grouped_fp8_fp4_matmul: Callable + grouped_bf16_matmul_nt: Callable + grouped_bf16_matmul_nn: Callable + per_token_cast_to_fp8: Callable + transform_weights_for_mega_moe: Callable + get_symm_buffer_for_mega_moe: Callable + fp8_fp4_mega_moe: Callable + + @functools.cache -def _load_deepgemm_kernel() -> SimpleNamespace: +def _load_deepgemm_kernel() -> DeepGEMM: """ - Load DeepGEMM once and return its entry points as a `SimpleNamespace`. + Load DeepGEMM once and return its entry points. Raises `ImportError` if CUDA/hardware requirements are not met or any required entry point is missing. @@ -105,10 +125,9 @@ def _load_deepgemm_kernel() -> SimpleNamespace: grouped_bf16_matmul_nt = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) grouped_bf16_matmul_nn = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") - symm_buffer_cls = getattr(kernel, "SymmBuffer", None) - fp8_fp4_mega_moe = getattr(kernel, "fp8_fp4_mega_moe", None) - get_symm_buffer_for_mega_moe = getattr(kernel, "get_symm_buffer_for_mega_moe", None) transform_weights_for_mega_moe = getattr(kernel, "transform_weights_for_mega_moe", None) + get_symm_buffer_for_mega_moe = getattr(kernel, "get_symm_buffer_for_mega_moe", None) + fp8_fp4_mega_moe = getattr(kernel, "fp8_fp4_mega_moe", None) missing = [ name @@ -118,10 +137,9 @@ def _load_deepgemm_kernel() -> SimpleNamespace: ("m_grouped_bf16_gemm_nt_contiguous", grouped_bf16_matmul_nt), ("m_grouped_bf16_gemm_nn_contiguous", grouped_bf16_matmul_nn), ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), - ("SymmBuffer", symm_buffer_cls), - ("fp8_fp4_mega_moe", fp8_fp4_mega_moe), - ("get_symm_buffer_for_mega_moe", get_symm_buffer_for_mega_moe), ("transform_weights_for_mega_moe", transform_weights_for_mega_moe), + ("get_symm_buffer_for_mega_moe", get_symm_buffer_for_mega_moe), + ("fp8_fp4_mega_moe", fp8_fp4_mega_moe), ] if attr is None ] @@ -131,7 +149,7 @@ def _load_deepgemm_kernel() -> SimpleNamespace: "Please update the `kernels` package (`pip install -U kernels`)." ) - return SimpleNamespace( + return DeepGEMM( fp8_fp4_matmul=fp8_fp4_matmul, grouped_fp8_fp4_matmul=grouped_fp8_fp4_matmul, grouped_bf16_matmul_nt=grouped_bf16_matmul_nt, @@ -140,7 +158,6 @@ def _load_deepgemm_kernel() -> SimpleNamespace: transform_weights_for_mega_moe=transform_weights_for_mega_moe, get_symm_buffer_for_mega_moe=get_symm_buffer_for_mega_moe, fp8_fp4_mega_moe=fp8_fp4_mega_moe, - symm_buffer_cls=symm_buffer_cls, ) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 082b40ce827a..7831f09d2559 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1079,6 +1079,16 @@ def update_module_attributes(self, module: nn.Module): module.num_experts = self.get_expected_sharded_shape((self.empty_param.shape[0],))[0] +def _is_ep_native_experts_impl(mod: nn.Module) -> bool: + """Whether `mod`'s experts implementation handles EP dispatch + combine itself. + + These kernels (e.g. DeepGEMM Mega MoE) want GLOBAL expert ids with unmasked routing + weights and produce the fully-reduced output, so `RouterParallel` skips the per-rank + index remap and `MoeTensorParalellExperts` skips the post-forward all-reduce. 
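For contrast, a sketch of the per-rank index remap that the default (non-EP-native) path performs; names and the -1 masking convention are illustrative, not lifted from `RouterParallel`:

    import torch

    def remap_to_local_experts(global_ids: torch.Tensor, ep_rank: int, experts_per_rank: int) -> torch.Tensor:
        # Default EP path: shift global expert ids into this rank's local range and mask out
        # tokens routed elsewhere; partial outputs are then summed by the post-forward all-reduce.
        local_ids = global_ids - ep_rank * experts_per_rank
        on_this_rank = (local_ids >= 0) & (local_ids < experts_per_rank)
        return torch.where(on_this_rank, local_ids, torch.full_like(local_ids, -1))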
+ """ + return getattr(getattr(mod, "config", None), "_experts_implementation", None) in {"deepgemm_megamoe"} + + class RouterParallel(TensorParallelLayer): """ Allows to reshape the router scores to support running expert parallel. @@ -1141,7 +1151,7 @@ def _prepare_output_fn(self, mod, outputs, device_mesh): # does the EP token dispatch itself and needs GLOBAL expert ids with unmasked routing # weights. Mirrored on the experts side by `MoeTensorParalellExperts._prepare_output_fn` # which skips the post-forward all_reduce. - if getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe": + if _is_ep_native_experts_impl(mod): return outputs ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() @@ -1201,14 +1211,14 @@ def _prepare_input_fn(self, mod, inputs, device_mesh): # Mega MoE handles EP dispatch + combine inside the kernel — append the EP `process_group` # so the forward can rendezvous the symm-buffer on first call. - if getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe": + if _is_ep_native_experts_impl(mod): return hidden_states, top_k_index, top_k_weights, device_mesh.get_group() return hidden_states, top_k_index, top_k_weights def _prepare_output_fn(self, mod, outputs, device_mesh): # Mega MoE handles the EP combine inside the kernel — output is already fully reduced. - if getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe": + if _is_ep_native_experts_impl(mod): return outputs # all_reduce_forward to sum partial expert outputs across GPUs From e436b7746f4d22a616d2c1290da6ed5230fddaad Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 4 May 2026 14:50:51 +0200 Subject: [PATCH 41/87] less defensive --- src/transformers/integrations/finegrained_fp8.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 4792b7d872ef..df2f979dc898 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -750,8 +750,7 @@ def _quantize_one(self, key: str, value: torch.Tensor) -> dict[str, torch.Tensor # DeepSeek V4-style storage (`scale_fmt="ue8m0"`): round inv_scales to UE8M0-representable # values (powers of 2) and cast to `float8_e8m0fnu` byte storage so the on-disk dtype # matches the parameter allocation in `FP8Linear`/`FP8Experts`. 
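         # (Worked example, for illustration only: an inv_scale of 7e-4 has log2 of about -10.48;
         #  ceil gives -10, so it is stored as 2**-10, roughly 9.77e-4, the smallest power of two
         #  at or above the original value.)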
- scale_fmt = getattr(self.hf_quantizer.quantization_config, "scale_fmt", "float") - if scale_fmt == "ue8m0": + if self.hf_quantizer.quantization_config.scale_fmt == "ue8m0": inv_scales = torch.pow(2.0, torch.ceil(torch.log2(inv_scales.clamp(min=torch.finfo(torch.float32).tiny)))) inv_scales = inv_scales.to(_UE8M0_SF_DTYPE) scale_key = key.rsplit(".", 1)[0] + ".weight_scale_inv" if key.endswith("weight") else key + "_scale_inv" From 8e8f0ee2e9ff0406e7850cd8c8a8d9f0e2b7448f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 5 May 2026 13:42:48 +0200 Subject: [PATCH 42/87] allow all kernels --- src/transformers/integrations/hub_kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index d8d30f13416f..7d64fa2a954f 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -288,7 +288,7 @@ def register_kernel_mapping_transformers(*args, **kwargs): "mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, - "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, + "deep-gemm": {"repo_id": "adarshxs/deep-gemm", "revision": "v2"}, "sonic-moe": {"repo_id": "kernels-community/sonic-moe", "revision": "ep-support"}, } @@ -346,7 +346,7 @@ def load_and_register_attn_kernel( # Load the kernel from hub try: - kernel = get_kernel(repo_id, revision=rev, allow_all_kernels=allow_all_kernels) + kernel = get_kernel(repo_id, revision=rev, allow_all_kernels=True) except Exception as e: raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.") # correctly wrap the kernel From c494e35bc5adb143c37560a1c03d1f6801402239 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 5 May 2026 13:44:39 +0200 Subject: [PATCH 43/87] alow all kernels --- src/transformers/integrations/hub_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 7d64fa2a954f..a1c0243dc54c 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -376,7 +376,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None) version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) - kernel = get_kernel(repo_id, revision=revision, version=version) + kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True) mapping[kernel_name] = kernel except FileNotFoundError: mapping[kernel_name] = None From 03b04421333654c7bd48d9db87c24e46eb92e587 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 5 May 2026 13:51:57 +0200 Subject: [PATCH 44/87] hub only --- src/transformers/integrations/deepgemm.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 8963046f3baf..dffc202fc471 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -104,21 +104,12 @@ def _load_deepgemm_kernel() -> DeepGEMM: "Please upgrade your CUDA toolkit or use a different 
`experts_implementation`." ) - try: - import deep_gemm as kernel - except ImportError: - if not is_kernels_available(): - raise ImportError( - "DeepGEMM requires either the `deep_gemm` package (`pip install -U deep-gemm`) or " - "the `kernels` package (`pip install -U kernels`) for the `kernels-community/deep-gemm` " - "hub build." - ) from None - kernel = lazy_load_kernel("deep-gemm") - if kernel is None: - raise ImportError( - "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " - "has a build matching the current torch/CUDA." - ) from None + kernel = lazy_load_kernel("deep-gemm") + if kernel is None: + raise ImportError( + "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " + "has a build matching the current torch/CUDA." + ) from None fp8_fp4_matmul = getattr(kernel, "fp8_fp4_gemm_nt", None) grouped_fp8_fp4_matmul = getattr(kernel, "m_grouped_fp8_fp4_gemm_nt_contiguous", None) From 2e51b3cb42e8dc3f097016e4e1196c0264868325 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 5 May 2026 23:32:18 +0200 Subject: [PATCH 45/87] fix --- src/transformers/integrations/deepgemm.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index dffc202fc471..8204652adfcb 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -156,13 +156,19 @@ def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: """Normalize a scale-factor tensor for the DeepGEMM kernel boundary. Two SF flavors are produced by our path: - - `float32` (DeepSeek V3-style): pass through; the kernel transforms float→int internally - on SM100 to feed the 1D1D path. - - `torch.float8_e8m0fnu` (DeepSeek V4-style, 1 byte per scale): reinterpret 4 contiguous - bytes as one `int32`. No copy; last-dim shrinks 4×. + - `float32` (DeepSeek V3-style): the kernel's `check_sf_layout` requires the + SF tensor to be MN-major (`sf.stride(-2) == 1`). Default contiguous tensors + are K-major, so flip via transpose+contiguous+transpose. No-op when the + layout is already MN-major. + - `torch.float8_e8m0fnu` (DeepSeek V4-style, 1 byte per scale): the kernel + expects MN-major TMA-aligned packed `int32` (4 contiguous K-bytes per + lane). Use DeepGEMM's helper which guarantees that exact layout. 
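A small illustration of the transpose+contiguous+transpose flip mentioned in the float32 bullet above (arbitrary shapes; this only demonstrates the stride property, not the kernel call):

    import torch

    sf = torch.rand(256, 8)                                         # K-major: strides (8, 1)
    mn_major = sf.transpose(-1, -2).contiguous().transpose(-1, -2)  # same values, strides (1, 256)
    assert mn_major.stride(-2) == 1 and torch.equal(sf, mn_major)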
""" if sf.dtype == torch.float8_e8m0fnu: - return sf.contiguous().view(torch.int32) + deepgemm = _load_deepgemm_kernel() + return deepgemm.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + if sf.stride(-2) != 1: + sf = sf.transpose(-1, -2).contiguous().transpose(-1, -2) return sf From 82c5fb594ac5484513b82f619a272bb67aaf3c9c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 5 May 2026 23:36:14 +0200 Subject: [PATCH 46/87] fix --- src/transformers/integrations/deepgemm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 8204652adfcb..c1010ab6b32c 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -67,6 +67,7 @@ class DeepGEMM: grouped_bf16_matmul_nt: Callable grouped_bf16_matmul_nn: Callable per_token_cast_to_fp8: Callable + get_mn_major_tma_aligned_packed_ue8m0_tensor: Callable transform_weights_for_mega_moe: Callable get_symm_buffer_for_mega_moe: Callable fp8_fp4_mega_moe: Callable @@ -116,6 +117,7 @@ def _load_deepgemm_kernel() -> DeepGEMM: grouped_bf16_matmul_nt = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) grouped_bf16_matmul_nn = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + get_mn_major_tma_aligned_packed_ue8m0_tensor = getattr(kernel, "get_mn_major_tma_aligned_packed_ue8m0_tensor", None) transform_weights_for_mega_moe = getattr(kernel, "transform_weights_for_mega_moe", None) get_symm_buffer_for_mega_moe = getattr(kernel, "get_symm_buffer_for_mega_moe", None) fp8_fp4_mega_moe = getattr(kernel, "fp8_fp4_mega_moe", None) @@ -128,6 +130,7 @@ def _load_deepgemm_kernel() -> DeepGEMM: ("m_grouped_bf16_gemm_nt_contiguous", grouped_bf16_matmul_nt), ("m_grouped_bf16_gemm_nn_contiguous", grouped_bf16_matmul_nn), ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), + ("get_mn_major_tma_aligned_packed_ue8m0_tensor", get_mn_major_tma_aligned_packed_ue8m0_tensor), ("transform_weights_for_mega_moe", transform_weights_for_mega_moe), ("get_symm_buffer_for_mega_moe", get_symm_buffer_for_mega_moe), ("fp8_fp4_mega_moe", fp8_fp4_mega_moe), @@ -146,6 +149,7 @@ def _load_deepgemm_kernel() -> DeepGEMM: grouped_bf16_matmul_nt=grouped_bf16_matmul_nt, grouped_bf16_matmul_nn=grouped_bf16_matmul_nn, per_token_cast_to_fp8=per_token_cast_to_fp8, + get_mn_major_tma_aligned_packed_ue8m0_tensor=get_mn_major_tma_aligned_packed_ue8m0_tensor, transform_weights_for_mega_moe=transform_weights_for_mega_moe, get_symm_buffer_for_mega_moe=get_symm_buffer_for_mega_moe, fp8_fp4_mega_moe=fp8_fp4_mega_moe, From 99fdf71daf32a3b7f9a918427790e57a25e24c03 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 5 May 2026 23:38:18 +0200 Subject: [PATCH 47/87] test --- src/transformers/integrations/deepgemm.py | 24 +++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index c1010ab6b32c..7eae265b9705 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -159,18 +159,22 @@ def _load_deepgemm_kernel() -> DeepGEMM: def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: """Normalize a scale-factor tensor for the DeepGEMM kernel boundary. 
- Two SF flavors are produced by our path: - - `float32` (DeepSeek V3-style): the kernel's `check_sf_layout` requires the - SF tensor to be MN-major (`sf.stride(-2) == 1`). Default contiguous tensors - are K-major, so flip via transpose+contiguous+transpose. No-op when the - layout is already MN-major. - - `torch.float8_e8m0fnu` (DeepSeek V4-style, 1 byte per scale): the kernel - expects MN-major TMA-aligned packed `int32` (4 contiguous K-bytes per - lane). Use DeepGEMM's helper which guarantees that exact layout. + The kernel's `check_sf_layout` requires `sf.stride(-2) == 1` (MN-major). + Default contiguous PyTorch tensors are K-major; flip when needed. + + Two SF flavors are handled: + - `float32` (DeepSeek V3-style): only flip layout if it's K-major. + - `torch.float8_e8m0fnu` (DeepSeek V4-style, 1 byte per scale): pack + 4 contiguous K-bytes into one `int32` (last dim shrinks 4×) and + ensure MN-major layout. Mirrors what + `get_mn_major_tma_aligned_packed_ue8m0_tensor` produces from a + float32 SF, just starting from already-packed bytes. """ if sf.dtype == torch.float8_e8m0fnu: - deepgemm = _load_deepgemm_kernel() - return deepgemm.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + # `view(int32)` requires the source to be contiguous (4 K-bytes adjacent). + # Pack first while still K-major, then flip to MN-major. + packed = sf.contiguous().view(torch.int32) + return packed.transpose(-1, -2).contiguous().transpose(-1, -2) if sf.stride(-2) != 1: sf = sf.transpose(-1, -2).contiguous().transpose(-1, -2) return sf From 21398a8b26a61836dc4cac9579d828d47635e85b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 5 May 2026 23:41:11 +0200 Subject: [PATCH 48/87] test --- src/transformers/integrations/deepgemm.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 7eae265b9705..869df986bbed 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -162,19 +162,20 @@ def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: The kernel's `check_sf_layout` requires `sf.stride(-2) == 1` (MN-major). Default contiguous PyTorch tensors are K-major; flip when needed. - Two SF flavors are handled: - - `float32` (DeepSeek V3-style): only flip layout if it's K-major. - - `torch.float8_e8m0fnu` (DeepSeek V4-style, 1 byte per scale): pack - 4 contiguous K-bytes into one `int32` (last dim shrinks 4×) and - ensure MN-major layout. Mirrors what + Three SF flavors arrive at this boundary: + - `float32` (DeepSeek V3-style): flip layout only. + - `int32` (already-packed UE8M0 from `per_token_cast_to_fp8( + use_packed_ue8m0=True)` or saved checkpoints): flip layout only. + - `float8_e8m0fnu` (raw UE8M0 bytes, 1 byte per scale; e.g. on-disk + weights): pack 4 contiguous K-bytes into `int32` (last dim shrinks + 4×) AND flip layout. Mirrors what `get_mn_major_tma_aligned_packed_ue8m0_tensor` produces from a - float32 SF, just starting from already-packed bytes. + float32 SF, just starting from already-quantized bytes. """ if sf.dtype == torch.float8_e8m0fnu: - # `view(int32)` requires the source to be contiguous (4 K-bytes adjacent). + # `view(int32)` requires the source contiguous (4 K-bytes adjacent). # Pack first while still K-major, then flip to MN-major. 
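         # (E.g. a (256, 8) e8m0 SF, i.e. 8 one-byte scales per row, views to a (256, 2) int32
         #  tensor: each int32 carries 4 K-consecutive exponent bytes, with the lowest-K byte in
         #  the least-significant position on little-endian hardware.)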
- packed = sf.contiguous().view(torch.int32) - return packed.transpose(-1, -2).contiguous().transpose(-1, -2) + sf = sf.contiguous().view(torch.int32) if sf.stride(-2) != 1: sf = sf.transpose(-1, -2).contiguous().transpose(-1, -2) return sf @@ -458,7 +459,7 @@ def deepgemm_fp8_fp4_experts_forward( act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) deepgemm.grouped_fp8_fp4_matmul( - (act_fp8, act_scales), + (act_fp8, _coerce_sf_for_kernel(act_scales)), (w_up, _coerce_sf_for_kernel(ws_up)), proj_out, grouped_layout, @@ -475,7 +476,7 @@ def deepgemm_fp8_fp4_experts_forward( proj_fp8, proj_scales = deepgemm.per_token_cast_to_fp8(proj_out, **cast_kwargs) proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) deepgemm.grouped_fp8_fp4_matmul( - (proj_fp8, proj_scales), + (proj_fp8, _coerce_sf_for_kernel(proj_scales)), (self.down_proj, _coerce_sf_for_kernel(self.down_proj_scale_inv)), proj_out, grouped_layout, From f21db7cd72bbb4313e6f1a547c3e918680c423c6 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 07:13:38 +0200 Subject: [PATCH 49/87] sync --- src/transformers/integrations/deepgemm.py | 18 +- test_deepgemm.py | 281 ++++++++++++++++++++++ 2 files changed, 294 insertions(+), 5 deletions(-) create mode 100644 test_deepgemm.py diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 869df986bbed..0e50e483711e 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -117,7 +117,9 @@ def _load_deepgemm_kernel() -> DeepGEMM: grouped_bf16_matmul_nt = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) grouped_bf16_matmul_nn = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") - get_mn_major_tma_aligned_packed_ue8m0_tensor = getattr(kernel, "get_mn_major_tma_aligned_packed_ue8m0_tensor", None) + get_mn_major_tma_aligned_packed_ue8m0_tensor = getattr( + kernel, "get_mn_major_tma_aligned_packed_ue8m0_tensor", None + ) transform_weights_for_mega_moe = getattr(kernel, "transform_weights_for_mega_moe", None) get_symm_buffer_for_mega_moe = getattr(kernel, "get_symm_buffer_for_mega_moe", None) fp8_fp4_mega_moe = getattr(kernel, "fp8_fp4_mega_moe", None) @@ -424,13 +426,19 @@ def deepgemm_fp8_fp4_experts_forward( ) cast_kwargs = {"use_ue8m0": True, "gran_k": 32, "use_packed_ue8m0": True} else: + # FP8 weights: DeepGEMM supports two SF granularities for the N axis + # of B (block-128 or per-row), and only gran_k=128 for the K axis. + # The block_size attribute is informational; the kernel infers the + # actual recipe from the SF dtype + shape (`get_default_recipe`). if self.block_size is None: raise ValueError( - "DeepGEMM requires block-wise quantization (block_size=[128, 128]), " - "but got per-tensor quantization (block_size=None)." + "DeepGEMM requires block-wise quantized FP8 weights, but the experts have " + "no `block_size` set (per-tensor quantization is not supported)." + ) + if self.block_size not in ((128, 128), (1, 128)): + raise ValueError( + f"DeepGEMM requires `block_size` ∈ {{(128, 128), (1, 128)}} for FP8 weights, got {self.block_size}." 
) - if self.block_size[0] != 128 or self.block_size[1] != 128: - raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") if ws_up.dtype == torch.float8_e8m0fnu: cast_kwargs = {"use_ue8m0": True, "gran_k": 128, "use_packed_ue8m0": True} else: diff --git a/test_deepgemm.py b/test_deepgemm.py new file mode 100644 index 000000000000..41ff2ccfba5e --- /dev/null +++ b/test_deepgemm.py @@ -0,0 +1,281 @@ +"""Smoke-test the three DeepGEMM experts dispatches with synthetic experts. + +Each test builds a synthetic experts module with the right weight dtypes / SF formats and +runs the kernel forward, checking the output is finite and shaped correctly. + +Coverage: + 1. DSv3-style: FP8 weights (`float8_e4m3fn`) + float32 SF — Hopper SM90+ + 2. DSv4-style: FP4 weights (`int8`-packed e2m1) + UE8M0 SF — Blackwell SM100+ + 3. Mega MoE: same as DSv4 but with EP dispatch + combine inside the kernel — SM100+ + + distributed (uses `transform_weights_for_mega_moe` for the layout) + +Usage: + # Single GPU (DSv3 + DSv4): + python test_deepgemm_integration.py + + # Mega MoE (≥2 ranks): + torchrun --nproc_per_node=2 test_deepgemm_integration.py +""" + +from __future__ import annotations + +import os +from types import SimpleNamespace + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from transformers.integrations.deepgemm import ( + _load_deepgemm_kernel, + deepgemm_fp8_fp4_experts_forward, + deepgemm_fp8_fp4_megamoe_experts_forward, +) + + +_FP8_DTYPE = torch.float8_e4m3fn +_FP8_MAX = torch.finfo(_FP8_DTYPE).max +_UE8M0_SF_DTYPE = torch.float8_e8m0fnu + + +def _round_to_ue8m0(x: torch.Tensor) -> torch.Tensor: + """Round a positive float tensor to the nearest power of 2 representable as UE8M0.""" + return torch.pow(2.0, torch.ceil(torch.log2(x.clamp(min=torch.finfo(torch.float32).tiny)))).to(_UE8M0_SF_DTYPE) + + +def _make_fp8_experts(num_experts: int, hidden_size: int, intermediate_size: int, ue8m0_sf: bool, device: torch.device) -> SimpleNamespace: + """Synthetic FP8 experts. + + DeepGEMM picks the SF recipe per-arch based on the SF dtype (see + `get_default_recipe` in `csrc/utils/layout.hpp`): + + * SM90 + float SF → recipe (1, 128, 128): block-quantized SF for B, + shape `(E, N/128, K/128)`. + * SM100 + float SF → recipe (1, 128, 128): same block-quantized + shape; kernel broadcasts → packs UE8M0 + internally (DSv3 path, "legacy" on Blackwell). + * SM100 + UE8M0 SF → recipe (1, 1, 128): per-row SF for B, shape + `(E, N, K/128)`. This is the DSv4-FP8 path. + + `ue8m0_sf=False` exercises the float-SF (DSv3) path; `ue8m0_sf=True` + exercises the per-row UE8M0 (DSv4-FP8) path. + """ + block_k = 128 + # Per-row when UE8M0 (gran_mn=1), block-128 when float SF (gran_mn=128). + block_n = 1 if ue8m0_sf else 128 + + def _alloc(e: int, n: int, k: int) -> tuple[torch.Tensor, torch.Tensor]: + # Random bf16 → clamp to FP8 range → cast. Values are nonsense but byte-valid. 
+ w_fp32 = (torch.randn(e, n, k, device=device) * 0.1).clamp(-_FP8_MAX, _FP8_MAX) + w_fp8 = w_fp32.to(_FP8_DTYPE) + sf_n = -(-n // block_n) # ceil-div + sf_k = -(-k // block_k) + sf = (torch.rand(e, sf_n, sf_k, device=device) * 0.05 + 0.001).to(torch.float32) + if ue8m0_sf: + sf = _round_to_ue8m0(sf) + return w_fp8, sf + + gate_up, gate_up_sf = _alloc(num_experts, 2 * intermediate_size, hidden_size) + down, down_sf = _alloc(num_experts, hidden_size, intermediate_size) + return SimpleNamespace( + num_experts=num_experts, + has_gate=True, + has_bias=False, + is_transposed=False, + # block_size matches the actual SF granularity: + # (128, 128) for the DSv3 (float-SF) block-quantized path, + # (1, 128) for the DSv4-FP8 (UE8M0-SF) per-row path. + block_size=(block_n, block_k), + activation_scheme="dynamic", + config=SimpleNamespace(hidden_act="silu"), + gate_up_proj=gate_up, + gate_up_proj_scale_inv=gate_up_sf, + down_proj=down, + down_proj_scale_inv=down_sf, + _apply_gate=lambda x: F.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + act_fn=F.silu, + ) + + +def _make_fp4_experts(num_experts: int, hidden_size: int, intermediate_size: int, device: torch.device) -> SimpleNamespace: + """Synthetic FP4 experts (`int8`-packed e2m1, K dim halved; UE8M0 SF, gran_k=32).""" + + def _alloc(e: int, n: int, k: int) -> tuple[torch.Tensor, torch.Tensor]: + # Any int8 byte pattern is a valid FP4-packed (2 e2m1 nibbles per byte). + w = torch.randint(low=-128, high=128, size=(e, n, k // 2), dtype=torch.int8, device=device) + # Random positive scales → round to UE8M0 (any e8m0 byte is a power-of-2 or special). + sf = (torch.rand(e, n, k // 32, device=device) * 0.05 + 0.001).to(torch.float32) + sf = _round_to_ue8m0(sf) + return w, sf + + gate_up, gate_up_sf = _alloc(num_experts, 2 * intermediate_size, hidden_size) + down, down_sf = _alloc(num_experts, hidden_size, intermediate_size) + return SimpleNamespace( + num_experts=num_experts, + has_gate=True, + has_bias=False, + is_transposed=False, + block_size=None, # FP4 ignores block_size — kernel infers SF granularity from dtype. 
+ activation_scheme="dynamic", + config=SimpleNamespace(hidden_act="silu"), + gate_up_proj=gate_up, + gate_up_proj_scale_inv=gate_up_sf, + down_proj=down, + down_proj_scale_inv=down_sf, + _apply_gate=lambda x: F.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + act_fn=F.silu, + ) + + +def _random_routing(num_tokens: int, top_k: int, num_experts: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + idx = torch.randint(0, num_experts, (num_tokens, top_k), dtype=torch.int32, device=device) + w = torch.rand(num_tokens, top_k, dtype=torch.float32, device=device) + return idx, w / w.sum(dim=-1, keepdim=True).clamp_min(1e-6) + + +def _check_output(out: torch.Tensor, expected_shape: tuple[int, ...], label: str) -> None: + assert out.shape == expected_shape, f"[{label}] shape mismatch: {tuple(out.shape)} vs {expected_shape}" + assert torch.isfinite(out).all(), f"[{label}] output has non-finite values" + print(f"[{label}] PASS out: {tuple(out.shape)} dtype={out.dtype}") + + +# ── Tests ──────────────────────────────────────────────────────────────────────── + + +def test_dsv3_fp8(device: torch.device) -> None: + label = "DSv3 (FP8 + float SF)" + if torch.cuda.get_device_capability(device)[0] < 9: + print(f"[{label}] SKIP: needs SM90+ (Hopper)") + return + T, H, I, E, K = 256, 1024, 512, 16, 4 + experts = _make_fp8_experts(E, H, I, ue8m0_sf=False, device=device) + hidden = torch.randn(T, H, dtype=torch.bfloat16, device=device) * 0.1 + idx, w = _random_routing(T, K, E, device) + out = deepgemm_fp8_fp4_experts_forward(experts, hidden, idx, w.to(torch.bfloat16)) + _check_output(out, (T, H), label) + + +def test_dsv4_fp8(device: torch.device) -> None: + label = "DSv4-FP8 (FP8 + UE8M0 SF)" + if torch.cuda.get_device_capability(device)[0] < 10: + print(f"[{label}] SKIP: needs SM100+ (Blackwell) for UE8M0 SF dispatch") + return + T, H, I, E, K = 256, 1024, 512, 16, 4 + experts = _make_fp8_experts(E, H, I, ue8m0_sf=True, device=device) + hidden = torch.randn(T, H, dtype=torch.bfloat16, device=device) * 0.1 + idx, w = _random_routing(T, K, E, device) + out = deepgemm_fp8_fp4_experts_forward(experts, hidden, idx, w.to(torch.bfloat16)) + _check_output(out, (T, H), label) + + +def test_dsv4_fp4(device: torch.device) -> None: + label = "DSv4 (FP4 + UE8M0 SF)" + if torch.cuda.get_device_capability(device)[0] < 10: + print(f"[{label}] SKIP: needs SM100+ (Blackwell)") + return + T, H, I, E, K = 256, 1024, 512, 16, 4 + experts = _make_fp4_experts(E, H, I, device) + hidden = torch.randn(T, H, dtype=torch.bfloat16, device=device) * 0.1 + idx, w = _random_routing(T, K, E, device) + out = deepgemm_fp8_fp4_experts_forward(experts, hidden, idx, w.to(torch.bfloat16)) + _check_output(out, (T, H), label) + + +def test_megamoe(device: torch.device, world_size: int, rank: int) -> None: + label = "Mega MoE (FP8 act × FP4 weight, fused EP)" + if torch.cuda.get_device_capability(device)[0] < 10: + if rank == 0: + print(f"[{label}] SKIP: needs SM100+ (Blackwell)") + return + if world_size < 2: + if rank == 0: + print(f"[{label}] SKIP: needs >=2 ranks (run with `torchrun --nproc_per_node=2`)") + return + + deepgemm = _load_deepgemm_kernel() + T_local, H, I, K = 64, 1024, 512, 4 + E_global = 16 + E_local = E_global // world_size + + # Build raw FP4 experts on this rank's slice, then transform to the kernel's layout. 
+ raw = _make_fp4_experts(E_local, H, I, device) + gate_up_t, gate_up_sf_t = deepgemm.transform_weights_for_mega_moe( + raw.gate_up_proj, raw.gate_up_proj_scale_inv, is_l1=True + ) + down_t, down_sf_t = deepgemm.transform_weights_for_mega_moe( + raw.down_proj, raw.down_proj_scale_inv, is_l1=False + ) + + experts = SimpleNamespace( + gate_up_proj=gate_up_t, + gate_up_proj_scale_inv=gate_up_sf_t, + down_proj=down_t, + down_proj_scale_inv=down_sf_t, + symm_buffer=None, # lazily allocated on first call + config=SimpleNamespace(), # no swiglu_limit → kernel runs unclamped + ) + + hidden = torch.randn(T_local, H, dtype=torch.bfloat16, device=device) * 0.1 + # Mega MoE expects GLOBAL expert ids (no per-rank remap); -1 marks skipped slots. + idx = torch.randint(0, E_global, (T_local, K), dtype=torch.int32, device=device) + w = torch.rand(T_local, K, dtype=torch.float32, device=device) + w = w / w.sum(dim=-1, keepdim=True).clamp_min(1e-6) + + out = deepgemm_fp8_fp4_megamoe_experts_forward( + experts, hidden, idx, w.to(torch.bfloat16), process_group=dist.group.WORLD + ) + if rank == 0: + _check_output(out, (T_local, H), label) + + +# ── Entrypoint ─────────────────────────────────────────────────────────────────── + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA required.") + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + + if world_size > 1 and not dist.is_initialized(): + dist.init_process_group("nccl") + + if rank == 0: + print(f"device cap: SM{''.join(str(x) for x in torch.cuda.get_device_capability(device))}, " + f"world_size={world_size}\n") + + # Single-GPU paths run on rank 0 only (ranks > 0 only participate in Mega MoE). + failures: list[tuple[str, BaseException]] = [] + if rank == 0: + for fn in (test_dsv3_fp8, test_dsv4_fp8, test_dsv4_fp4): + try: + fn(device) + except BaseException as exc: + failures.append((fn.__name__, exc)) + print(f"[{fn.__name__}] FAIL — {type(exc).__name__}: {exc}") + + if world_size > 1: + dist.barrier() + try: + test_megamoe(device, world_size, rank) + except BaseException as exc: + if rank == 0: + failures.append(("test_megamoe", exc)) + print(f"[test_megamoe] FAIL — {type(exc).__name__}: {exc}") + dist.destroy_process_group() + + if rank == 0: + if failures: + print(f"\n=== {len(failures)} test(s) failed ===") + for name, exc in failures: + print(f" - {name}: {type(exc).__name__}: {exc}") + raise SystemExit(1) + print("\n=== all tests passed ===") + + +if __name__ == "__main__": + main() From 1fe3768096c0be86ef9e573ca6993a28330028cc Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 07:15:55 +0200 Subject: [PATCH 50/87] check nvcc --- check_nvcc_b200.py | 168 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 check_nvcc_b200.py diff --git a/check_nvcc_b200.py b/check_nvcc_b200.py new file mode 100644 index 000000000000..254100db26d0 --- /dev/null +++ b/check_nvcc_b200.py @@ -0,0 +1,168 @@ +"""Smoke-test the user's nvcc + CUDA setup against the running GPU. + +Compiles and runs a tiny CUDA kernel targeting the device's actual compute +capability (e.g. `sm_100a` on B200, `sm_90a` on H100). Mirrors what DeepGEMM's +JIT does at the first kernel call: + + 1. Locate `nvcc` via `$CUDA_HOME/bin/nvcc`, then PATH, then `/usr/local/cuda`. + 2. 
nvcc-compile a kernel that uses an SM-specific intrinsic / API. + 3. Launch it, copy result back, sanity-check. + +If this succeeds, DeepGEMM JIT will work at runtime. If it fails, the message +points at the specific layer (toolchain, driver, runtime) so you can fix it +before pulling DeepGEMM into a model run. + +Usage: + python check_nvcc_b200.py + CUDA_HOME=/path/to/cuda python check_nvcc_b200.py +""" + +from __future__ import annotations + +import ctypes +import ctypes.util +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + +import torch + + +def _find_cuda_home() -> str: + """Same search order as the deep-gemm wheel's `_find_cuda_home`.""" + for var in ("CUDA_HOME", "CUDA_PATH"): + cand = os.environ.get(var) + if cand and (Path(cand) / "bin" / "nvcc").is_file(): + return cand + + nvcc = shutil.which("nvcc") + if nvcc: + return str(Path(nvcc).parent.parent) + + try: + import nvidia.cuda_nvcc as _nvcc # type: ignore + cand = Path(_nvcc.__file__).parent + if (cand / "bin" / "nvcc").is_file(): + return str(cand) + except ImportError: + pass + + for cand in ("/usr/local/cuda", "/opt/cuda", "/opt/nvidia/cuda", "/usr/lib/cuda"): + if (Path(cand) / "bin" / "nvcc").is_file(): + return cand + import glob + for cand in sorted(glob.glob("/usr/local/cuda-*"), reverse=True): + if (Path(cand) / "bin" / "nvcc").is_file(): + return cand + raise SystemExit("nvcc not found. Set CUDA_HOME or install CUDA toolkit.") + + +_KERNEL_SRC = r""" +#include +#include + +// One element per thread; writes its global index. Probes that arch-specific +// codegen + scheduling work end-to-end on the device. +__global__ void identity_kernel(int* out, int n) { + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + out[i] = i; + } +} + +extern "C" __host__ int run_check(int n) { + int* d_out = nullptr; + cudaError_t err = cudaMalloc(&d_out, n * sizeof(int)); + if (err != cudaSuccess) { fprintf(stderr, "cudaMalloc: %s\n", cudaGetErrorString(err)); return 1; } + + int threads = 128, blocks = (n + threads - 1) / threads; + identity_kernel<<>>(d_out, n); + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { fprintf(stderr, "kernel launch: %s\n", cudaGetErrorString(err)); cudaFree(d_out); return 2; } + + int* h_out = (int*)malloc(n * sizeof(int)); + err = cudaMemcpy(h_out, d_out, n * sizeof(int), cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { fprintf(stderr, "cudaMemcpy: %s\n", cudaGetErrorString(err)); free(h_out); cudaFree(d_out); return 3; } + + int ok = 1; + for (int i = 0; i < n; ++i) if (h_out[i] != i) { ok = 0; break; } + free(h_out); + cudaFree(d_out); + return ok ? 
0 : 4; +} +""" + + +def main() -> int: + if not torch.cuda.is_available(): + print("FAIL: CUDA not available to torch.") + return 1 + + cap = torch.cuda.get_device_capability() + sm = f"{cap[0]}{cap[1]}a" + name = torch.cuda.get_device_name() + print(f"GPU: {name} (compute capability sm_{sm})") + + cuda_home = _find_cuda_home() + nvcc = str(Path(cuda_home) / "bin" / "nvcc") + print(f"CUDA_HOME: {cuda_home}") + + ver = subprocess.run([nvcc, "--version"], capture_output=True, text=True) + print(ver.stdout.strip().splitlines()[-1] if ver.stdout else ver.stderr) + + with tempfile.TemporaryDirectory() as td: + src = Path(td) / "probe.cu" + so = Path(td) / "probe.so" + src.write_text(_KERNEL_SRC) + + cmd = [ + nvcc, "-shared", "-Xcompiler=-fPIC", + "-O2", "-std=c++17", + f"-gencode=arch=compute_{cap[0]}{cap[1]}{'a' if cap[0] >= 9 else ''},code=sm_{sm}", + "-o", str(so), str(src), + ] + print("\n[1/3] nvcc compile…") + r = subprocess.run(cmd, capture_output=True, text=True) + if r.returncode != 0: + print(f"FAIL: nvcc compile (exit {r.returncode})") + print("--- stderr ---") + print(r.stderr) + return 1 + print(" OK") + + print("[2/3] dlopen…") + try: + lib = ctypes.CDLL(str(so)) + except OSError as e: + print(f"FAIL: dlopen: {e}") + print("Hint: missing libcudart.so on LD_LIBRARY_PATH. Try:") + print(f" export LD_LIBRARY_PATH={cuda_home}/lib64:$LD_LIBRARY_PATH") + return 1 + lib.run_check.restype = ctypes.c_int + lib.run_check.argtypes = [ctypes.c_int] + print(" OK") + + print("[3/3] launch kernel…") + rc = lib.run_check(1024) + labels = { + 0: "OK", + 1: "cudaMalloc failed", + 2: "kernel launch / sync failed", + 3: "cudaMemcpy failed", + 4: "kernel produced wrong values", + } + print(f" run_check → {rc} ({labels.get(rc, 'unknown')})") + if rc != 0: + print("\nFAIL: nvcc compiles but the kernel did not run correctly.") + print("Common causes: GPU driver too old for the toolkit, mismatched libcudart.") + return 1 + + print(f"\nPASS: nvcc {Path(nvcc).name} can compile + run sm_{sm} kernels on this {name}.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From bbfda35ecbec3485c885c11a15ddd03c226f565c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 07:18:47 +0200 Subject: [PATCH 51/87] probe --- probe_deepgemm_sf.py | 72 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 probe_deepgemm_sf.py diff --git a/probe_deepgemm_sf.py b/probe_deepgemm_sf.py new file mode 100644 index 000000000000..77d5a97ecaf9 --- /dev/null +++ b/probe_deepgemm_sf.py @@ -0,0 +1,72 @@ +"""Print the actual SF shapes / strides / dtypes the DeepGEMM integration feeds +into `m_grouped_fp8_fp4_gemm_nt_contiguous`, for each test case. + +When the kernel rejects an SF with `check_sf_layout` assertions, you usually +can't tell from the message *which* SF (activation or weight) failed and what +its actual layout was. This wraps `_coerce_sf_for_kernel` to log every call, +then runs the smoke tests so you can see the exact tensor metadata that hit +the kernel boundary right before the assertion fired. 
+ +Usage: + CUDA_HOME=$HOME/cuda-12.9 python probe_deepgemm_sf.py +""" + +from __future__ import annotations + +import sys + +import torch + +import test_deepgemm as t +from transformers.integrations import deepgemm as di + + +_real_coerce = di._coerce_sf_for_kernel +_call_idx = [0] + + +def _verbose_coerce(sf: torch.Tensor) -> torch.Tensor: + out = _real_coerce(sf) + _call_idx[0] += 1 + print( + f" [#{_call_idx[0]}] in: shape={tuple(sf.shape)} " + f"stride={tuple(sf.stride())} dtype={sf.dtype}" + ) + print( + f" out: shape={tuple(out.shape)} " + f"stride={tuple(out.stride())} dtype={out.dtype}" + ) + return out + + +di._coerce_sf_for_kernel = _verbose_coerce + + +def _run(name: str, fn) -> bool: + print(f"\n=== {name} ===") + _call_idx[0] = 0 + try: + fn(d) + print(f" → PASS") + return True + except BaseException as exc: + print(f" → FAIL: {type(exc).__name__}: {str(exc)[:300]}") + return False + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + sys.exit("CUDA required.") + torch.cuda.set_device(0) + d = torch.device("cuda", 0) + print( + f"GPU: {torch.cuda.get_device_name(d)} " + f"SM{''.join(str(x) for x in torch.cuda.get_device_capability(d))}" + ) + + results = [ + _run("test_dsv3_fp8", t.test_dsv3_fp8), + _run("test_dsv4_fp8", t.test_dsv4_fp8), + _run("test_dsv4_fp4", t.test_dsv4_fp4), + ] + sys.exit(0 if all(results) else 1) From 68a373cd1a1f3277d3727acf8dbe2025b72a7f8b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 07:22:35 +0200 Subject: [PATCH 52/87] fix --- src/transformers/integrations/deepgemm.py | 58 ++++++++++++++++++----- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 0e50e483711e..a435a65c2fa5 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -161,26 +161,49 @@ def _load_deepgemm_kernel() -> DeepGEMM: def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: """Normalize a scale-factor tensor for the DeepGEMM kernel boundary. - The kernel's `check_sf_layout` requires `sf.stride(-2) == 1` (MN-major). - Default contiguous PyTorch tensors are K-major; flip when needed. + `check_sf_layout` (csrc/utils/layout.hpp) imposes two constraints: + + 1. `sf.stride(-2) == 1` — MN-major. + 2. `sf.stride(-1) == get_tma_aligned_size(mn, esize)` — TMA-aligned. + + PyTorch's default contiguous layout is K-major, so we explicitly build a + new tensor with the required strides via `empty_strided`. This also fixes + the size-1 last-dim case where `transpose+contiguous+transpose` is a no-op + (PyTorch reports stride(-1)=1 for size-1 dims, which fails (2)). Three SF flavors arrive at this boundary: - - `float32` (DeepSeek V3-style): flip layout only. + - `float32` (DeepSeek V3-style): rewrite layout only. - `int32` (already-packed UE8M0 from `per_token_cast_to_fp8( - use_packed_ue8m0=True)` or saved checkpoints): flip layout only. + use_packed_ue8m0=True)` or saved checkpoints): rewrite layout only. - `float8_e8m0fnu` (raw UE8M0 bytes, 1 byte per scale; e.g. on-disk weights): pack 4 contiguous K-bytes into `int32` (last dim shrinks - 4×) AND flip layout. Mirrors what - `get_mn_major_tma_aligned_packed_ue8m0_tensor` produces from a - float32 SF, just starting from already-quantized bytes. + 4×) and rewrite layout. """ if sf.dtype == torch.float8_e8m0fnu: - # `view(int32)` requires the source contiguous (4 K-bytes adjacent). - # Pack first while still K-major, then flip to MN-major. 
+ # view(int32) requires the source contiguous (4 K-bytes adjacent). sf = sf.contiguous().view(torch.int32) - if sf.stride(-2) != 1: - sf = sf.transpose(-1, -2).contiguous().transpose(-1, -2) - return sf + + if sf.dim() not in (2, 3): + raise ValueError(f"DeepGEMM SF must be 2D or 3D, got {sf.dim()}D") + + mn = sf.size(-2) + kf = sf.size(-1) + elem_size = sf.element_size() + # `get_tma_aligned_size`: align(mn, 16 / element_size). + align_to = 16 // elem_size + aligned_mn = -(-mn // align_to) * align_to # ceil-multiple + + if sf.dim() == 2: + target_strides = (1, aligned_mn) + else: # 3D + target_strides = (kf * aligned_mn, 1, aligned_mn) + + if tuple(sf.stride()) == target_strides: + return sf + + out = torch.empty_strided(sf.shape, target_strides, dtype=sf.dtype, device=sf.device) + out.copy_(sf) + return out def deepgemm_fp8_fp4_linear( @@ -228,10 +251,13 @@ def deepgemm_fp8_fp4_linear( qinput_2d, scale_2d = deepgemm.per_token_cast_to_fp8(input_2d, **cast_kwargs) output = torch.empty(qinput_2d.shape[0], weight.shape[0], device=input.device, dtype=output_dtype) + # Pass the recipe explicitly — see comment in `deepgemm_fp8_fp4_experts_forward`. + sf_recipe = (1, 1, cast_kwargs["gran_k"]) if cast_kwargs.get("use_packed_ue8m0") else None deepgemm.fp8_fp4_matmul( (qinput_2d, _coerce_sf_for_kernel(scale_2d)), (weight, _coerce_sf_for_kernel(weight_scale_inv)), output, + recipe=sf_recipe, ) output = output.view(input.shape[:-1] + (weight.shape[0],)) if bias is not None: @@ -461,6 +487,12 @@ def deepgemm_fp8_fp4_experts_forward( expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) + # The kernel infers a default recipe from the SF dtype/shape on SM100 — + # `(1, 1, 128)` for any int SF, regardless of the SF's actual gran_k. For + # FP4 weights (gran_k=32) this picks the wrong shape contract, so pass + # the recipe explicitly. `(1, 1, gran_k)` matches `cast_kwargs["gran_k"]`. + sf_recipe = (1, 1, cast_kwargs["gran_k"]) if cast_kwargs.get("use_packed_ue8m0") else None + # --- Up projection per expert (DeepGEMM grouped contiguous) --- act_fp8, act_scales = deepgemm.per_token_cast_to_fp8(selected_hidden_states_g, **cast_kwargs) act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) @@ -471,6 +503,7 @@ def deepgemm_fp8_fp4_experts_forward( (w_up, _coerce_sf_for_kernel(ws_up)), proj_out, grouped_layout, + recipe=sf_recipe, use_psum_layout=use_psum_layout, ) @@ -488,6 +521,7 @@ def deepgemm_fp8_fp4_experts_forward( (self.down_proj, _coerce_sf_for_kernel(self.down_proj_scale_inv)), proj_out, grouped_layout, + recipe=sf_recipe, use_psum_layout=use_psum_layout, ) From 825f6d49917718d3a68ee939ec8fa7338b1747ce Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 07:29:28 +0200 Subject: [PATCH 53/87] test psum --- src/transformers/integrations/deepgemm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index a435a65c2fa5..cf44cf89dd08 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -482,7 +482,10 @@ def deepgemm_fp8_fp4_experts_forward( selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + # A/B test: gate psum_layout on int-SF only, to see whether the float-SF + # NaN on B200 correlates with the psum_layout dispatch. 
+ is_int_sf = bool(cast_kwargs.get("use_packed_ue8m0")) or is_fp4_weights + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 and is_int_sf sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) From 028a39f16e4709e24b732cc5611bc1cccaee3470 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 07:32:03 +0200 Subject: [PATCH 54/87] test --- src/transformers/integrations/deepgemm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index cf44cf89dd08..e362bbb48084 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -183,6 +183,11 @@ def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: # view(int32) requires the source contiguous (4 K-bytes adjacent). sf = sf.contiguous().view(torch.int32) + # A/B: skip the MN-major rewrite for float SF — the kernel re-lays it out + # itself via `transform_sf_into_required_layout` (broadcast + pack). + if sf.dtype == torch.float32: + return sf.contiguous() + if sf.dim() not in (2, 3): raise ValueError(f"DeepGEMM SF must be 2D or 3D, got {sf.dim()}D") @@ -482,10 +487,7 @@ def deepgemm_fp8_fp4_experts_forward( selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - # A/B test: gate psum_layout on int-SF only, to see whether the float-SF - # NaN on B200 correlates with the psum_layout dispatch. - is_int_sf = bool(cast_kwargs.get("use_packed_ue8m0")) or is_fp4_weights - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 and is_int_sf + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) From ee173a5fae856b43f001b52f6ae2cc5280508dd0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 07:35:15 +0200 Subject: [PATCH 55/87] test --- src/transformers/integrations/deepgemm.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index e362bbb48084..a435a65c2fa5 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -183,11 +183,6 @@ def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: # view(int32) requires the source contiguous (4 K-bytes adjacent). sf = sf.contiguous().view(torch.int32) - # A/B: skip the MN-major rewrite for float SF — the kernel re-lays it out - # itself via `transform_sf_into_required_layout` (broadcast + pack). 
- if sf.dtype == torch.float32: - return sf.contiguous() - if sf.dim() not in (2, 3): raise ValueError(f"DeepGEMM SF must be 2D or 3D, got {sf.dim()}D") From 900e984dbb92a8c608f6bcda9e2523a7667e6478 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 07:37:21 +0200 Subject: [PATCH 56/87] probe --- probe_deepgemm_sf.py | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/probe_deepgemm_sf.py b/probe_deepgemm_sf.py index 77d5a97ecaf9..814445e06249 100644 --- a/probe_deepgemm_sf.py +++ b/probe_deepgemm_sf.py @@ -28,13 +28,18 @@ def _verbose_coerce(sf: torch.Tensor) -> torch.Tensor: out = _real_coerce(sf) _call_idx[0] += 1 + nonfinite_in = (~torch.isfinite(sf.float())).sum().item() if sf.is_floating_point() else 0 + nonfinite_out = (~torch.isfinite(out.float())).sum().item() if out.is_floating_point() else 0 print( f" [#{_call_idx[0]}] in: shape={tuple(sf.shape)} " - f"stride={tuple(sf.stride())} dtype={sf.dtype}" + f"stride={tuple(sf.stride())} dtype={sf.dtype} " + f"min={sf.float().abs().min().item():.3e} max={sf.float().abs().max().item():.3e} " + f"nonfinite={nonfinite_in}" ) print( f" out: shape={tuple(out.shape)} " - f"stride={tuple(out.stride())} dtype={out.dtype}" + f"stride={tuple(out.stride())} dtype={out.dtype} " + f"nonfinite={nonfinite_out}" ) return out @@ -42,6 +47,36 @@ def _verbose_coerce(sf: torch.Tensor) -> torch.Tensor: di._coerce_sf_for_kernel = _verbose_coerce +# Wrap the matmul itself: print output stats after the call so we can see +# where NaN actually appears in the pipeline. +_real_matmul = None + + +def _verbose_matmul(*args, **kwargs): + global _real_matmul + out_tensor = args[2] # (a_pair, b_pair, d, ...) + label = f"matmul (d.shape={tuple(out_tensor.shape)})" + _real_matmul(*args, **kwargs) + nf = (~torch.isfinite(out_tensor)).sum().item() + print( + f" → {label}: nonfinite_count={nf} " + f"min_abs={out_tensor.abs().min().item():.3e} " + f"max_abs={out_tensor.abs().max().item():.3e}" + ) + + +def _patch_matmul(): + global _real_matmul + deepgemm = di._load_deepgemm_kernel() + _real_matmul = deepgemm.grouped_fp8_fp4_matmul + + # Replace the cached kernel's matmul with our wrapper. + object.__setattr__(deepgemm, "grouped_fp8_fp4_matmul", _verbose_matmul) + + +_patch_matmul() + + def _run(name: str, fn) -> bool: print(f"\n=== {name} ===") _call_idx[0] = 0 From 09a711f362e5746bab4ecb948ade074d4a2e004e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 07:39:24 +0200 Subject: [PATCH 57/87] fix --- src/transformers/integrations/deepgemm.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index a435a65c2fa5..4d4271215902 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -311,10 +311,16 @@ def _build_deepgemm_contiguous_layout( def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: """Pad a sorted tensor into the TMA-aligned contiguous layout. - Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) - or via the psum offsets (Blackwell), so their values never enter the computation. 
+ Padding rows are zero-initialized: on SM100 the psum_layout dispatch computes + every row in the per-expert aligned range (it has no per-row skip mask, only + cumulative offsets), so any garbage in the padding feeds straight into the + GEMM. For float-SF activations that's catastrophic — uninitialized float32 + bit patterns can be huge (or NaN), blow up the FP8 dequant, and overflow. + With zero-initialised padding: FP8 acts → 0, float SF → 0 (dequant = 0), + UE8M0 SF → byte 0 (≈2^-127, dequant ≈ 0). Dot product on padding rows + becomes 0, harmless. """ - padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + padded = torch.zeros(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) padded[sorted_to_padded] = x return padded From a0d49400f8015541c63dd22dac4baca69bfad4d5 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 08:46:55 +0200 Subject: [PATCH 58/87] test --- test_deepgemm.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/test_deepgemm.py b/test_deepgemm.py index 41ff2ccfba5e..f6f2f63d0aae 100644 --- a/test_deepgemm.py +++ b/test_deepgemm.py @@ -65,14 +65,26 @@ def _make_fp8_experts(num_experts: int, hidden_size: int, intermediate_size: int block_n = 1 if ue8m0_sf else 128 def _alloc(e: int, n: int, k: int) -> tuple[torch.Tensor, torch.Tensor]: - # Random bf16 → clamp to FP8 range → cast. Values are nonsense but byte-valid. - w_fp32 = (torch.randn(e, n, k, device=device) * 0.1).clamp(-_FP8_MAX, _FP8_MAX) - w_fp8 = w_fp32.to(_FP8_DTYPE) + # Per-block-amax FP8 quantization (matches what real DeepSeek-V3 + # checkpoints look like): generate random bf16 weights, compute the + # max-abs of each (block_n × block_k) tile as the SF, divide by 448 + # so the quantized values use the full FP8 range, then cast to FP8. + # Without this, dequantized weights are tiny and the GEMM + # accumulation on Blackwell's float→UE8M0 conversion path can + # produce NaN. + w_fp32 = torch.randn(e, n, k, device=device) * 0.1 sf_n = -(-n // block_n) # ceil-div sf_k = -(-k // block_k) - sf = (torch.rand(e, sf_n, sf_k, device=device) * 0.05 + 0.001).to(torch.float32) + # Block amax → scale. + w_blocks = w_fp32.view(e, sf_n, block_n, sf_k, block_k) + amax = w_blocks.abs().amax(dim=(2, 4)).clamp(min=1e-4) # (e, sf_n, sf_k) + sf = (amax / _FP8_MAX).to(torch.float32) if ue8m0_sf: sf = _round_to_ue8m0(sf) + # Quantize using the dequantized SF (so the cast actually matches). + sf_dequant = sf.float() + w_scaled = w_fp32 / sf_dequant.view(e, sf_n, 1, sf_k, 1).expand(-1, -1, block_n, -1, block_k).reshape(e, n, k) + w_fp8 = w_scaled.clamp(-_FP8_MAX, _FP8_MAX).to(_FP8_DTYPE) return w_fp8, sf gate_up, gate_up_sf = _alloc(num_experts, 2 * intermediate_size, hidden_size) From 653b7b3bcd164c7da32a1b5f48d9810f2a177c09 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 08:52:56 +0200 Subject: [PATCH 59/87] nan issue --- probe_dsv3_conversion.py | 78 ++++++++++++++++++++++ repro_nan_dsv3_b200.py | 136 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 probe_dsv3_conversion.py create mode 100644 repro_nan_dsv3_b200.py diff --git a/probe_dsv3_conversion.py b/probe_dsv3_conversion.py new file mode 100644 index 000000000000..723e063e7fe0 --- /dev/null +++ b/probe_dsv3_conversion.py @@ -0,0 +1,78 @@ +"""Compare the kernel's float→packed-UE8M0 conversion against a Python +equivalent, on the exact SF tensors that DSv3 feeds to the GEMM. 
Goal: find +out whether `transpose_and_pack_fp32_into_ue8m0` (the JIT helper the kernel +runs internally for DSv3 on SM100) produces values matching what we'd compute +ourselves. + +If the kernel's output matches Python on every byte, the NaN comes from +somewhere else in the GEMM. If it doesn't match, the conversion itself is the +bug. +""" + +from __future__ import annotations + +import torch + +from transformers.integrations.deepgemm import _load_deepgemm_kernel + + +def py_pack(sf_fp32: torch.Tensor) -> torch.Tensor: + """Same as `pack_fp32_into_ue8m0`: extract the biased exponent (bits + [30:23]) of each float as a uint8, then pack 4 K-consecutive bytes into + one int32 (LSB = lowest K). Returns an MN-major int32 tensor. + """ + # Extract biased exponent → uint8 + byte = (sf_fp32.view(torch.int32) >> 23).to(torch.uint8) + # Reshape so K dim is divisible by 4 + *batch, mn, k = byte.shape + assert k % 4 == 0 + # Pack each group of 4 K-bytes into 1 int32 in little-endian order + grouped = byte.view(*batch, mn, k // 4, 4).to(torch.int32) + packed = grouped[..., 0] | (grouped[..., 1] << 8) | (grouped[..., 2] << 16) | (grouped[..., 3] << 24) + return packed # shape (..., mn, k//4) K-major; caller should rewrite to MN-major + + +def main(): + torch.cuda.set_device(0) + d = torch.device("cuda", 0) + dg = _load_deepgemm_kernel() + + # Activation SF case: shape (M, K_blocks), per-row. + print("=== activation SF: (3056, 8) float32 → kernel pack vs python pack ===") + M, Kb = 3056, 8 + sf_a = (torch.rand(M, Kb, device=d) * 0.05 + 0.001).to(torch.float32) + + kernel_packed = dg.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf_a) + py_packed = py_pack(sf_a) + print(f" kernel out shape={tuple(kernel_packed.shape)} stride={tuple(kernel_packed.stride())} dtype={kernel_packed.dtype}") + print(f" python out shape={tuple(py_packed.shape)}") + # Compare bytes + kernel_bytes = kernel_packed.contiguous().view(torch.uint8).flatten() + python_bytes = py_packed.contiguous().view(torch.uint8).flatten() + n = min(kernel_bytes.numel(), python_bytes.numel()) + diff = (kernel_bytes[:n] != python_bytes[:n]).sum().item() + print(f" byte-equal count: {(n - diff)}/{n} (diff={diff})") + + print("\n=== weight SF: (16, 8, 8) float32 → kernel broadcast+pack ===") + E, sn, sk = 16, 8, 8 + N, K = 1024, 1024 + sf_w = (torch.rand(E, sn, sk, device=d) * 0.05 + 0.001).to(torch.float32) + + # Python broadcast: each block-row repeated 128 times along dim -2. + sf_w_broadcast = sf_w.repeat_interleave(N // sn, dim=-2) # (E, N, sk) + py_packed_w = py_pack(sf_w_broadcast) # (E, N, sk//4) + print(f" python broadcast+pack shape={tuple(py_packed_w.shape)}") + + # Kernel path: pass float SF to transform_sf_into_required_layout via the + # recipe machinery. We don't have direct access; the closest helper is + # the public one which expects per-row float input. Skip and just confirm + # python pack is correct on the broadcasted form. + + # Sanity: print a slice of packed bytes for visual inspection. 
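    # (Editor-added sanity check, illustrative only — not part of the original probe.)
    # On exact powers of two the pack is easy to verify by hand: 1.0, 2.0, 0.5, 4.0
    # have biased exponents 0x7F, 0x80, 0x7E, 0x81, so the little-endian bytes of the
    # packed int32 should come out in exactly that order.
    demo = torch.tensor([[1.0, 2.0, 0.5, 4.0]], device=d)
    assert py_pack(demo).view(torch.uint8).flatten().tolist() == [0x7F, 0x80, 0x7E, 0x81]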
+ print(f" python packed[0, 0, :] = {py_packed_w[0, 0, :].tolist()}") + print(f" python packed[0, 127, :] = {py_packed_w[0, 127, :].tolist()} (same block as row 0)") + print(f" python packed[0, 128, :] = {py_packed_w[0, 128, :].tolist()} (next block)") + + +if __name__ == "__main__": + main() diff --git a/repro_nan_dsv3_b200.py b/repro_nan_dsv3_b200.py new file mode 100644 index 000000000000..5d803dfd37c9 --- /dev/null +++ b/repro_nan_dsv3_b200.py @@ -0,0 +1,136 @@ +"""Minimal reproducer for a NaN that appears on B200 (SM100) but not on H100 +(SM90) when calling DeepGEMM's `m_grouped_fp8_fp4_gemm_nt_contiguous` with +float32 scale factors (DSv3-style block-quantized SFs). + +Setup: + * FP8 weights (E, N, K) cast from a real bf16 tensor with proper + per-(128, 128)-block amax scaling — i.e., dequant_w ≈ original. + * Per-token FP8 activations (M, K) with proper float32 SFs. + * Block-quantized float32 weight SF of shape (E, N/128, K/128). + +Path: + * On SM90: kernel uses the `(FP32, 128, 128)` recipe directly, dispatches + `sm90_m_grouped_fp8_gemm_contiguous_1d2d`. Output is finite. + * On SM100: kernel converts float SF → packed UE8M0 int32 internally via + `index_select(broadcast)` + `transpose_and_pack_fp32_into_ue8m0`, then + dispatches `sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d`. Output has + millions of NaNs. + +To run: + H100: python repro_nan_dsv3_b200.py + B200: CUDA_HOME=$HOME/cuda-12.9 python repro_nan_dsv3_b200.py +""" + +from __future__ import annotations + +import sys + +import torch + +from transformers.integrations.deepgemm import _load_deepgemm_kernel + + +_FP8 = torch.float8_e4m3fn +_FP8_MAX = torch.finfo(_FP8).max # 448.0 + + +def make_grouped_fp8(e: int, n: int, k: int, block_n: int, block_k: int, device): + """Per-block-amax FP8 quantization. Returns (w_fp8, sf_float32).""" + w_fp32 = torch.randn(e, n, k, device=device) * 0.1 + sf_n, sf_k = n // block_n, k // block_k + blocks = w_fp32.view(e, sf_n, block_n, sf_k, block_k) + amax = blocks.abs().amax(dim=(2, 4)).clamp(min=1e-4) # (e, sf_n, sf_k) + sf = (amax / _FP8_MAX).to(torch.float32) + sf_expanded = ( + sf.view(e, sf_n, 1, sf_k, 1).expand(-1, -1, block_n, -1, block_k).reshape(e, n, k) + ) + w_fp8 = (w_fp32 / sf_expanded).clamp(-_FP8_MAX, _FP8_MAX).to(_FP8) + return w_fp8, sf + + +def make_per_token_fp8(x_bf16: torch.Tensor, gran_k: int = 128): + """Per-row amax FP8 quantization. 
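(Editor's aside — a worked instance of the amax/448 recipe these helpers use, with the row length shrunk to 4 for readability; the values are chosen so the FP8 round trip is exact.)

import torch

row = torch.tensor([[0.5, -2.0, 1.0, 0.25]])
sf = row.abs().amax(dim=1, keepdim=True) / 448.0             # per-row scale: 2 / 448
q = (row / sf).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)  # [112, -448, 224, 56]
assert torch.allclose(q.float() * sf, row, rtol=1e-3)        # dequant ≈ original row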
Returns (x_fp8, sf_float32).""" + m, n = x_bf16.shape + assert n % gran_k == 0 + x_view = x_bf16.float().view(m, n // gran_k, gran_k) + amax = x_view.abs().amax(dim=2).clamp(min=1e-4) # (m, n/gran_k) + sf = (amax / _FP8_MAX).to(torch.float32) + x_fp8 = (x_view / sf.unsqueeze(2)).clamp(-_FP8_MAX, _FP8_MAX).view(m, n).to(_FP8) + return x_fp8, sf + + +def to_mn_major(sf: torch.Tensor) -> torch.Tensor: + """Rewrite SF to MN-major + TMA-aligned strides (kernel requires this).""" + elem = sf.element_size() + align = 16 // elem + mn = sf.size(-2) + aligned_mn = -(-mn // align) * align + if sf.dim() == 2: + target = (1, aligned_mn) + elif sf.dim() == 3: + target = (sf.size(-1) * aligned_mn, 1, aligned_mn) + else: + raise ValueError(sf.dim()) + if tuple(sf.stride()) == target: + return sf + out = torch.empty_strided(sf.shape, target, dtype=sf.dtype, device=sf.device) + out.copy_(sf) + return out + + +def main() -> int: + if not torch.cuda.is_available(): + sys.exit("CUDA required.") + device = torch.device("cuda", 0) + torch.cuda.set_device(device) + cap = torch.cuda.get_device_capability(device) + print(f"GPU: {torch.cuda.get_device_name(device)} SM{cap[0]}{cap[1]}") + + dg = _load_deepgemm_kernel() + torch.manual_seed(0) + + # Shapes — small but trigger the real codegen path. + M, N, K, E = 512, 1024, 1024, 4 + block_n, block_k = 128, 128 + + # FP8 weights with realistic per-(128,128)-block amax scales. + w_fp8, w_sf_block = make_grouped_fp8(E, N, K, block_n, block_k, device) + print(f" w_fp8: {tuple(w_fp8.shape)} {w_fp8.dtype}, " + f"w_sf_block: {tuple(w_sf_block.shape)} (block-quantized 128×128)") + + # Activations (one expert per token in this minimal case: round-robin). + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 0.1 + a_fp8, a_sf = make_per_token_fp8(x) + print(f" a_fp8: {tuple(a_fp8.shape)} {a_fp8.dtype}, " + f"a_sf: {tuple(a_sf.shape)} per-token") + + # Grouped layout: equal split, per-row expert id (Hopper layout). + grouped_layout = torch.repeat_interleave( + torch.arange(E, dtype=torch.int32, device=device), M // E + ) + + # Kernel call. Pass float32 SF to force the SM100 broadcast+pack path. + d = torch.empty(M, N, dtype=torch.bfloat16, device=device) + dg.grouped_fp8_fp4_matmul( + (a_fp8, to_mn_major(a_sf)), + (w_fp8, to_mn_major(w_sf_block)), + d, + grouped_layout, + # No `recipe` → kernel picks default `(1, 128, 128)` for (float, float). + # No `use_psum_layout` → default False (per-row id grouped_layout). + ) + + nf = (~torch.isfinite(d)).sum().item() + finite_pct = 100.0 * (1 - nf / d.numel()) + print(f" output: shape={tuple(d.shape)} " + f"nonfinite={nf}/{d.numel()} ({100.0 - finite_pct:.2f}%) " + f"finite_pct={finite_pct:.2f}%") + if nf > 0: + print(" → REPRO: output has NaN/Inf.") + return 1 + print(" → OK: output is finite.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From ed8af6bb2d25b93e7b99607c9e3b8ee6027e1928 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 08:55:55 +0200 Subject: [PATCH 60/87] repro --- repro_nan_dsv3_b200.py | 160 +++++++++++++++-------------------------- 1 file changed, 56 insertions(+), 104 deletions(-) diff --git a/repro_nan_dsv3_b200.py b/repro_nan_dsv3_b200.py index 5d803dfd37c9..7ed42d18fe88 100644 --- a/repro_nan_dsv3_b200.py +++ b/repro_nan_dsv3_b200.py @@ -1,20 +1,13 @@ -"""Minimal reproducer for a NaN that appears on B200 (SM100) but not on H100 -(SM90) when calling DeepGEMM's `m_grouped_fp8_fp4_gemm_nt_contiguous` with -float32 scale factors (DSv3-style block-quantized SFs). 
- -Setup: - * FP8 weights (E, N, K) cast from a real bf16 tensor with proper - per-(128, 128)-block amax scaling — i.e., dequant_w ≈ original. - * Per-token FP8 activations (M, K) with proper float32 SFs. - * Block-quantized float32 weight SF of shape (E, N/128, K/128). - -Path: - * On SM90: kernel uses the `(FP32, 128, 128)` recipe directly, dispatches - `sm90_m_grouped_fp8_gemm_contiguous_1d2d`. Output is finite. - * On SM100: kernel converts float SF → packed UE8M0 int32 internally via - `index_select(broadcast)` + `transpose_and_pack_fp32_into_ue8m0`, then - dispatches `sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d`. Output has - millions of NaNs. +"""Minimal isolated test of DeepGEMM's `get_mn_major_tma_aligned_packed_ue8m0_tensor`. + +The full GEMM produces NaN on B200 with float32 SFs but is finite on H100. +Earlier I speculated the bug was in `transpose_and_pack_fp32_into_ue8m0` (the +JIT helper that converts float SF → packed UE8M0 int32 on SM100). This script +tests *only* that conversion against a byte-exact Python reference, with no +GEMM in the loop. Outcome: + + - bytes match on both archs → conversion is fine, NaN is from the GEMM. + - bytes diverge on B200 → confirmed conversion bug. To run: H100: python repro_nan_dsv3_b200.py @@ -30,52 +23,31 @@ from transformers.integrations.deepgemm import _load_deepgemm_kernel -_FP8 = torch.float8_e4m3fn -_FP8_MAX = torch.finfo(_FP8).max # 448.0 - - -def make_grouped_fp8(e: int, n: int, k: int, block_n: int, block_k: int, device): - """Per-block-amax FP8 quantization. Returns (w_fp8, sf_float32).""" - w_fp32 = torch.randn(e, n, k, device=device) * 0.1 - sf_n, sf_k = n // block_n, k // block_k - blocks = w_fp32.view(e, sf_n, block_n, sf_k, block_k) - amax = blocks.abs().amax(dim=(2, 4)).clamp(min=1e-4) # (e, sf_n, sf_k) - sf = (amax / _FP8_MAX).to(torch.float32) - sf_expanded = ( - sf.view(e, sf_n, 1, sf_k, 1).expand(-1, -1, block_n, -1, block_k).reshape(e, n, k) - ) - w_fp8 = (w_fp32 / sf_expanded).clamp(-_FP8_MAX, _FP8_MAX).to(_FP8) - return w_fp8, sf - - -def make_per_token_fp8(x_bf16: torch.Tensor, gran_k: int = 128): - """Per-row amax FP8 quantization. Returns (x_fp8, sf_float32).""" - m, n = x_bf16.shape - assert n % gran_k == 0 - x_view = x_bf16.float().view(m, n // gran_k, gran_k) - amax = x_view.abs().amax(dim=2).clamp(min=1e-4) # (m, n/gran_k) - sf = (amax / _FP8_MAX).to(torch.float32) - x_fp8 = (x_view / sf.unsqueeze(2)).clamp(-_FP8_MAX, _FP8_MAX).view(m, n).to(_FP8) - return x_fp8, sf - - -def to_mn_major(sf: torch.Tensor) -> torch.Tensor: - """Rewrite SF to MN-major + TMA-aligned strides (kernel requires this).""" - elem = sf.element_size() - align = 16 // elem - mn = sf.size(-2) - aligned_mn = -(-mn // align) * align - if sf.dim() == 2: - target = (1, aligned_mn) - elif sf.dim() == 3: - target = (sf.size(-1) * aligned_mn, 1, aligned_mn) - else: - raise ValueError(sf.dim()) - if tuple(sf.stride()) == target: - return sf - out = torch.empty_strided(sf.shape, target, dtype=sf.dtype, device=sf.device) - out.copy_(sf) - return out +def python_pack(sf_fp32: torch.Tensor) -> torch.Tensor: + """Reference: extract biased exponent (bits [30:23]) of each float32 as a + uint8, then pack 4 K-consecutive bytes into one int32 (LE, byte 0 = lowest + K). Output shape: same as input but last dim shrunk 4×; layout: K-major. 
+ """ + byte = (sf_fp32.contiguous().view(torch.int32) >> 23).to(torch.uint8) + *batch, mn, k = byte.shape + assert k % 4 == 0 + g = byte.view(*batch, mn, k // 4, 4).to(torch.int32) + return g[..., 0] | (g[..., 1] << 8) | (g[..., 2] << 16) | (g[..., 3] << 24) + + +def compare(name: str, kernel_out: torch.Tensor, py_out: torch.Tensor) -> bool: + # Compare byte-by-byte, ignoring TMA-alignment padding the kernel may add. + *_, mn, kf = py_out.shape + # Slice kernel output to the same shape (it can be wider in mn due to TMA align). + if kernel_out.shape != py_out.shape: + kernel_out = kernel_out[..., :mn, :kf].contiguous() + py_bytes = py_out.contiguous().view(torch.uint8).flatten() + k_bytes = kernel_out.contiguous().view(torch.uint8).flatten() + n = min(py_bytes.numel(), k_bytes.numel()) + diff = (py_bytes[:n] != k_bytes[:n]).sum().item() + print(f" [{name}] kernel={tuple(kernel_out.shape)} stride={tuple(kernel_out.stride())} " + f"py={tuple(py_out.shape)} diff_bytes={diff}/{n}") + return diff == 0 def main() -> int: @@ -89,47 +61,27 @@ def main() -> int: dg = _load_deepgemm_kernel() torch.manual_seed(0) - # Shapes — small but trigger the real codegen path. - M, N, K, E = 512, 1024, 1024, 4 - block_n, block_k = 128, 128 - - # FP8 weights with realistic per-(128,128)-block amax scales. - w_fp8, w_sf_block = make_grouped_fp8(E, N, K, block_n, block_k, device) - print(f" w_fp8: {tuple(w_fp8.shape)} {w_fp8.dtype}, " - f"w_sf_block: {tuple(w_sf_block.shape)} (block-quantized 128×128)") - - # Activations (one expert per token in this minimal case: round-robin). - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 0.1 - a_fp8, a_sf = make_per_token_fp8(x) - print(f" a_fp8: {tuple(a_fp8.shape)} {a_fp8.dtype}, " - f"a_sf: {tuple(a_sf.shape)} per-token") - - # Grouped layout: equal split, per-row expert id (Hopper layout). - grouped_layout = torch.repeat_interleave( - torch.arange(E, dtype=torch.int32, device=device), M // E - ) - - # Kernel call. Pass float32 SF to force the SM100 broadcast+pack path. - d = torch.empty(M, N, dtype=torch.bfloat16, device=device) - dg.grouped_fp8_fp4_matmul( - (a_fp8, to_mn_major(a_sf)), - (w_fp8, to_mn_major(w_sf_block)), - d, - grouped_layout, - # No `recipe` → kernel picks default `(1, 128, 128)` for (float, float). - # No `use_psum_layout` → default False (per-row id grouped_layout). 
- ) - - nf = (~torch.isfinite(d)).sum().item() - finite_pct = 100.0 * (1 - nf / d.numel()) - print(f" output: shape={tuple(d.shape)} " - f"nonfinite={nf}/{d.numel()} ({100.0 - finite_pct:.2f}%) " - f"finite_pct={finite_pct:.2f}%") - if nf > 0: - print(" → REPRO: output has NaN/Inf.") - return 1 - print(" → OK: output is finite.") - return 0 + cases = [ + ("act SF (per-token)", (512, 8)), # 2D, contig + ("weight SF (grouped)", (4, 1024, 8)), # 3D, per-row N + ("weight SF (block)", (4, 8, 8)), # 3D, block-quant — needs broadcast in real path + ] + + all_ok = True + for name, shape in cases: + sf = (torch.rand(*shape, device=device) * 0.05 + 0.001).to(torch.float32) + kernel_out = dg.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + py_out = python_pack(sf) + if not compare(name, kernel_out, py_out): + all_ok = False + + if all_ok: + print("\nResult: kernel pack matches Python reference byte-for-byte.") + print(" → bug is NOT in `transpose_and_pack_fp32_into_ue8m0`.") + return 0 + print("\nResult: kernel pack diverges from Python reference.") + print(" → confirmed bug in `transpose_and_pack_fp32_into_ue8m0`.") + return 1 if __name__ == "__main__": From fb8d338800213fb620b5451aa590ca111ff2c6ef Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 08:57:58 +0200 Subject: [PATCH 61/87] repro --- repro_nan_dsv3_b200.py | 133 +++++++++++++++++++++++++---------------- 1 file changed, 81 insertions(+), 52 deletions(-) diff --git a/repro_nan_dsv3_b200.py b/repro_nan_dsv3_b200.py index 7ed42d18fe88..024be391eac9 100644 --- a/repro_nan_dsv3_b200.py +++ b/repro_nan_dsv3_b200.py @@ -1,17 +1,24 @@ -"""Minimal isolated test of DeepGEMM's `get_mn_major_tma_aligned_packed_ue8m0_tensor`. - -The full GEMM produces NaN on B200 with float32 SFs but is finite on H100. -Earlier I speculated the bug was in `transpose_and_pack_fp32_into_ue8m0` (the -JIT helper that converts float SF → packed UE8M0 int32 on SM100). This script -tests *only* that conversion against a byte-exact Python reference, with no -GEMM in the loop. Outcome: - - - bytes match on both archs → conversion is fine, NaN is from the GEMM. - - bytes diverge on B200 → confirmed conversion bug. - -To run: - H100: python repro_nan_dsv3_b200.py - B200: CUDA_HOME=$HOME/cuda-12.9 python repro_nan_dsv3_b200.py +"""Pinpoint the DSv3-on-B200 NaN. + +DeepGEMM's `pack_fp32_into_ue8m0` doesn't convert fp32 to UE8M0 — it expects +the fp32 input to *already* be UE8M0-rounded (each value an exact power of 2, +mantissa bits all zero) and just repacks the exponent bytes. The kernel's +inner shifts (`>> 23`, `>> 15`, `>> 7`, `<< 1`) only cleanly extract the +biased exponent for the first lane; the rest leak mantissa bits into adjacent +byte slots when the mantissa isn't zero. + +This script verifies that on raw arbitrary fp32 SFs (kernel output diverges +from a "biased-exponent only" reference) but matches byte-for-byte once the +input is rounded to powers of 2 via `ceil_to_ue8m0`. + +Implication: on SM100 the kernel's `(FP32, x, gran_k)` → packed-int path +silently corrupts SFs unless the caller pre-rounds them. SM90 sidesteps this +because its FP8 path consumes raw fp32 SFs directly without going through +`pack_fp32_into_ue8m0`. + +Run on H100 and B200; both should print: + raw → DIVERGES (kernel needs UE8M0-rounded inputs). + ue8m0 → MATCHES. 
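(Editor's illustration — a hand-checkable instance of the UE8M0 rounding this docstring refers to, using the same bit trick as `ceil_to_ue8m0` below; the helper name here is made up.)

import torch

def ceil_pow2(x: torch.Tensor) -> torch.Tensor:
    # Add 2^23 - 1 to the raw bits, then clear the mantissa: any nonzero mantissa
    # carries into the exponent (round up); exact powers of two pass through.
    bits = x.view(torch.int32) + ((1 << 23) - 1)
    return bits.bitwise_and_(~((1 << 23) - 1)).view(torch.float)

print(ceil_pow2(torch.tensor([0.003, 0.25, 1.5])))  # → 0.00390625 (2^-8), 0.25, 2.0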
""" from __future__ import annotations @@ -23,10 +30,19 @@ from transformers.integrations.deepgemm import _load_deepgemm_kernel -def python_pack(sf_fp32: torch.Tensor) -> torch.Tensor: - """Reference: extract biased exponent (bits [30:23]) of each float32 as a - uint8, then pack 4 K-consecutive bytes into one int32 (LE, byte 0 = lowest - K). Output shape: same as input but last dim shrunk 4×; layout: K-major. +def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor: + """Round each positive float up to the nearest power of 2 representable as + UE8M0 (mantissa zeroed out). Mirrors upstream's `deep_gemm.utils.math`. + """ + return ( + (x.view(torch.int32) + ((1 << 23) - 1)).bitwise_and_(~((1 << 23) - 1)).view(torch.float) + ) + + +def python_pack_exponent_only(sf_fp32: torch.Tensor) -> torch.Tensor: + """Reference assuming UE8M0 input: extract biased exponent (bits [30:23]) + of each float as a uint8, then pack 4 K-consecutive bytes into one int32 + LE. This *only* matches the kernel when the input has zero mantissa. """ byte = (sf_fp32.contiguous().view(torch.int32) >> 23).to(torch.uint8) *batch, mn, k = byte.shape @@ -35,19 +51,16 @@ def python_pack(sf_fp32: torch.Tensor) -> torch.Tensor: return g[..., 0] | (g[..., 1] << 8) | (g[..., 2] << 16) | (g[..., 3] << 24) -def compare(name: str, kernel_out: torch.Tensor, py_out: torch.Tensor) -> bool: - # Compare byte-by-byte, ignoring TMA-alignment padding the kernel may add. - *_, mn, kf = py_out.shape - # Slice kernel output to the same shape (it can be wider in mn due to TMA align). - if kernel_out.shape != py_out.shape: - kernel_out = kernel_out[..., :mn, :kf].contiguous() - py_bytes = py_out.contiguous().view(torch.uint8).flatten() - k_bytes = kernel_out.contiguous().view(torch.uint8).flatten() - n = min(py_bytes.numel(), k_bytes.numel()) - diff = (py_bytes[:n] != k_bytes[:n]).sum().item() - print(f" [{name}] kernel={tuple(kernel_out.shape)} stride={tuple(kernel_out.stride())} " - f"py={tuple(py_out.shape)} diff_bytes={diff}/{n}") - return diff == 0 +def compare(label: str, kernel_out: torch.Tensor, ref_out: torch.Tensor) -> int: + if kernel_out.shape != ref_out.shape: + kernel_out = kernel_out[..., : ref_out.size(-2), : ref_out.size(-1)].contiguous() + py_b = ref_out.contiguous().view(torch.uint8).flatten() + k_b = kernel_out.contiguous().view(torch.uint8).flatten() + n = min(py_b.numel(), k_b.numel()) + diff = (py_b[:n] != k_b[:n]).sum().item() + status = "MATCH" if diff == 0 else "DIVERGE" + print(f" [{label}] diff_bytes={diff}/{n} ({status})") + return diff def main() -> int: @@ -61,27 +74,43 @@ def main() -> int: dg = _load_deepgemm_kernel() torch.manual_seed(0) - cases = [ - ("act SF (per-token)", (512, 8)), # 2D, contig - ("weight SF (grouped)", (4, 1024, 8)), # 3D, per-row N - ("weight SF (block)", (4, 8, 8)), # 3D, block-quant — needs broadcast in real path - ] - - all_ok = True - for name, shape in cases: - sf = (torch.rand(*shape, device=device) * 0.05 + 0.001).to(torch.float32) - kernel_out = dg.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) - py_out = python_pack(sf) - if not compare(name, kernel_out, py_out): - all_ok = False - - if all_ok: - print("\nResult: kernel pack matches Python reference byte-for-byte.") - print(" → bug is NOT in `transpose_and_pack_fp32_into_ue8m0`.") - return 0 - print("\nResult: kernel pack diverges from Python reference.") - print(" → confirmed bug in `transpose_and_pack_fp32_into_ue8m0`.") - return 1 + # One representative SF tensor — what the integration would feed for + # block-quantized weight SFs 
in DSv3 inference. + sf_raw = (torch.rand(4, 8, 8, device=device) * 0.05 + 0.001).to(torch.float32) + sf_ue8m0 = ceil_to_ue8m0(sf_raw) + + print("\nfp32 SF → kernel pack vs python `extract biased exponent and pack 4` reference:\n") + + diff_raw = compare( + "raw fp32 SF (mantissa != 0)", + dg.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf_raw), + python_pack_exponent_only(sf_raw), + ) + diff_ue8m0 = compare( + "ceil_to_ue8m0 fp32 SF (mantissa == 0)", + dg.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf_ue8m0), + python_pack_exponent_only(sf_ue8m0), + ) + + print() + print("Conclusion:") + print( + f" raw: {'DIVERGES' if diff_raw else 'MATCHES'} " + "(kernel reads mantissa bits when not zero)" + ) + print( + f" ue8m0: {'DIVERGES' if diff_ue8m0 else 'MATCHES'} " + "(kernel cleanly repacks exponent bytes)" + ) + + if diff_ue8m0 != 0: + print("\nUnexpected: pack diverges even with UE8M0-rounded input — that is a real kernel bug.") + return 1 + if diff_raw == 0: + print("\nUnexpected: kernel matches without UE8M0 rounding — investigate.") + return 1 + print("\nFix: in our integration, round float SFs via ceil_to_ue8m0 before passing.") + return 0 if __name__ == "__main__": From 562eb51202232d5b021bda9373cf297eca0d1c8c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:02:35 +0200 Subject: [PATCH 62/87] fix --- src/transformers/integrations/deepgemm.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 4d4271215902..2d09c4d4033e 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -183,6 +183,22 @@ def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: # view(int32) requires the source contiguous (4 K-bytes adjacent). sf = sf.contiguous().view(torch.int32) + # On SM100 the kernel's `pack_fp32_into_ue8m0` only repacks bytes — it + # reads bits [30:23] (biased exponent) of each fp32 and writes them as + # UE8M0 *but* its inner shifts (>> 15, >> 7, << 1) leak mantissa bits + # into adjacent byte slots. The kernel therefore requires inputs whose + # mantissa is already zero (i.e. each value an exact power of 2). + # `amax / 448` floats produced by per-token / per-block quantizers do + # not satisfy that, so round up to the nearest UE8M0 power-of-2 here. + # On SM90 the dispatch consumes raw fp32 SFs without going through this + # pack, so the rounding would only lose precision — skip it there. + if sf.dtype == torch.float32 and torch.cuda.get_device_capability(sf.device)[0] >= 10: + sf = ( + (sf.view(torch.int32) + ((1 << 23) - 1)) + .bitwise_and_(~((1 << 23) - 1)) + .view(torch.float) + ) + if sf.dim() not in (2, 3): raise ValueError(f"DeepGEMM SF must be 2D or 3D, got {sf.dim()}D") From 481e5e687f70350df2d7fc3241c35a05273ff57a Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:16:55 +0200 Subject: [PATCH 63/87] simplifications --- src/transformers/integrations/deepgemm.py | 628 +++++++++------------- test_deepgemm.py | 34 +- 2 files changed, 283 insertions(+), 379 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 2d09c4d4033e..5919daba4cc5 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -15,17 +15,13 @@ """DeepGEMM integration: fused grouped GEMM kernels from `kernels-community/deep-gemm`. 
Provides: -- `deepgemm_bf16_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. -- `deepgemm_fp8_fp4_linear`: end-to-end FP8/FP4 linear (BF16 in, BF16 out) — quantizes activations - inside, dispatches cast settings on weight dtype, and runs the FP8/FP4 matmul. Used as the - DeepGEMM fast path inside `fp8_linear`. -- `deepgemm_fp8_fp4_experts_forward`: FP8 (or FP4 on SM100+) M-grouped experts forward, registered as "deepgemm" in the FP8 ExpertsInterface. -- `deepgemm_fp8_fp4_megamoe_experts_forward`: FP8 acts × FP4 weights Mega MoE forward (SM100+, - fuses EP dispatch + L1 + SwiGLU + L2 + EP combine via a `SymmBuffer`). - -Requirements: CUDA, Hopper (SM90+), CUDA runtime >= 12.3, `kernels-community/deep-gemm` (>= 2.5 -so the Mega MoE symbols are available — the loader raises if any required symbol is missing). -Mega MoE additionally requires SM100+ at call time. +- `deepgemm_bf16_experts_forward`: BF16 M-grouped experts forward. +- `deepgemm_fp8_fp4_linear`: end-to-end FP8/FP4 linear (BF16 in, BF16 out). +- `deepgemm_fp8_fp4_experts_forward`: FP8 (or FP4 on SM100+) M-grouped experts forward. +- `deepgemm_fp8_fp4_megamoe_experts_forward`: FP8×FP4 Mega MoE forward (SM100+). + +Requirements: CUDA, Hopper (SM90+), CUDA runtime ≥ 12.3, kernels-community/deep-gemm +≥ 2.5 (Mega MoE symbols required). Mega MoE additionally needs SM100+ at call time. """ from __future__ import annotations @@ -44,23 +40,17 @@ logger = logging.get_logger(__name__) # DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. -# TMA is an H100 hardware addition that allows applications to asynchronously and -# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. _DEEPGEMM_M_ALIGNMENT = 128 - _FP8_DTYPE = torch.float8_e4m3fn -_FP8_MIN = torch.finfo(_FP8_DTYPE).min _FP8_MAX = torch.finfo(_FP8_DTYPE).max +# ── Kernel loading ───────────────────────────────────────────────────────────── + + @dataclass(frozen=True) class DeepGEMM: - """Entry points exposed by the `kernels-community/deep-gemm` kernel. - - Mega MoE entry points are always importable on a current build — they raise at call - time on SM90 (Hopper), guarded by a runtime device-capability check in - `deepgemm_fp8_fp4_megamoe_experts_forward`. - """ + """Curated entry points exposed by `kernels-community/deep-gemm`.""" fp8_fp4_matmul: Callable grouped_fp8_fp4_matmul: Callable @@ -75,42 +65,25 @@ class DeepGEMM: @functools.cache def _load_deepgemm_kernel() -> DeepGEMM: - """ - Load DeepGEMM once and return its entry points. - - Raises `ImportError` if CUDA/hardware requirements are not met or any required entry - point is missing. - """ + """Load DeepGEMM once; raise `ImportError` if env or any required symbol is missing.""" if not is_kernels_available(): raise ImportError("DeepGEMM kernel requires the `kernels` package. Install it with `pip install -U kernels`.") - if not torch.cuda.is_available(): - raise ImportError( - "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." - ) + raise ImportError("DeepGEMM kernel requires CUDA, but CUDA is not available.") - # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions major = torch.cuda.get_device_capability()[0] if major < 9: - raise ImportError( - f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " - f"has compute capability {major}.x. Use a different `experts_implementation`." 
- ) + raise ImportError(f"DeepGEMM requires Hopper (SM90+); current device is SM{major}0.") - # DeepGEMM requires CUDA runtime >= 12.3 cuda_major, cuda_minor = get_cuda_runtime_version() - if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): - raise ImportError( - f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " - "Please upgrade your CUDA toolkit or use a different `experts_implementation`." - ) + if (cuda_major, cuda_minor) < (12, 3): + raise ImportError(f"DeepGEMM requires CUDA runtime ≥ 12.3, found {cuda_major}.{cuda_minor}.") kernel = lazy_load_kernel("deep-gemm") if kernel is None: raise ImportError( - "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " - "has a build matching the current torch/CUDA." - ) from None + "Failed to load `kernels-community/deep-gemm` — check that a build matches the current torch/CUDA." + ) fp8_fp4_matmul = getattr(kernel, "fp8_fp4_gemm_nt", None) grouped_fp8_fp4_matmul = getattr(kernel, "m_grouped_fp8_fp4_gemm_nt_contiguous", None) @@ -141,10 +114,8 @@ def _load_deepgemm_kernel() -> DeepGEMM: ] if missing: raise ImportError( - f"DeepGEMM kernel is missing required symbols: {', '.join(missing)}. " - "Please update the `kernels` package (`pip install -U kernels`)." + f"DeepGEMM kernel is missing required symbols: {', '.join(missing)}. Update with `pip install -U kernels`." ) - return DeepGEMM( fp8_fp4_matmul=fp8_fp4_matmul, grouped_fp8_fp4_matmul=grouped_fp8_fp4_matmul, @@ -158,166 +129,109 @@ def _load_deepgemm_kernel() -> DeepGEMM: ) -def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: - """Normalize a scale-factor tensor for the DeepGEMM kernel boundary. +# ── Scale-factor helpers ─────────────────────────────────────────────────────── + - `check_sf_layout` (csrc/utils/layout.hpp) imposes two constraints: +def _ceil_to_ue8m0(sf: torch.Tensor) -> torch.Tensor: + """Round each fp32 SF up to the nearest power of 2 (zero mantissa). - 1. `sf.stride(-2) == 1` — MN-major. - 2. `sf.stride(-1) == get_tma_aligned_size(mn, esize)` — TMA-aligned. + Mirrors `deep_gemm.utils.math.ceil_to_ue8m0`. On SM100 the kernel's + `pack_fp32_into_ue8m0` cleanly extracts the biased exponent only when the + mantissa is already zero — its inner shifts (`>> 15`, `>> 7`, `<< 1`) + otherwise leak mantissa bits into adjacent UE8M0 byte slots and silently + corrupt the SF. SM90 consumes raw fp32 SFs without going through this path. + """ + int_view = sf.view(torch.int32) + return (int_view + ((1 << 23) - 1)).bitwise_and_(~((1 << 23) - 1)).view(torch.float) - PyTorch's default contiguous layout is K-major, so we explicitly build a - new tensor with the required strides via `empty_strided`. This also fixes - the size-1 last-dim case where `transpose+contiguous+transpose` is a no-op - (PyTorch reports stride(-1)=1 for size-1 dims, which fails (2)). - Three SF flavors arrive at this boundary: - - `float32` (DeepSeek V3-style): rewrite layout only. - - `int32` (already-packed UE8M0 from `per_token_cast_to_fp8( - use_packed_ue8m0=True)` or saved checkpoints): rewrite layout only. - - `float8_e8m0fnu` (raw UE8M0 bytes, 1 byte per scale; e.g. on-disk - weights): pack 4 contiguous K-bytes into `int32` (last dim shrinks - 4×) and rewrite layout. +def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: + """Lay out `sf` as DeepGEMM's `check_sf_layout` expects: MN-major + (`stride(-2) == 1`) and TMA-aligned (`stride(-1) == align(mn, 16/esize)`). 
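(Editor's worked example of the alignment arithmetic implemented further down; the shapes are invented, not taken from the patch.)

# A float32 SF of shape (E, N, K//128) = (4, 1024, 8): element size 4, so rows align
# to 16 // 4 = 4; 1024 is already a multiple of 4, hence aligned_mn = 1024 and the
# MN-major target strides are (8 * 1024, 1, 1024).
mn, kf, esize = 1024, 8, 4
align_to = 16 // esize
aligned_mn = -(-mn // align_to) * align_to
assert (kf * aligned_mn, 1, aligned_mn) == (8192, 1, 1024)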
+ + Inputs come in three flavors: + - `float8_e8m0fnu`: raw UE8M0 bytes — pack 4 K-bytes → int32 (last dim /4). + - `float32`: per-token / per-block SFs from `per_token_cast_to_fp8` or + on-disk weights — round to UE8M0 on SM100 (see `_ceil_to_ue8m0`). + - `int32`: already-packed UE8M0 — pass through. """ if sf.dtype == torch.float8_e8m0fnu: - # view(int32) requires the source contiguous (4 K-bytes adjacent). sf = sf.contiguous().view(torch.int32) - - # On SM100 the kernel's `pack_fp32_into_ue8m0` only repacks bytes — it - # reads bits [30:23] (biased exponent) of each fp32 and writes them as - # UE8M0 *but* its inner shifts (>> 15, >> 7, << 1) leak mantissa bits - # into adjacent byte slots. The kernel therefore requires inputs whose - # mantissa is already zero (i.e. each value an exact power of 2). - # `amax / 448` floats produced by per-token / per-block quantizers do - # not satisfy that, so round up to the nearest UE8M0 power-of-2 here. - # On SM90 the dispatch consumes raw fp32 SFs without going through this - # pack, so the rounding would only lose precision — skip it there. - if sf.dtype == torch.float32 and torch.cuda.get_device_capability(sf.device)[0] >= 10: - sf = ( - (sf.view(torch.int32) + ((1 << 23) - 1)) - .bitwise_and_(~((1 << 23) - 1)) - .view(torch.float) - ) + elif sf.dtype == torch.float32 and torch.cuda.get_device_capability(sf.device)[0] >= 10: + sf = _ceil_to_ue8m0(sf) if sf.dim() not in (2, 3): raise ValueError(f"DeepGEMM SF must be 2D or 3D, got {sf.dim()}D") - mn = sf.size(-2) - kf = sf.size(-1) - elem_size = sf.element_size() - # `get_tma_aligned_size`: align(mn, 16 / element_size). - align_to = 16 // elem_size - aligned_mn = -(-mn // align_to) * align_to # ceil-multiple - - if sf.dim() == 2: - target_strides = (1, aligned_mn) - else: # 3D - target_strides = (kf * aligned_mn, 1, aligned_mn) + mn, kf = sf.size(-2), sf.size(-1) + align_to = 16 // sf.element_size() # `get_tma_aligned_size`: align(mn, 16 / element_size) + aligned_mn = -(-mn // align_to) * align_to + target_strides = (1, aligned_mn) if sf.dim() == 2 else (kf * aligned_mn, 1, aligned_mn) if tuple(sf.stride()) == target_strides: return sf - out = torch.empty_strided(sf.shape, target_strides, dtype=sf.dtype, device=sf.device) out.copy_(sf) return out -def deepgemm_fp8_fp4_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale_inv: torch.Tensor, - bias: torch.Tensor | None = None, - output_dtype: torch.dtype = torch.bfloat16, - activation_scale: torch.Tensor | None = None, -) -> torch.Tensor: - """End-to-end DeepGEMM linear: per-token activation quant + FP8/FP4 matmul. +def _select_fp8_cast_kwargs( + weight: torch.Tensor, weight_scale_inv: torch.Tensor, block_size: tuple | None, device: torch.device +) -> dict: + """Pick the `per_token_cast_to_fp8` kwargs from weight dtype + SF dtype. - Activation cast settings are inferred from the tensor dtypes: - - FP4 weights (`weight.dtype == torch.int8`): always gran_k=32 with packed-UE8M0 SF. Requires - SM100+ (Blackwell). - - FP8 weights + UE8M0 weight SFs (`weight_scale_inv.dtype == torch.float8_e8m0fnu`, - DeepSeek V4-style): gran_k=128 with packed-UE8M0 SF (skips the kernel-side float→int SF - transform on SM100). - - FP8 weights + float weight SFs (DeepSeek V3-style): gran_k=128 with float SF (works on - Hopper and Blackwell). - - Static (per-tensor) activation quantization is not supported — DeepGEMM's kernel needs per-row - SFs and rejects scalar SFs at its host-side check. 
Callers should route static activations - through the Triton fallback. + Three cases mirror the kernel's recipes: + - FP4 weights (`int8`): gran_k=32 packed-UE8M0 SF. SM100+ only. + - FP8 weights + UE8M0 SF: gran_k=128 packed-UE8M0 SF (DSv4). + - FP8 weights + float SF: gran_k=128 float SF (DSv3). """ - if activation_scale is not None: - raise NotImplementedError( - "Static (per-tensor) activation quantization is not supported on the DeepGEMM path. " - "Use the Triton fallback for static activations." + if weight.dtype == torch.int8: # FP4 + if torch.cuda.get_device_capability(device)[0] < 10: + raise RuntimeError("FP4 weights (int8-packed e2m1) require SM100+ (Blackwell).") + return {"use_ue8m0": True, "gran_k": 32, "use_packed_ue8m0": True} + # FP8 weights: validate block_size (informational; kernel infers recipe from SF dtype/shape). + if block_size is None: + raise ValueError( + "DeepGEMM requires block-wise quantized FP8 weights, but the experts have no `block_size` set." ) + if block_size not in ((128, 128), (1, 128)): + raise ValueError(f"DeepGEMM requires `block_size` ∈ {{(128, 128), (1, 128)}}, got {block_size}.") + if weight_scale_inv.dtype == torch.float8_e8m0fnu: + return {"use_ue8m0": True, "gran_k": 128, "use_packed_ue8m0": True} + return {"use_ue8m0": False, "gran_k": 128} - is_fp4 = weight.dtype == torch.int8 - if is_fp4 and torch.cuda.get_device_capability(input.device)[0] < 10: - raise RuntimeError("FP4 weights (int8-packed e2m1) require SM100+ (Blackwell).") - deepgemm = _load_deepgemm_kernel() - - if is_fp4: - cast_kwargs = {"use_ue8m0": True, "gran_k": 32, "use_packed_ue8m0": True} - elif weight_scale_inv.dtype == torch.float8_e8m0fnu: - cast_kwargs = {"use_ue8m0": True, "gran_k": 128, "use_packed_ue8m0": True} - else: - cast_kwargs = {"use_ue8m0": False, "gran_k": 128} - input_2d = input.view(-1, input.shape[-1]) - qinput_2d, scale_2d = deepgemm.per_token_cast_to_fp8(input_2d, **cast_kwargs) - - output = torch.empty(qinput_2d.shape[0], weight.shape[0], device=input.device, dtype=output_dtype) - # Pass the recipe explicitly — see comment in `deepgemm_fp8_fp4_experts_forward`. - sf_recipe = (1, 1, cast_kwargs["gran_k"]) if cast_kwargs.get("use_packed_ue8m0") else None - deepgemm.fp8_fp4_matmul( - (qinput_2d, _coerce_sf_for_kernel(scale_2d)), - (weight, _coerce_sf_for_kernel(weight_scale_inv)), - output, - recipe=sf_recipe, - ) - output = output.view(input.shape[:-1] + (weight.shape[0],)) - if bias is not None: - output.add_(bias) - return output +# ── Layout helpers (M-grouped contiguous, TMA-aligned) ───────────────────────── def _build_deepgemm_contiguous_layout( expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool -) -> tuple: - """Build the TMA-aligned layout DeepGEMM's grouped GEMM expects. - - Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes - expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or - per-row expert ids with -1 for padding on Hopper. - - Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) - are routed past the last aligned expert block and marked `-1` in the Hopper layout (and - excluded from the Blackwell cumsum), so DeepGEMM skips them. +) -> tuple[torch.Tensor, torch.Tensor, int]: + """Build the TMA-aligned grouped layout DeepGEMM expects. 
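(Editor's toy trace of the mapping this helper builds, with the alignment shrunk to 4 so the numbers stay readable — the real path uses 128.)

# expert_ids_sorted = [0, 0, 0, 1, 1], num_experts = 2, alignment = 4
# tokens_per_expert         = [3, 2]
# aligned_tokens_per_expert = [4, 4]
# padding_per_expert        = [1, 2]  → cumulative_padding = [0, 1, 3]
# sorted_to_padded = arange(5) + [0, 0, 0, 1, 1] = [0, 1, 2, 4, 5]
# total_padded_rows (upper bound) = 5 + min(5, 2) * 3 = 11
# Hopper layout (per-row id, -1 = skip): [0, 0, 0, -1, 1, 1, -1, -1, -1, -1, -1]
# Blackwell layout (cumsum of aligned)  : [4, 8]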
+ + Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`: + - `grouped_layout` is per-row expert id (Hopper, with `-1` for padding / + sentinels) or a cumsum of aligned per-expert counts (Blackwell). + - EP sentinels (values == `num_experts`) are routed past the last expert + block so DeepGEMM skips them. """ device = expert_ids_sorted.device num_tokens = expert_ids_sorted.size(0) - # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. + # `histc` drops values > max, so EP sentinels (== num_experts) don't count. tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment - # Upper bound avoids GPU->CPU sync; padding rows are skipped by DeepGEMM. + # Upper bound — avoids GPU→CPU sync; padding rows are skipped. total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) - # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the - # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, - # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the - # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. + # Exclusive cumsum of per-expert padding (index `num_experts` = total padding, + # which routes EP sentinels past all aligned blocks on Blackwell). padding_per_expert = aligned_tokens_per_expert - tokens_per_expert cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - if use_psum_layout: # Blackwell (SM100+) - # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= - # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler - # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` - # between experts only matches the padded tensor when the stored cumsum is over aligned counts. + if use_psum_layout: # SM100+: kernel reads cumsum of aligned counts as expert boundaries. grouped_layout = aligned_tokens_per_expert.cumsum(0).int() - else: - # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). + else: # SM90: per-row expert id, -1 = skip (padding & sentinels). grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) @@ -327,14 +241,11 @@ def _build_deepgemm_contiguous_layout( def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: """Pad a sorted tensor into the TMA-aligned contiguous layout. - Padding rows are zero-initialized: on SM100 the psum_layout dispatch computes - every row in the per-expert aligned range (it has no per-row skip mask, only - cumulative offsets), so any garbage in the padding feeds straight into the - GEMM. For float-SF activations that's catastrophic — uninitialized float32 - bit patterns can be huge (or NaN), blow up the FP8 dequant, and overflow. - With zero-initialised padding: FP8 acts → 0, float SF → 0 (dequant = 0), - UE8M0 SF → byte 0 (≈2^-127, dequant ≈ 0). Dot product on padding rows - becomes 0, harmless. 
+ Padding rows are zero-initialized: on SM100 the psum_layout dispatch + computes every row in the per-expert aligned range (no per-row skip mask, + only cumulative offsets), so any garbage in the padding feeds straight into + the GEMM. With zeros: FP8 acts → 0, float SF → 0, UE8M0 SF → byte 0 — dot + product on padding rows is 0, harmless. """ padded = torch.zeros(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) padded[sorted_to_padded] = x @@ -342,10 +253,103 @@ def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_pad def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: - """Remove padding rows from the TMA-aligned contiguous layout.""" return x_padded[sorted_to_padded] +# ── Routing helpers (sort → matmul → restore) ───────────────────────────────── + + +def _dispatch_routed_input( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + num_experts: int, + use_psum_layout: bool, +) -> tuple: + """Sort tokens by expert id and build the M-grouped layout. + + EP sentinels (`top_k_index == num_experts`) are kept unclamped here so the + sort pushes them to the tail and `_build_deepgemm_contiguous_layout` routes + them past valid expert blocks. + + Returns `(sorted_hidden, sorted_weights, expert_ids_g, perm, + sorted_to_padded, grouped_layout, total_padded_rows)`. + """ + num_top_k = top_k_index.size(-1) + expert_ids_g, perm = torch.sort(top_k_index.reshape(-1)) + sorted_hidden = hidden_states[perm // num_top_k] + sorted_weights = top_k_weights.reshape(-1)[perm] + sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( + expert_ids_g, num_experts, _DEEPGEMM_M_ALIGNMENT, use_psum_layout + ) + return sorted_hidden, sorted_weights, expert_ids_g, perm, sorted_to_padded, grouped_layout, total_padded_rows + + +def _combine_routed_output( + out_padded: torch.Tensor, + sorted_weights: torch.Tensor, + expert_ids_g: torch.Tensor, + perm: torch.Tensor, + sorted_to_padded: torch.Tensor, + num_experts: int, + num_tokens: int, + num_top_k: int, + hidden_dim: int, + out_dtype: torch.dtype, +) -> torch.Tensor: + """Unpad → weighted multiply → mask sentinels → restore order → top-k reduce.""" + out = _unpad_from_deepgemm_contiguous_layout(out_padded, sorted_to_padded) + weighted = out * sorted_weights.to(out.dtype).unsqueeze(-1) + # Sentinel rows past the valid expert blocks may carry NaN from allocator + # reuse (`0 * NaN = NaN`); zero them so the top-k reduction stays finite. + weighted.masked_fill_((expert_ids_g >= num_experts).unsqueeze(-1), 0.0) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=out.device) + # Deterministic reshape+sum (index_add_ with duplicates is non-deterministic on CUDA). + return weighted[inv_perm].view(num_tokens, num_top_k, hidden_dim).sum(dim=1).to(out_dtype) + + +# ── Public dispatches ────────────────────────────────────────────────────────── + + +def deepgemm_fp8_fp4_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale_inv: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, + activation_scale: torch.Tensor | None = None, +) -> torch.Tensor: + """End-to-end DeepGEMM linear: per-token activation quant + FP8/FP4 matmul. + + Static (per-tensor) activation quantization is rejected — DeepGEMM needs + per-row SFs. Callers should route static activations through the Triton fallback. 
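(Editor's usage sketch; tensor names are placeholders and the shapes assume block-quantized FP8 weights with float32 scales on an SM90+ GPU.)

# x:  (batch, seq, hidden) bf16 activations
# w:  (out_features, hidden) float8_e4m3fn weight
# ws: (out_features // 128, hidden // 128) float32 block scales
y = deepgemm_fp8_fp4_linear(x, w, ws, bias=None, output_dtype=torch.bfloat16)
# y: (batch, seq, out_features) in bf16 — both scale factors are applied inside the kernel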
+ """ + if activation_scale is not None: + raise NotImplementedError("Static activation quantization is not supported on the DeepGEMM path.") + + deepgemm = _load_deepgemm_kernel() + cast_kwargs = _select_fp8_cast_kwargs(weight, weight_scale_inv, block_size=(128, 128), device=input.device) + + input_2d = input.view(-1, input.shape[-1]) + qinput_2d, scale_2d = deepgemm.per_token_cast_to_fp8(input_2d, **cast_kwargs) + output = torch.empty(qinput_2d.shape[0], weight.shape[0], device=input.device, dtype=output_dtype) + + # Pass `(1, 1, gran_k)` for int-SF paths so the kernel uses the right K granularity + # (the default `(1, 1, 128)` mismatches FP4's gran_k=32). Float-SF leaves it None. + sf_recipe = (1, 1, cast_kwargs["gran_k"]) if cast_kwargs.get("use_packed_ue8m0") else None + deepgemm.fp8_fp4_matmul( + (qinput_2d, _coerce_sf_for_kernel(scale_2d)), + (weight, _coerce_sf_for_kernel(weight_scale_inv)), + output, + recipe=sf_recipe, + ) + output = output.view(input.shape[:-1] + (weight.shape[0],)) + if bias is not None: + output.add_(bias) + return output + + def deepgemm_bf16_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, @@ -355,86 +359,53 @@ def deepgemm_bf16_experts_forward( if hidden_states.dtype != torch.bfloat16: raise ValueError(f"DeepGEMM experts path requires bfloat16 hidden states, got {hidden_states.dtype}") - # Non-transposed HF experts have weight layout (E, N, K) -> NT kernel. - # Transposed HF experts have weight layout (E, K, N) -> NN kernel. deepgemm = _load_deepgemm_kernel() + # Non-transposed weights (E, N, K) → NT kernel; transposed (E, K, N) → NN kernel. grouped_bf16_matmul = deepgemm.grouped_bf16_matmul_nn if self.is_transposed else deepgemm.grouped_bf16_matmul_nt device = hidden_states.device num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) - - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - sample_weights = top_k_weights.reshape(-1) # (S,) - expert_ids = top_k_index.reshape(-1) # (S,) - - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, - # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the - # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. - # Sentinel rows are zeroed post-weighted-mul (see below), since the kernel leaves them uninitialized. - expert_ids_g, perm = torch.sort(expert_ids) - selected_hidden_states_g = hidden_states[perm // num_top_k] - sample_weights_g = sample_weights[perm] + num_tokens, hidden_dim = hidden_states.size(0), hidden_states.size(-1) use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout + sorted_hidden, sorted_weights, expert_ids_g, perm, sorted_to_padded, grouped_layout, total_padded_rows = ( + _dispatch_routed_input(hidden_states, top_k_index, top_k_weights, self.num_experts, use_psum_layout) ) if self.has_bias: - # Clamp now that the layout has been built — needed for the per-row bias gather below to stay - # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. + # Clamp now that the layout is built; bias added to sentinel rows lands in skipped positions. 
expert_ids_g.clamp_(0, self.num_experts - 1) - # --- Up projection per expert (DeepGEMM grouped contiguous, bf16) --- + # Up projection. w_up = self.gate_up_proj if self.has_gate else self.up_proj - # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] - act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) + act = _pad_for_deepgemm(sorted_hidden, sorted_to_padded, total_padded_rows) proj_out = torch.empty(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) grouped_bf16_matmul(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) - - # The kernel has no bias input -> add per-expert bias in-place on the unpadded slice; - # padding rows get discarded at unpad time. if self.has_bias: up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias proj_out.index_add_(0, sorted_to_padded, up_bias[expert_ids_g]) - # Apply gating or activation - if self.has_gate: - proj_out = self._apply_gate(proj_out) - else: - proj_out = self.act_fn(proj_out) + proj_out = self._apply_gate(proj_out) if self.has_gate else self.act_fn(proj_out) - # --- Down projection per expert (DeepGEMM grouped contiguous, bf16) --- + # Down projection. out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) grouped_bf16_matmul(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) - if self.has_bias: out.index_add_(0, sorted_to_padded, self.down_proj_bias[expert_ids_g]) - # Remove padding rows - out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) - - # Apply routing weights - weighted_out = out * sample_weights_g.to(out.dtype).unsqueeze(-1) # (S, hidden_dim) - - # EP sentinel handling: `out` rows past the valid expert blocks are left uninitialized by the kernel, - # so `out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here - # so the downstream reduction stays finite even when the routing weight was already zero. - weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) - - # Restore original order - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - weighted_out = weighted_out[inv_perm] - - # Accumulate results using deterministic reshape+sum instead of index_add_ - # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) - final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - - return final_hidden_states.to(hidden_states.dtype) + return _combine_routed_output( + out, + sorted_weights, + expert_ids_g, + perm, + sorted_to_padded, + self.num_experts, + num_tokens, + num_top_k, + hidden_dim, + hidden_states.dtype, + ) def deepgemm_fp8_fp4_experts_forward( @@ -445,78 +416,26 @@ def deepgemm_fp8_fp4_experts_forward( ) -> torch.Tensor: if self.activation_scheme == "static": raise NotImplementedError( - "DeepGEMM experts dispatch does not support activation_scheme='static'. " - "Use the default eager dispatch or switch to activation_scheme='dynamic'." + "DeepGEMM experts dispatch does not support activation_scheme='static'. Use 'dynamic'." 
) deepgemm = _load_deepgemm_kernel() - device = hidden_states.device num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) - - # FP4 weights are int8-packed (2 e2m1 values per byte; `kPackedFP4 == torch::kInt8` in DeepGEMM). - # `m_grouped_fp8_fp4_gemm_nt_contiguous` accepts both FP8 and FP4 weight dtypes. Activation cast - # tracks (weight dtype, weight SF dtype), mirroring `deepgemm_fp8_fp4_linear`: - # - FP4 weights: gran_k=32 packed-UE8M0 SF (SM100+). - # - FP8 weights + UE8M0 SFs: gran_k=128 packed-UE8M0 SF (skips the kernel-side float→int - # transform on SM100). - # - FP8 weights + float SFs: gran_k=128 float SF (Hopper or Blackwell). + num_tokens, hidden_dim = hidden_states.size(0), hidden_states.size(-1) + w_up = self.gate_up_proj if self.has_gate else self.up_proj ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv - is_fp4_weights = w_up.dtype == torch.int8 - - if is_fp4_weights: - if torch.cuda.get_device_capability(device)[0] < 10: - raise RuntimeError( - "FP4 expert weights (int8-packed e2m1) require SM100+ (Blackwell); use FP8 weights on Hopper." - ) - cast_kwargs = {"use_ue8m0": True, "gran_k": 32, "use_packed_ue8m0": True} - else: - # FP8 weights: DeepGEMM supports two SF granularities for the N axis - # of B (block-128 or per-row), and only gran_k=128 for the K axis. - # The block_size attribute is informational; the kernel infers the - # actual recipe from the SF dtype + shape (`get_default_recipe`). - if self.block_size is None: - raise ValueError( - "DeepGEMM requires block-wise quantized FP8 weights, but the experts have " - "no `block_size` set (per-tensor quantization is not supported)." - ) - if self.block_size not in ((128, 128), (1, 128)): - raise ValueError( - f"DeepGEMM requires `block_size` ∈ {{(128, 128), (1, 128)}} for FP8 weights, got {self.block_size}." - ) - if ws_up.dtype == torch.float8_e8m0fnu: - cast_kwargs = {"use_ue8m0": True, "gran_k": 128, "use_packed_ue8m0": True} - else: - cast_kwargs = {"use_ue8m0": False, "gran_k": 128} - - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - sample_weights = top_k_weights.reshape(-1) # (S,) - expert_ids = top_k_index.reshape(-1) # (S,) - - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, - # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the - # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. - # Sentinel rows are zeroed post-weighted-mul (see below), since the kernel leaves them uninitialized. - expert_ids_g, perm = torch.sort(expert_ids) - selected_hidden_states_g = hidden_states[perm // num_top_k] - sample_weights_g = sample_weights[perm] + cast_kwargs = _select_fp8_cast_kwargs(w_up, ws_up, getattr(self, "block_size", None), device) use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout + sorted_hidden, sorted_weights, expert_ids_g, perm, sorted_to_padded, grouped_layout, total_padded_rows = ( + _dispatch_routed_input(hidden_states, top_k_index, top_k_weights, self.num_experts, use_psum_layout) ) - - # The kernel infers a default recipe from the SF dtype/shape on SM100 — - # `(1, 1, 128)` for any int SF, regardless of the SF's actual gran_k. 
For - # FP4 weights (gran_k=32) this picks the wrong shape contract, so pass - # the recipe explicitly. `(1, 1, gran_k)` matches `cast_kwargs["gran_k"]`. sf_recipe = (1, 1, cast_kwargs["gran_k"]) if cast_kwargs.get("use_packed_ue8m0") else None - # --- Up projection per expert (DeepGEMM grouped contiguous) --- - act_fp8, act_scales = deepgemm.per_token_cast_to_fp8(selected_hidden_states_g, **cast_kwargs) + # Up projection. + act_fp8, act_scales = deepgemm.per_token_cast_to_fp8(sorted_hidden, **cast_kwargs) act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) @@ -528,46 +447,32 @@ def deepgemm_fp8_fp4_experts_forward( recipe=sf_recipe, use_psum_layout=use_psum_layout, ) + proj_out = self._apply_gate(proj_out) if self.has_gate else self.act_fn(proj_out) - # Apply gating or activation - if self.has_gate: - proj_out = self._apply_gate(proj_out) - else: - proj_out = self.act_fn(proj_out) - - # --- Down projection per expert (DeepGEMM grouped contiguous) --- + # Down projection. proj_fp8, proj_scales = deepgemm.per_token_cast_to_fp8(proj_out, **cast_kwargs) - proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) deepgemm.grouped_fp8_fp4_matmul( (proj_fp8, _coerce_sf_for_kernel(proj_scales)), (self.down_proj, _coerce_sf_for_kernel(self.down_proj_scale_inv)), - proj_out, + out, grouped_layout, recipe=sf_recipe, use_psum_layout=use_psum_layout, ) - # Remove padding rows - proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) - - # Apply routing weights - weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) - - # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, - # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here - # so the downstream reduction stays finite even when the routing weight was already zero. - weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) - - # Restore original order - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - weighted_out = weighted_out[inv_perm] - - # Accumulate results using deterministic reshape+sum instead of index_add_ - # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) - final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - - return final_hidden_states.to(hidden_states.dtype) + return _combine_routed_output( + out, + sorted_weights, + expert_ids_g, + perm, + sorted_to_padded, + self.num_experts, + num_tokens, + num_top_k, + hidden_dim, + hidden_states.dtype, + ) def deepgemm_fp8_fp4_megamoe_experts_forward( @@ -577,82 +482,50 @@ def deepgemm_fp8_fp4_megamoe_experts_forward( top_k_weights: torch.Tensor, process_group: torch.distributed.ProcessGroup | None = None, ) -> torch.Tensor: - """FP8 acts × FP4 weights Mega MoE forward via DeepGEMM. - - Fuses EP dispatch + L1 (FP8×FP4) + SwiGLU + L2 (FP8×FP4) + EP combine into a single - kernel, overlapping NVLink communication with tensor-core compute. 
The kernel handles - the full `(num_tokens, hidden) -> (num_tokens, hidden)` MoE forward including the - weighted top-k reduction; the caller must NOT all-reduce the output (the EP combine is - already inside the kernel). - - `process_group` (the EP group) is passed in by `MoeTensorParalellExperts._prepare_input_fn` - when the module is wrapped for TP — it is required for the symmetric-buffer rendezvous on - first forward. - - Caller-managed attributes on `self` (this dispatch does no quantization or weight - transformation — assume they are pre-set on the module): - - `gate_up_proj`: int8-packed FP4 L1 weight, - shape `(num_experts_per_rank, intermediate_hidden * 2, hidden // 2)`, - interleaved gate/up via `transform_weights_for_mega_moe`. - - `gate_up_proj_scale_inv`: int-packed UE8M0 SF for L1, UTCCP-transposed via - `transform_weights_for_mega_moe`. - - `down_proj`, `down_proj_scale_inv`: same conventions for L2. - - The `SymmBuffer` is lazily allocated on first call (and re-allocated if a later call - has more tokens than the cached buffer). The SwiGLU clamp is read from - `self.config.swiglu_limit` if present, otherwise the kernel runs unclamped. - - Args: - hidden_states: bf16 `(num_tokens, hidden)`. - top_k_index: int `(num_tokens, num_topk)` of GLOBAL expert ids; -1 marks skipped - slots (the kernel ignores them). Note: this differs from the `RouterParallel` - output used by the other dispatches, which remaps indices to local + sentinel. - top_k_weights: float `(num_tokens, num_topk)` routing weights. - - Returns: - `(num_tokens, hidden)` in `hidden_states.dtype` (already weighted-summed across - topk and reduced across EP ranks). + """FP8 acts × FP4 weights Mega MoE forward (SM100+). + + Fuses EP dispatch + L1 + SwiGLU + L2 + EP combine into one kernel, + overlapping NVLink with tensor-core compute. The kernel handles the full + `(num_tokens, hidden) → (num_tokens, hidden)` MoE forward including the + weighted top-k reduction; the caller must NOT all-reduce the output. + + `process_group` is supplied automatically by `MoeTensorParalellExperts._prepare_input_fn` + when the module is wrapped for TP — it's required for the symm-buffer rendezvous + on first forward. `top_k_index` is GLOBAL expert ids (`-1` marks skipped slots). + + Caller-managed `self` attributes: + - `gate_up_proj`, `gate_up_proj_scale_inv`: L1 weight + UE8M0 SF, both + transformed via `transform_weights_for_mega_moe(is_l1=True)`. + - `down_proj`, `down_proj_scale_inv`: same for L2 (`is_l1=False`). + - `config.swiglu_limit` (optional): SwiGLU clamp; absent → unclamped. """ - # Mega MoE is Blackwell-only — the impl is `sm100_fp8_fp4_mega_moe.cuh` and there is - # no SM90 path. Use the regular "deepgemm" dispatch on Hopper. if torch.cuda.get_device_capability(hidden_states.device)[0] < 10: - raise RuntimeError("DeepGEMM Mega MoE requires SM100+ (Blackwell). The 'deepgemm' dispatch supports SM90+.") + raise RuntimeError("DeepGEMM Mega MoE requires SM100+ (Blackwell). Use the 'deepgemm' dispatch on Hopper.") deepgemm = _load_deepgemm_kernel() - + num_tokens, hidden_dim = hidden_states.size(0), hidden_states.size(-1) num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) num_experts = self.gate_up_proj.size(0) intermediate_hidden = self.gate_up_proj.size(1) // 2 - activation_clamp = getattr(getattr(self, "config", None), "swiglu_limit", None) - # Lazily allocate the symmetric buffer on first call (re-allocate if the cached buffer is - # too small for this call). 
`process_group` is threaded in by `MoeTensorParalellExperts`. + # Lazily allocate the symmetric buffer (re-allocate if cached one is too small). if getattr(self, "symm_buffer", None) is None or self.symm_buffer.num_max_tokens_per_rank < num_tokens: if process_group is None: raise ValueError( - "DeepGEMM Mega MoE requires a `process_group` for the EP group; the TP wrapping " - "(`MoeTensorParalellExperts`) supplies it automatically. If you are calling this " - "dispatch directly, pass `process_group=...` explicitly." + "DeepGEMM Mega MoE requires a `process_group` for the EP group. The TP wrapping " + "(MoeTensorParalellExperts) supplies it automatically; pass it explicitly otherwise." ) - # `gate_up_proj.size(0)` is the per-rank expert count after `GroupedGemmParallel` - # sharding; the buffer needs the GLOBAL count (kernel asserts `num_experts % num_ranks - # == 0` and computes the per-rank slice itself). - num_experts_global = num_experts * process_group.size() + # `gate_up_proj.size(0)` is per-rank after sharding; the buffer needs the GLOBAL count. self.symm_buffer = deepgemm.get_symm_buffer_for_mega_moe( process_group, hidden=hidden_dim, num_topk=num_top_k, - num_experts=num_experts_global, + num_experts=num_experts * process_group.size(), num_max_tokens_per_rank=num_tokens, intermediate_hidden=intermediate_hidden, ) - # Quantize activations to FP8 with packed UE8M0 per-32 SF — the layout the kernel expects. x_fp8, x_sf = deepgemm.per_token_cast_to_fp8(hidden_states, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) - - # Stage inputs into the symmetric buffer; the kernel reads from there during dispatch. self.symm_buffer.x[:num_tokens].copy_(x_fp8) self.symm_buffer.x_sf[:num_tokens].copy_(x_sf) self.symm_buffer.topk_idx[:num_tokens].copy_(top_k_index) @@ -664,7 +537,6 @@ def deepgemm_fp8_fp4_megamoe_experts_forward( (self.gate_up_proj, self.gate_up_proj_scale_inv), (self.down_proj, self.down_proj_scale_inv), self.symm_buffer, - activation_clamp=activation_clamp, + activation_clamp=getattr(getattr(self, "config", None), "swiglu_limit", None), ) - return y.to(hidden_states.dtype) diff --git a/test_deepgemm.py b/test_deepgemm.py index f6f2f63d0aae..9ecca0ed78c6 100644 --- a/test_deepgemm.py +++ b/test_deepgemm.py @@ -28,6 +28,7 @@ from transformers.integrations.deepgemm import ( _load_deepgemm_kernel, + deepgemm_bf16_experts_forward, deepgemm_fp8_fp4_experts_forward, deepgemm_fp8_fp4_megamoe_experts_forward, ) @@ -109,6 +110,24 @@ def _alloc(e: int, n: int, k: int) -> tuple[torch.Tensor, torch.Tensor]: ) +def _make_bf16_experts(num_experts: int, hidden_size: int, intermediate_size: int, device: torch.device) -> SimpleNamespace: + """Synthetic BF16 experts (no quantization, no SF) — exercises the + `deepgemm_bf16_experts_forward` path that calls the bf16 grouped GEMM.""" + gate_up = torch.randn(num_experts, 2 * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device) * 0.1 + down = torch.randn(num_experts, hidden_size, intermediate_size, dtype=torch.bfloat16, device=device) * 0.1 + return SimpleNamespace( + num_experts=num_experts, + has_gate=True, + has_bias=False, + is_transposed=False, + config=SimpleNamespace(hidden_act="silu"), + gate_up_proj=gate_up, + down_proj=down, + _apply_gate=lambda x: F.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + act_fn=F.silu, + ) + + def _make_fp4_experts(num_experts: int, hidden_size: int, intermediate_size: int, device: torch.device) -> SimpleNamespace: """Synthetic FP4 experts (`int8`-packed e2m1, K dim halved; UE8M0 SF, gran_k=32).""" @@ 
-154,6 +173,19 @@ def _check_output(out: torch.Tensor, expected_shape: tuple[int, ...], label: str # ── Tests ──────────────────────────────────────────────────────────────────────── +def test_bf16(device: torch.device) -> None: + label = "BF16 experts (no quant)" + if torch.cuda.get_device_capability(device)[0] < 9: + print(f"[{label}] SKIP: needs SM90+ (Hopper)") + return + T, H, I, E, K = 256, 1024, 512, 16, 4 + experts = _make_bf16_experts(E, H, I, device) + hidden = torch.randn(T, H, dtype=torch.bfloat16, device=device) * 0.1 + idx, w = _random_routing(T, K, E, device) + out = deepgemm_bf16_experts_forward(experts, hidden, idx, w.to(torch.bfloat16)) + _check_output(out, (T, H), label) + + def test_dsv3_fp8(device: torch.device) -> None: label = "DSv3 (FP8 + float SF)" if torch.cuda.get_device_capability(device)[0] < 9: @@ -263,7 +295,7 @@ def main() -> None: # Single-GPU paths run on rank 0 only (ranks > 0 only participate in Mega MoE). failures: list[tuple[str, BaseException]] = [] if rank == 0: - for fn in (test_dsv3_fp8, test_dsv4_fp8, test_dsv4_fp4): + for fn in (test_bf16, test_dsv3_fp8, test_dsv4_fp8, test_dsv4_fp4): try: fn(device) except BaseException as exc: From d3dbd325e929f254378a72e32b320b6564d7f3b2 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:20:17 +0200 Subject: [PATCH 64/87] fix --- src/transformers/integrations/deepgemm.py | 7 ++++--- test_deepgemm.py | 8 +++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 5919daba4cc5..09bb1c872d6b 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -494,9 +494,10 @@ def deepgemm_fp8_fp4_megamoe_experts_forward( on first forward. `top_k_index` is GLOBAL expert ids (`-1` marks skipped slots). Caller-managed `self` attributes: - - `gate_up_proj`, `gate_up_proj_scale_inv`: L1 weight + UE8M0 SF, both - transformed via `transform_weights_for_mega_moe(is_l1=True)`. - - `down_proj`, `down_proj_scale_inv`: same for L2 (`is_l1=False`). + - `gate_up_proj`, `gate_up_proj_scale_inv`: L1 weight + UE8M0 SF. + - `down_proj`, `down_proj_scale_inv`: L2 weight + UE8M0 SF. + Both pairs must be transformed together via + `transform_weights_for_mega_moe((gate_up, gate_up_sf), (down, down_sf))`. - `config.swiglu_limit` (optional): SwiGLU clamp; absent → unclamped. """ if torch.cuda.get_device_capability(hidden_states.device)[0] < 10: diff --git a/test_deepgemm.py b/test_deepgemm.py index 9ecca0ed78c6..4bea773d6d67 100644 --- a/test_deepgemm.py +++ b/test_deepgemm.py @@ -243,11 +243,9 @@ def test_megamoe(device: torch.device, world_size: int, rank: int) -> None: # Build raw FP4 experts on this rank's slice, then transform to the kernel's layout. 
raw = _make_fp4_experts(E_local, H, I, device) - gate_up_t, gate_up_sf_t = deepgemm.transform_weights_for_mega_moe( - raw.gate_up_proj, raw.gate_up_proj_scale_inv, is_l1=True - ) - down_t, down_sf_t = deepgemm.transform_weights_for_mega_moe( - raw.down_proj, raw.down_proj_scale_inv, is_l1=False + (gate_up_t, gate_up_sf_t), (down_t, down_sf_t) = deepgemm.transform_weights_for_mega_moe( + (raw.gate_up_proj, raw.gate_up_proj_scale_inv), + (raw.down_proj, raw.down_proj_scale_inv), ) experts = SimpleNamespace( From 9f168ff191575069404ed3eebf01720606d6339e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:22:29 +0200 Subject: [PATCH 65/87] fix --- test_deepgemm.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test_deepgemm.py b/test_deepgemm.py index 4bea773d6d67..fe3a14f85b1e 100644 --- a/test_deepgemm.py +++ b/test_deepgemm.py @@ -243,9 +243,16 @@ def test_megamoe(device: torch.device, world_size: int, rank: int) -> None: # Build raw FP4 experts on this rank's slice, then transform to the kernel's layout. raw = _make_fp4_experts(E_local, H, I, device) + # Mega MoE requires SFs already packed as int32 UE8M0 (it transposes them for UTCCP). + gate_up_sf_packed = deepgemm.transform_sf_into_required_layout( + raw.gate_up_proj_scale_inv, 2 * I, H, recipe=(1, 1, 32), num_groups=E_local + ) + down_sf_packed = deepgemm.transform_sf_into_required_layout( + raw.down_proj_scale_inv, H, I, recipe=(1, 1, 32), num_groups=E_local + ) (gate_up_t, gate_up_sf_t), (down_t, down_sf_t) = deepgemm.transform_weights_for_mega_moe( - (raw.gate_up_proj, raw.gate_up_proj_scale_inv), - (raw.down_proj, raw.down_proj_scale_inv), + (raw.gate_up_proj, gate_up_sf_packed), + (raw.down_proj, down_sf_packed), ) experts = SimpleNamespace( From 7274c22ecbd41807bb552f1a0953a8894ab382c8 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:23:33 +0200 Subject: [PATCH 66/87] fix --- src/transformers/integrations/deepgemm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 09bb1c872d6b..5fccc94b65d8 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -58,6 +58,7 @@ class DeepGEMM: grouped_bf16_matmul_nn: Callable per_token_cast_to_fp8: Callable get_mn_major_tma_aligned_packed_ue8m0_tensor: Callable + transform_sf_into_required_layout: Callable transform_weights_for_mega_moe: Callable get_symm_buffer_for_mega_moe: Callable fp8_fp4_mega_moe: Callable @@ -93,6 +94,7 @@ def _load_deepgemm_kernel() -> DeepGEMM: get_mn_major_tma_aligned_packed_ue8m0_tensor = getattr( kernel, "get_mn_major_tma_aligned_packed_ue8m0_tensor", None ) + transform_sf_into_required_layout = getattr(kernel, "transform_sf_into_required_layout", None) transform_weights_for_mega_moe = getattr(kernel, "transform_weights_for_mega_moe", None) get_symm_buffer_for_mega_moe = getattr(kernel, "get_symm_buffer_for_mega_moe", None) fp8_fp4_mega_moe = getattr(kernel, "fp8_fp4_mega_moe", None) @@ -106,6 +108,7 @@ def _load_deepgemm_kernel() -> DeepGEMM: ("m_grouped_bf16_gemm_nn_contiguous", grouped_bf16_matmul_nn), ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), ("get_mn_major_tma_aligned_packed_ue8m0_tensor", get_mn_major_tma_aligned_packed_ue8m0_tensor), + ("transform_sf_into_required_layout", transform_sf_into_required_layout), ("transform_weights_for_mega_moe", transform_weights_for_mega_moe), ("get_symm_buffer_for_mega_moe", 
get_symm_buffer_for_mega_moe), ("fp8_fp4_mega_moe", fp8_fp4_mega_moe), @@ -123,6 +126,7 @@ def _load_deepgemm_kernel() -> DeepGEMM: grouped_bf16_matmul_nn=grouped_bf16_matmul_nn, per_token_cast_to_fp8=per_token_cast_to_fp8, get_mn_major_tma_aligned_packed_ue8m0_tensor=get_mn_major_tma_aligned_packed_ue8m0_tensor, + transform_sf_into_required_layout=transform_sf_into_required_layout, transform_weights_for_mega_moe=transform_weights_for_mega_moe, get_symm_buffer_for_mega_moe=get_symm_buffer_for_mega_moe, fp8_fp4_mega_moe=fp8_fp4_mega_moe, From a800b8c35716b279bba8d0889487b9cc10ebfee8 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:24:35 +0200 Subject: [PATCH 67/87] fix --- test_deepgemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_deepgemm.py b/test_deepgemm.py index fe3a14f85b1e..e37949b1613d 100644 --- a/test_deepgemm.py +++ b/test_deepgemm.py @@ -245,10 +245,10 @@ def test_megamoe(device: torch.device, world_size: int, rank: int) -> None: raw = _make_fp4_experts(E_local, H, I, device) # Mega MoE requires SFs already packed as int32 UE8M0 (it transposes them for UTCCP). gate_up_sf_packed = deepgemm.transform_sf_into_required_layout( - raw.gate_up_proj_scale_inv, 2 * I, H, recipe=(1, 1, 32), num_groups=E_local + raw.gate_up_proj_scale_inv, 2 * I, H, recipe_ab=(1, 32), num_groups=E_local ) down_sf_packed = deepgemm.transform_sf_into_required_layout( - raw.down_proj_scale_inv, H, I, recipe=(1, 1, 32), num_groups=E_local + raw.down_proj_scale_inv, H, I, recipe_ab=(1, 32), num_groups=E_local ) (gate_up_t, gate_up_sf_t), (down_t, down_sf_t) = deepgemm.transform_weights_for_mega_moe( (raw.gate_up_proj, gate_up_sf_packed), From 804c98853df67b8fdb7e27f65542aa7f46d68978 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:32:32 +0200 Subject: [PATCH 68/87] fix --- test_deepgemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_deepgemm.py b/test_deepgemm.py index e37949b1613d..04d7255ba900 100644 --- a/test_deepgemm.py +++ b/test_deepgemm.py @@ -245,10 +245,10 @@ def test_megamoe(device: torch.device, world_size: int, rank: int) -> None: raw = _make_fp4_experts(E_local, H, I, device) # Mega MoE requires SFs already packed as int32 UE8M0 (it transposes them for UTCCP). gate_up_sf_packed = deepgemm.transform_sf_into_required_layout( - raw.gate_up_proj_scale_inv, 2 * I, H, recipe_ab=(1, 32), num_groups=E_local + raw.gate_up_proj_scale_inv, 2 * I, H, recipe=(1, 32), num_groups=E_local ) down_sf_packed = deepgemm.transform_sf_into_required_layout( - raw.down_proj_scale_inv, H, I, recipe_ab=(1, 32), num_groups=E_local + raw.down_proj_scale_inv, H, I, recipe=(1, 32), num_groups=E_local ) (gate_up_t, gate_up_sf_t), (down_t, down_sf_t) = deepgemm.transform_weights_for_mega_moe( (raw.gate_up_proj, gate_up_sf_packed), From 7f365628b5f45503dc4737963caa26eae9dbd054 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:33:36 +0200 Subject: [PATCH 69/87] fix --- test_deepgemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_deepgemm.py b/test_deepgemm.py index 04d7255ba900..0382c9f696e0 100644 --- a/test_deepgemm.py +++ b/test_deepgemm.py @@ -245,10 +245,10 @@ def test_megamoe(device: torch.device, world_size: int, rank: int) -> None: raw = _make_fp4_experts(E_local, H, I, device) # Mega MoE requires SFs already packed as int32 UE8M0 (it transposes them for UTCCP). 
gate_up_sf_packed = deepgemm.transform_sf_into_required_layout( - raw.gate_up_proj_scale_inv, 2 * I, H, recipe=(1, 32), num_groups=E_local + raw.gate_up_proj_scale_inv.float(), 2 * I, H, recipe=(1, 32), num_groups=E_local ) down_sf_packed = deepgemm.transform_sf_into_required_layout( - raw.down_proj_scale_inv, H, I, recipe=(1, 32), num_groups=E_local + raw.down_proj_scale_inv.float(), H, I, recipe=(1, 32), num_groups=E_local ) (gate_up_t, gate_up_sf_t), (down_t, down_sf_t) = deepgemm.transform_weights_for_mega_moe( (raw.gate_up_proj, gate_up_sf_packed), From 6df9f47ef76f3e2fea007a75809a6d033df2a60a Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:45:08 +0200 Subject: [PATCH 70/87] fix --- src/transformers/integrations/deepgemm.py | 68 +++++++++++++---------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 5fccc94b65d8..987abca7d8ee 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -270,32 +270,48 @@ def _dispatch_routed_input( num_experts: int, use_psum_layout: bool, ) -> tuple: - """Sort tokens by expert id and build the M-grouped layout. + """Sort tokens by expert id and build the M-grouped padded layout. - EP sentinels (`top_k_index == num_experts`) are kept unclamped here so the - sort pushes them to the tail and `_build_deepgemm_contiguous_layout` routes - them past valid expert blocks. - - Returns `(sorted_hidden, sorted_weights, expert_ids_g, perm, - sorted_to_padded, grouped_layout, total_padded_rows)`. + Returns `(sorted_hidden_states_g, sample_weights_g, expert_ids_g, + sentinel_mask, perm, sorted_to_padded, grouped_layout, + total_padded_rows)`. """ + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) num_top_k = top_k_index.size(-1) - expert_ids_g, perm = torch.sort(top_k_index.reshape(-1)) - sorted_hidden = hidden_states[perm // num_top_k] - sorted_weights = top_k_weights.reshape(-1)[perm] + expert_ids = top_k_index.reshape(-1) # (S,) + sample_weights = top_k_weights.reshape(-1) # (S,) + + # Sort by expert for grouped processing + expert_ids_g, perm = torch.sort(expert_ids) + sorted_hidden_states_g = hidden_states[perm // num_top_k] + sample_weights_g = sample_weights[perm] + + # Build the M-grouped padded layout (DeepGEMM contract: each expert's rows + # start on a `_DEEPGEMM_M_ALIGNMENT` boundary, sentinels routed past valid + # expert blocks). sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( expert_ids_g, num_experts, _DEEPGEMM_M_ALIGNMENT, use_psum_layout ) - return sorted_hidden, sorted_weights, expert_ids_g, perm, sorted_to_padded, grouped_layout, total_padded_rows + + # EP sentinel mask is captured before the in-place clamp; used by the post-mask in + # `_combine_routed_output` to zero sentinel rows before the per-token reduction. The clamp + # keeps any per-row gather (e.g. bias) in-bounds — bias added at sentinel positions falls + # in rows the kernel skips, so harmless. Safe to mutate now: the layout was built from the + # unclamped tensor and nothing downstream needs the sentinel info from `expert_ids_g` itself. 
+ sentinel_mask = (expert_ids_g >= num_experts).unsqueeze(-1) + expert_ids_g.clamp_(max=num_experts - 1) + return ( + sorted_hidden_states_g, sample_weights_g, expert_ids_g, sentinel_mask, perm, + sorted_to_padded, grouped_layout, total_padded_rows, + ) def _combine_routed_output( out_padded: torch.Tensor, sorted_weights: torch.Tensor, - expert_ids_g: torch.Tensor, + sentinel_mask: torch.Tensor, perm: torch.Tensor, sorted_to_padded: torch.Tensor, - num_experts: int, num_tokens: int, num_top_k: int, hidden_dim: int, @@ -306,7 +322,7 @@ def _combine_routed_output( weighted = out * sorted_weights.to(out.dtype).unsqueeze(-1) # Sentinel rows past the valid expert blocks may carry NaN from allocator # reuse (`0 * NaN = NaN`); zero them so the top-k reduction stays finite. - weighted.masked_fill_((expert_ids_g >= num_experts).unsqueeze(-1), 0.0) + weighted.masked_fill_(sentinel_mask, 0.0) inv_perm = torch.empty_like(perm) inv_perm[perm] = torch.arange(perm.size(0), device=out.device) # Deterministic reshape+sum (index_add_ with duplicates is non-deterministic on CUDA). @@ -372,13 +388,10 @@ def deepgemm_bf16_experts_forward( num_tokens, hidden_dim = hidden_states.size(0), hidden_states.size(-1) use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - sorted_hidden, sorted_weights, expert_ids_g, perm, sorted_to_padded, grouped_layout, total_padded_rows = ( - _dispatch_routed_input(hidden_states, top_k_index, top_k_weights, self.num_experts, use_psum_layout) - ) - - if self.has_bias: - # Clamp now that the layout is built; bias added to sentinel rows lands in skipped positions. - expert_ids_g.clamp_(0, self.num_experts - 1) + ( + sorted_hidden, sorted_weights, expert_ids_g, sentinel_mask, perm, + sorted_to_padded, grouped_layout, total_padded_rows, + ) = _dispatch_routed_input(hidden_states, top_k_index, top_k_weights, self.num_experts, use_psum_layout) # Up projection. w_up = self.gate_up_proj if self.has_gate else self.up_proj @@ -401,10 +414,9 @@ def deepgemm_bf16_experts_forward( return _combine_routed_output( out, sorted_weights, - expert_ids_g, + sentinel_mask, perm, sorted_to_padded, - self.num_experts, num_tokens, num_top_k, hidden_dim, @@ -433,9 +445,10 @@ def deepgemm_fp8_fp4_experts_forward( cast_kwargs = _select_fp8_cast_kwargs(w_up, ws_up, getattr(self, "block_size", None), device) use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - sorted_hidden, sorted_weights, expert_ids_g, perm, sorted_to_padded, grouped_layout, total_padded_rows = ( - _dispatch_routed_input(hidden_states, top_k_index, top_k_weights, self.num_experts, use_psum_layout) - ) + ( + sorted_hidden, sorted_weights, _expert_ids_g, sentinel_mask, perm, + sorted_to_padded, grouped_layout, total_padded_rows, + ) = _dispatch_routed_input(hidden_states, top_k_index, top_k_weights, self.num_experts, use_psum_layout) sf_recipe = (1, 1, cast_kwargs["gran_k"]) if cast_kwargs.get("use_packed_ue8m0") else None # Up projection. 
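A minimal plain-PyTorch sketch of the dispatch/combine bookkeeping refactored above — sort the (token, top-k) pairs by expert, weight the grouped-GEMM output, undo the sort, and reduce over the top-k axis with a deterministic reshape+sum. Toy shapes only; the TMA padding/alignment, FP8 casts and sentinel layout handled by the real helpers are deliberately omitted, and every name below is illustrative:

    import torch

    T, K, H, E = 8, 2, 4, 3                       # tokens, top-k, hidden, experts (toy sizes)
    top_k_index = torch.randint(0, E, (T, K))
    top_k_weights = torch.rand(T, K)
    hidden_states = torch.randn(T, H)

    # Dispatch: flatten the (token, k) pairs and sort them by expert id.
    expert_ids = top_k_index.reshape(-1)          # (S,) with S = T * K
    expert_ids_g, perm = torch.sort(expert_ids)
    sorted_hidden = hidden_states[perm // K]      # grouped input, one row per pair
    sorted_weights = top_k_weights.reshape(-1)[perm]

    # Stand-in for the grouped GEMMs: one output row per sorted pair.
    out_sorted = torch.randn(T * K, H)

    # Combine: weight each pair, undo the sort, reduce over top-k. A scatter via
    # index_add_ over token ids would also work, but with duplicate indices it is
    # non-deterministic on CUDA (atomicAdd), hence the reshape+sum.
    weighted = out_sorted * sorted_weights.unsqueeze(-1)
    inv_perm = torch.empty_like(perm)
    inv_perm[perm] = torch.arange(perm.numel())
    combined = weighted[inv_perm].view(T, K, H).sum(dim=1)   # (T, H)

The real dispatch additionally routes EP sentinel pairs past the valid expert blocks and zeroes their rows before this reduction, as the comments above describe.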
@@ -468,10 +481,9 @@ def deepgemm_fp8_fp4_experts_forward( return _combine_routed_output( out, sorted_weights, - expert_ids_g, + sentinel_mask, perm, sorted_to_padded, - self.num_experts, num_tokens, num_top_k, hidden_dim, From c3432c94a430f1ff0af80fe4b90db8ca61eccb1c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:47:33 +0200 Subject: [PATCH 71/87] empty --- src/transformers/integrations/deepgemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 987abca7d8ee..cd2ea5033e1a 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -251,7 +251,7 @@ def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_pad the GEMM. With zeros: FP8 acts → 0, float SF → 0, UE8M0 SF → byte 0 — dot product on padding rows is 0, harmless. """ - padded = torch.zeros(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) padded[sorted_to_padded] = x return padded From 77f09fe5510bfbef9efb5a082642c9cf2d271be2 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 09:47:40 +0200 Subject: [PATCH 72/87] simplify --- src/transformers/integrations/deepgemm.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index cd2ea5033e1a..78ef71be95d2 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -243,14 +243,7 @@ def _build_deepgemm_contiguous_layout( def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: - """Pad a sorted tensor into the TMA-aligned contiguous layout. - - Padding rows are zero-initialized: on SM100 the psum_layout dispatch - computes every row in the per-expert aligned range (no per-row skip mask, - only cumulative offsets), so any garbage in the padding feeds straight into - the GEMM. With zeros: FP8 acts → 0, float SF → 0, UE8M0 SF → byte 0 — dot - product on padding rows is 0, harmless. - """ + """Pad a sorted tensor into the TMA-aligned contiguous layout.""" padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) padded[sorted_to_padded] = x return padded From 7dfbeddeddb948496b97fc1a205956859d81b6ab Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 10:06:04 +0200 Subject: [PATCH 73/87] test deepseek --- test_deepseek.py | 124 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 test_deepseek.py diff --git a/test_deepseek.py b/test_deepseek.py new file mode 100644 index 000000000000..f6a4c33afc38 --- /dev/null +++ b/test_deepseek.py @@ -0,0 +1,124 @@ +"""End-to-end DeepGEMM EP test on real DeepSeek checkpoints. + +Drives both DeepGEMM dispatches end-to-end through `from_pretrained`: + + 1. `deepseek-ai/DeepSeek-V3.2` → `experts_implementation="deepgemm"` + (block-quantized FP8 weights, float SF, DSv3 recipe). + 2. `deepseek-ai/DeepSeek-V4-Flash` → `experts_implementation="deepgemm_megamoe"` + (per-row UE8M0 SF, FP4 weights, fused EP + L1 + SwiGLU + L2 path; SM100+ only). 
+ +Run on B200 with the local HF cache and torchrun: + + HF_HOME=/raid/arthur \\ + CUDA_HOME=$HOME/cuda-12.9 \\ + torchrun --nproc_per_node=8 test_deepgemm_real.py + +Generates a short continuation per checkpoint and asserts the output is finite. +Expensive (loads hundreds of GB per checkpoint). +""" + +from __future__ import annotations + +import gc +import os +import sys + +import torch +import torch.distributed as dist + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.distributed import DistributedConfig + + +_CHECKPOINTS = [ + ("deepseek-ai/DeepSeek-V3.2", "deepgemm", 9), + ("deepseek-ai/DeepSeek-V4-Flash", "deepgemm_megamoe", 10), +] +_PROMPT = "DeepGEMM tests: list three properties of UE8M0 scale factors." + + +def _rank0_print(msg: str) -> None: + if int(os.environ.get("RANK", "0")) == 0: + print(msg, flush=True) + + +def _run_one(model_id: str, dispatch: str, rank: int) -> None: + _rank0_print(f"\n=== {model_id} (dispatch={dispatch}) ===") + + model = AutoModelForCausalLM.from_pretrained( + model_id, + tp_plan="auto", + distributed_config=DistributedConfig(enable_expert_parallel=True), + experts_implementation=dispatch, + torch_dtype=torch.bfloat16, + ) + model.eval() + tok = AutoTokenizer.from_pretrained(model_id) + inputs = tok(_PROMPT, return_tensors="pt").to(model.device) + + dist.barrier() + with torch.no_grad(): + out_ids = model.generate( + **inputs, + max_new_tokens=32, + do_sample=False, + pad_token_id=tok.eos_token_id, + ) + + if rank == 0: + finite = torch.isfinite(out_ids.float()).all().item() + new_tokens = out_ids[0, inputs.input_ids.size(1):] + completion = tok.decode(new_tokens, skip_special_tokens=True) + print(f"[{model_id}] generated {new_tokens.numel()} tokens (finite={finite}):", flush=True) + print(f" prompt: {_PROMPT}", flush=True) + print(f" completion: {completion}", flush=True) + if not finite or new_tokens.numel() == 0: + raise RuntimeError(f"{model_id}: generation failed (finite={finite}, n={new_tokens.numel()})") + + dist.barrier() + del model, tok, inputs, out_ids + gc.collect() + torch.cuda.empty_cache() + + +def main() -> int: + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + if world_size < 2: + sys.exit("EP test needs >=2 ranks (run with `torchrun --nproc_per_node=N`).") + + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group("nccl", device_id=torch.device("cuda", local_rank)) + + cap_major = torch.cuda.get_device_capability()[0] + _rank0_print(f"device cap: SM{cap_major}0, world_size={world_size}") + + failures: list[tuple[str, BaseException]] = [] + for model_id, dispatch, min_sm in _CHECKPOINTS: + if cap_major < min_sm: + _rank0_print(f"[{model_id}] SKIP: needs SM{min_sm}0+, got SM{cap_major}0") + continue + try: + _run_one(model_id, dispatch, rank) + except BaseException as exc: + if rank == 0: + failures.append((model_id, exc)) + print(f"[{model_id}] FAIL — {type(exc).__name__}: {exc}", flush=True) + + dist.barrier() + dist.destroy_process_group() + + if rank == 0: + if failures: + print(f"\n=== {len(failures)} model(s) failed ===", flush=True) + for name, exc in failures: + print(f" - {name}: {type(exc).__name__}: {exc}", flush=True) + return 1 + print("\n=== all models passed ===", flush=True) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From 689cc29617d9e1eca56a4b753880caede4fac9f0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 
10:08:28 +0200 Subject: [PATCH 74/87] dsv4 only --- test_deepseek.py | 60 ++++++++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/test_deepseek.py b/test_deepseek.py index f6a4c33afc38..ff3503b67803 100644 --- a/test_deepseek.py +++ b/test_deepseek.py @@ -1,20 +1,25 @@ -"""End-to-end DeepGEMM EP test on real DeepSeek checkpoints. +"""End-to-end DeepGEMM EP test on a real DeepSeek checkpoint. -Drives both DeepGEMM dispatches end-to-end through `from_pretrained`: +Drives both DeepGEMM dispatches against `deepseek-ai/DeepSeek-V4-Flash` +(per-row UE8M0 SF, FP4 weights): - 1. `deepseek-ai/DeepSeek-V3.2` → `experts_implementation="deepgemm"` - (block-quantized FP8 weights, float SF, DSv3 recipe). - 2. `deepseek-ai/DeepSeek-V4-Flash` → `experts_implementation="deepgemm_megamoe"` - (per-row UE8M0 SF, FP4 weights, fused EP + L1 + SwiGLU + L2 path; SM100+ only). + 1. `experts_implementation="deepgemm"` — M-grouped FP8/FP4 path. + 2. `experts_implementation="deepgemm_megamoe"` — fused EP + L1 + SwiGLU + L2 + (SM100+ only). -Run on B200 with the local HF cache and torchrun: +Run on B200 with the local HF cache and torchrun. `HF_HUB_OFFLINE=1` keeps +loading off the network so a cache populated by another user (read-only for us) +still works. - HF_HOME=/raid/arthur \\ + HF_HUB_OFFLINE=1 HF_HOME=/raid/arthur \\ CUDA_HOME=$HOME/cuda-12.9 \\ - torchrun --nproc_per_node=8 test_deepgemm_real.py + torchrun --nproc_per_node=8 test_deepseek.py -Generates a short continuation per checkpoint and asserts the output is finite. -Expensive (loads hundreds of GB per checkpoint). +DeepSeek-V3.2 is intentionally not included: this transformers checkout only +registers `deepseek_v3` / `deepseek_v4`, not `deepseek_v32`. Add it back here +once the architecture lands. + +Generates a short continuation per dispatch and asserts the output is finite. """ from __future__ import annotations @@ -30,9 +35,10 @@ from transformers.distributed import DistributedConfig -_CHECKPOINTS = [ - ("deepseek-ai/DeepSeek-V3.2", "deepgemm", 9), - ("deepseek-ai/DeepSeek-V4-Flash", "deepgemm_megamoe", 10), +_CHECKPOINT = "deepseek-ai/DeepSeek-V4-Flash" +_RUNS = [ + ("deepgemm", 9), # M-grouped FP8/FP4 path + ("deepgemm_megamoe", 10), # fused Mega MoE (SM100+) ] _PROMPT = "DeepGEMM tests: list three properties of UE8M0 scale factors." 
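The prompt above asks the model about UE8M0 scale factors; as background, UE8M0 is the exponent-only scale format these checkpoints ship: one unsigned byte per block, no sign or mantissa bits, so every scale is a power of two stored as a biased exponent (bias 127). The toy encoding below is illustrative only — it does not reproduce the kernel's cast (which also packs the bytes into int32 words in a TMA-friendly layout), and rounding the exponent up is just a safe choice for the sketch, not a claim about what deep-gemm does:

    import torch

    scales = torch.tensor([0.25, 1.0, 3.0, 48.0])    # toy per-block float scales

    exp = torch.ceil(torch.log2(scales))             # round up so scaled values stay in range
    ue8m0_bytes = (exp + 127).to(torch.uint8)        # e.g. 3.0 -> 2**2 -> byte 129
    decoded = torch.exp2(ue8m0_bytes.float() - 127)  # tensor([0.25, 1., 4., 64.])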
@@ -42,18 +48,18 @@ def _rank0_print(msg: str) -> None: print(msg, flush=True) -def _run_one(model_id: str, dispatch: str, rank: int) -> None: - _rank0_print(f"\n=== {model_id} (dispatch={dispatch}) ===") +def _run_one(dispatch: str, rank: int) -> None: + _rank0_print(f"\n=== {_CHECKPOINT} (dispatch={dispatch}) ===") model = AutoModelForCausalLM.from_pretrained( - model_id, + _CHECKPOINT, tp_plan="auto", distributed_config=DistributedConfig(enable_expert_parallel=True), experts_implementation=dispatch, torch_dtype=torch.bfloat16, ) model.eval() - tok = AutoTokenizer.from_pretrained(model_id) + tok = AutoTokenizer.from_pretrained(_CHECKPOINT) inputs = tok(_PROMPT, return_tensors="pt").to(model.device) dist.barrier() @@ -69,11 +75,11 @@ def _run_one(model_id: str, dispatch: str, rank: int) -> None: finite = torch.isfinite(out_ids.float()).all().item() new_tokens = out_ids[0, inputs.input_ids.size(1):] completion = tok.decode(new_tokens, skip_special_tokens=True) - print(f"[{model_id}] generated {new_tokens.numel()} tokens (finite={finite}):", flush=True) + print(f"[{dispatch}] generated {new_tokens.numel()} tokens (finite={finite}):", flush=True) print(f" prompt: {_PROMPT}", flush=True) print(f" completion: {completion}", flush=True) if not finite or new_tokens.numel() == 0: - raise RuntimeError(f"{model_id}: generation failed (finite={finite}, n={new_tokens.numel()})") + raise RuntimeError(f"{dispatch}: generation failed (finite={finite}, n={new_tokens.numel()})") dist.barrier() del model, tok, inputs, out_ids @@ -96,27 +102,27 @@ def main() -> int: _rank0_print(f"device cap: SM{cap_major}0, world_size={world_size}") failures: list[tuple[str, BaseException]] = [] - for model_id, dispatch, min_sm in _CHECKPOINTS: + for dispatch, min_sm in _RUNS: if cap_major < min_sm: - _rank0_print(f"[{model_id}] SKIP: needs SM{min_sm}0+, got SM{cap_major}0") + _rank0_print(f"[{dispatch}] SKIP: needs SM{min_sm}0+, got SM{cap_major}0") continue try: - _run_one(model_id, dispatch, rank) + _run_one(dispatch, rank) except BaseException as exc: if rank == 0: - failures.append((model_id, exc)) - print(f"[{model_id}] FAIL — {type(exc).__name__}: {exc}", flush=True) + failures.append((dispatch, exc)) + print(f"[{dispatch}] FAIL — {type(exc).__name__}: {exc}", flush=True) dist.barrier() dist.destroy_process_group() if rank == 0: if failures: - print(f"\n=== {len(failures)} model(s) failed ===", flush=True) + print(f"\n=== {len(failures)} dispatch(es) failed ===", flush=True) for name, exc in failures: print(f" - {name}: {type(exc).__name__}: {exc}", flush=True) return 1 - print("\n=== all models passed ===", flush=True) + print("\n=== all dispatches passed ===", flush=True) return 0 From 8de089ce1fb3633db3f32924073bf82b362cdb69 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 10:11:50 +0200 Subject: [PATCH 75/87] download dsv4 --- test_deepseek.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test_deepseek.py b/test_deepseek.py index ff3503b67803..c463f8292ede 100644 --- a/test_deepseek.py +++ b/test_deepseek.py @@ -7,11 +7,10 @@ 2. `experts_implementation="deepgemm_megamoe"` — fused EP + L1 + SwiGLU + L2 (SM100+ only). -Run on B200 with the local HF cache and torchrun. `HF_HUB_OFFLINE=1` keeps -loading off the network so a cache populated by another user (read-only for us) -still works. +Run on B200 with a writable HF cache on the raid mount and torchrun. First run +downloads the checkpoint (hundreds of GB). 
- HF_HUB_OFFLINE=1 HF_HOME=/raid/arthur \\ + HF_HOME=/raid/ilyas \\ CUDA_HOME=$HOME/cuda-12.9 \\ torchrun --nproc_per_node=8 test_deepseek.py @@ -56,7 +55,7 @@ def _run_one(dispatch: str, rank: int) -> None: tp_plan="auto", distributed_config=DistributedConfig(enable_expert_parallel=True), experts_implementation=dispatch, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, ) model.eval() tok = AutoTokenizer.from_pretrained(_CHECKPOINT) From cb4d6f9ec4f81f05990bd6fc61a4ae25ab297713 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 10:17:05 +0200 Subject: [PATCH 76/87] fix test --- test_deepseek.py | 130 ++++++++++++++++++++++++++++++----------------- 1 file changed, 83 insertions(+), 47 deletions(-) diff --git a/test_deepseek.py b/test_deepseek.py index c463f8292ede..1882d9b14175 100644 --- a/test_deepseek.py +++ b/test_deepseek.py @@ -1,11 +1,18 @@ """End-to-end DeepGEMM EP test on a real DeepSeek checkpoint. -Drives both DeepGEMM dispatches against `deepseek-ai/DeepSeek-V4-Flash` -(per-row UE8M0 SF, FP4 weights): - - 1. `experts_implementation="deepgemm"` — M-grouped FP8/FP4 path. - 2. `experts_implementation="deepgemm_megamoe"` — fused EP + L1 + SwiGLU + L2 - (SM100+ only). +Drives every relevant experts dispatch against `deepseek-ai/DeepSeek-V4-Flash` +using **two model loads** total — same weights are reused across dispatches +via `model.set_experts_implementation`: + + 1. Dequantized load (`FineGrainedFP8Config(dequantize=True)`, + `dtype=torch.bfloat16`) → bf16 weights. Cycles: + - `grouped_mm` (torch grouped GEMM) + - `sonicmoe` (sonicmoe kernel) + - `deepgemm` (`deepgemm_bf16_experts_forward`) + 2. Native quantized load (`dtype="auto"`) → FP8/FP4 weights kept on disk. + Cycles: + - `deepgemm` (`deepgemm_fp8_fp4_experts_forward`) + - `deepgemm_megamoe` (fused Mega MoE; skipped on < SM100) Run on B200 with a writable HF cache on the raid mount and torchrun. First run downloads the checkpoint (hundreds of GB). @@ -15,10 +22,7 @@ torchrun --nproc_per_node=8 test_deepseek.py DeepSeek-V3.2 is intentionally not included: this transformers checkout only -registers `deepseek_v3` / `deepseek_v4`, not `deepseek_v32`. Add it back here -once the architecture lands. - -Generates a short continuation per dispatch and asserts the output is finite. +registers `deepseek_v3` / `deepseek_v4`, not `deepseek_v32`. """ from __future__ import annotations @@ -30,37 +34,32 @@ import torch import torch.distributed as dist -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config from transformers.distributed import DistributedConfig _CHECKPOINT = "deepseek-ai/DeepSeek-V4-Flash" -_RUNS = [ - ("deepgemm", 9), # M-grouped FP8/FP4 path - ("deepgemm_megamoe", 10), # fused Mega MoE (SM100+) -] _PROMPT = "DeepGEMM tests: list three properties of UE8M0 scale factors." +# (label, dispatch, min_sm). All entries in a phase share one model load. 
+_DEQUANTIZED_DISPATCHES = [ + ("dequantized + grouped_mm", "grouped_mm", 9), + ("dequantized + sonicmoe", "sonicmoe", 9), + ("dequantized + deepgemm", "deepgemm", 9), +] +_QUANTIZED_DISPATCHES = [ + ("quantized + deepgemm", "deepgemm", 9), + ("quantized + deepgemm_megamoe", "deepgemm_megamoe", 10), +] + def _rank0_print(msg: str) -> None: if int(os.environ.get("RANK", "0")) == 0: print(msg, flush=True) -def _run_one(dispatch: str, rank: int) -> None: - _rank0_print(f"\n=== {_CHECKPOINT} (dispatch={dispatch}) ===") - - model = AutoModelForCausalLM.from_pretrained( - _CHECKPOINT, - tp_plan="auto", - distributed_config=DistributedConfig(enable_expert_parallel=True), - experts_implementation=dispatch, - dtype=torch.bfloat16, - ) - model.eval() - tok = AutoTokenizer.from_pretrained(_CHECKPOINT) +def _generate_and_check(model, tok, label: str, rank: int) -> None: inputs = tok(_PROMPT, return_tensors="pt").to(model.device) - dist.barrier() with torch.no_grad(): out_ids = model.generate( @@ -69,21 +68,50 @@ def _run_one(dispatch: str, rank: int) -> None: do_sample=False, pad_token_id=tok.eos_token_id, ) - if rank == 0: finite = torch.isfinite(out_ids.float()).all().item() new_tokens = out_ids[0, inputs.input_ids.size(1):] completion = tok.decode(new_tokens, skip_special_tokens=True) - print(f"[{dispatch}] generated {new_tokens.numel()} tokens (finite={finite}):", flush=True) + print(f"[{label}] generated {new_tokens.numel()} tokens (finite={finite}):", flush=True) print(f" prompt: {_PROMPT}", flush=True) print(f" completion: {completion}", flush=True) if not finite or new_tokens.numel() == 0: - raise RuntimeError(f"{dispatch}: generation failed (finite={finite}, n={new_tokens.numel()})") - + raise RuntimeError(f"{label}: generation failed (finite={finite}, n={new_tokens.numel()})") dist.barrier() - del model, tok, inputs, out_ids - gc.collect() - torch.cuda.empty_cache() + + +def _run_phase(load_kwargs: dict, dispatches, cap_major: int, rank: int, failures: list) -> None: + runnable = [(lab, d) for (lab, d, sm) in dispatches if cap_major >= sm] + skipped = [(lab, d, sm) for (lab, d, sm) in dispatches if cap_major < sm] + for lab, _, sm in skipped: + _rank0_print(f"[{lab}] SKIP: needs SM{sm}0+, got SM{cap_major}0") + if not runnable: + return + + _rank0_print(f"\n--- loading {_CHECKPOINT} (kwargs: {sorted(load_kwargs)}) ---") + model = AutoModelForCausalLM.from_pretrained( + _CHECKPOINT, + tp_plan="auto", + distributed_config=DistributedConfig(enable_expert_parallel=True), + **load_kwargs, + ) + model.eval() + tok = AutoTokenizer.from_pretrained(_CHECKPOINT) + + try: + for label, dispatch in runnable: + _rank0_print(f"\n=== {label} ===") + try: + model.set_experts_implementation(dispatch) + _generate_and_check(model, tok, label, rank) + except BaseException as exc: + if rank == 0: + failures.append((label, exc)) + print(f"[{label}] FAIL — {type(exc).__name__}: {exc}", flush=True) + finally: + del model, tok + gc.collect() + torch.cuda.empty_cache() def main() -> int: @@ -101,27 +129,35 @@ def main() -> int: _rank0_print(f"device cap: SM{cap_major}0, world_size={world_size}") failures: list[tuple[str, BaseException]] = [] - for dispatch, min_sm in _RUNS: - if cap_major < min_sm: - _rank0_print(f"[{dispatch}] SKIP: needs SM{min_sm}0+, got SM{cap_major}0") - continue - try: - _run_one(dispatch, rank) - except BaseException as exc: - if rank == 0: - failures.append((dispatch, exc)) - print(f"[{dispatch}] FAIL — {type(exc).__name__}: {exc}", flush=True) + + _run_phase( + load_kwargs={ + 
"quantization_config": FineGrainedFP8Config(dequantize=True), + "dtype": torch.bfloat16, + }, + dispatches=_DEQUANTIZED_DISPATCHES, + cap_major=cap_major, + rank=rank, + failures=failures, + ) + _run_phase( + load_kwargs={"dtype": "auto"}, + dispatches=_QUANTIZED_DISPATCHES, + cap_major=cap_major, + rank=rank, + failures=failures, + ) dist.barrier() dist.destroy_process_group() if rank == 0: if failures: - print(f"\n=== {len(failures)} dispatch(es) failed ===", flush=True) + print(f"\n=== {len(failures)} run(s) failed ===", flush=True) for name, exc in failures: print(f" - {name}: {type(exc).__name__}: {exc}", flush=True) return 1 - print("\n=== all dispatches passed ===", flush=True) + print("\n=== all runs passed ===", flush=True) return 0 From 8f29ed6ac9d481c664706d405f7c6bb9a87c107d Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 10:25:49 +0200 Subject: [PATCH 77/87] push --- .../quantizers/quantizer_finegrained_fp8.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index be10624d4842..7ce0aa6e3db7 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -186,21 +186,21 @@ def update_weight_conversions(self, weight_conversions): :meth:`get_weight_conversions` is still appended at the end as a fallback for plain ``nn.Linear`` weights with no model-specific converter. """ - if not (self.pre_quantized and self.quantization_config.dequantize): - return weight_conversions + self.get_weight_conversions() - from ..core_model_loading import WeightConverter, WeightRenaming from ..integrations.finegrained_fp8 import Fp8Dequantize # Some upstream FP8 checkpoints (e.g. DeepSeek-V4-Flash) ship per-block scales # under a ``.scale`` suffix instead of HF's canonical ``.weight_scale_inv``. - # Prepending the rename here (instead of in each model's conversion_mapping) - # keeps the model-side mapping clean — the rename only kicks in when FP8 dequant - # is actually active, so a non-FP8 save / load round-trip doesn't see a stray - # rule that ``test_reverse_loading_mapping`` can't match. + # Apply the rename whenever this quantizer is active — needed for both + # `dequantize=True` (where scales fold into the dequantized weight) and the + # native quantized path (where `weight_scale_inv` parameters live on the model). + # Confined to this quantizer, so non-FP8 loads never see the rule. 
scale_rename = WeightRenaming(source_patterns=r"^(.+)\.scale$", target_patterns=r"\1.weight_scale_inv") weight_conversions = [scale_rename] + list(weight_conversions) + if not (self.pre_quantized and self.quantization_config.dequantize): + return weight_conversions + self.get_weight_conversions() + updated: list = [] for conv in weight_conversions: # Only WeightConverter has ``.operations`` to extend with the dequant op; From 946e2007166a2cd9f6b1d9a8e9a68b06af80d545 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 10:30:39 +0200 Subject: [PATCH 78/87] test --- .../quantizers/quantizer_finegrained_fp8.py | 14 ++--- test_deepseek.py | 56 ++++++++++++------- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index 7ce0aa6e3db7..be10624d4842 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -186,21 +186,21 @@ def update_weight_conversions(self, weight_conversions): :meth:`get_weight_conversions` is still appended at the end as a fallback for plain ``nn.Linear`` weights with no model-specific converter. """ + if not (self.pre_quantized and self.quantization_config.dequantize): + return weight_conversions + self.get_weight_conversions() + from ..core_model_loading import WeightConverter, WeightRenaming from ..integrations.finegrained_fp8 import Fp8Dequantize # Some upstream FP8 checkpoints (e.g. DeepSeek-V4-Flash) ship per-block scales # under a ``.scale`` suffix instead of HF's canonical ``.weight_scale_inv``. - # Apply the rename whenever this quantizer is active — needed for both - # `dequantize=True` (where scales fold into the dequantized weight) and the - # native quantized path (where `weight_scale_inv` parameters live on the model). - # Confined to this quantizer, so non-FP8 loads never see the rule. + # Prepending the rename here (instead of in each model's conversion_mapping) + # keeps the model-side mapping clean — the rename only kicks in when FP8 dequant + # is actually active, so a non-FP8 save / load round-trip doesn't see a stray + # rule that ``test_reverse_loading_mapping`` can't match. 
scale_rename = WeightRenaming(source_patterns=r"^(.+)\.scale$", target_patterns=r"\1.weight_scale_inv") weight_conversions = [scale_rename] + list(weight_conversions) - if not (self.pre_quantized and self.quantization_config.dequantize): - return weight_conversions + self.get_weight_conversions() - updated: list = [] for conv in weight_conversions: # Only WeightConverter has ``.operations`` to extend with the dequant op; diff --git a/test_deepseek.py b/test_deepseek.py index 1882d9b14175..17d9fbce1264 100644 --- a/test_deepseek.py +++ b/test_deepseek.py @@ -80,23 +80,31 @@ def _generate_and_check(model, tok, label: str, rank: int) -> None: dist.barrier() -def _run_phase(load_kwargs: dict, dispatches, cap_major: int, rank: int, failures: list) -> None: +def _run_phase(load_kwargs: dict, dispatches, cap_major: int, rank: int, results: list) -> None: runnable = [(lab, d) for (lab, d, sm) in dispatches if cap_major >= sm] skipped = [(lab, d, sm) for (lab, d, sm) in dispatches if cap_major < sm] for lab, _, sm in skipped: _rank0_print(f"[{lab}] SKIP: needs SM{sm}0+, got SM{cap_major}0") + results.append((lab, "SKIP", f"needs SM{sm}0+")) if not runnable: return _rank0_print(f"\n--- loading {_CHECKPOINT} (kwargs: {sorted(load_kwargs)}) ---") - model = AutoModelForCausalLM.from_pretrained( - _CHECKPOINT, - tp_plan="auto", - distributed_config=DistributedConfig(enable_expert_parallel=True), - **load_kwargs, - ) - model.eval() - tok = AutoTokenizer.from_pretrained(_CHECKPOINT) + try: + model = AutoModelForCausalLM.from_pretrained( + _CHECKPOINT, + tp_plan="auto", + distributed_config=DistributedConfig(enable_expert_parallel=True), + **load_kwargs, + ) + model.eval() + tok = AutoTokenizer.from_pretrained(_CHECKPOINT) + except BaseException as exc: + if rank == 0: + print(f"[load] FAIL — {type(exc).__name__}: {exc}", flush=True) + for label, _ in runnable: + results.append((label, "FAIL", f"load: {type(exc).__name__}: {exc}")) + return try: for label, dispatch in runnable: @@ -104,10 +112,11 @@ def _run_phase(load_kwargs: dict, dispatches, cap_major: int, rank: int, failure try: model.set_experts_implementation(dispatch) _generate_and_check(model, tok, label, rank) + results.append((label, "PASS", "")) except BaseException as exc: if rank == 0: - failures.append((label, exc)) print(f"[{label}] FAIL — {type(exc).__name__}: {exc}", flush=True) + results.append((label, "FAIL", f"{type(exc).__name__}: {exc}")) finally: del model, tok gc.collect() @@ -128,7 +137,7 @@ def main() -> int: cap_major = torch.cuda.get_device_capability()[0] _rank0_print(f"device cap: SM{cap_major}0, world_size={world_size}") - failures: list[tuple[str, BaseException]] = [] + results: list[tuple[str, str, str]] = [] # (label, status, detail) _run_phase( load_kwargs={ @@ -138,26 +147,35 @@ def main() -> int: dispatches=_DEQUANTIZED_DISPATCHES, cap_major=cap_major, rank=rank, - failures=failures, + results=results, ) _run_phase( load_kwargs={"dtype": "auto"}, dispatches=_QUANTIZED_DISPATCHES, cap_major=cap_major, rank=rank, - failures=failures, + results=results, ) dist.barrier() dist.destroy_process_group() if rank == 0: - if failures: - print(f"\n=== {len(failures)} run(s) failed ===", flush=True) - for name, exc in failures: - print(f" - {name}: {type(exc).__name__}: {exc}", flush=True) - return 1 - print("\n=== all runs passed ===", flush=True) + passed = [r for r in results if r[1] == "PASS"] + failed = [r for r in results if r[1] == "FAIL"] + skipped = [r for r in results if r[1] == "SKIP"] + width = max((len(r[0]) for r in 
results), default=20) + print("\n=== summary ===", flush=True) + for label, status, detail in results: + line = f" {label.ljust(width)} {status}" + if detail: + line += f" ({detail})" + print(line, flush=True) + print( + f"\n totals: {len(passed)} passed, {len(failed)} failed, {len(skipped)} skipped", + flush=True, + ) + return 1 if failed else 0 return 0 From 6619647ddd1b8b87f8f45b33cffd81fee9d376c9 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 10:35:08 +0200 Subject: [PATCH 79/87] fix --- .../quantizers/quantizer_finegrained_fp8.py | 14 +++---- test_deepseek.py | 42 ++++++------------- 2 files changed, 19 insertions(+), 37 deletions(-) diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index be10624d4842..b480496abc88 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -186,21 +186,21 @@ def update_weight_conversions(self, weight_conversions): :meth:`get_weight_conversions` is still appended at the end as a fallback for plain ``nn.Linear`` weights with no model-specific converter. """ - if not (self.pre_quantized and self.quantization_config.dequantize): - return weight_conversions + self.get_weight_conversions() - from ..core_model_loading import WeightConverter, WeightRenaming from ..integrations.finegrained_fp8 import Fp8Dequantize # Some upstream FP8 checkpoints (e.g. DeepSeek-V4-Flash) ship per-block scales # under a ``.scale`` suffix instead of HF's canonical ``.weight_scale_inv``. - # Prepending the rename here (instead of in each model's conversion_mapping) - # keeps the model-side mapping clean — the rename only kicks in when FP8 dequant - # is actually active, so a non-FP8 save / load round-trip doesn't see a stray - # rule that ``test_reverse_loading_mapping`` can't match. + # Apply the rename whenever this quantizer is active — needed both when keeping + # the quantized layout (scale_inv parameters live on the model) and when + # dequantizing (scales fold into the dequantized weight). Confined to this + # quantizer, so non-FP8 loads never see the rule. scale_rename = WeightRenaming(source_patterns=r"^(.+)\.scale$", target_patterns=r"\1.weight_scale_inv") weight_conversions = [scale_rename] + list(weight_conversions) + if not (self.pre_quantized and self.quantization_config.dequantize): + return weight_conversions + self.get_weight_conversions() + updated: list = [] for conv in weight_conversions: # Only WeightConverter has ``.operations`` to extend with the dequant op; diff --git a/test_deepseek.py b/test_deepseek.py index 17d9fbce1264..b71c95a4bdb8 100644 --- a/test_deepseek.py +++ b/test_deepseek.py @@ -1,18 +1,15 @@ """End-to-end DeepGEMM EP test on a real DeepSeek checkpoint. -Drives every relevant experts dispatch against `deepseek-ai/DeepSeek-V4-Flash` -using **two model loads** total — same weights are reused across dispatches -via `model.set_experts_implementation`: - - 1. Dequantized load (`FineGrainedFP8Config(dequantize=True)`, - `dtype=torch.bfloat16`) → bf16 weights. Cycles: - - `grouped_mm` (torch grouped GEMM) - - `sonicmoe` (sonicmoe kernel) - - `deepgemm` (`deepgemm_bf16_experts_forward`) - 2. Native quantized load (`dtype="auto"`) → FP8/FP4 weights kept on disk. 
- Cycles: - - `deepgemm` (`deepgemm_fp8_fp4_experts_forward`) - - `deepgemm_megamoe` (fused Mega MoE; skipped on < SM100) +Drives the FP8/FP4 DeepGEMM dispatches against `deepseek-ai/DeepSeek-V4-Flash` +with one model load shared across dispatches via `model.set_experts_implementation`: + + - `deepgemm` → `deepgemm_fp8_fp4_experts_forward` + - `deepgemm_megamoe` → fused Mega MoE (skipped on < SM100) + +Dequantize-on-load is intentionally NOT exercised here: V4-Flash is ~671B +parameters; dequantizing to bf16 needs ~1.3 TB across the world, which doesn't +fit in 8× B200 (178 GB each). For the bf16 / dequantized expert dispatches +(`grouped_mm`, `sonicmoe`, `deepgemm` bf16) use a smaller MoE checkpoint. Run on B200 with a writable HF cache on the raid mount and torchrun. First run downloads the checkpoint (hundreds of GB). @@ -34,19 +31,14 @@ import torch import torch.distributed as dist -from transformers import AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config +from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.distributed import DistributedConfig _CHECKPOINT = "deepseek-ai/DeepSeek-V4-Flash" _PROMPT = "DeepGEMM tests: list three properties of UE8M0 scale factors." -# (label, dispatch, min_sm). All entries in a phase share one model load. -_DEQUANTIZED_DISPATCHES = [ - ("dequantized + grouped_mm", "grouped_mm", 9), - ("dequantized + sonicmoe", "sonicmoe", 9), - ("dequantized + deepgemm", "deepgemm", 9), -] +# (label, dispatch, min_sm). All entries share one model load. _QUANTIZED_DISPATCHES = [ ("quantized + deepgemm", "deepgemm", 9), ("quantized + deepgemm_megamoe", "deepgemm_megamoe", 10), @@ -139,16 +131,6 @@ def main() -> int: results: list[tuple[str, str, str]] = [] # (label, status, detail) - _run_phase( - load_kwargs={ - "quantization_config": FineGrainedFP8Config(dequantize=True), - "dtype": torch.bfloat16, - }, - dispatches=_DEQUANTIZED_DISPATCHES, - cap_major=cap_major, - rank=rank, - results=results, - ) _run_phase( load_kwargs={"dtype": "auto"}, dispatches=_QUANTIZED_DISPATCHES, From 913339ac82f758b900683010dc737a9069a7e2c0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 10:48:26 +0200 Subject: [PATCH 80/87] fix ep plan --- .../models/deepseek_v4/configuration_deepseek_v4.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py index 2cbc02c6d0f7..2a6d70d83b2c 100644 --- a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py @@ -124,7 +124,9 @@ class DeepseekV4Config(PreTrainedConfig): # no `base_model_tp_plan` for V4: we don't ship a pure-TP plan, only EP. 
"layers.*.mlp.gate": "ep_router", "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm", "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm", "layers.*.mlp.experts": "moe_tp_experts", } From 7cfc2b26279c822211488d9217d64243a4dba267 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 10:57:16 +0200 Subject: [PATCH 81/87] fix attempt --- .../quantizers/quantizer_finegrained_fp8.py | 97 +++++++++++-------- 1 file changed, 57 insertions(+), 40 deletions(-) diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index b480496abc88..dc32b17622a1 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -168,59 +168,76 @@ def get_weight_conversions(self): return [] def update_weight_conversions(self, weight_conversions): - """When loading with ``dequantize=True``, attach an :class:`Fp8Dequantize` op to - every existing :class:`WeightConverter` so that per-block scales are folded into - the weight *before* any later merge/concat ops collapse the per-expert structure. - - For each model-supplied converter that has a ``.weight`` source, we: - 1. anchor the existing weight patterns with ``$`` so they don't accidentally - also match the ``.weight_scale_inv`` keys (the regex is searched, so the - unanchored prefix would match both, sending scales to the wrong bucket); - 2. add anchored ``*.weight_scale_inv`` sources next to each weight pattern so - the loader collects scale tensors alongside the weight tensors into the - *same* converter bucket (both keys rewrite to the same target); - 3. prepend a fresh :class:`Fp8Dequantize` op so dequant runs first, before - any merge/concat collapses the per-expert structure. - - The generic ``weight$ + weight_scale_inv → weight`` converter from - :meth:`get_weight_conversions` is still appended at the end as a fallback for - plain ``nn.Linear`` weights with no model-specific converter. + """Adapt model-supplied :class:`WeightConverter` instances for FP8 loading. + + Two paths share the same underlying problem: the model's converters merge per-expert + ``.weight`` keys into batched ``gate_up_proj`` / ``down_proj`` tensors, but the + FP8 ``.weight_scale_inv`` siblings need the same merge to populate + ``gate_up_proj_scale_inv`` / ``down_proj_scale_inv`` correctly. The unanchored + ``.weight`` source patterns also accidentally match ``.weight_scale_inv`` keys + (regex-search), so we anchor them first and synthesize parallel scale converters. + + - ``dequantize=True``: scales fold into the weight via :class:`Fp8Dequantize` + *before* merge/concat — both `.weight` and `.weight_scale_inv` go to the same + target bucket. + - ``dequantize=False`` (native quantized): scales stay as separate parameters, + so we synthesize a parallel ``WeightConverter`` per model converter that + targets ``_scale_inv`` and reuses the same merge ops. + + We also prepend a ``.scale → .weight_scale_inv`` rename for upstream checkpoints + (e.g. DeepSeek-V4-Flash) that ship per-block scales under the ``.scale`` suffix. """ from ..core_model_loading import WeightConverter, WeightRenaming from ..integrations.finegrained_fp8 import Fp8Dequantize - # Some upstream FP8 checkpoints (e.g. DeepSeek-V4-Flash) ship per-block scales - # under a ``.scale`` suffix instead of HF's canonical ``.weight_scale_inv``. 
- # Apply the rename whenever this quantizer is active — needed both when keeping - # the quantized layout (scale_inv parameters live on the model) and when - # dequantizing (scales fold into the dequantized weight). Confined to this - # quantizer, so non-FP8 loads never see the rule. + # `.scale` → `.weight_scale_inv`. Confined to this quantizer, so non-FP8 loads never see it. scale_rename = WeightRenaming(source_patterns=r"^(.+)\.scale$", target_patterns=r"\1.weight_scale_inv") weight_conversions = [scale_rename] + list(weight_conversions) - - if not (self.pre_quantized and self.quantization_config.dequantize): - return weight_conversions + self.get_weight_conversions() + dequantize = self.pre_quantized and self.quantization_config.dequantize updated: list = [] for conv in weight_conversions: - # Only WeightConverter has ``.operations`` to extend with the dequant op; - # WeightRenaming (e.g. the ``scale_rename`` we prepended) just passes through. + # Only WeightConverter has `.operations` and `.source_patterns` to manipulate. if not isinstance(conv, WeightConverter): updated.append(conv) continue weight_sources = [p for p in conv.source_patterns if p.endswith(".weight")] - if weight_sources: - anchored_weight = [p + "$" for p in weight_sources] - scale_sources = [p[: -len(".weight")] + ".weight_scale_inv$" for p in weight_sources] - other = [p for p in conv.source_patterns if not p.endswith(".weight")] - new_sources = anchored_weight + scale_sources + other - new_ops = [Fp8Dequantize(self)] + list(conv.operations) - conv = WeightConverter( - source_patterns=new_sources, + if not weight_sources: + updated.append(conv) + continue + + anchored_weight = [p + "$" for p in weight_sources] + scale_sources = [p[: -len(".weight")] + ".weight_scale_inv$" for p in weight_sources] + other = [p for p in conv.source_patterns if not p.endswith(".weight")] + + if dequantize: + # Both .weight and .weight_scale_inv go to the same target; Fp8Dequantize + # folds scales into the weight before any merge/concat downstream. + updated.append(WeightConverter( + source_patterns=anchored_weight + scale_sources + other, target_patterns=conv._original_target_patterns, - operations=new_ops, - ) - updated.append(conv) - # Generic fallback for plain ``nn.Linear`` weights with no model-specific converter. + operations=[Fp8Dequantize(self)] + list(conv.operations), + )) + else: + # Native quantized: anchor the existing weight converter so .weight_scale_inv + # keys don't leak in, then synthesize a parallel converter routing the scale + # keys to `_scale_inv` with the same merge ops. + updated.append(WeightConverter( + source_patterns=anchored_weight + other, + target_patterns=conv._original_target_patterns, + operations=list(conv.operations), + )) + target = conv._original_target_patterns + if isinstance(target, str): + scale_target = target + "_scale_inv" + else: + scale_target = [t + "_scale_inv" for t in target] + updated.append(WeightConverter( + source_patterns=scale_sources, + target_patterns=scale_target, + operations=list(conv.operations), + )) + + # Generic fallback for plain `nn.Linear` weights with no model-specific converter. 
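
# Small illustration (hypothetical keys, not part of the patch) of why the model-supplied
# `.weight` source patterns get suffixed with `$` above: with plain re.search, the
# unanchored pattern also matches the `.weight_scale_inv` sibling key and would pull the
# scale tensor into the weight bucket; the anchor restores the distinction.
import re

pattern = r"experts\.\d+\.down_proj\.weight"
weight_key = "model.layers.3.mlp.experts.7.down_proj.weight"
scale_key = "model.layers.3.mlp.experts.7.down_proj.weight_scale_inv"

assert re.search(pattern, weight_key) and re.search(pattern, scale_key)          # unanchored: both match
assert re.search(pattern + "$", weight_key) and not re.search(pattern + "$", scale_key)
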
updated.extend(self.get_weight_conversions()) return updated From 4c12f6ea39c54bff758de2be5698b845a4ed2c0b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 11:00:20 +0200 Subject: [PATCH 82/87] debug --- .../quantizers/quantizer_finegrained_fp8.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index dc32b17622a1..eb81db03fd3c 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -194,6 +194,15 @@ def update_weight_conversions(self, weight_conversions): scale_rename = WeightRenaming(source_patterns=r"^(.+)\.scale$", target_patterns=r"\1.weight_scale_inv") weight_conversions = [scale_rename] + list(weight_conversions) dequantize = self.pre_quantized and self.quantization_config.dequantize + import os + if os.environ.get("RANK", "0") == "0": + print( + f"[fp8 quantizer] update_weight_conversions: dequantize={dequantize} " + f"input_converters={len(weight_conversions)} " + f"weight_converters_with_weight_src=" + f"{sum(1 for c in weight_conversions if hasattr(c, 'source_patterns') and any(str(p).endswith('.weight') for p in (c.source_patterns if isinstance(c.source_patterns, list) else [c.source_patterns])))}", + flush=True, + ) updated: list = [] for conv in weight_conversions: @@ -240,4 +249,13 @@ def update_weight_conversions(self, weight_conversions): # Generic fallback for plain `nn.Linear` weights with no model-specific converter. updated.extend(self.get_weight_conversions()) + if os.environ.get("RANK", "0") == "0": + scale_targets = [ + getattr(c, "_original_target_patterns", None) + for c in updated + if isinstance(c, WeightConverter) + and isinstance(getattr(c, "_original_target_patterns", None), str) + and "_scale_inv" in c._original_target_patterns + ] + print(f"[fp8 quantizer] output_converters={len(updated)} scale_inv_targets={scale_targets}", flush=True) return updated From 92e63378bbc80c09b0c72b824efb931888e28ce7 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 11:07:03 +0200 Subject: [PATCH 83/87] attempt --- .../quantizers/quantizer_finegrained_fp8.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index eb81db03fd3c..c571cfbb2f4a 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -250,12 +250,14 @@ def update_weight_conversions(self, weight_conversions): # Generic fallback for plain `nn.Linear` weights with no model-specific converter. 
updated.extend(self.get_weight_conversions()) if os.environ.get("RANK", "0") == "0": - scale_targets = [ - getattr(c, "_original_target_patterns", None) - for c in updated - if isinstance(c, WeightConverter) - and isinstance(getattr(c, "_original_target_patterns", None), str) - and "_scale_inv" in c._original_target_patterns - ] - print(f"[fp8 quantizer] output_converters={len(updated)} scale_inv_targets={scale_targets}", flush=True) + print(f"[fp8 quantizer] output_converters={len(updated)}", flush=True) + for i, c in enumerate(updated): + if isinstance(c, WeightConverter) and any( + "scale_inv" in str(p) or "scale_inv" in str(getattr(c, "_original_target_patterns", "") or "") + for p in (c.source_patterns if isinstance(c.source_patterns, list) else [c.source_patterns]) + ): + print( + f" [{i}] src={c.source_patterns} target={getattr(c, '_original_target_patterns', None)}", + flush=True, + ) return updated From a9997f5952dfdf240671a4ef15b73e9512cadf60 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 6 May 2026 11:08:52 +0200 Subject: [PATCH 84/87] debug --- src/transformers/integrations/deepgemm.py | 18 +++++++++++++++-- .../quantizers/quantizer_finegrained_fp8.py | 20 ------------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 78ef71be95d2..617b0c9975de 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -449,9 +449,23 @@ def deepgemm_fp8_fp4_experts_forward( act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) + a_sf = _coerce_sf_for_kernel(act_scales) + b_sf = _coerce_sf_for_kernel(ws_up) + import os as _os + if _os.environ.get("RANK", "0") == "0" and not getattr(self, "_dg_shape_logged", False): + print( + f"[deepgemm.fp8_fp4] recipe={sf_recipe} use_psum={use_psum_layout}\n" + f" a_fp8.shape={tuple(act_fp8.shape)} dtype={act_fp8.dtype}\n" + f" a_sf.shape={tuple(a_sf.shape)} stride={tuple(a_sf.stride())} dtype={a_sf.dtype}\n" + f" b.shape={tuple(w_up.shape)} dtype={w_up.dtype}\n" + f" b_sf.shape={tuple(b_sf.shape)} stride={tuple(b_sf.stride())} dtype={b_sf.dtype}\n" + f" raw ws_up.shape={tuple(ws_up.shape)} dtype={ws_up.dtype}", + flush=True, + ) + self._dg_shape_logged = True deepgemm.grouped_fp8_fp4_matmul( - (act_fp8, _coerce_sf_for_kernel(act_scales)), - (w_up, _coerce_sf_for_kernel(ws_up)), + (act_fp8, a_sf), + (w_up, b_sf), proj_out, grouped_layout, recipe=sf_recipe, diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index c571cfbb2f4a..dc32b17622a1 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -194,15 +194,6 @@ def update_weight_conversions(self, weight_conversions): scale_rename = WeightRenaming(source_patterns=r"^(.+)\.scale$", target_patterns=r"\1.weight_scale_inv") weight_conversions = [scale_rename] + list(weight_conversions) dequantize = self.pre_quantized and self.quantization_config.dequantize - import os - if os.environ.get("RANK", "0") == "0": - print( - f"[fp8 quantizer] update_weight_conversions: dequantize={dequantize} " - f"input_converters={len(weight_conversions)} " - f"weight_converters_with_weight_src=" - f"{sum(1 for c in 
weight_conversions if hasattr(c, 'source_patterns') and any(str(p).endswith('.weight') for p in (c.source_patterns if isinstance(c.source_patterns, list) else [c.source_patterns])))}", - flush=True, - ) updated: list = [] for conv in weight_conversions: @@ -249,15 +240,4 @@ def update_weight_conversions(self, weight_conversions): # Generic fallback for plain `nn.Linear` weights with no model-specific converter. updated.extend(self.get_weight_conversions()) - if os.environ.get("RANK", "0") == "0": - print(f"[fp8 quantizer] output_converters={len(updated)}", flush=True) - for i, c in enumerate(updated): - if isinstance(c, WeightConverter) and any( - "scale_inv" in str(p) or "scale_inv" in str(getattr(c, "_original_target_patterns", "") or "") - for p in (c.source_patterns if isinstance(c.source_patterns, list) else [c.source_patterns]) - ): - print( - f" [{i}] src={c.source_patterns} target={getattr(c, '_original_target_patterns', None)}", - flush=True, - ) return updated From 197a8f62cb4dd17666e7c06a8d09858d5c167a59 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Mon, 11 May 2026 02:40:36 -0700 Subject: [PATCH 85/87] fixes in modeling --- probe_deepgemm_sf.py | 107 ----------------- .../deepseek_v4/configuration_deepseek_v4.py | 30 +++-- .../deepseek_v4/modeling_deepseek_v4.py | 110 ++++++++++++++---- .../models/deepseek_v4/modular_deepseek_v4.py | 110 ++++++++++++++---- .../quantizers/quantizer_finegrained_fp8.py | 54 ++++++--- 5 files changed, 231 insertions(+), 180 deletions(-) delete mode 100644 probe_deepgemm_sf.py diff --git a/probe_deepgemm_sf.py b/probe_deepgemm_sf.py deleted file mode 100644 index 814445e06249..000000000000 --- a/probe_deepgemm_sf.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Print the actual SF shapes / strides / dtypes the DeepGEMM integration feeds -into `m_grouped_fp8_fp4_gemm_nt_contiguous`, for each test case. - -When the kernel rejects an SF with `check_sf_layout` assertions, you usually -can't tell from the message *which* SF (activation or weight) failed and what -its actual layout was. This wraps `_coerce_sf_for_kernel` to log every call, -then runs the smoke tests so you can see the exact tensor metadata that hit -the kernel boundary right before the assertion fired. - -Usage: - CUDA_HOME=$HOME/cuda-12.9 python probe_deepgemm_sf.py -""" - -from __future__ import annotations - -import sys - -import torch - -import test_deepgemm as t -from transformers.integrations import deepgemm as di - - -_real_coerce = di._coerce_sf_for_kernel -_call_idx = [0] - - -def _verbose_coerce(sf: torch.Tensor) -> torch.Tensor: - out = _real_coerce(sf) - _call_idx[0] += 1 - nonfinite_in = (~torch.isfinite(sf.float())).sum().item() if sf.is_floating_point() else 0 - nonfinite_out = (~torch.isfinite(out.float())).sum().item() if out.is_floating_point() else 0 - print( - f" [#{_call_idx[0]}] in: shape={tuple(sf.shape)} " - f"stride={tuple(sf.stride())} dtype={sf.dtype} " - f"min={sf.float().abs().min().item():.3e} max={sf.float().abs().max().item():.3e} " - f"nonfinite={nonfinite_in}" - ) - print( - f" out: shape={tuple(out.shape)} " - f"stride={tuple(out.stride())} dtype={out.dtype} " - f"nonfinite={nonfinite_out}" - ) - return out - - -di._coerce_sf_for_kernel = _verbose_coerce - - -# Wrap the matmul itself: print output stats after the call so we can see -# where NaN actually appears in the pipeline. -_real_matmul = None - - -def _verbose_matmul(*args, **kwargs): - global _real_matmul - out_tensor = args[2] # (a_pair, b_pair, d, ...) 
- label = f"matmul (d.shape={tuple(out_tensor.shape)})" - _real_matmul(*args, **kwargs) - nf = (~torch.isfinite(out_tensor)).sum().item() - print( - f" → {label}: nonfinite_count={nf} " - f"min_abs={out_tensor.abs().min().item():.3e} " - f"max_abs={out_tensor.abs().max().item():.3e}" - ) - - -def _patch_matmul(): - global _real_matmul - deepgemm = di._load_deepgemm_kernel() - _real_matmul = deepgemm.grouped_fp8_fp4_matmul - - # Replace the cached kernel's matmul with our wrapper. - object.__setattr__(deepgemm, "grouped_fp8_fp4_matmul", _verbose_matmul) - - -_patch_matmul() - - -def _run(name: str, fn) -> bool: - print(f"\n=== {name} ===") - _call_idx[0] = 0 - try: - fn(d) - print(f" → PASS") - return True - except BaseException as exc: - print(f" → FAIL: {type(exc).__name__}: {str(exc)[:300]}") - return False - - -if __name__ == "__main__": - if not torch.cuda.is_available(): - sys.exit("CUDA required.") - torch.cuda.set_device(0) - d = torch.device("cuda", 0) - print( - f"GPU: {torch.cuda.get_device_name(d)} " - f"SM{''.join(str(x) for x in torch.cuda.get_device_capability(d))}" - ) - - results = [ - _run("test_dsv3_fp8", t.test_dsv3_fp8), - _run("test_dsv4_fp8", t.test_dsv4_fp8), - _run("test_dsv4_fp4", t.test_dsv4_fp4), - ] - sys.exit(0 if all(results) else 1) diff --git a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py index 2a6d70d83b2c..69480a529bf8 100644 --- a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py @@ -112,22 +112,32 @@ class DeepseekV4Config(PreTrainedConfig): "norm": (["hidden_states"], ["hidden_states"]), } base_model_ep_plan = { - # EP-only by default, same shape as gpt-oss: route on the gate, run the - # routed experts as a grouped-GEMM kernel sharded along the expert axis, - # and wrap the experts module with `moe_tp_experts` so its output gets - # all-reduced across ranks. Attention stays replicated (V4 is shared-KV - # MQA + a CSA / HCA compressor branch — both broadcast a single KV head - # across all attention heads via `repeat_kv`, so colwise-sharding - # `q_b_proj` would leave KV replicated and `repeat_kv` would no longer - # match the rank-local query head count). The shared MLP also stays - # replicated — it's small and not worth TP-ing. There's deliberately - # no `base_model_tp_plan` for V4: we don't ship a pure-TP plan, only EP. + # V4 ships EP only (no `base_model_tp_plan` — the runtime picks one plan or + # the other, never both, and V4 is MoE so EP is the only sensible config). + # MoE parallelism: route on the gate, run the routed experts as a grouped-GEMM + # kernel sharded along the expert axis, and wrap the experts module with + # `moe_tp_experts` so its output gets all-reduced across ranks. Same shape as + # gpt-oss. Main attention stays replicated: V4 is shared-KV MQA + a CSA / HCA + # compressor branch — both broadcast a single KV head across all attention + # heads via `repeat_kv`, so colwise-sharding `q_b_proj` would leave KV + # replicated and `repeat_kv` would no longer match the rank-local query head + # count. The shared MLP also stays replicated — it's small and not worth + # sharding. 
The Lightning Indexer is the one carve-out: its keys are + # replicated (own compressor at index_head_dim fed by replicated + # hidden_states), so head-sharding is well-formed; `q_b_proj` and + # `weights_proj` go colwise, and `scores_sync` is a `nn.Identity` whose + # `"all_reduce"` output hook sums the per-rank partial `index_scores` across + # the mesh so every rank picks the same top-k. Mirrors the reference + # inference (`inference/model.py:393, 394, 422-423`). "layers.*.mlp.gate": "ep_router", "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm", "layers.*.mlp.experts.down_proj": "grouped_gemm", "layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm", "layers.*.mlp.experts": "moe_tp_experts", + "layers.*.self_attn.compressor.indexer.q_b_proj": "colwise", + "layers.*.self_attn.compressor.indexer.weights_proj": "colwise", + "layers.*.self_attn.compressor.indexer.scores_sync": "all_reduce", } vocab_size: int = 129280 diff --git a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py index 746219c11138..955caada32a6 100644 --- a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py @@ -398,7 +398,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, None]: batch, _, _ = hidden_states.shape cache_layer: DeepseekV4HCACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -427,7 +427,9 @@ def forward( if cache_layer is not None: compressed = cache_layer.update_compressor_states("compressor", compressed) - return compressed.unsqueeze(1) + # `None` validity: HCA has no indexer / per-query gather — every query attends to + # every entry in the compressor section, so attention just right-pads its mask with 0s. + return compressed.unsqueeze(1), None class DeepseekV4Indexer(nn.Module): @@ -476,6 +478,10 @@ def __init__(self, config: DeepseekV4Config): self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.head_dim, bias=False) self.weights_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False) self.rotary_emb = DeepseekV4RotaryEmbedding(config) + # No-op marker. Under TP it picks up the `"all_reduce"` plan entry, which adds a + # forward output hook that sums the (partial, per-rank) `index_scores` across the + # TP mesh so every rank picks the same top-k. See `inference/model.py:422-423`. 
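
# Numeric sanity check (toy shapes, not part of the patch) for the indexer TP story above:
# with q_b_proj / weights_proj colwise-sharded, each rank computes
# `(scores * weights).sum(dim=heads)` over its own subset of heads, so `index_scores` is a
# partial sum per rank; summing the partials — what the "all_reduce" hook on `scores_sync`
# provides — reproduces the single-device result, so every rank picks the same top-k.
import torch

B, S, H, T = 2, 4, 8, 16
scores = torch.randn(B, S, H, T)
weights = torch.randn(B, S, H)

full = (scores * weights.unsqueeze(-1)).sum(dim=2)                         # single device
part0 = (scores[:, :, :4] * weights[:, :, :4].unsqueeze(-1)).sum(dim=2)    # "rank 0": heads 0-3
part1 = (scores[:, :, 4:] * weights[:, :, 4:].unsqueeze(-1)).sum(dim=2)    # "rank 1": heads 4-7
torch.testing.assert_close(part0 + part1, full)  # all-reduced partials == full index_scores
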
+ self.scores_sync = nn.Identity() def forward( self, @@ -484,7 +490,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.LongTensor: + ) -> tuple[torch.LongTensor, torch.BoolTensor]: batch, seq_len, _ = hidden_states.shape cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -539,9 +545,24 @@ def forward( scores = torch.matmul(q.float(), compressed_kv.transpose(-1, -2).float().unsqueeze(1)) # [B, S, H, T] scores = F.relu(scores) * self.softmax_scale weights = self.weights_proj(hidden_states).float() * self.weights_scaling # [B, S, H] - index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] + index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] — partial under TP + index_scores = self.scores_sync(index_scores) + # Causal mask before topk (mirrors `inference/model.py:424-426`): query at absolute + # position `p` may only pick compressed entries `i` whose source-token window ends + # at or before `p` — i.e. ``i < (p + 1) // compress_rate``. Without this, prefill + # queries can score and pick entries that aggregate future source tokens. + if compressed_kv.shape[1] > 0: + cutoff = (position_ids + 1).div(self.compress_rate, rounding_mode="floor").unsqueeze(-1) # [B, S, 1] + i_idx = torch.arange(compressed_kv.shape[1], device=index_scores.device) + index_scores = index_scores.masked_fill(i_idx >= cutoff, float("-inf")) topk = min(self.index_topk, compressed_kv.shape[1]) - return index_scores.topk(topk, dim=-1).indices + result = index_scores.topk(topk, dim=-1) + # `valid[b, s, j] == False` iff the j-th pick for query `s` was a forced `-inf` + # placeholder (the query had fewer than `index_topk` causally-valid entries). The + # pick still indexes a real but non-causal entry, so attention has to skip it — + # mirrors `inference/model.py:428-430` (`topk_idxs == -1` slots). + valid = result.values > float("-inf") + return result.indices, valid class DeepseekV4CSACompressor(nn.Module): @@ -585,7 +606,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.BoolTensor]: batch, seq_len, _ = hidden_states.shape cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -641,11 +662,14 @@ def forward( compressed = cache_layer.update_compressor_states("compressor", compressed) compressed_kv = compressed.unsqueeze(1) - # Lightning Indexer: gather top-`index_topk` compressed entries per query. - topk = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) # [B, S, k] + # Lightning Indexer: gather top-`index_topk` compressed entries per query and pass + # the per-pick validity mask up so attention can ``-inf``-mask the placeholder picks + # (queries with fewer than `index_topk` causally-valid entries). 
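
# Worked example (made-up numbers) of the pre-topk causal cutoff used above: a query at
# absolute position p may only pick compressed entries i with i < (p + 1) // compress_rate,
# i.e. entries whose source-token window closed at or before p.
import torch

compress_rate, num_entries = 4, 3
position_ids = torch.tensor([[0, 3, 4, 11]])                                  # [B=1, S=4]
cutoff = (position_ids + 1).div(compress_rate, rounding_mode="floor")         # [[0, 1, 1, 3]]
non_causal = torch.arange(num_entries) >= cutoff.unsqueeze(-1)                # [B, S, num_entries]
print(non_causal[0].int().tolist())
# [[1, 1, 1],   position 0:  no closed window yet -> every pick becomes an -inf placeholder
#  [0, 1, 1],   position 3:  only entry 0 is visible
#  [0, 1, 1],   position 4:  still only entry 0 (the window covering position 4 is not closed)
#  [0, 0, 0]]   position 11: all three compressed entries are causally valid
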
+ topk_indices, topk_valid = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) expanded = compressed_kv.unsqueeze(2).expand(-1, -1, seq_len, -1, -1) - idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) - return torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + idx = topk_indices.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) + gathered = torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + return gathered, topk_valid.reshape(batch, -1) # valid: [B, S*k] over the flat-packed KV axis def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -765,17 +789,45 @@ def forward( if past_key_values is not None: # sliding where K==V kv = past_key_values.update(kv, kv, self.layer_idx)[0] + compressor_valid = None if self.compressor is not None: # Compressed KV (CSA or HCA) - compressed_kv = self.compressor(hidden_states, q_residual, position_ids, past_key_values, self.layer_idx) + compressed_kv, compressor_valid = self.compressor( + hidden_states, q_residual, position_ids, past_key_values, self.layer_idx + ) kv = torch.cat([kv, compressed_kv], dim=2) # The compressor path concatenates extra entries onto the KV axis after the # standard sliding-window cache update, so a tensor `attention_mask` (built - # for the pre-concat KV length) needs to be right-padded to cover them. + # for the pre-concat KV length) needs to be extended to cover them. # Flex-attention passes a `BlockMask` whose KV-length axis comes from its - # own `mask_mod`, not from a dense tensor — skip the pad in that case. + # own `mask_mod`, not from a dense tensor — skip the extend in that case. if isinstance(attention_mask, torch.Tensor) and kv.shape[2] > attention_mask.shape[-1]: - attention_mask = F.pad(attention_mask, (0, kv.shape[2] - attention_mask.shape[-1]), value=0.0) + n_compressor = kv.shape[2] - attention_mask.shape[-1] + S = q.shape[2] + if self.layer_type == "compressed_sparse_attention" and S > 1 and n_compressor == S * (n_compressor // S): + # CSA's compressor returns per-query top-`k` entries flat-packed as + # `[B, 1, S*k, D]`. The official inference's `sparse_attn` gathers per query + # so each query only attends to its own `k` slots; we mirror it with a + # block-diagonal mask over the compressor section. `compressor_valid` is + # False at slots the indexer flagged as `-inf` placeholders — strip those too. + # Decode (S=1) collapses to a no-op block-diagonal so we let the simple + # right-pad branch handle it. + k_per_query = n_compressor // S + q_idx = torch.arange(S, device=q.device).unsqueeze(1) # [S, 1] + kv_idx = torch.arange(n_compressor, device=q.device).div(k_per_query, rounding_mode="floor") # [S*k] + own_q = q_idx == kv_idx # [S, S*k] + if compressor_valid is not None: + allowed = own_q.unsqueeze(0) & compressor_valid.unsqueeze(1) # [B, S, S*k] + extra = torch.where(allowed, 0.0, float("-inf")).to(attention_mask.dtype).unsqueeze(1) + extra = extra.expand(attention_mask.shape[0], attention_mask.shape[1], -1, -1) + else: + extra = torch.where(own_q, 0.0, float("-inf")).to(attention_mask.dtype) + extra = extra.expand(*attention_mask.shape[:-2], S, n_compressor) + attention_mask = torch.cat([attention_mask, extra], dim=-1) + else: + # HCA (every query sees every entry) and CSA decode (single query's k slots + # span the whole compressor section): plain 0.0 right-pad is correct. 
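
# Shape-only illustration (not part of the patch) of the block-diagonal pattern built above
# for CSA prefill: with S queries and k flat-packed picks per query, query s may only attend
# to compressor slots [s*k, (s+1)*k) — its own gathered entries.
import torch

S, k = 3, 2
q_idx = torch.arange(S).unsqueeze(1)                            # [S, 1]
kv_idx = torch.arange(S * k).div(k, rounding_mode="floor")      # [S*k] -> [0, 0, 1, 1, 2, 2]
own_q = q_idx == kv_idx
print(own_q.int().tolist())
# [[1, 1, 0, 0, 0, 0],
#  [0, 0, 1, 1, 0, 0],
#  [0, 0, 0, 0, 1, 1]]
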
+ attention_mask = F.pad(attention_mask, (0, n_compressor), value=0.0) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward @@ -870,14 +922,18 @@ def forward(self, hidden_streams: torch.Tensor) -> tuple[torch.Tensor, torch.Ten pre_scale, post_scale, comb_scale = self.scale.unbind(0) hc = self.hc_mult pre = torch.sigmoid(mix[..., :hc] * pre_scale + self.base[:hc]) + self.hc_eps - post = torch.sigmoid(mix[..., hc : 2 * hc] * post_scale + self.base[hc : 2 * hc]) + self.hc_eps - comb = ( - torch.sigmoid( - mix[..., 2 * hc :].view(*mix.shape[:-1], hc, hc) * comb_scale + self.base[2 * hc :].view(hc, hc) - ) - + self.hc_eps - ) - for _ in range(self.hc_sinkhorn_iters): + # `post` is `2 * sigmoid` (range [0, 2], no eps) to match the reference kernel + # (`inference/kernel.py:394`). `pre` and `comb` keep the `sigmoid + eps` form + # they share with the kernel (lines 392, 408). + post = 2 * torch.sigmoid(mix[..., hc : 2 * hc] * post_scale + self.base[hc : 2 * hc]) + # Sinkhorn init mirrors `inference/kernel.py:401-413`: row-softmax + eps, + # then one column-normalisation, then `iters - 1` symmetric (row, col) rounds. + # Different positive starting matrix → different doubly-stochastic fixed point + # than a `sigmoid+eps` init, so this matters even though row/col count nets out. + comb_logits = mix[..., 2 * hc :].view(*mix.shape[:-1], hc, hc) * comb_scale + self.base[2 * hc :].view(hc, hc) + comb = torch.softmax(comb_logits, dim=-1) + self.hc_eps + comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) + for _ in range(self.hc_sinkhorn_iters - 1): comb = comb / (comb.sum(dim=-1, keepdim=True) + self.hc_eps) comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) # Collapse the `hc_mult` parallel streams down to a single sequence using @@ -1069,16 +1125,22 @@ def forward( # `post` / `comb` come out of the HC modules in fp32 (Sinkhorn projection runs # in float); the .to(dtype) puts everything back to the input dtype before mixing # so both sites stay consistent with `hidden_states`'s entry dtype. + # `comb` is consumed transposed: reference `inference/model.py:685` indexes it as + # `sum_j comb[j, k] * residual[j, d]` (sum over the FIRST hc axis), which is + # equivalent to `comb.T @ residual`. Sinkhorn produces a doubly-stochastic but + # non-symmetric matrix, so the direction matters. 
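
# Standalone sketch (arbitrary sizes, not part of the patch) of the Sinkhorn schedule
# described above — row-softmax plus eps, one column normalisation, then `iters - 1`
# symmetric (row, column) rounds — with a check that the result is close to doubly stochastic.
import torch

hc, eps, iters = 4, 1e-5, 3
logits = torch.randn(hc, hc)

comb = torch.softmax(logits, dim=-1) + eps               # positive, row-normalised init
comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)     # one column normalisation
for _ in range(iters - 1):
    comb = comb / (comb.sum(dim=-1, keepdim=True) + eps)
    comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)

print(comb.sum(dim=-1))   # row sums    -> all close to 1
print(comb.sum(dim=-2))   # column sums -> all close to 1
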
dtype = hidden_states.dtype post, comb, collapsed = self.attn_hc(hidden_states) attn_output, _ = self.self_attn(self.input_layernorm(collapsed), **kwargs) hidden_states = post.to(dtype).unsqueeze(-1) * attn_output.unsqueeze(-2) + torch.matmul( - comb.to(dtype), hidden_states + comb.to(dtype).transpose(-1, -2), hidden_states ) post, comb, collapsed = self.ffn_hc(hidden_states) mlp_output = self.mlp(self.post_attention_layernorm(collapsed), input_ids=input_ids) - return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul(comb.to(dtype), hidden_states) + return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul( + comb.to(dtype).transpose(-1, -2), hidden_states + ) @auto_docstring diff --git a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py index 759bfabf017b..06c91ed6cf22 100644 --- a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py @@ -334,7 +334,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, None]: batch, _, _ = hidden_states.shape cache_layer: DeepseekV4HCACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -363,7 +363,9 @@ def forward( if cache_layer is not None: compressed = cache_layer.update_compressor_states("compressor", compressed) - return compressed.unsqueeze(1) + # `None` validity: HCA has no indexer / per-query gather — every query attends to + # every entry in the compressor section, so attention just right-pads its mask with 0s. + return compressed.unsqueeze(1), None class DeepseekV4Indexer(nn.Module): @@ -412,6 +414,10 @@ def __init__(self, config: DeepseekV4Config): self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.head_dim, bias=False) self.weights_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False) self.rotary_emb = DeepseekV4RotaryEmbedding(config) + # No-op marker. Under TP it picks up the `"all_reduce"` plan entry, which adds a + # forward output hook that sums the (partial, per-rank) `index_scores` across the + # TP mesh so every rank picks the same top-k. See `inference/model.py:422-423`. + self.scores_sync = nn.Identity() def forward( self, @@ -420,7 +426,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.LongTensor: + ) -> tuple[torch.LongTensor, torch.BoolTensor]: batch, seq_len, _ = hidden_states.shape cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -475,9 +481,24 @@ def forward( scores = torch.matmul(q.float(), compressed_kv.transpose(-1, -2).float().unsqueeze(1)) # [B, S, H, T] scores = F.relu(scores) * self.softmax_scale weights = self.weights_proj(hidden_states).float() * self.weights_scaling # [B, S, H] - index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] + index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] — partial under TP + index_scores = self.scores_sync(index_scores) + # Causal mask before topk (mirrors `inference/model.py:424-426`): query at absolute + # position `p` may only pick compressed entries `i` whose source-token window ends + # at or before `p` — i.e. ``i < (p + 1) // compress_rate``. Without this, prefill + # queries can score and pick entries that aggregate future source tokens. 
+ if compressed_kv.shape[1] > 0: + cutoff = (position_ids + 1).div(self.compress_rate, rounding_mode="floor").unsqueeze(-1) # [B, S, 1] + i_idx = torch.arange(compressed_kv.shape[1], device=index_scores.device) + index_scores = index_scores.masked_fill(i_idx >= cutoff, float("-inf")) topk = min(self.index_topk, compressed_kv.shape[1]) - return index_scores.topk(topk, dim=-1).indices + result = index_scores.topk(topk, dim=-1) + # `valid[b, s, j] == False` iff the j-th pick for query `s` was a forced `-inf` + # placeholder (the query had fewer than `index_topk` causally-valid entries). The + # pick still indexes a real but non-causal entry, so attention has to skip it — + # mirrors `inference/model.py:428-430` (`topk_idxs == -1` slots). + valid = result.values > float("-inf") + return result.indices, valid class DeepseekV4CSACompressor(nn.Module): @@ -521,7 +542,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.BoolTensor]: batch, seq_len, _ = hidden_states.shape cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -577,11 +598,14 @@ def forward( compressed = cache_layer.update_compressor_states("compressor", compressed) compressed_kv = compressed.unsqueeze(1) - # Lightning Indexer: gather top-`index_topk` compressed entries per query. - topk = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) # [B, S, k] + # Lightning Indexer: gather top-`index_topk` compressed entries per query and pass + # the per-pick validity mask up so attention can ``-inf``-mask the placeholder picks + # (queries with fewer than `index_topk` causally-valid entries). + topk_indices, topk_valid = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) expanded = compressed_kv.unsqueeze(2).expand(-1, -1, seq_len, -1, -1) - idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) - return torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + idx = topk_indices.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) + gathered = torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + return gathered, topk_valid.reshape(batch, -1) # valid: [B, S*k] over the flat-packed KV axis COMPRESSOR_CLASSES = { @@ -658,17 +682,45 @@ def forward( if past_key_values is not None: # sliding where K==V kv = past_key_values.update(kv, kv, self.layer_idx)[0] + compressor_valid = None if self.compressor is not None: # Compressed KV (CSA or HCA) - compressed_kv = self.compressor(hidden_states, q_residual, position_ids, past_key_values, self.layer_idx) + compressed_kv, compressor_valid = self.compressor( + hidden_states, q_residual, position_ids, past_key_values, self.layer_idx + ) kv = torch.cat([kv, compressed_kv], dim=2) # The compressor path concatenates extra entries onto the KV axis after the # standard sliding-window cache update, so a tensor `attention_mask` (built - # for the pre-concat KV length) needs to be right-padded to cover them. + # for the pre-concat KV length) needs to be extended to cover them. # Flex-attention passes a `BlockMask` whose KV-length axis comes from its - # own `mask_mod`, not from a dense tensor — skip the pad in that case. + # own `mask_mod`, not from a dense tensor — skip the extend in that case. 
if isinstance(attention_mask, torch.Tensor) and kv.shape[2] > attention_mask.shape[-1]: - attention_mask = F.pad(attention_mask, (0, kv.shape[2] - attention_mask.shape[-1]), value=0.0) + n_compressor = kv.shape[2] - attention_mask.shape[-1] + S = q.shape[2] + if self.layer_type == "compressed_sparse_attention" and S > 1 and n_compressor == S * (n_compressor // S): + # CSA's compressor returns per-query top-`k` entries flat-packed as + # `[B, 1, S*k, D]`. The official inference's `sparse_attn` gathers per query + # so each query only attends to its own `k` slots; we mirror it with a + # block-diagonal mask over the compressor section. `compressor_valid` is + # False at slots the indexer flagged as `-inf` placeholders — strip those too. + # Decode (S=1) collapses to a no-op block-diagonal so we let the simple + # right-pad branch handle it. + k_per_query = n_compressor // S + q_idx = torch.arange(S, device=q.device).unsqueeze(1) # [S, 1] + kv_idx = torch.arange(n_compressor, device=q.device).div(k_per_query, rounding_mode="floor") # [S*k] + own_q = q_idx == kv_idx # [S, S*k] + if compressor_valid is not None: + allowed = own_q.unsqueeze(0) & compressor_valid.unsqueeze(1) # [B, S, S*k] + extra = torch.where(allowed, 0.0, float("-inf")).to(attention_mask.dtype).unsqueeze(1) + extra = extra.expand(attention_mask.shape[0], attention_mask.shape[1], -1, -1) + else: + extra = torch.where(own_q, 0.0, float("-inf")).to(attention_mask.dtype) + extra = extra.expand(*attention_mask.shape[:-2], S, n_compressor) + attention_mask = torch.cat([attention_mask, extra], dim=-1) + else: + # HCA (every query sees every entry) and CSA decode (single query's k slots + # span the whole compressor section): plain 0.0 right-pad is correct. + attention_mask = F.pad(attention_mask, (0, n_compressor), value=0.0) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward @@ -763,14 +815,18 @@ def forward(self, hidden_streams: torch.Tensor) -> tuple[torch.Tensor, torch.Ten pre_scale, post_scale, comb_scale = self.scale.unbind(0) hc = self.hc_mult pre = torch.sigmoid(mix[..., :hc] * pre_scale + self.base[:hc]) + self.hc_eps - post = torch.sigmoid(mix[..., hc : 2 * hc] * post_scale + self.base[hc : 2 * hc]) + self.hc_eps - comb = ( - torch.sigmoid( - mix[..., 2 * hc :].view(*mix.shape[:-1], hc, hc) * comb_scale + self.base[2 * hc :].view(hc, hc) - ) - + self.hc_eps - ) - for _ in range(self.hc_sinkhorn_iters): + # `post` is `2 * sigmoid` (range [0, 2], no eps) to match the reference kernel + # (`inference/kernel.py:394`). `pre` and `comb` keep the `sigmoid + eps` form + # they share with the kernel (lines 392, 408). + post = 2 * torch.sigmoid(mix[..., hc : 2 * hc] * post_scale + self.base[hc : 2 * hc]) + # Sinkhorn init mirrors `inference/kernel.py:401-413`: row-softmax + eps, + # then one column-normalisation, then `iters - 1` symmetric (row, col) rounds. + # Different positive starting matrix → different doubly-stochastic fixed point + # than a `sigmoid+eps` init, so this matters even though row/col count nets out. 
+ comb_logits = mix[..., 2 * hc :].view(*mix.shape[:-1], hc, hc) * comb_scale + self.base[2 * hc :].view(hc, hc) + comb = torch.softmax(comb_logits, dim=-1) + self.hc_eps + comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) + for _ in range(self.hc_sinkhorn_iters - 1): comb = comb / (comb.sum(dim=-1, keepdim=True) + self.hc_eps) comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) # Collapse the `hc_mult` parallel streams down to a single sequence using @@ -936,16 +992,22 @@ def forward( # `post` / `comb` come out of the HC modules in fp32 (Sinkhorn projection runs # in float); the .to(dtype) puts everything back to the input dtype before mixing # so both sites stay consistent with `hidden_states`'s entry dtype. + # `comb` is consumed transposed: reference `inference/model.py:685` indexes it as + # `sum_j comb[j, k] * residual[j, d]` (sum over the FIRST hc axis), which is + # equivalent to `comb.T @ residual`. Sinkhorn produces a doubly-stochastic but + # non-symmetric matrix, so the direction matters. dtype = hidden_states.dtype post, comb, collapsed = self.attn_hc(hidden_states) attn_output, _ = self.self_attn(self.input_layernorm(collapsed), **kwargs) hidden_states = post.to(dtype).unsqueeze(-1) * attn_output.unsqueeze(-2) + torch.matmul( - comb.to(dtype), hidden_states + comb.to(dtype).transpose(-1, -2), hidden_states ) post, comb, collapsed = self.ffn_hc(hidden_states) mlp_output = self.mlp(self.post_attention_layernorm(collapsed), input_ids=input_ids) - return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul(comb.to(dtype), hidden_states) + return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul( + comb.to(dtype).transpose(-1, -2), hidden_states + ) class DeepseekV4PreTrainedModel(MixtralPreTrainedModel): diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index dc32b17622a1..10e5ca45cbf0 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -105,6 +105,24 @@ def _process_model_before_weight_loading( model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules ) + # DeepSeek-V4-Flash checkpoints mix FP8 and bf16 projections in the attention + # compressor/indexer branch: modules below are stored directly in bf16 (no + # companion scale tensor), so converting them to FP8Linear would create + # missing ``weight_scale_inv`` keys at load and random-init those params. + # Use `config.model_type` (not class-name substring) so this still applies + # with wrappers / auto classes where `model.__class__.__name__` may differ. + if self.pre_quantized and getattr(getattr(model, "config", None), "model_type", None) == "deepseek_v4": + self.modules_to_not_convert.extend( + [ + "self_attn.compressor.kv_proj", + "self_attn.compressor.gate_proj", + "self_attn.compressor.indexer.kv_proj", + "self_attn.compressor.indexer.gate_proj", + "self_attn.compressor.indexer.weights_proj", + ] + ) + self.modules_to_not_convert = list(set(self.modules_to_not_convert)) + model = replace_with_fp8_linear( model, modules_to_not_convert=self.modules_to_not_convert, @@ -213,30 +231,36 @@ def update_weight_conversions(self, weight_conversions): if dequantize: # Both .weight and .weight_scale_inv go to the same target; Fp8Dequantize # folds scales into the weight before any merge/concat downstream. 
- updated.append(WeightConverter( - source_patterns=anchored_weight + scale_sources + other, - target_patterns=conv._original_target_patterns, - operations=[Fp8Dequantize(self)] + list(conv.operations), - )) + updated.append( + WeightConverter( + source_patterns=anchored_weight + scale_sources + other, + target_patterns=conv._original_target_patterns, + operations=[Fp8Dequantize(self)] + list(conv.operations), + ) + ) else: # Native quantized: anchor the existing weight converter so .weight_scale_inv # keys don't leak in, then synthesize a parallel converter routing the scale # keys to `_scale_inv` with the same merge ops. - updated.append(WeightConverter( - source_patterns=anchored_weight + other, - target_patterns=conv._original_target_patterns, - operations=list(conv.operations), - )) + updated.append( + WeightConverter( + source_patterns=anchored_weight + other, + target_patterns=conv._original_target_patterns, + operations=list(conv.operations), + ) + ) target = conv._original_target_patterns if isinstance(target, str): scale_target = target + "_scale_inv" else: scale_target = [t + "_scale_inv" for t in target] - updated.append(WeightConverter( - source_patterns=scale_sources, - target_patterns=scale_target, - operations=list(conv.operations), - )) + updated.append( + WeightConverter( + source_patterns=scale_sources, + target_patterns=scale_target, + operations=list(conv.operations), + ) + ) # Generic fallback for plain `nn.Linear` weights with no model-specific converter. updated.extend(self.get_weight_conversions()) From 30d0780de4eab7e445d04b5f7ef42882aaeb6ab6 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Mon, 11 May 2026 03:41:53 -0700 Subject: [PATCH 86/87] more modeling changes --- .../integrations/tensor_parallel.py | 81 ++++--- .../deepseek_v4/configuration_deepseek_v4.py | 26 ++- .../deepseek_v4/modeling_deepseek_v4.py | 62 ++++-- .../models/deepseek_v4/modular_deepseek_v4.py | 61 +++++- test_deepseek.py | 197 +++++++++++++++--- 5 files changed, 339 insertions(+), 88 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 7831f09d2559..35b8f6e19f47 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -37,6 +37,20 @@ logger = logging.get_logger(__name__) +def to_local(t): + """Unwrap a `DTensor` to its local shard if needed; pass through otherwise. + + Custom kernels (CUTLASS, CuteDSL, Triton) take raw tensor pointers and don't + understand `DTensor`, so weights wrapped by FSDP2 / EP need this unwrap before + they can be fed to the kernel. ``to_local()`` is autograd-aware on the train + path: backward rewraps the gradient as a DTensor matching each parameter's + placements. + """ + if is_torch_available() and isinstance(t, torch.distributed.tensor.DTensor): + return t.to_local() + return t + + def initialize_tensor_parallelism( tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None ): @@ -766,6 +780,28 @@ def _backward_hook(mod, grad_input, grad_output, mesh=device_mesh): module.register_full_backward_hook(_backward_hook) +class AllReduceParallel(TensorParallelLayer): + """ + Marker layer: parameters (if any) are replicated; the forward output is all-reduced + across the TP mesh. 
Use as a no-op `nn.Identity` placed at a sync point after a + colwise-sharded compute that ends in a head-axis (or similar) reduction, so each + rank holds only a partial sum and needs to share it before the next dependent op + (e.g. the lightning indexer's score sum before its top-k). + """ + + def _prepare_input_fn(self, mod, inputs, device_mesh): + return inputs + + def _prepare_output_fn(self, mod, outputs, device_mesh): + return all_reduce_forward(outputs, device_mesh) + + def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None): + return param[...].to(device=device, dtype=dtype) + + def prepare_module_tp(self, module, device_mesh, **kwargs): + distribute_module(module, device_mesh, output_fn=self._prepare_output_fn) + + class MlaKvAProjParallel(TensorParallelLayer): """ For MLA attention used in DeepSeek-V2 style models (deepseek_v2, longcat_flash, glm_moe_dsa, glm4_moe_lite): @@ -1079,16 +1115,6 @@ def update_module_attributes(self, module: nn.Module): module.num_experts = self.get_expected_sharded_shape((self.empty_param.shape[0],))[0] -def _is_ep_native_experts_impl(mod: nn.Module) -> bool: - """Whether `mod`'s experts implementation handles EP dispatch + combine itself. - - These kernels (e.g. DeepGEMM Mega MoE) want GLOBAL expert ids with unmasked routing - weights and produce the fully-reduced output, so `RouterParallel` skips the per-rank - index remap and `MoeTensorParalellExperts` skips the post-forward all-reduce. - """ - return getattr(getattr(mod, "config", None), "_experts_implementation", None) in {"deepgemm_megamoe"} - - class RouterParallel(TensorParallelLayer): """ Allows to reshape the router scores to support running expert parallel. @@ -1098,7 +1124,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def _prepare_input_fn(self, mod, inputs, device_mesh): - return inputs[0] if inputs else inputs + return inputs def _prepare_output_fn(self, mod, outputs, device_mesh): """ @@ -1145,13 +1171,11 @@ def _prepare_output_fn(self, mod, outputs, device_mesh): The sentinel index (num_local_experts) is skipped by one_hot encoding or clamped + masked in grouped_mm/batched_mm. After the expert forward, an all_reduce sums partial outputs across EP ranks to produce the full result. + + Mega MoE skips this remap: its kernel does the EP dispatch itself and wants raw + global expert ids with unmasked routing weights. """ - # Mega MoE: keep the router's raw output. The router still runs (it produces topk_idx - # / topk_weights per token), but we skip the EP-time post-processing — Mega MoE's kernel - # does the EP token dispatch itself and needs GLOBAL expert ids with unmasked routing - # weights. Mirrored on the experts side by `MoeTensorParalellExperts._prepare_output_fn` - # which skips the post-forward all_reduce. - if _is_ep_native_experts_impl(mod): + if _is_megamoe(mod): return outputs ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() @@ -1201,6 +1225,12 @@ def _prepare_input_fn(self, mod, inputs, device_mesh): top_k_index = inputs[1] top_k_weights = inputs[2] + # Mega MoE is inference-only (the kernel has no backward) and handles EP + # dispatch + combine + per-rank token sharding internally. Skip the gradient + # sync hooks and append the EP `process_group` so the forward can rendezvous. 
+ if _is_megamoe(mod): + return hidden_states, top_k_index, top_k_weights, device_mesh.get_group() + # all_reduce_backward on hidden_states for correct colwise (gate_up_proj) gradient hidden_states = all_reduce_backward(hidden_states, device_mesh) @@ -1209,18 +1239,12 @@ def _prepare_input_fn(self, mod, inputs, device_mesh): # and partial_expert_output is different on each GPU before all-reduce top_k_weights = all_reduce_backward(top_k_weights, device_mesh) - # Mega MoE handles EP dispatch + combine inside the kernel — append the EP `process_group` - # so the forward can rendezvous the symm-buffer on first call. - if _is_ep_native_experts_impl(mod): - return hidden_states, top_k_index, top_k_weights, device_mesh.get_group() - return hidden_states, top_k_index, top_k_weights def _prepare_output_fn(self, mod, outputs, device_mesh): - # Mega MoE handles the EP combine inside the kernel — output is already fully reduced. - if _is_ep_native_experts_impl(mod): + # Mega MoE returned the fully-combined gathered output; skip the all-reduce. + if _is_megamoe(mod): return outputs - # all_reduce_forward to sum partial expert outputs across GPUs return all_reduce_forward(outputs, device_mesh) @@ -1232,6 +1256,10 @@ def shard_tensor( return param[...].to(device=device, dtype=dtype) +def _is_megamoe(mod: nn.Module) -> bool: + return getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe" + + class MoeIdentityExpertParallel(TensorParallelLayer): """ TP class for zero/identity experts in MoE layers. @@ -1274,6 +1302,7 @@ class ParallelInterface(GeneralInterface): "moe_identity_expert": MoeIdentityExpertParallel(), "replicated_with_grad_allreduce": ReplicatedWithGradAllReduce(), "mla_kv_a_proj": MlaKvAProjParallel(), + "all_reduce": AllReduceParallel(), } if is_torch_available() and _torch_distributed_available else {} @@ -1294,6 +1323,7 @@ class ParallelInterface(GeneralInterface): "sequence_parallel": None, "replicated_with_grad_allreduce": None, "mla_kv_a_proj": None, + "all_reduce": None, } # Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced) @@ -1309,6 +1339,7 @@ class ParallelInterface(GeneralInterface): "sequence_parallel": None, "replicated_with_grad_allreduce": None, "mla_kv_a_proj": None, + "all_reduce": None, } @classmethod diff --git a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py index 69480a529bf8..751729173866 100644 --- a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py @@ -199,7 +199,7 @@ class DeepseekV4Config(PreTrainedConfig): # back to wrapping the whole dict as a single set of params when the subset check # fails, which then warns about `main` / `compress` as unrecognized keys. Override # to iterate the rope-type-keyed sub-dicts directly. - _rope_type_labels = ("main", "compress") + _rope_type_labels = ("sliding", "compress") def validate_rope(self): rope_parameters_dict = getattr(self, "rope_parameters", None) or {} @@ -297,19 +297,27 @@ def __post_init__(self, **kwargs): # `rope_parameters`: split the flat dict (left by `convert_rope_params_to_dict`, # which folded any legacy `rope_scaling` block in) into per-rope-type - # `{main, compress}` sub-dicts. Idempotent: re-loading an already-split config - # is a no-op via the `isinstance` short-circuit. The two sub-dicts differ only - # in `rope_theta` (main: 10000, compress: 160000). 
+ # `{sliding, compress}` sub-dicts. Mirrors reference `inference/model.py:475-481`: + # sliding-window attention layers use base RoPE (`rope_theta=10000`, no YaRN — + # `original_seq_len=0` disables it); CSA/HCA layers (and their internal + # compressors/indexer) use `compress_rope_theta=160000` with YaRN frequency + # interpolation. Idempotent: re-loading an already-split config is a no-op via + # the `isinstance` short-circuit. rp = self.rope_parameters or {} - if isinstance(rp.get("main"), dict) and isinstance(rp.get("compress"), dict): + if isinstance(rp.get("sliding"), dict) and isinstance(rp.get("compress"), dict): # Already nested — drop any leftover top-level keys. - self.rope_parameters = {"main": rp["main"], "compress": rp["compress"]} + self.rope_parameters = {"sliding": rp["sliding"], "compress": rp["compress"]} else: - base = {k: v for k, v in rp.items() if k not in ("main", "compress")} - base.setdefault("rope_theta", self.rope_theta) + base = {k: v for k, v in rp.items() if k not in ("main", "sliding", "compress")} base.setdefault("rope_type", "default") base["partial_rotary_factor"] = self.partial_rotary_factor - self.rope_parameters = {"main": dict(base), "compress": {**base, "rope_theta": self.compress_rope_theta}} + sliding = { + "rope_theta": self.rope_theta, + "rope_type": "default", + "partial_rotary_factor": self.partial_rotary_factor, + } + compress = {**base, "rope_theta": self.compress_rope_theta} + self.rope_parameters = {"sliding": sliding, "compress": compress} __all__ = ["DeepseekV4Config"] diff --git a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py index 955caada32a6..9567d6ad733b 100644 --- a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py @@ -44,6 +44,13 @@ @use_kernel_forward_from_hub("RMSNorm") class DeepseekV4RMSNorm(nn.Module): + """Like V3's, but the weight·hidden multiply happens in fp32 before the cast + back to the input dtype — mirrors reference `inference/model.py:191-196` + where `weight` is declared `dtype=torch.float32` and the multiply runs in + fp32 before the final `.to(dtype)`. V3 downcasts hidden_states first and + multiplies in bf16, which loses precision across the ~5 norms × 43 layers. 
+ """ + def __init__(self, hidden_size, eps: float = 1e-6) -> None: """ DeepseekV4RMSNorm is equivalent to T5LayerNorm @@ -53,11 +60,11 @@ def __init__(self, hidden_size, eps: float = 1e-6) -> None: self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) + dtype = hidden_states.dtype + hidden_states = hidden_states.float() variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return (self.weight.float() * hidden_states).to(dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" @@ -106,10 +113,13 @@ def __init__(self, config: DeepseekV4Config): rope_init_fn = self.compute_default_rope_parameters if self.rope_type[layer_type] != "default": rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] - inv_freq, attention_scaling = rope_init_fn(config, layer_type=layer_type) + # Reference (`inference/model.py:228`) builds cos/sin via + # `torch.polar(torch.ones_like(freqs), freqs)` — unit magnitude, no YaRN + # `attention_factor` scaling. We discard `rope_init_fn`'s scaling factor + # and emit raw `cos = freqs.cos()` / `sin = freqs.sin()` in `forward`. + inv_freq, _ = rope_init_fn(config, layer_type=layer_type) self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False) self.register_buffer(f"{layer_type}_original_inv_freq", inv_freq.clone(), persistent=False) - setattr(self, f"{layer_type}_attention_scaling", attention_scaling) @staticmethod def compute_default_rope_parameters( @@ -157,14 +167,13 @@ def forward(self, x, position_ids, layer_type=None): # the `repeat_interleave(2)` next to the rotation math, where the link between # the doubled dim and `rotate_half` is local and obvious. inv_freq = getattr(self, f"{layer_type}_inv_freq") - attention_scaling = getattr(self, f"{layer_type}_attention_scaling") inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - cos = freqs.cos() * attention_scaling - sin = freqs.sin() * attention_scaling + cos = freqs.cos() + sin = freqs.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -764,11 +773,18 @@ def __init__(self, config: DeepseekV4Config, layer_idx: int): self.compressor = ( COMPRESSOR_CLASSES[self.layer_type](config) if self.layer_type != "sliding_attention" else None ) + # Reference `inference/model.py:475-481` picks `freqs_cis` per attention layer: + # sliding layers use base RoPE (`rope_theta=10000`, no YaRN); CSA/HCA layers + # use `compress_rope_theta=160000` with YaRN. The compressor inside a CSA/HCA + # layer shares its parent attention's `freqs_cis`, so main Q/K and compressor + # rotate consistently. We mirror that by selecting from a per-branch dict + # built once in `DeepseekV4Model.forward`. 
+ self.rope_branch = "sliding" if self.layer_type == "sliding_attention" else "compress" def forward( self, hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], + position_embeddings: dict[str, tuple[torch.Tensor, torch.Tensor]], position_ids: torch.Tensor, attention_mask: torch.Tensor | None, past_key_values: Cache | None = None, @@ -776,7 +792,7 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - cos, sin = position_embeddings + cos, sin = position_embeddings[self.rope_branch] q_residual = self.q_a_norm(self.q_a_proj(hidden_states)) q = self.q_b_proj(q_residual).view(*hidden_shape).transpose(1, 2) @@ -963,6 +979,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DeepseekV4MLP(nn.Module): + """Used by the shared expert. Mirrors reference `inference/model.py:587-606`: + gate/up promoted to fp32, optionally clamped by `swiglu_limit`, SiLU+mul stay in + fp32, then the product is cast back to the input dtype before `down_proj`. + """ + def __init__(self, config): super().__init__() self.config = config @@ -973,9 +994,15 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + gate = self.gate_proj(x).float() + up = self.up_proj(x).float() + limit = self.config.swiglu_limit + if limit > 0: + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + return self.down_proj((self.act_fn(gate) * up).to(dtype)) @use_experts_implementation @@ -1285,7 +1312,14 @@ def forward( position_ids=position_ids, ) hidden_states = inputs_embeds.unsqueeze(2).expand(-1, -1, self.config.hc_mult, -1).contiguous() - position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="main") + # Per-layer rope: sliding-attention layers use base RoPE (`rope_theta=10000`, + # no YaRN), CSA/HCA layers use `compress_rope_theta=160000` with YaRN — + # mirrors reference `inference/model.py:475-481`. Each attention picks the + # right branch via `self.rope_branch`. + position_embeddings = { + "sliding": self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="sliding"), + "compress": self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="compress"), + } for layer in self.layers: hidden_states = layer( diff --git a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py index 06c91ed6cf22..9f6501273b5e 100644 --- a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py @@ -64,7 +64,19 @@ def apply_rotary_pos_emb( class DeepseekV4RMSNorm(DeepseekV3RMSNorm): - pass + """Like V3's, but the weight·hidden multiply happens in fp32 before the cast + back to the input dtype — mirrors reference `inference/model.py:191-196` + where `weight` is declared `dtype=torch.float32` and the multiply runs in + fp32 before the final `.to(dtype)`. V3 downcasts hidden_states first and + multiplies in bf16, which loses precision across the ~5 norms × 43 layers. 
+ """ + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + dtype = hidden_states.dtype + hidden_states = hidden_states.float() + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight.float() * hidden_states).to(dtype) class DeepseekV4UnweightedRMSNorm(nn.Module): @@ -108,10 +120,13 @@ def __init__(self, config: DeepseekV4Config): rope_init_fn = self.compute_default_rope_parameters if self.rope_type[layer_type] != "default": rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] - inv_freq, attention_scaling = rope_init_fn(config, layer_type=layer_type) + # Reference (`inference/model.py:228`) builds cos/sin via + # `torch.polar(torch.ones_like(freqs), freqs)` — unit magnitude, no YaRN + # `attention_factor` scaling. We discard `rope_init_fn`'s scaling factor + # and emit raw `cos = freqs.cos()` / `sin = freqs.sin()` in `forward`. + inv_freq, _ = rope_init_fn(config, layer_type=layer_type) self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False) self.register_buffer(f"{layer_type}_original_inv_freq", inv_freq.clone(), persistent=False) - setattr(self, f"{layer_type}_attention_scaling", attention_scaling) def forward(self, x, position_ids, layer_type=None): # Key difference vs Laguna's forward: no `torch.cat([freqs, freqs], dim=-1)` @@ -120,14 +135,13 @@ def forward(self, x, position_ids, layer_type=None): # the `repeat_interleave(2)` next to the rotation math, where the link between # the doubled dim and `rotate_half` is local and obvious. inv_freq = getattr(self, f"{layer_type}_inv_freq") - attention_scaling = getattr(self, f"{layer_type}_attention_scaling") inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - cos = freqs.cos() * attention_scaling - sin = freqs.sin() * attention_scaling + cos = freqs.cos() + sin = freqs.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -657,11 +671,18 @@ def __init__(self, config: DeepseekV4Config, layer_idx: int): self.compressor = ( COMPRESSOR_CLASSES[self.layer_type](config) if self.layer_type != "sliding_attention" else None ) + # Reference `inference/model.py:475-481` picks `freqs_cis` per attention layer: + # sliding layers use base RoPE (`rope_theta=10000`, no YaRN); CSA/HCA layers + # use `compress_rope_theta=160000` with YaRN. The compressor inside a CSA/HCA + # layer shares its parent attention's `freqs_cis`, so main Q/K and compressor + # rotate consistently. We mirror that by selecting from a per-branch dict + # built once in `DeepseekV4Model.forward`. 
+ self.rope_branch = "sliding" if self.layer_type == "sliding_attention" else "compress" def forward( self, hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], + position_embeddings: dict[str, tuple[torch.Tensor, torch.Tensor]], position_ids: torch.Tensor, attention_mask: torch.Tensor | None, past_key_values: Cache | None = None, @@ -669,7 +690,7 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - cos, sin = position_embeddings + cos, sin = position_embeddings[self.rope_branch] q_residual = self.q_a_norm(self.q_a_proj(hidden_states)) q = self.q_b_proj(q_residual).view(*hidden_shape).transpose(1, 2) @@ -856,7 +877,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DeepseekV4MLP(LlamaMLP): - pass + """Used by the shared expert. Mirrors reference `inference/model.py:587-606`: + gate/up promoted to fp32, optionally clamped by `swiglu_limit`, SiLU+mul stay in + fp32, then the product is cast back to the input dtype before `down_proj`. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + gate = self.gate_proj(x).float() + up = self.up_proj(x).float() + limit = self.config.swiglu_limit + if limit > 0: + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + return self.down_proj((self.act_fn(gate) * up).to(dtype)) @use_experts_implementation @@ -1139,7 +1173,14 @@ def forward( position_ids=position_ids, ) hidden_states = inputs_embeds.unsqueeze(2).expand(-1, -1, self.config.hc_mult, -1).contiguous() - position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="main") + # Per-layer rope: sliding-attention layers use base RoPE (`rope_theta=10000`, + # no YaRN), CSA/HCA layers use `compress_rope_theta=160000` with YaRN — + # mirrors reference `inference/model.py:475-481`. Each attention picks the + # right branch via `self.rope_branch`. + position_embeddings = { + "sliding": self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="sliding"), + "compress": self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="compress"), + } for layer in self.layers: hidden_states = layer( diff --git a/test_deepseek.py b/test_deepseek.py index b71c95a4bdb8..93914db804ce 100644 --- a/test_deepseek.py +++ b/test_deepseek.py @@ -27,21 +27,40 @@ import gc import os import sys +import traceback import torch import torch.distributed as dist from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.distributed import DistributedConfig +from transformers.utils.quantization_config import FineGrainedFP8Config _CHECKPOINT = "deepseek-ai/DeepSeek-V4-Flash" -_PROMPT = "DeepGEMM tests: list three properties of UE8M0 scale factors." +_PROMPTS = [ + "The capital of France is", + "List the first ten prime numbers:", + "Translate to French: 'The quick brown fox jumps over the lazy dog.'", + "Write a Python function fibonacci(n) that returns the nth Fibonacci number.", + "What are the three properties of the UE8M0 scale factor format?", + 'Write a short story that begins with: "Once upon a time, in a forest far away, there lived a..."', +] + + +def _format_chat(prompt: str) -> str: + """Wrap a single-turn user prompt in V4-Flash's chat-mode template. + + The tokenizer doesn't ship a Jinja `chat_template`; the canonical format is in + `encoding/encoding_dsv4.py` on the model repo. 
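Stepping back to the modeling change above: the clamped-SwiGLU epilogue in `DeepseekV4MLP.forward` can be sketched in isolation like this. The `swiglu_limit` value and tensor sizes are assumptions for illustration, not config values:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
swiglu_limit = 7.0                                   # assumed example limit; 0 or negative disables the clamp
gate = torch.randn(4, 8, dtype=torch.bfloat16) * 5   # stand-ins for gate_proj(x) / up_proj(x) outputs
up = torch.randn(4, 8, dtype=torch.bfloat16) * 5

g32, u32 = gate.float(), up.float()                  # promote to fp32 before clamping / activation
if swiglu_limit > 0:
    g32 = g32.clamp(max=swiglu_limit)                       # gate clamped from above only
    u32 = u32.clamp(min=-swiglu_limit, max=swiglu_limit)    # up clamped symmetrically
out = (F.silu(g32) * u32).to(gate.dtype)             # cast back to the input dtype before the down projection
print(out.shape, out.dtype)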
Chat mode places `` + right after `<|Assistant|>` to skip the reasoning block and answer directly. + """ + return f"<|begin▁of▁sentence|><|User|>{prompt}<|Assistant|>" # (label, dispatch, min_sm). All entries share one model load. _QUANTIZED_DISPATCHES = [ ("quantized + deepgemm", "deepgemm", 9), - ("quantized + deepgemm_megamoe", "deepgemm_megamoe", 10), + # ("quantized + deepgemm_megamoe", "deepgemm_megamoe", 10), ] @@ -50,29 +69,141 @@ def _rank0_print(msg: str) -> None: print(msg, flush=True) -def _generate_and_check(model, tok, label: str, rank: int) -> None: - inputs = tok(_PROMPT, return_tensors="pt").to(model.device) - dist.barrier() - with torch.no_grad(): - out_ids = model.generate( - **inputs, - max_new_tokens=32, - do_sample=False, - pad_token_id=tok.eos_token_id, +def _render_report(results: list[tuple[str, str, str]], completions: dict[str, list[str]]) -> None: + """Render a side-by-side rich table comparing each dispatch's per-prompt completion.""" + from rich.console import Console + from rich.table import Table + + console = Console() + console.print("") + + # Per-dispatch status row first. + status_table = Table(title="Run summary", show_lines=False) + status_table.add_column("dispatch", style="bold") + status_table.add_column("status") + status_table.add_column("detail", overflow="fold") + for label, status, detail in results: + style = "green" if status == "PASS" else "yellow" if status == "SKIP" else "red" + status_table.add_row(label, f"[{style}]{status}[/{style}]", detail) + console.print(status_table) + + if not completions: + return + + # Per-prompt completions side by side. Each dispatch gets its own column. + dispatch_keys = list(completions.keys()) + title = "completions: " + " vs ".join(dispatch_keys) + completion_table = Table(title=title, show_lines=True) + completion_table.add_column("#", justify="right", no_wrap=True) + completion_table.add_column("prompt", overflow="fold", max_width=30) + for d in dispatch_keys: + completion_table.add_column(d, overflow="fold", max_width=42) + + for i, prompt in enumerate(_PROMPTS): + row = [str(i + 1), prompt] + for d in dispatch_keys: + comps = completions.get(d, []) + row.append(comps[i] if i < len(comps) else "") + completion_table.add_row(*row) + console.print(completion_table) + + +def _generate_and_check(model, tok, label: str, rank: int, completions: list[str]) -> None: + for i, prompt in enumerate(_PROMPTS): + inputs = tok(_format_chat(prompt), return_tensors="pt", add_special_tokens=False).to(model.device) + dist.barrier() + with torch.no_grad(): + out_ids = model.generate( + **inputs, + max_new_tokens=64, + do_sample=False, + pad_token_id=tok.eos_token_id, + ) + if rank == 0: + finite = torch.isfinite(out_ids.float()).all().item() + new_tokens = out_ids[0, inputs.input_ids.size(1):] + completion = tok.decode(new_tokens, skip_special_tokens=True) + completion_raw = tok.decode(new_tokens, skip_special_tokens=False) + completions.append(completion) + print(f"[{label}] prompt {i + 1}/{len(_PROMPTS)} — {new_tokens.numel()} tokens (finite={finite}):", flush=True) + print(f" completion: {completion!r}", flush=True) + print(f" raw decode: {completion_raw!r}", flush=True) + print(f" token ids: {new_tokens.tolist()}", flush=True) + if not finite or new_tokens.numel() == 0: + raise RuntimeError(f"{label}: generation failed (finite={finite}, n={new_tokens.numel()})") + dist.barrier() + + +def _run_dequant_phase(results: list, completions: dict[str, list[str]]) -> None: + """Dequantized (bf16) baseline via 
FineGrainedFP8Config(dequantize=True) + + `device_map="auto"` CPU offload. Runs only on rank 0 — the dequantized model + is ~1.3 TB and we use the host's RAM (`max_memory`) to keep all 8 ranks' + activations from contending with the model weights on GPU 0. + """ + label = "dequantized (bf16)" + dispatch = "dequantized" + print(f"\n--- loading {_CHECKPOINT} (dequantize=True, device_map=auto across all GPUs) ---", flush=True) + # Spread the bf16 model across all 8 GPUs (~1.36 TB total) instead of GPU 0 + CPU offload — + # the latter forces every forward pass to stream weights over PCIe and is ~10× slower. + # Other torchrun ranks have torn down by now, so GPUs 1..N are free. + n_gpus = torch.cuda.device_count() + max_memory = {i: "170GiB" for i in range(n_gpus)} + max_memory["cpu"] = "1500GiB" # fallback for any leftover + qcfg = FineGrainedFP8Config(dequantize=True) + try: + model = AutoModelForCausalLM.from_pretrained( + _CHECKPOINT, + device_map="auto", + dtype="auto", + quantization_config=qcfg, + max_memory=max_memory, ) - if rank == 0: - finite = torch.isfinite(out_ids.float()).all().item() - new_tokens = out_ids[0, inputs.input_ids.size(1):] - completion = tok.decode(new_tokens, skip_special_tokens=True) - print(f"[{label}] generated {new_tokens.numel()} tokens (finite={finite}):", flush=True) - print(f" prompt: {_PROMPT}", flush=True) - print(f" completion: {completion}", flush=True) - if not finite or new_tokens.numel() == 0: - raise RuntimeError(f"{label}: generation failed (finite={finite}, n={new_tokens.numel()})") - dist.barrier() + model.eval() + tok = AutoTokenizer.from_pretrained(_CHECKPOINT) + except BaseException as exc: + print(f"[load] FAIL — {type(exc).__name__}: {exc}", flush=True) + results.append((label, "FAIL", f"load: {type(exc).__name__}: {exc}")) + return + completions[dispatch] = [] + print(f"\n=== {label} ===", flush=True) + try: + for i, prompt in enumerate(_PROMPTS): + inputs = tok(_format_chat(prompt), return_tensors="pt", add_special_tokens=False).to(model.device) + with torch.no_grad(): + out_ids = model.generate( + **inputs, + max_new_tokens=64, + do_sample=False, + pad_token_id=tok.eos_token_id, + ) + new_tokens = out_ids[0, inputs.input_ids.size(1):] + completion = tok.decode(new_tokens, skip_special_tokens=True) + completion_raw = tok.decode(new_tokens, skip_special_tokens=False) + completions[dispatch].append(completion) + print(f"[{label}] prompt {i + 1}/{len(_PROMPTS)} — {new_tokens.numel()} tokens:", flush=True) + print(f" completion: {completion!r}", flush=True) + print(f" raw decode: {completion_raw!r}", flush=True) + print(f" token ids: {new_tokens.tolist()}", flush=True) + results.append((label, "PASS", "")) + except BaseException as exc: + print(f"[{label}] FAIL — {type(exc).__name__}: {exc}", flush=True) + traceback.print_exc() + results.append((label, "FAIL", f"{type(exc).__name__}: {exc}")) + finally: + del model, tok + gc.collect() + torch.cuda.empty_cache() -def _run_phase(load_kwargs: dict, dispatches, cap_major: int, rank: int, results: list) -> None: + +def _run_phase( + load_kwargs: dict, + dispatches, + cap_major: int, + rank: int, + results: list, + completions: dict[str, list[str]], +) -> None: runnable = [(lab, d) for (lab, d, sm) in dispatches if cap_major >= sm] skipped = [(lab, d, sm) for (lab, d, sm) in dispatches if cap_major < sm] for lab, _, sm in skipped: @@ -103,11 +234,13 @@ def _run_phase(load_kwargs: dict, dispatches, cap_major: int, rank: int, results _rank0_print(f"\n=== {label} ===") try: 
model.set_experts_implementation(dispatch) - _generate_and_check(model, tok, label, rank) + completions[dispatch] = [] + _generate_and_check(model, tok, label, rank, completions[dispatch]) results.append((label, "PASS", "")) except BaseException as exc: if rank == 0: print(f"[{label}] FAIL — {type(exc).__name__}: {exc}", flush=True) + traceback.print_exc() results.append((label, "FAIL", f"{type(exc).__name__}: {exc}")) finally: del model, tok @@ -130,6 +263,7 @@ def main() -> int: _rank0_print(f"device cap: SM{cap_major}0, world_size={world_size}") results: list[tuple[str, str, str]] = [] # (label, status, detail) + completions: dict[str, list[str]] = {} # dispatch → per-prompt completions _run_phase( load_kwargs={"dtype": "auto"}, @@ -137,22 +271,25 @@ def main() -> int: cap_major=cap_major, rank=rank, results=results, + completions=completions, ) dist.barrier() dist.destroy_process_group() + # Dequantized bf16 baseline: rank 0 only, spread across all GPUs once + # ranks 1..N have exited and released their per-rank quantized weights. + if rank == 0: + import time + time.sleep(5) + torch.cuda.empty_cache() + _run_dequant_phase(results, completions) + if rank == 0: passed = [r for r in results if r[1] == "PASS"] failed = [r for r in results if r[1] == "FAIL"] skipped = [r for r in results if r[1] == "SKIP"] - width = max((len(r[0]) for r in results), default=20) - print("\n=== summary ===", flush=True) - for label, status, detail in results: - line = f" {label.ljust(width)} {status}" - if detail: - line += f" ({detail})" - print(line, flush=True) + _render_report(results, completions) print( f"\n totals: {len(passed)} passed, {len(failed)} failed, {len(skipped)} skipped", flush=True, From 3b1c470a6dc5b8db43dff3044b742fec31eaff5d Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Mon, 11 May 2026 03:46:23 -0700 Subject: [PATCH 87/87] more modeling changes --- check_nvcc_b200.py | 168 --------- probe_dsv3_conversion.py | 78 ----- repro_nan_dsv3_b200.py | 117 ------- .../deepseek_v4/configuration_deepseek_v4.py | 2 +- .../deepseek_v4/modeling_deepseek_v4.py | 26 +- .../models/deepseek_v4/modular_deepseek_v4.py | 26 +- test_deepgemm.py | 330 ------------------ 7 files changed, 13 insertions(+), 734 deletions(-) delete mode 100644 check_nvcc_b200.py delete mode 100644 probe_dsv3_conversion.py delete mode 100644 repro_nan_dsv3_b200.py delete mode 100644 test_deepgemm.py diff --git a/check_nvcc_b200.py b/check_nvcc_b200.py deleted file mode 100644 index 254100db26d0..000000000000 --- a/check_nvcc_b200.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Smoke-test the user's nvcc + CUDA setup against the running GPU. - -Compiles and runs a tiny CUDA kernel targeting the device's actual compute -capability (e.g. `sm_100a` on B200, `sm_90a` on H100). Mirrors what DeepGEMM's -JIT does at the first kernel call: - - 1. Locate `nvcc` via `$CUDA_HOME/bin/nvcc`, then PATH, then `/usr/local/cuda`. - 2. nvcc-compile a kernel that uses an SM-specific intrinsic / API. - 3. Launch it, copy result back, sanity-check. - -If this succeeds, DeepGEMM JIT will work at runtime. If it fails, the message -points at the specific layer (toolchain, driver, runtime) so you can fix it -before pulling DeepGEMM into a model run. 
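For orientation, the arch flag the probe derives boils down to a couple of lines — a simplified variant of what the script below does, where the trailing "a" suffix selects the arch-specific target used from SM90 onwards:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()   # e.g. (9, 0) on H100, (10, 0) on B200
    suffix = "a" if major >= 9 else ""
    arch = f"{major}{minor}{suffix}"
    print(f"-gencode=arch=compute_{arch},code=sm_{arch}")  # e.g. -gencode=arch=compute_90a,code=sm_90a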
- -Usage: - python check_nvcc_b200.py - CUDA_HOME=/path/to/cuda python check_nvcc_b200.py -""" - -from __future__ import annotations - -import ctypes -import ctypes.util -import os -import shutil -import subprocess -import sys -import tempfile -from pathlib import Path - -import torch - - -def _find_cuda_home() -> str: - """Same search order as the deep-gemm wheel's `_find_cuda_home`.""" - for var in ("CUDA_HOME", "CUDA_PATH"): - cand = os.environ.get(var) - if cand and (Path(cand) / "bin" / "nvcc").is_file(): - return cand - - nvcc = shutil.which("nvcc") - if nvcc: - return str(Path(nvcc).parent.parent) - - try: - import nvidia.cuda_nvcc as _nvcc # type: ignore - cand = Path(_nvcc.__file__).parent - if (cand / "bin" / "nvcc").is_file(): - return str(cand) - except ImportError: - pass - - for cand in ("/usr/local/cuda", "/opt/cuda", "/opt/nvidia/cuda", "/usr/lib/cuda"): - if (Path(cand) / "bin" / "nvcc").is_file(): - return cand - import glob - for cand in sorted(glob.glob("/usr/local/cuda-*"), reverse=True): - if (Path(cand) / "bin" / "nvcc").is_file(): - return cand - raise SystemExit("nvcc not found. Set CUDA_HOME or install CUDA toolkit.") - - -_KERNEL_SRC = r""" -#include -#include - -// One element per thread; writes its global index. Probes that arch-specific -// codegen + scheduling work end-to-end on the device. -__global__ void identity_kernel(int* out, int n) { - const int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - out[i] = i; - } -} - -extern "C" __host__ int run_check(int n) { - int* d_out = nullptr; - cudaError_t err = cudaMalloc(&d_out, n * sizeof(int)); - if (err != cudaSuccess) { fprintf(stderr, "cudaMalloc: %s\n", cudaGetErrorString(err)); return 1; } - - int threads = 128, blocks = (n + threads - 1) / threads; - identity_kernel<<>>(d_out, n); - err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { fprintf(stderr, "kernel launch: %s\n", cudaGetErrorString(err)); cudaFree(d_out); return 2; } - - int* h_out = (int*)malloc(n * sizeof(int)); - err = cudaMemcpy(h_out, d_out, n * sizeof(int), cudaMemcpyDeviceToHost); - if (err != cudaSuccess) { fprintf(stderr, "cudaMemcpy: %s\n", cudaGetErrorString(err)); free(h_out); cudaFree(d_out); return 3; } - - int ok = 1; - for (int i = 0; i < n; ++i) if (h_out[i] != i) { ok = 0; break; } - free(h_out); - cudaFree(d_out); - return ok ? 
0 : 4; -} -""" - - -def main() -> int: - if not torch.cuda.is_available(): - print("FAIL: CUDA not available to torch.") - return 1 - - cap = torch.cuda.get_device_capability() - sm = f"{cap[0]}{cap[1]}a" - name = torch.cuda.get_device_name() - print(f"GPU: {name} (compute capability sm_{sm})") - - cuda_home = _find_cuda_home() - nvcc = str(Path(cuda_home) / "bin" / "nvcc") - print(f"CUDA_HOME: {cuda_home}") - - ver = subprocess.run([nvcc, "--version"], capture_output=True, text=True) - print(ver.stdout.strip().splitlines()[-1] if ver.stdout else ver.stderr) - - with tempfile.TemporaryDirectory() as td: - src = Path(td) / "probe.cu" - so = Path(td) / "probe.so" - src.write_text(_KERNEL_SRC) - - cmd = [ - nvcc, "-shared", "-Xcompiler=-fPIC", - "-O2", "-std=c++17", - f"-gencode=arch=compute_{cap[0]}{cap[1]}{'a' if cap[0] >= 9 else ''},code=sm_{sm}", - "-o", str(so), str(src), - ] - print("\n[1/3] nvcc compile…") - r = subprocess.run(cmd, capture_output=True, text=True) - if r.returncode != 0: - print(f"FAIL: nvcc compile (exit {r.returncode})") - print("--- stderr ---") - print(r.stderr) - return 1 - print(" OK") - - print("[2/3] dlopen…") - try: - lib = ctypes.CDLL(str(so)) - except OSError as e: - print(f"FAIL: dlopen: {e}") - print("Hint: missing libcudart.so on LD_LIBRARY_PATH. Try:") - print(f" export LD_LIBRARY_PATH={cuda_home}/lib64:$LD_LIBRARY_PATH") - return 1 - lib.run_check.restype = ctypes.c_int - lib.run_check.argtypes = [ctypes.c_int] - print(" OK") - - print("[3/3] launch kernel…") - rc = lib.run_check(1024) - labels = { - 0: "OK", - 1: "cudaMalloc failed", - 2: "kernel launch / sync failed", - 3: "cudaMemcpy failed", - 4: "kernel produced wrong values", - } - print(f" run_check → {rc} ({labels.get(rc, 'unknown')})") - if rc != 0: - print("\nFAIL: nvcc compiles but the kernel did not run correctly.") - print("Common causes: GPU driver too old for the toolkit, mismatched libcudart.") - return 1 - - print(f"\nPASS: nvcc {Path(nvcc).name} can compile + run sm_{sm} kernels on this {name}.") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/probe_dsv3_conversion.py b/probe_dsv3_conversion.py deleted file mode 100644 index 723e063e7fe0..000000000000 --- a/probe_dsv3_conversion.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Compare the kernel's float→packed-UE8M0 conversion against a Python -equivalent, on the exact SF tensors that DSv3 feeds to the GEMM. Goal: find -out whether `transpose_and_pack_fp32_into_ue8m0` (the JIT helper the kernel -runs internally for DSv3 on SM100) produces values matching what we'd compute -ourselves. - -If the kernel's output matches Python on every byte, the NaN comes from -somewhere else in the GEMM. If it doesn't match, the conversion itself is the -bug. -""" - -from __future__ import annotations - -import torch - -from transformers.integrations.deepgemm import _load_deepgemm_kernel - - -def py_pack(sf_fp32: torch.Tensor) -> torch.Tensor: - """Same as `pack_fp32_into_ue8m0`: extract the biased exponent (bits - [30:23]) of each float as a uint8, then pack 4 K-consecutive bytes into - one int32 (LSB = lowest K). Returns an MN-major int32 tensor. 
- """ - # Extract biased exponent → uint8 - byte = (sf_fp32.view(torch.int32) >> 23).to(torch.uint8) - # Reshape so K dim is divisible by 4 - *batch, mn, k = byte.shape - assert k % 4 == 0 - # Pack each group of 4 K-bytes into 1 int32 in little-endian order - grouped = byte.view(*batch, mn, k // 4, 4).to(torch.int32) - packed = grouped[..., 0] | (grouped[..., 1] << 8) | (grouped[..., 2] << 16) | (grouped[..., 3] << 24) - return packed # shape (..., mn, k//4) K-major; caller should rewrite to MN-major - - -def main(): - torch.cuda.set_device(0) - d = torch.device("cuda", 0) - dg = _load_deepgemm_kernel() - - # Activation SF case: shape (M, K_blocks), per-row. - print("=== activation SF: (3056, 8) float32 → kernel pack vs python pack ===") - M, Kb = 3056, 8 - sf_a = (torch.rand(M, Kb, device=d) * 0.05 + 0.001).to(torch.float32) - - kernel_packed = dg.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf_a) - py_packed = py_pack(sf_a) - print(f" kernel out shape={tuple(kernel_packed.shape)} stride={tuple(kernel_packed.stride())} dtype={kernel_packed.dtype}") - print(f" python out shape={tuple(py_packed.shape)}") - # Compare bytes - kernel_bytes = kernel_packed.contiguous().view(torch.uint8).flatten() - python_bytes = py_packed.contiguous().view(torch.uint8).flatten() - n = min(kernel_bytes.numel(), python_bytes.numel()) - diff = (kernel_bytes[:n] != python_bytes[:n]).sum().item() - print(f" byte-equal count: {(n - diff)}/{n} (diff={diff})") - - print("\n=== weight SF: (16, 8, 8) float32 → kernel broadcast+pack ===") - E, sn, sk = 16, 8, 8 - N, K = 1024, 1024 - sf_w = (torch.rand(E, sn, sk, device=d) * 0.05 + 0.001).to(torch.float32) - - # Python broadcast: each block-row repeated 128 times along dim -2. - sf_w_broadcast = sf_w.repeat_interleave(N // sn, dim=-2) # (E, N, sk) - py_packed_w = py_pack(sf_w_broadcast) # (E, N, sk//4) - print(f" python broadcast+pack shape={tuple(py_packed_w.shape)}") - - # Kernel path: pass float SF to transform_sf_into_required_layout via the - # recipe machinery. We don't have direct access; the closest helper is - # the public one which expects per-row float input. Skip and just confirm - # python pack is correct on the broadcasted form. - - # Sanity: print a slice of packed bytes for visual inspection. - print(f" python packed[0, 0, :] = {py_packed_w[0, 0, :].tolist()}") - print(f" python packed[0, 127, :] = {py_packed_w[0, 127, :].tolist()} (same block as row 0)") - print(f" python packed[0, 128, :] = {py_packed_w[0, 128, :].tolist()} (next block)") - - -if __name__ == "__main__": - main() diff --git a/repro_nan_dsv3_b200.py b/repro_nan_dsv3_b200.py deleted file mode 100644 index 024be391eac9..000000000000 --- a/repro_nan_dsv3_b200.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Pinpoint the DSv3-on-B200 NaN. - -DeepGEMM's `pack_fp32_into_ue8m0` doesn't convert fp32 to UE8M0 — it expects -the fp32 input to *already* be UE8M0-rounded (each value an exact power of 2, -mantissa bits all zero) and just repacks the exponent bytes. The kernel's -inner shifts (`>> 23`, `>> 15`, `>> 7`, `<< 1`) only cleanly extract the -biased exponent for the first lane; the rest leak mantissa bits into adjacent -byte slots when the mantissa isn't zero. - -This script verifies that on raw arbitrary fp32 SFs (kernel output diverges -from a "biased-exponent only" reference) but matches byte-for-byte once the -input is rounded to powers of 2 via `ceil_to_ue8m0`. - -Implication: on SM100 the kernel's `(FP32, x, gran_k)` → packed-int path -silently corrupts SFs unless the caller pre-rounds them. 
SM90 sidesteps this -because its FP8 path consumes raw fp32 SFs directly without going through -`pack_fp32_into_ue8m0`. - -Run on H100 and B200; both should print: - raw → DIVERGES (kernel needs UE8M0-rounded inputs). - ue8m0 → MATCHES. -""" - -from __future__ import annotations - -import sys - -import torch - -from transformers.integrations.deepgemm import _load_deepgemm_kernel - - -def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor: - """Round each positive float up to the nearest power of 2 representable as - UE8M0 (mantissa zeroed out). Mirrors upstream's `deep_gemm.utils.math`. - """ - return ( - (x.view(torch.int32) + ((1 << 23) - 1)).bitwise_and_(~((1 << 23) - 1)).view(torch.float) - ) - - -def python_pack_exponent_only(sf_fp32: torch.Tensor) -> torch.Tensor: - """Reference assuming UE8M0 input: extract biased exponent (bits [30:23]) - of each float as a uint8, then pack 4 K-consecutive bytes into one int32 - LE. This *only* matches the kernel when the input has zero mantissa. - """ - byte = (sf_fp32.contiguous().view(torch.int32) >> 23).to(torch.uint8) - *batch, mn, k = byte.shape - assert k % 4 == 0 - g = byte.view(*batch, mn, k // 4, 4).to(torch.int32) - return g[..., 0] | (g[..., 1] << 8) | (g[..., 2] << 16) | (g[..., 3] << 24) - - -def compare(label: str, kernel_out: torch.Tensor, ref_out: torch.Tensor) -> int: - if kernel_out.shape != ref_out.shape: - kernel_out = kernel_out[..., : ref_out.size(-2), : ref_out.size(-1)].contiguous() - py_b = ref_out.contiguous().view(torch.uint8).flatten() - k_b = kernel_out.contiguous().view(torch.uint8).flatten() - n = min(py_b.numel(), k_b.numel()) - diff = (py_b[:n] != k_b[:n]).sum().item() - status = "MATCH" if diff == 0 else "DIVERGE" - print(f" [{label}] diff_bytes={diff}/{n} ({status})") - return diff - - -def main() -> int: - if not torch.cuda.is_available(): - sys.exit("CUDA required.") - device = torch.device("cuda", 0) - torch.cuda.set_device(device) - cap = torch.cuda.get_device_capability(device) - print(f"GPU: {torch.cuda.get_device_name(device)} SM{cap[0]}{cap[1]}") - - dg = _load_deepgemm_kernel() - torch.manual_seed(0) - - # One representative SF tensor — what the integration would feed for - # block-quantized weight SFs in DSv3 inference. 
- sf_raw = (torch.rand(4, 8, 8, device=device) * 0.05 + 0.001).to(torch.float32) - sf_ue8m0 = ceil_to_ue8m0(sf_raw) - - print("\nfp32 SF → kernel pack vs python `extract biased exponent and pack 4` reference:\n") - - diff_raw = compare( - "raw fp32 SF (mantissa != 0)", - dg.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf_raw), - python_pack_exponent_only(sf_raw), - ) - diff_ue8m0 = compare( - "ceil_to_ue8m0 fp32 SF (mantissa == 0)", - dg.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf_ue8m0), - python_pack_exponent_only(sf_ue8m0), - ) - - print() - print("Conclusion:") - print( - f" raw: {'DIVERGES' if diff_raw else 'MATCHES'} " - "(kernel reads mantissa bits when not zero)" - ) - print( - f" ue8m0: {'DIVERGES' if diff_ue8m0 else 'MATCHES'} " - "(kernel cleanly repacks exponent bytes)" - ) - - if diff_ue8m0 != 0: - print("\nUnexpected: pack diverges even with UE8M0-rounded input — that is a real kernel bug.") - return 1 - if diff_raw == 0: - print("\nUnexpected: kernel matches without UE8M0 rounding — investigate.") - return 1 - print("\nFix: in our integration, round float SFs via ceil_to_ue8m0 before passing.") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py index 751729173866..2f9e3527d430 100644 --- a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py @@ -308,7 +308,7 @@ def __post_init__(self, **kwargs): # Already nested — drop any leftover top-level keys. self.rope_parameters = {"sliding": rp["sliding"], "compress": rp["compress"]} else: - base = {k: v for k, v in rp.items() if k not in ("main", "sliding", "compress")} + base = {k: v for k, v in rp.items() if k not in ("sliding", "compress")} base.setdefault("rope_type", "default") base["partial_rotary_factor"] = self.partial_rotary_factor sliding = { diff --git a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py index 9567d6ad733b..ec0abffff670 100644 --- a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py @@ -831,14 +831,9 @@ def forward( k_per_query = n_compressor // S q_idx = torch.arange(S, device=q.device).unsqueeze(1) # [S, 1] kv_idx = torch.arange(n_compressor, device=q.device).div(k_per_query, rounding_mode="floor") # [S*k] - own_q = q_idx == kv_idx # [S, S*k] - if compressor_valid is not None: - allowed = own_q.unsqueeze(0) & compressor_valid.unsqueeze(1) # [B, S, S*k] - extra = torch.where(allowed, 0.0, float("-inf")).to(attention_mask.dtype).unsqueeze(1) - extra = extra.expand(attention_mask.shape[0], attention_mask.shape[1], -1, -1) - else: - extra = torch.where(own_q, 0.0, float("-inf")).to(attention_mask.dtype) - extra = extra.expand(*attention_mask.shape[:-2], S, n_compressor) + allowed = (q_idx == kv_idx).unsqueeze(0) & compressor_valid.unsqueeze(1) # [B, S, S*k] + extra = torch.where(allowed, 0.0, float("-inf")).to(attention_mask.dtype).unsqueeze(1) + extra = extra.expand(attention_mask.shape[0], attention_mask.shape[1], -1, -1) attention_mask = torch.cat([attention_mask, extra], dim=-1) else: # HCA (every query sees every entry) and CSA decode (single query's k slots @@ -899,10 +894,10 @@ class DeepseekV4HyperConnection(nn.Module): [B, S, H] [B, S, H] [B, S, H, H] × scale[0] × scale[1] × scale[2] + base[:H] + base[H:2H] + 
base[2H:] - σ() + eps σ() + eps σ() + eps + σ() + eps 2·σ() softmax(-1) + eps │ │ │ - pre post Sinkhorn(iters) - (stream collapse weights) (block-output placement) row/col normalise + pre post Sinkhorn(iters) + (stream collapse weights) (block-output placement, range [0, 2]) row/col normalise │ comb (stream mixer) @@ -924,15 +919,6 @@ def __init__(self, config: DeepseekV4Config): self.scale = nn.Parameter(torch.empty(3)) def forward(self, hidden_streams: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - r""" - Compute `pre`, `post`, `comb` from the mHC mapping (paper §2.2 eq. 8). - `comb` is projected onto the doubly-stochastic manifold via Sinkhorn- - Knopp: starting from the sigmoid-positive matrix, alternate row and - column normalisation for `hc_sinkhorn_iters` steps. `pre` then collapses - the `hc_mult` parallel streams into a single sequence (input projection - into the sublayer); `post` and `comb` are returned for the caller to - apply on the sublayer output. - """ flat = self.input_norm(hidden_streams.flatten(start_dim=2).float()) mix = F.linear(flat, self.fn.float()) # [B, S, (2+H)*H] pre_scale, post_scale, comb_scale = self.scale.unbind(0) diff --git a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py index 9f6501273b5e..96ca7e1e3edb 100644 --- a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py @@ -729,14 +729,9 @@ def forward( k_per_query = n_compressor // S q_idx = torch.arange(S, device=q.device).unsqueeze(1) # [S, 1] kv_idx = torch.arange(n_compressor, device=q.device).div(k_per_query, rounding_mode="floor") # [S*k] - own_q = q_idx == kv_idx # [S, S*k] - if compressor_valid is not None: - allowed = own_q.unsqueeze(0) & compressor_valid.unsqueeze(1) # [B, S, S*k] - extra = torch.where(allowed, 0.0, float("-inf")).to(attention_mask.dtype).unsqueeze(1) - extra = extra.expand(attention_mask.shape[0], attention_mask.shape[1], -1, -1) - else: - extra = torch.where(own_q, 0.0, float("-inf")).to(attention_mask.dtype) - extra = extra.expand(*attention_mask.shape[:-2], S, n_compressor) + allowed = (q_idx == kv_idx).unsqueeze(0) & compressor_valid.unsqueeze(1) # [B, S, S*k] + extra = torch.where(allowed, 0.0, float("-inf")).to(attention_mask.dtype).unsqueeze(1) + extra = extra.expand(attention_mask.shape[0], attention_mask.shape[1], -1, -1) attention_mask = torch.cat([attention_mask, extra], dim=-1) else: # HCA (every query sees every entry) and CSA decode (single query's k slots @@ -797,10 +792,10 @@ class DeepseekV4HyperConnection(nn.Module): [B, S, H] [B, S, H] [B, S, H, H] × scale[0] × scale[1] × scale[2] + base[:H] + base[H:2H] + base[2H:] - σ() + eps σ() + eps σ() + eps + σ() + eps 2·σ() softmax(-1) + eps │ │ │ - pre post Sinkhorn(iters) - (stream collapse weights) (block-output placement) row/col normalise + pre post Sinkhorn(iters) + (stream collapse weights) (block-output placement, range [0, 2]) row/col normalise │ comb (stream mixer) @@ -822,15 +817,6 @@ def __init__(self, config: DeepseekV4Config): self.scale = nn.Parameter(torch.empty(3)) def forward(self, hidden_streams: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - r""" - Compute `pre`, `post`, `comb` from the mHC mapping (paper §2.2 eq. 8). - `comb` is projected onto the doubly-stochastic manifold via Sinkhorn- - Knopp: starting from the sigmoid-positive matrix, alternate row and - column normalisation for `hc_sinkhorn_iters` steps. 
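A minimal sketch of that projection, with an assumed stream count and iteration budget (illustration values, not the model's hyper-parameters):

import torch

torch.manual_seed(0)
hc_mult, sinkhorn_iters, eps = 4, 3, 1e-6

comb = torch.sigmoid(torch.randn(hc_mult, hc_mult)) + eps   # strictly positive starting point
for _ in range(sinkhorn_iters):
    comb = comb / comb.sum(dim=-1, keepdim=True)            # rows sum to 1
    comb = comb / comb.sum(dim=-2, keepdim=True)            # columns sum to 1

print(comb.sum(dim=-2))   # exactly 1: the last step normalised columns
print(comb.sum(dim=-1))   # ≈ 1: rows drift slightly after the final column step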
`pre` then collapses - the `hc_mult` parallel streams into a single sequence (input projection - into the sublayer); `post` and `comb` are returned for the caller to - apply on the sublayer output. - """ flat = self.input_norm(hidden_streams.flatten(start_dim=2).float()) mix = F.linear(flat, self.fn.float()) # [B, S, (2+H)*H] pre_scale, post_scale, comb_scale = self.scale.unbind(0) diff --git a/test_deepgemm.py b/test_deepgemm.py deleted file mode 100644 index 0382c9f696e0..000000000000 --- a/test_deepgemm.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Smoke-test the three DeepGEMM experts dispatches with synthetic experts. - -Each test builds a synthetic experts module with the right weight dtypes / SF formats and -runs the kernel forward, checking the output is finite and shaped correctly. - -Coverage: - 1. DSv3-style: FP8 weights (`float8_e4m3fn`) + float32 SF — Hopper SM90+ - 2. DSv4-style: FP4 weights (`int8`-packed e2m1) + UE8M0 SF — Blackwell SM100+ - 3. Mega MoE: same as DSv4 but with EP dispatch + combine inside the kernel — SM100+ - + distributed (uses `transform_weights_for_mega_moe` for the layout) - -Usage: - # Single GPU (DSv3 + DSv4): - python test_deepgemm_integration.py - - # Mega MoE (≥2 ranks): - torchrun --nproc_per_node=2 test_deepgemm_integration.py -""" - -from __future__ import annotations - -import os -from types import SimpleNamespace - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -from transformers.integrations.deepgemm import ( - _load_deepgemm_kernel, - deepgemm_bf16_experts_forward, - deepgemm_fp8_fp4_experts_forward, - deepgemm_fp8_fp4_megamoe_experts_forward, -) - - -_FP8_DTYPE = torch.float8_e4m3fn -_FP8_MAX = torch.finfo(_FP8_DTYPE).max -_UE8M0_SF_DTYPE = torch.float8_e8m0fnu - - -def _round_to_ue8m0(x: torch.Tensor) -> torch.Tensor: - """Round a positive float tensor to the nearest power of 2 representable as UE8M0.""" - return torch.pow(2.0, torch.ceil(torch.log2(x.clamp(min=torch.finfo(torch.float32).tiny)))).to(_UE8M0_SF_DTYPE) - - -def _make_fp8_experts(num_experts: int, hidden_size: int, intermediate_size: int, ue8m0_sf: bool, device: torch.device) -> SimpleNamespace: - """Synthetic FP8 experts. - - DeepGEMM picks the SF recipe per-arch based on the SF dtype (see - `get_default_recipe` in `csrc/utils/layout.hpp`): - - * SM90 + float SF → recipe (1, 128, 128): block-quantized SF for B, - shape `(E, N/128, K/128)`. - * SM100 + float SF → recipe (1, 128, 128): same block-quantized - shape; kernel broadcasts → packs UE8M0 - internally (DSv3 path, "legacy" on Blackwell). - * SM100 + UE8M0 SF → recipe (1, 1, 128): per-row SF for B, shape - `(E, N, K/128)`. This is the DSv4-FP8 path. - - `ue8m0_sf=False` exercises the float-SF (DSv3) path; `ue8m0_sf=True` - exercises the per-row UE8M0 (DSv4-FP8) path. - """ - block_k = 128 - # Per-row when UE8M0 (gran_mn=1), block-128 when float SF (gran_mn=128). - block_n = 1 if ue8m0_sf else 128 - - def _alloc(e: int, n: int, k: int) -> tuple[torch.Tensor, torch.Tensor]: - # Per-block-amax FP8 quantization (matches what real DeepSeek-V3 - # checkpoints look like): generate random bf16 weights, compute the - # max-abs of each (block_n × block_k) tile as the SF, divide by 448 - # so the quantized values use the full FP8 range, then cast to FP8. - # Without this, dequantized weights are tiny and the GEMM - # accumulation on Blackwell's float→UE8M0 conversion path can - # produce NaN. 
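The same recipe can be checked end to end with a quick dequantization round trip — toy sizes, and it needs a torch build that ships the float8 dtypes:

import torch

torch.manual_seed(0)
fp8 = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8).max                          # 448 for e4m3fn
n = k = 256
bn = bk = 128

w = torch.randn(n, k) * 0.1
blocks = w.view(n // bn, bn, k // bk, bk)
amax = blocks.abs().amax(dim=(1, 3)).clamp(min=1e-4)    # per (128 x 128) tile
sf = amax / fp8_max                                     # scale so each tile spans the full FP8 range

sf_full = sf.repeat_interleave(bn, dim=0).repeat_interleave(bk, dim=1)
w_q = (w / sf_full).clamp(-fp8_max, fp8_max).to(fp8)
w_dq = w_q.float() * sf_full
print((w_dq - w).abs().max().item())                    # small quantization error, no collapsed magnitudes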
- w_fp32 = torch.randn(e, n, k, device=device) * 0.1 - sf_n = -(-n // block_n) # ceil-div - sf_k = -(-k // block_k) - # Block amax → scale. - w_blocks = w_fp32.view(e, sf_n, block_n, sf_k, block_k) - amax = w_blocks.abs().amax(dim=(2, 4)).clamp(min=1e-4) # (e, sf_n, sf_k) - sf = (amax / _FP8_MAX).to(torch.float32) - if ue8m0_sf: - sf = _round_to_ue8m0(sf) - # Quantize using the dequantized SF (so the cast actually matches). - sf_dequant = sf.float() - w_scaled = w_fp32 / sf_dequant.view(e, sf_n, 1, sf_k, 1).expand(-1, -1, block_n, -1, block_k).reshape(e, n, k) - w_fp8 = w_scaled.clamp(-_FP8_MAX, _FP8_MAX).to(_FP8_DTYPE) - return w_fp8, sf - - gate_up, gate_up_sf = _alloc(num_experts, 2 * intermediate_size, hidden_size) - down, down_sf = _alloc(num_experts, hidden_size, intermediate_size) - return SimpleNamespace( - num_experts=num_experts, - has_gate=True, - has_bias=False, - is_transposed=False, - # block_size matches the actual SF granularity: - # (128, 128) for the DSv3 (float-SF) block-quantized path, - # (1, 128) for the DSv4-FP8 (UE8M0-SF) per-row path. - block_size=(block_n, block_k), - activation_scheme="dynamic", - config=SimpleNamespace(hidden_act="silu"), - gate_up_proj=gate_up, - gate_up_proj_scale_inv=gate_up_sf, - down_proj=down, - down_proj_scale_inv=down_sf, - _apply_gate=lambda x: F.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], - act_fn=F.silu, - ) - - -def _make_bf16_experts(num_experts: int, hidden_size: int, intermediate_size: int, device: torch.device) -> SimpleNamespace: - """Synthetic BF16 experts (no quantization, no SF) — exercises the - `deepgemm_bf16_experts_forward` path that calls the bf16 grouped GEMM.""" - gate_up = torch.randn(num_experts, 2 * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device) * 0.1 - down = torch.randn(num_experts, hidden_size, intermediate_size, dtype=torch.bfloat16, device=device) * 0.1 - return SimpleNamespace( - num_experts=num_experts, - has_gate=True, - has_bias=False, - is_transposed=False, - config=SimpleNamespace(hidden_act="silu"), - gate_up_proj=gate_up, - down_proj=down, - _apply_gate=lambda x: F.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], - act_fn=F.silu, - ) - - -def _make_fp4_experts(num_experts: int, hidden_size: int, intermediate_size: int, device: torch.device) -> SimpleNamespace: - """Synthetic FP4 experts (`int8`-packed e2m1, K dim halved; UE8M0 SF, gran_k=32).""" - - def _alloc(e: int, n: int, k: int) -> tuple[torch.Tensor, torch.Tensor]: - # Any int8 byte pattern is a valid FP4-packed (2 e2m1 nibbles per byte). - w = torch.randint(low=-128, high=128, size=(e, n, k // 2), dtype=torch.int8, device=device) - # Random positive scales → round to UE8M0 (any e8m0 byte is a power-of-2 or special). - sf = (torch.rand(e, n, k // 32, device=device) * 0.05 + 0.001).to(torch.float32) - sf = _round_to_ue8m0(sf) - return w, sf - - gate_up, gate_up_sf = _alloc(num_experts, 2 * intermediate_size, hidden_size) - down, down_sf = _alloc(num_experts, hidden_size, intermediate_size) - return SimpleNamespace( - num_experts=num_experts, - has_gate=True, - has_bias=False, - is_transposed=False, - block_size=None, # FP4 ignores block_size — kernel infers SF granularity from dtype. 
- activation_scheme="dynamic", - config=SimpleNamespace(hidden_act="silu"), - gate_up_proj=gate_up, - gate_up_proj_scale_inv=gate_up_sf, - down_proj=down, - down_proj_scale_inv=down_sf, - _apply_gate=lambda x: F.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], - act_fn=F.silu, - ) - - -def _random_routing(num_tokens: int, top_k: int, num_experts: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: - idx = torch.randint(0, num_experts, (num_tokens, top_k), dtype=torch.int32, device=device) - w = torch.rand(num_tokens, top_k, dtype=torch.float32, device=device) - return idx, w / w.sum(dim=-1, keepdim=True).clamp_min(1e-6) - - -def _check_output(out: torch.Tensor, expected_shape: tuple[int, ...], label: str) -> None: - assert out.shape == expected_shape, f"[{label}] shape mismatch: {tuple(out.shape)} vs {expected_shape}" - assert torch.isfinite(out).all(), f"[{label}] output has non-finite values" - print(f"[{label}] PASS out: {tuple(out.shape)} dtype={out.dtype}") - - -# ── Tests ──────────────────────────────────────────────────────────────────────── - - -def test_bf16(device: torch.device) -> None: - label = "BF16 experts (no quant)" - if torch.cuda.get_device_capability(device)[0] < 9: - print(f"[{label}] SKIP: needs SM90+ (Hopper)") - return - T, H, I, E, K = 256, 1024, 512, 16, 4 - experts = _make_bf16_experts(E, H, I, device) - hidden = torch.randn(T, H, dtype=torch.bfloat16, device=device) * 0.1 - idx, w = _random_routing(T, K, E, device) - out = deepgemm_bf16_experts_forward(experts, hidden, idx, w.to(torch.bfloat16)) - _check_output(out, (T, H), label) - - -def test_dsv3_fp8(device: torch.device) -> None: - label = "DSv3 (FP8 + float SF)" - if torch.cuda.get_device_capability(device)[0] < 9: - print(f"[{label}] SKIP: needs SM90+ (Hopper)") - return - T, H, I, E, K = 256, 1024, 512, 16, 4 - experts = _make_fp8_experts(E, H, I, ue8m0_sf=False, device=device) - hidden = torch.randn(T, H, dtype=torch.bfloat16, device=device) * 0.1 - idx, w = _random_routing(T, K, E, device) - out = deepgemm_fp8_fp4_experts_forward(experts, hidden, idx, w.to(torch.bfloat16)) - _check_output(out, (T, H), label) - - -def test_dsv4_fp8(device: torch.device) -> None: - label = "DSv4-FP8 (FP8 + UE8M0 SF)" - if torch.cuda.get_device_capability(device)[0] < 10: - print(f"[{label}] SKIP: needs SM100+ (Blackwell) for UE8M0 SF dispatch") - return - T, H, I, E, K = 256, 1024, 512, 16, 4 - experts = _make_fp8_experts(E, H, I, ue8m0_sf=True, device=device) - hidden = torch.randn(T, H, dtype=torch.bfloat16, device=device) * 0.1 - idx, w = _random_routing(T, K, E, device) - out = deepgemm_fp8_fp4_experts_forward(experts, hidden, idx, w.to(torch.bfloat16)) - _check_output(out, (T, H), label) - - -def test_dsv4_fp4(device: torch.device) -> None: - label = "DSv4 (FP4 + UE8M0 SF)" - if torch.cuda.get_device_capability(device)[0] < 10: - print(f"[{label}] SKIP: needs SM100+ (Blackwell)") - return - T, H, I, E, K = 256, 1024, 512, 16, 4 - experts = _make_fp4_experts(E, H, I, device) - hidden = torch.randn(T, H, dtype=torch.bfloat16, device=device) * 0.1 - idx, w = _random_routing(T, K, E, device) - out = deepgemm_fp8_fp4_experts_forward(experts, hidden, idx, w.to(torch.bfloat16)) - _check_output(out, (T, H), label) - - -def test_megamoe(device: torch.device, world_size: int, rank: int) -> None: - label = "Mega MoE (FP8 act × FP4 weight, fused EP)" - if torch.cuda.get_device_capability(device)[0] < 10: - if rank == 0: - print(f"[{label}] SKIP: needs SM100+ (Blackwell)") - return - if world_size < 2: - if rank == 0: - 
print(f"[{label}] SKIP: needs >=2 ranks (run with `torchrun --nproc_per_node=2`)") - return - - deepgemm = _load_deepgemm_kernel() - T_local, H, I, K = 64, 1024, 512, 4 - E_global = 16 - E_local = E_global // world_size - - # Build raw FP4 experts on this rank's slice, then transform to the kernel's layout. - raw = _make_fp4_experts(E_local, H, I, device) - # Mega MoE requires SFs already packed as int32 UE8M0 (it transposes them for UTCCP). - gate_up_sf_packed = deepgemm.transform_sf_into_required_layout( - raw.gate_up_proj_scale_inv.float(), 2 * I, H, recipe=(1, 32), num_groups=E_local - ) - down_sf_packed = deepgemm.transform_sf_into_required_layout( - raw.down_proj_scale_inv.float(), H, I, recipe=(1, 32), num_groups=E_local - ) - (gate_up_t, gate_up_sf_t), (down_t, down_sf_t) = deepgemm.transform_weights_for_mega_moe( - (raw.gate_up_proj, gate_up_sf_packed), - (raw.down_proj, down_sf_packed), - ) - - experts = SimpleNamespace( - gate_up_proj=gate_up_t, - gate_up_proj_scale_inv=gate_up_sf_t, - down_proj=down_t, - down_proj_scale_inv=down_sf_t, - symm_buffer=None, # lazily allocated on first call - config=SimpleNamespace(), # no swiglu_limit → kernel runs unclamped - ) - - hidden = torch.randn(T_local, H, dtype=torch.bfloat16, device=device) * 0.1 - # Mega MoE expects GLOBAL expert ids (no per-rank remap); -1 marks skipped slots. - idx = torch.randint(0, E_global, (T_local, K), dtype=torch.int32, device=device) - w = torch.rand(T_local, K, dtype=torch.float32, device=device) - w = w / w.sum(dim=-1, keepdim=True).clamp_min(1e-6) - - out = deepgemm_fp8_fp4_megamoe_experts_forward( - experts, hidden, idx, w.to(torch.bfloat16), process_group=dist.group.WORLD - ) - if rank == 0: - _check_output(out, (T_local, H), label) - - -# ── Entrypoint ─────────────────────────────────────────────────────────────────── - - -def main() -> None: - if not torch.cuda.is_available(): - raise SystemExit("CUDA required.") - - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - torch.cuda.set_device(local_rank) - device = torch.device("cuda", local_rank) - - if world_size > 1 and not dist.is_initialized(): - dist.init_process_group("nccl") - - if rank == 0: - print(f"device cap: SM{''.join(str(x) for x in torch.cuda.get_device_capability(device))}, " - f"world_size={world_size}\n") - - # Single-GPU paths run on rank 0 only (ranks > 0 only participate in Mega MoE). - failures: list[tuple[str, BaseException]] = [] - if rank == 0: - for fn in (test_bf16, test_dsv3_fp8, test_dsv4_fp8, test_dsv4_fp4): - try: - fn(device) - except BaseException as exc: - failures.append((fn.__name__, exc)) - print(f"[{fn.__name__}] FAIL — {type(exc).__name__}: {exc}") - - if world_size > 1: - dist.barrier() - try: - test_megamoe(device, world_size, rank) - except BaseException as exc: - if rank == 0: - failures.append(("test_megamoe", exc)) - print(f"[test_megamoe] FAIL — {type(exc).__name__}: {exc}") - dist.destroy_process_group() - - if rank == 0: - if failures: - print(f"\n=== {len(failures)} test(s) failed ===") - for name, exc in failures: - print(f" - {name}: {type(exc).__name__}: {exc}") - raise SystemExit(1) - print("\n=== all tests passed ===") - - -if __name__ == "__main__": - main()