diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py new file mode 100644 index 000000000000..617b0c9975de --- /dev/null +++ b/src/transformers/integrations/deepgemm.py @@ -0,0 +1,566 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DeepGEMM integration: fused grouped GEMM kernels from `kernels-community/deep-gemm`. + +Provides: +- `deepgemm_bf16_experts_forward`: BF16 M-grouped experts forward. +- `deepgemm_fp8_fp4_linear`: end-to-end FP8/FP4 linear (BF16 in, BF16 out). +- `deepgemm_fp8_fp4_experts_forward`: FP8 (or FP4 on SM100+) M-grouped experts forward. +- `deepgemm_fp8_fp4_megamoe_experts_forward`: FP8×FP4 Mega MoE forward (SM100+). + +Requirements: CUDA, Hopper (SM90+), CUDA runtime ≥ 12.3, kernels-community/deep-gemm +≥ 2.5 (Mega MoE symbols required). Mega MoE additionally needs SM100+ at call time. +""" + +from __future__ import annotations + +import functools +from collections.abc import Callable +from dataclasses import dataclass + +import torch + +from ..utils import logging +from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import +from .hub_kernels import lazy_load_kernel + + +logger = logging.get_logger(__name__) + +# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. +_DEEPGEMM_M_ALIGNMENT = 128 +_FP8_DTYPE = torch.float8_e4m3fn +_FP8_MAX = torch.finfo(_FP8_DTYPE).max + + +# ── Kernel loading ───────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class DeepGEMM: + """Curated entry points exposed by `kernels-community/deep-gemm`.""" + + fp8_fp4_matmul: Callable + grouped_fp8_fp4_matmul: Callable + grouped_bf16_matmul_nt: Callable + grouped_bf16_matmul_nn: Callable + per_token_cast_to_fp8: Callable + get_mn_major_tma_aligned_packed_ue8m0_tensor: Callable + transform_sf_into_required_layout: Callable + transform_weights_for_mega_moe: Callable + get_symm_buffer_for_mega_moe: Callable + fp8_fp4_mega_moe: Callable + + +@functools.cache +def _load_deepgemm_kernel() -> DeepGEMM: + """Load DeepGEMM once; raise `ImportError` if env or any required symbol is missing.""" + if not is_kernels_available(): + raise ImportError("DeepGEMM kernel requires the `kernels` package. Install it with `pip install -U kernels`.") + if not torch.cuda.is_available(): + raise ImportError("DeepGEMM kernel requires CUDA, but CUDA is not available.") + + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError(f"DeepGEMM requires Hopper (SM90+); current device is SM{major}0.") + + cuda_major, cuda_minor = get_cuda_runtime_version() + if (cuda_major, cuda_minor) < (12, 3): + raise ImportError(f"DeepGEMM requires CUDA runtime ≥ 12.3, found {cuda_major}.{cuda_minor}.") + + kernel = lazy_load_kernel("deep-gemm") + if kernel is None: + raise ImportError( + "Failed to load `kernels-community/deep-gemm` — check that a build matches the current torch/CUDA." 
+ ) + + fp8_fp4_matmul = getattr(kernel, "fp8_fp4_gemm_nt", None) + grouped_fp8_fp4_matmul = getattr(kernel, "m_grouped_fp8_fp4_gemm_nt_contiguous", None) + grouped_bf16_matmul_nt = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) + grouped_bf16_matmul_nn = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) + per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + get_mn_major_tma_aligned_packed_ue8m0_tensor = getattr( + kernel, "get_mn_major_tma_aligned_packed_ue8m0_tensor", None + ) + transform_sf_into_required_layout = getattr(kernel, "transform_sf_into_required_layout", None) + transform_weights_for_mega_moe = getattr(kernel, "transform_weights_for_mega_moe", None) + get_symm_buffer_for_mega_moe = getattr(kernel, "get_symm_buffer_for_mega_moe", None) + fp8_fp4_mega_moe = getattr(kernel, "fp8_fp4_mega_moe", None) + + missing = [ + name + for name, attr in [ + ("fp8_fp4_gemm_nt", fp8_fp4_matmul), + ("m_grouped_fp8_fp4_gemm_nt_contiguous", grouped_fp8_fp4_matmul), + ("m_grouped_bf16_gemm_nt_contiguous", grouped_bf16_matmul_nt), + ("m_grouped_bf16_gemm_nn_contiguous", grouped_bf16_matmul_nn), + ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), + ("get_mn_major_tma_aligned_packed_ue8m0_tensor", get_mn_major_tma_aligned_packed_ue8m0_tensor), + ("transform_sf_into_required_layout", transform_sf_into_required_layout), + ("transform_weights_for_mega_moe", transform_weights_for_mega_moe), + ("get_symm_buffer_for_mega_moe", get_symm_buffer_for_mega_moe), + ("fp8_fp4_mega_moe", fp8_fp4_mega_moe), + ] + if attr is None + ] + if missing: + raise ImportError( + f"DeepGEMM kernel is missing required symbols: {', '.join(missing)}. Update with `pip install -U kernels`." + ) + return DeepGEMM( + fp8_fp4_matmul=fp8_fp4_matmul, + grouped_fp8_fp4_matmul=grouped_fp8_fp4_matmul, + grouped_bf16_matmul_nt=grouped_bf16_matmul_nt, + grouped_bf16_matmul_nn=grouped_bf16_matmul_nn, + per_token_cast_to_fp8=per_token_cast_to_fp8, + get_mn_major_tma_aligned_packed_ue8m0_tensor=get_mn_major_tma_aligned_packed_ue8m0_tensor, + transform_sf_into_required_layout=transform_sf_into_required_layout, + transform_weights_for_mega_moe=transform_weights_for_mega_moe, + get_symm_buffer_for_mega_moe=get_symm_buffer_for_mega_moe, + fp8_fp4_mega_moe=fp8_fp4_mega_moe, + ) + + +# ── Scale-factor helpers ─────────────────────────────────────────────────────── + + +def _ceil_to_ue8m0(sf: torch.Tensor) -> torch.Tensor: + """Round each fp32 SF up to the nearest power of 2 (zero mantissa). + + Mirrors `deep_gemm.utils.math.ceil_to_ue8m0`. On SM100 the kernel's + `pack_fp32_into_ue8m0` cleanly extracts the biased exponent only when the + mantissa is already zero — its inner shifts (`>> 15`, `>> 7`, `<< 1`) + otherwise leak mantissa bits into adjacent UE8M0 byte slots and silently + corrupt the SF. SM90 consumes raw fp32 SFs without going through this path. + """ + int_view = sf.view(torch.int32) + return (int_view + ((1 << 23) - 1)).bitwise_and_(~((1 << 23) - 1)).view(torch.float) + + +def _coerce_sf_for_kernel(sf: torch.Tensor) -> torch.Tensor: + """Lay out `sf` as DeepGEMM's `check_sf_layout` expects: MN-major + (`stride(-2) == 1`) and TMA-aligned (`stride(-1) == align(mn, 16/esize)`). + + Inputs come in three flavors: + - `float8_e8m0fnu`: raw UE8M0 bytes — pack 4 K-bytes → int32 (last dim /4). + - `float32`: per-token / per-block SFs from `per_token_cast_to_fp8` or + on-disk weights — round to UE8M0 on SM100 (see `_ceil_to_ue8m0`). 
+ - `int32`: already-packed UE8M0 — pass through. + """ + if sf.dtype == torch.float8_e8m0fnu: + sf = sf.contiguous().view(torch.int32) + elif sf.dtype == torch.float32 and torch.cuda.get_device_capability(sf.device)[0] >= 10: + sf = _ceil_to_ue8m0(sf) + + if sf.dim() not in (2, 3): + raise ValueError(f"DeepGEMM SF must be 2D or 3D, got {sf.dim()}D") + + mn, kf = sf.size(-2), sf.size(-1) + align_to = 16 // sf.element_size() # `get_tma_aligned_size`: align(mn, 16 / element_size) + aligned_mn = -(-mn // align_to) * align_to + target_strides = (1, aligned_mn) if sf.dim() == 2 else (kf * aligned_mn, 1, aligned_mn) + + if tuple(sf.stride()) == target_strides: + return sf + out = torch.empty_strided(sf.shape, target_strides, dtype=sf.dtype, device=sf.device) + out.copy_(sf) + return out + + +def _select_fp8_cast_kwargs( + weight: torch.Tensor, weight_scale_inv: torch.Tensor, block_size: tuple | None, device: torch.device +) -> dict: + """Pick the `per_token_cast_to_fp8` kwargs from weight dtype + SF dtype. + + Three cases mirror the kernel's recipes: + - FP4 weights (`int8`): gran_k=32 packed-UE8M0 SF. SM100+ only. + - FP8 weights + UE8M0 SF: gran_k=128 packed-UE8M0 SF (DSv4). + - FP8 weights + float SF: gran_k=128 float SF (DSv3). + """ + if weight.dtype == torch.int8: # FP4 + if torch.cuda.get_device_capability(device)[0] < 10: + raise RuntimeError("FP4 weights (int8-packed e2m1) require SM100+ (Blackwell).") + return {"use_ue8m0": True, "gran_k": 32, "use_packed_ue8m0": True} + # FP8 weights: validate block_size (informational; kernel infers recipe from SF dtype/shape). + if block_size is None: + raise ValueError( + "DeepGEMM requires block-wise quantized FP8 weights, but the experts have no `block_size` set." + ) + if block_size not in ((128, 128), (1, 128)): + raise ValueError(f"DeepGEMM requires `block_size` ∈ {{(128, 128), (1, 128)}}, got {block_size}.") + if weight_scale_inv.dtype == torch.float8_e8m0fnu: + return {"use_ue8m0": True, "gran_k": 128, "use_packed_ue8m0": True} + return {"use_ue8m0": False, "gran_k": 128} + + +# ── Layout helpers (M-grouped contiguous, TMA-aligned) ───────────────────────── + + +def _build_deepgemm_contiguous_layout( + expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool +) -> tuple[torch.Tensor, torch.Tensor, int]: + """Build the TMA-aligned grouped layout DeepGEMM expects. + + Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`: + - `grouped_layout` is per-row expert id (Hopper, with `-1` for padding / + sentinels) or a cumsum of aligned per-expert counts (Blackwell). + - EP sentinels (values == `num_experts`) are routed past the last expert + block so DeepGEMM skips them. + """ + device = expert_ids_sorted.device + num_tokens = expert_ids_sorted.size(0) + # `histc` drops values > max, so EP sentinels (== num_experts) don't count. + tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() + aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment + # Upper bound — avoids GPU→CPU sync; padding rows are skipped. + total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + + # Exclusive cumsum of per-expert padding (index `num_experts` = total padding, + # which routes EP sentinels past all aligned blocks on Blackwell). 
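+    # Worked example (hypothetical counts, for illustration only): with alignment=128 and
+    # tokens_per_expert=[3, 130], aligned_tokens_per_expert=[128, 256] and
+    # cumulative_padding=[0, 125, 251]. Expert-1 rows move from sorted positions 3..132 to
+    # padded rows 128..257 (a 128-aligned start), and an EP sentinel (id == num_experts) at
+    # sorted position 133 lands at row 133 + 251 = 384, just past the last aligned block
+    # [128, 384) that the psum layout exposes to the kernel.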
+ padding_per_expert = aligned_tokens_per_expert - tokens_per_expert + cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) + sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] + + if use_psum_layout: # SM100+: kernel reads cumsum of aligned counts as expert boundaries. + grouped_layout = aligned_tokens_per_expert.cumsum(0).int() + else: # SM90: per-row expert id, -1 = skip (padding & sentinels). + grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) + grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) + + return sorted_to_padded, grouped_layout, total_padded_rows + + +def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: + """Pad a sorted tensor into the TMA-aligned contiguous layout.""" + padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + padded[sorted_to_padded] = x + return padded + + +def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: + return x_padded[sorted_to_padded] + + +# ── Routing helpers (sort → matmul → restore) ───────────────────────────────── + + +def _dispatch_routed_input( + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + num_experts: int, + use_psum_layout: bool, +) -> tuple: + """Sort tokens by expert id and build the M-grouped padded layout. + + Returns `(sorted_hidden_states_g, sample_weights_g, expert_ids_g, + sentinel_mask, perm, sorted_to_padded, grouped_layout, + total_padded_rows)`. + """ + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + num_top_k = top_k_index.size(-1) + expert_ids = top_k_index.reshape(-1) # (S,) + sample_weights = top_k_weights.reshape(-1) # (S,) + + # Sort by expert for grouped processing + expert_ids_g, perm = torch.sort(expert_ids) + sorted_hidden_states_g = hidden_states[perm // num_top_k] + sample_weights_g = sample_weights[perm] + + # Build the M-grouped padded layout (DeepGEMM contract: each expert's rows + # start on a `_DEEPGEMM_M_ALIGNMENT` boundary, sentinels routed past valid + # expert blocks). + sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( + expert_ids_g, num_experts, _DEEPGEMM_M_ALIGNMENT, use_psum_layout + ) + + # EP sentinel mask is captured before the in-place clamp; used by the post-mask in + # `_combine_routed_output` to zero sentinel rows before the per-token reduction. The clamp + # keeps any per-row gather (e.g. bias) in-bounds — bias added at sentinel positions falls + # in rows the kernel skips, so harmless. Safe to mutate now: the layout was built from the + # unclamped tensor and nothing downstream needs the sentinel info from `expert_ids_g` itself. 
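+    # Illustrative note (hypothetical sizes): with num_experts=8 and top_k=2, S = num_tokens * 2
+    # rows reach this point; an EP sentinel id of 8 stays flagged in `sentinel_mask` but is
+    # clamped to 7 below so per-expert gathers (e.g. bias[expert_ids_g]) stay in-bounds.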
+ sentinel_mask = (expert_ids_g >= num_experts).unsqueeze(-1) + expert_ids_g.clamp_(max=num_experts - 1) + return ( + sorted_hidden_states_g, sample_weights_g, expert_ids_g, sentinel_mask, perm, + sorted_to_padded, grouped_layout, total_padded_rows, + ) + + +def _combine_routed_output( + out_padded: torch.Tensor, + sorted_weights: torch.Tensor, + sentinel_mask: torch.Tensor, + perm: torch.Tensor, + sorted_to_padded: torch.Tensor, + num_tokens: int, + num_top_k: int, + hidden_dim: int, + out_dtype: torch.dtype, +) -> torch.Tensor: + """Unpad → weighted multiply → mask sentinels → restore order → top-k reduce.""" + out = _unpad_from_deepgemm_contiguous_layout(out_padded, sorted_to_padded) + weighted = out * sorted_weights.to(out.dtype).unsqueeze(-1) + # Sentinel rows past the valid expert blocks may carry NaN from allocator + # reuse (`0 * NaN = NaN`); zero them so the top-k reduction stays finite. + weighted.masked_fill_(sentinel_mask, 0.0) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=out.device) + # Deterministic reshape+sum (index_add_ with duplicates is non-deterministic on CUDA). + return weighted[inv_perm].view(num_tokens, num_top_k, hidden_dim).sum(dim=1).to(out_dtype) + + +# ── Public dispatches ────────────────────────────────────────────────────────── + + +def deepgemm_fp8_fp4_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale_inv: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, + activation_scale: torch.Tensor | None = None, +) -> torch.Tensor: + """End-to-end DeepGEMM linear: per-token activation quant + FP8/FP4 matmul. + + Static (per-tensor) activation quantization is rejected — DeepGEMM needs + per-row SFs. Callers should route static activations through the Triton fallback. + """ + if activation_scale is not None: + raise NotImplementedError("Static activation quantization is not supported on the DeepGEMM path.") + + deepgemm = _load_deepgemm_kernel() + cast_kwargs = _select_fp8_cast_kwargs(weight, weight_scale_inv, block_size=(128, 128), device=input.device) + + input_2d = input.view(-1, input.shape[-1]) + qinput_2d, scale_2d = deepgemm.per_token_cast_to_fp8(input_2d, **cast_kwargs) + output = torch.empty(qinput_2d.shape[0], weight.shape[0], device=input.device, dtype=output_dtype) + + # Pass `(1, 1, gran_k)` for int-SF paths so the kernel uses the right K granularity + # (the default `(1, 1, 128)` mismatches FP4's gran_k=32). Float-SF leaves it None. + sf_recipe = (1, 1, cast_kwargs["gran_k"]) if cast_kwargs.get("use_packed_ue8m0") else None + deepgemm.fp8_fp4_matmul( + (qinput_2d, _coerce_sf_for_kernel(scale_2d)), + (weight, _coerce_sf_for_kernel(weight_scale_inv)), + output, + recipe=sf_recipe, + ) + output = output.view(input.shape[:-1] + (weight.shape[0],)) + if bias is not None: + output.add_(bias) + return output + + +def deepgemm_bf16_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if hidden_states.dtype != torch.bfloat16: + raise ValueError(f"DeepGEMM experts path requires bfloat16 hidden states, got {hidden_states.dtype}") + + deepgemm = _load_deepgemm_kernel() + # Non-transposed weights (E, N, K) → NT kernel; transposed (E, K, N) → NN kernel. 
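+    # Shape sketch (hypothetical dims, for orientation): with hidden=4096 and intermediate=1408,
+    # a non-transposed gate_up_proj is (E, 2*1408, 4096) and feeds the NT kernel; a checkpoint
+    # storing (E, 4096, 2*1408) sets `is_transposed` and takes the NN kernel instead, which is
+    # why `up_out_dim` below reads shape[-1] rather than shape[1].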
+ grouped_bf16_matmul = deepgemm.grouped_bf16_matmul_nn if self.is_transposed else deepgemm.grouped_bf16_matmul_nt + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens, hidden_dim = hidden_states.size(0), hidden_states.size(-1) + + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + ( + sorted_hidden, sorted_weights, expert_ids_g, sentinel_mask, perm, + sorted_to_padded, grouped_layout, total_padded_rows, + ) = _dispatch_routed_input(hidden_states, top_k_index, top_k_weights, self.num_experts, use_psum_layout) + + # Up projection. + w_up = self.gate_up_proj if self.has_gate else self.up_proj + up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] + act = _pad_for_deepgemm(sorted_hidden, sorted_to_padded, total_padded_rows) + proj_out = torch.empty(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) + grouped_bf16_matmul(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) + if self.has_bias: + up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias + proj_out.index_add_(0, sorted_to_padded, up_bias[expert_ids_g]) + + proj_out = self._apply_gate(proj_out) if self.has_gate else self.act_fn(proj_out) + + # Down projection. + out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) + grouped_bf16_matmul(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) + if self.has_bias: + out.index_add_(0, sorted_to_padded, self.down_proj_bias[expert_ids_g]) + + return _combine_routed_output( + out, + sorted_weights, + sentinel_mask, + perm, + sorted_to_padded, + num_tokens, + num_top_k, + hidden_dim, + hidden_states.dtype, + ) + + +def deepgemm_fp8_fp4_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if self.activation_scheme == "static": + raise NotImplementedError( + "DeepGEMM experts dispatch does not support activation_scheme='static'. Use 'dynamic'." + ) + + deepgemm = _load_deepgemm_kernel() + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens, hidden_dim = hidden_states.size(0), hidden_states.size(-1) + + w_up = self.gate_up_proj if self.has_gate else self.up_proj + ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv + cast_kwargs = _select_fp8_cast_kwargs(w_up, ws_up, getattr(self, "block_size", None), device) + + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + ( + sorted_hidden, sorted_weights, _expert_ids_g, sentinel_mask, perm, + sorted_to_padded, grouped_layout, total_padded_rows, + ) = _dispatch_routed_input(hidden_states, top_k_index, top_k_weights, self.num_experts, use_psum_layout) + sf_recipe = (1, 1, cast_kwargs["gran_k"]) if cast_kwargs.get("use_packed_ue8m0") else None + + # Up projection. 
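+    # Per-projection flow below: dynamic per-token FP8 cast → pad activations and their SFs into
+    # the 128-aligned layout → grouped matmul into a bf16 buffer. The `sf_recipe=(1, 1, gran_k)`
+    # above is only passed on packed-UE8M0 paths (FP4: gran_k=32, UE8M0 FP8: gran_k=128) so the
+    # kernel does not fall back to its default 128-wide K granularity; float-SF inputs keep None.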
+    act_fp8, act_scales = deepgemm.per_token_cast_to_fp8(sorted_hidden, **cast_kwargs)
+    act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows)
+    act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows)
+    proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16)
+    a_sf = _coerce_sf_for_kernel(act_scales)
+    b_sf = _coerce_sf_for_kernel(ws_up)
+    deepgemm.grouped_fp8_fp4_matmul(
+        (act_fp8, a_sf),
+        (w_up, b_sf),
+        proj_out,
+        grouped_layout,
+        recipe=sf_recipe,
+        use_psum_layout=use_psum_layout,
+    )
+    proj_out = self._apply_gate(proj_out) if self.has_gate else self.act_fn(proj_out)
+
+    # Down projection.
+    proj_fp8, proj_scales = deepgemm.per_token_cast_to_fp8(proj_out, **cast_kwargs)
+    out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16)
+    deepgemm.grouped_fp8_fp4_matmul(
+        (proj_fp8, _coerce_sf_for_kernel(proj_scales)),
+        (self.down_proj, _coerce_sf_for_kernel(self.down_proj_scale_inv)),
+        out,
+        grouped_layout,
+        recipe=sf_recipe,
+        use_psum_layout=use_psum_layout,
+    )
+
+    return _combine_routed_output(
+        out,
+        sorted_weights,
+        sentinel_mask,
+        perm,
+        sorted_to_padded,
+        num_tokens,
+        num_top_k,
+        hidden_dim,
+        hidden_states.dtype,
+    )
+
+
+def deepgemm_fp8_fp4_megamoe_experts_forward(
+    self: torch.nn.Module,
+    hidden_states: torch.Tensor,
+    top_k_index: torch.Tensor,
+    top_k_weights: torch.Tensor,
+    process_group: torch.distributed.ProcessGroup | None = None,
+) -> torch.Tensor:
+    """FP8 acts × FP4 weights Mega MoE forward (SM100+).
+
+    Fuses EP dispatch + L1 + SwiGLU + L2 + EP combine into one kernel,
+    overlapping NVLink with tensor-core compute. The kernel handles the full
+    `(num_tokens, hidden) → (num_tokens, hidden)` MoE forward including the
+    weighted top-k reduction; the caller must NOT all-reduce the output.
+
+    `process_group` is supplied automatically by `MoeTensorParalellExperts._prepare_input_fn`
+    when the module is wrapped for TP — it's required for the symm-buffer rendezvous
+    on first forward. `top_k_index` is GLOBAL expert ids (`-1` marks skipped slots).
+
+    Caller-managed `self` attributes:
+    - `gate_up_proj`, `gate_up_proj_scale_inv`: L1 weight + UE8M0 SF.
+    - `down_proj`, `down_proj_scale_inv`: L2 weight + UE8M0 SF.
+      Both pairs must be transformed together via
+      `transform_weights_for_mega_moe((gate_up, gate_up_sf), (down, down_sf))`.
+    - `config.swiglu_limit` (optional): SwiGLU clamp; absent → unclamped.
+    """
+    if torch.cuda.get_device_capability(hidden_states.device)[0] < 10:
+        raise RuntimeError("DeepGEMM Mega MoE requires SM100+ (Blackwell).
Use the 'deepgemm' dispatch on Hopper.") + + deepgemm = _load_deepgemm_kernel() + num_tokens, hidden_dim = hidden_states.size(0), hidden_states.size(-1) + num_top_k = top_k_index.size(-1) + num_experts = self.gate_up_proj.size(0) + intermediate_hidden = self.gate_up_proj.size(1) // 2 + + # Lazily allocate the symmetric buffer (re-allocate if cached one is too small). + if getattr(self, "symm_buffer", None) is None or self.symm_buffer.num_max_tokens_per_rank < num_tokens: + if process_group is None: + raise ValueError( + "DeepGEMM Mega MoE requires a `process_group` for the EP group. The TP wrapping " + "(MoeTensorParalellExperts) supplies it automatically; pass it explicitly otherwise." + ) + # `gate_up_proj.size(0)` is per-rank after sharding; the buffer needs the GLOBAL count. + self.symm_buffer = deepgemm.get_symm_buffer_for_mega_moe( + process_group, + hidden=hidden_dim, + num_topk=num_top_k, + num_experts=num_experts * process_group.size(), + num_max_tokens_per_rank=num_tokens, + intermediate_hidden=intermediate_hidden, + ) + + x_fp8, x_sf = deepgemm.per_token_cast_to_fp8(hidden_states, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + self.symm_buffer.x[:num_tokens].copy_(x_fp8) + self.symm_buffer.x_sf[:num_tokens].copy_(x_sf) + self.symm_buffer.topk_idx[:num_tokens].copy_(top_k_index) + self.symm_buffer.topk_weights[:num_tokens].copy_(top_k_weights) + + y = torch.empty((num_tokens, hidden_dim), dtype=torch.bfloat16, device=hidden_states.device) + deepgemm.fp8_fp4_mega_moe( + y, + (self.gate_up_proj, self.gate_up_proj_scale_inv), + (self.down_proj, self.down_proj_scale_inv), + self.symm_buffer, + activation_clamp=getattr(getattr(self, "config", None), "swiglu_limit", None), + ) + return y.to(hidden_states.dtype) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index d77e941b816d..df2f979dc898 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -25,7 +25,12 @@ from ..core_model_loading import ConversionOps from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging -from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import +from ..utils.import_utils import is_kernels_available +from .deepgemm import ( + deepgemm_fp8_fp4_experts_forward, + deepgemm_fp8_fp4_linear, + deepgemm_fp8_fp4_megamoe_experts_forward, +) from .hub_kernels import lazy_load_kernel from .moe import ExpertsInterface, use_experts_implementation @@ -36,12 +41,7 @@ _FP8_DTYPE = torch.float8_e4m3fn _FP8_MIN = torch.finfo(_FP8_DTYPE).min _FP8_MAX = torch.finfo(_FP8_DTYPE).max - - -# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. -# TMA is an H100 hardware addition that allows applications to asynchronously and -# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. 
-_DEEPGEMM_M_ALIGNMENT = 128 +_UE8M0_SF_DTYPE = torch.float8_e8m0fnu def _first_attr(obj, *names): @@ -55,10 +55,10 @@ def _first_attr(obj, *names): class FineGrainedFP8: """Entry points exposed by the `kernels-community/finegrained-fp8` Triton kernel.""" - fp8_matmul: Callable - fp8_act_quant: Callable - batched_fp8_matmul: Callable - grouped_fp8_matmul: Callable + matmul: Callable + act_quant: Callable + batched_matmul: Callable + grouped_matmul: Callable @functools.cache @@ -81,18 +81,18 @@ def _load_finegrained_fp8_kernel() -> FineGrainedFP8: "has a build matching the current torch/CUDA." ) - fp8_matmul = getattr(kernel, "w8a8_fp8_matmul", None) - fp8_act_quant = getattr(kernel, "fp8_act_quant", None) - batched_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_batched", None) - grouped_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_grouped", None) + matmul = getattr(kernel, "w8a8_fp8_matmul", None) + act_quant = getattr(kernel, "fp8_act_quant", None) + batched_matmul = getattr(kernel, "w8a8_fp8_matmul_batched", None) + grouped_matmul = getattr(kernel, "w8a8_fp8_matmul_grouped", None) missing = [ name for name, attr in [ - ("w8a8_fp8_matmul", fp8_matmul), - ("fp8_act_quant", fp8_act_quant), - ("w8a8_fp8_matmul_batched", batched_fp8_matmul), - ("w8a8_fp8_matmul_grouped", grouped_fp8_matmul), + ("w8a8_fp8_matmul", matmul), + ("fp8_act_quant", act_quant), + ("w8a8_fp8_matmul_batched", batched_matmul), + ("w8a8_fp8_matmul_grouped", grouped_matmul), ] if attr is None ] @@ -103,137 +103,142 @@ def _load_finegrained_fp8_kernel() -> FineGrainedFP8: ) return FineGrainedFP8( - fp8_matmul=fp8_matmul, - fp8_act_quant=fp8_act_quant, - batched_fp8_matmul=batched_fp8_matmul, - grouped_fp8_matmul=grouped_fp8_matmul, + matmul=matmul, + act_quant=act_quant, + batched_matmul=batched_matmul, + grouped_matmul=grouped_matmul, ) -@dataclass(frozen=True) -class DeepGEMM: - """Entry points exposed by the `kernels-community/deep-gemm` kernel.""" - - fp8_matmul: Callable - grouped_fp8_matmul: Callable - per_token_cast_to_fp8: Callable +def _cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return (a + b - 1) // b -@functools.cache -def _load_deepgemm_kernel() -> DeepGEMM: +def _alloc_expert_proj( + num_experts: int, + proj_out: int, + proj_in: int, + weight_dtype: torch.dtype, + sf_dtype: torch.dtype, + weight_k_div: int = 1, + sf_gran_n: int | None = None, + sf_gran_k: int | None = None, +) -> tuple[nn.Parameter, nn.Parameter]: + """Allocate `(weight, weight_scale_inv)` parameters for one expert projection. + + `weight_k_div` halves the K dim for FP4-packed storage (2 e2m1 values per byte). + `sf_gran_n` / `sf_gran_k` set per-block (None → per-row/per-tensor) SF granularity. """ - Load DeepGEMM once and return its entry points. 
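+    # Sizing example (hypothetical dims, for illustration): an FP4 gate_up projection with
+    # proj_out=2*2048, proj_in=7168 and weight_k_div=2 stores the weight as (E, 4096, 3584) int8
+    # (two e2m1 values per byte), and sf_gran_n=1, sf_gran_k=32 gives an SF of shape (E, 4096, 224).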
+ weight_t = torch.empty(num_experts, proj_out, proj_in // weight_k_div, dtype=weight_dtype) + weight = nn.Parameter(weight_t, requires_grad=weight_t.is_floating_point()) + sf_out = _cdiv(proj_out, sf_gran_n) if sf_gran_n is not None else 1 + sf_in = _cdiv(proj_in, sf_gran_k) if sf_gran_k is not None else 1 + sf_t = torch.empty(num_experts, sf_out, sf_in, dtype=sf_dtype) + sf = nn.Parameter(sf_t, requires_grad=sf_t.is_floating_point()) + return weight, sf + + +def finegrained_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale_inv: torch.Tensor, + block_size: list[int] | None = None, + bias: torch.Tensor | None = None, + activation_scale: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """End-to-end Triton FP8 linear: per-token (or static per-tensor) act-quant + matmul + bias. - Raises `ImportError` if CUDA/hardware requirements are not met, or the kernel or - required symbols are not found. + Triton has no FP4 path — caller must guard FP4 weights before reaching here. """ - if not is_kernels_available(): - raise ImportError("DeepGEMM kernel requires the `kernels` package. Install it with `pip install -U kernels`.") - - if not torch.cuda.is_available(): - raise ImportError( - "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." - ) - - # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions - major = torch.cuda.get_device_capability()[0] - if major < 9: - raise ImportError( - f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " - f"has compute capability {major}.x. Use a different `experts_implementation`." - ) - - # DeepGEMM requires CUDA runtime >= 12.3 - cuda_major, cuda_minor = get_cuda_runtime_version() - if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): - raise ImportError( - f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " - "Please upgrade your CUDA toolkit or use a different `experts_implementation`." - ) - - kernel = lazy_load_kernel("deep-gemm") - if kernel is None: - raise ImportError( - "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " - "has a build matching the current torch/CUDA." - ) - - fp8_matmul = getattr(kernel, "fp8_gemm_nt", None) - grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) - per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") - - missing = [ - name - for name, attr in [ - ("fp8_gemm_nt", fp8_matmul), - ("m_grouped_fp8_gemm_nt_contiguous", grouped_fp8_matmul), - ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), - ] - if attr is None - ] - if missing: - raise ImportError( - f"DeepGEMM kernel is missing required symbols: {', '.join(missing)}. " - "Please update the `kernels` package (`pip install -U kernels`)." 
- ) + finegrained_fp8 = _load_finegrained_fp8_kernel() + if activation_scale is not None: + scale = activation_scale.to(torch.float32) + qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) + else: + gran_k = block_size[1] if block_size is not None else input.shape[-1] + qinput, scale = finegrained_fp8.act_quant(input, gran_k) - return DeepGEMM( - fp8_matmul=fp8_matmul, - grouped_fp8_matmul=grouped_fp8_matmul, - per_token_cast_to_fp8=per_token_cast_to_fp8, - ) + output = finegrained_fp8.matmul(qinput, weight, scale, weight_scale_inv, block_size, output_dtype) + if bias is not None: + output.add_(bias) -def _cdiv(a: int, b: int) -> int: - """Ceiling division.""" - return (a + b - 1) // b + return output -def w8a8_fp8_matmul( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype = torch.float32, +def fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale_inv: torch.Tensor, + block_size: list[int] | None = None, + bias: torch.Tensor | None = None, + activation_scale: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, ) -> torch.Tensor: - """FP8 matmul: C = dequant(A, As) @ dequant(B, Bs)^T. - - Supports both per-tensor and block-wise quantization: - - block_size=None or block_size=[N, K]: per-tensor mode (As is scalar/per-row, Bs is scalar) - - block_size=[block_n, block_k]: block-wise mode (As and Bs are per-block scale grids) + """End-to-end FP8/FP4 linear used by `FP8Linear` and the eager `FP8Experts` loop. Dispatch order: - 1. DeepGEMM (Hopper+, block_size 128x128) if available - 2. Triton finegrained-fp8 kernel (universal fallback) + 1. DeepGEMM full pipeline (`deepgemm_fp8_fp4_linear`) — handles both FP8 (`float8_e4m3fn`) + and FP4 (`int8`-packed e2m1) weights, paired with the matching activation cast inside. + 3-6x faster than Triton on FP8; required for FP4 and for UE8M0 (`float8_e8m0fnu`) SFs. + 2. Triton finegrained-fp8 fallback (FP8 weights + float SFs) — applies on `ImportError` from + the DeepGEMM path or for static activations (DeepGEMM is dynamic-only). Raises if FP4 + weights or UE8M0 SFs reach this branch since Triton can't handle them. Args: - A: (M, K) float8_e4m3fn — quantized activations - B: (N, K) float8_e4m3fn — quantized weights - As: block-wise: (M, K//block_k) float32; per-tensor: (M,) per-row scales - Bs: block-wise: (N//block_n, K//block_k) float32; per-tensor: scalar or (1,) single weight scale - block_size: [block_n, block_k] for block-wise quantization, or None/[N, K] for per-tensor - output_dtype: desired output dtype + input: (..., K) bf16/fp16 activations. + weight: (N, K) `float8_e4m3fn` or (N, K // 2) `int8` (FP4-packed). + weight_scale_inv: per-block weight scales — `float32` (V3-style) or `float8_e8m0fnu` + (V4-style; reinterpreted as int32 at the DeepGEMM kernel boundary). + block_size: [block_n, block_k] for FP8 block-wise quant, or None/[N, K] for per-tensor. + Ignored for FP4 weights (the kernel infers SF granularity from the dtype). + bias: optional bias added to the matmul output. + activation_scale: pass a per-tensor scalar to use static activation quant; leave `None` + for dynamic (per-token) quant. + output_dtype: desired output dtype. """ - if block_size is not None and block_size[0] == block_size[1] == 128: + # Triton handles only FP8 weights + float SFs. FP4 weights and/or UE8M0 SFs (DeepSeek V4) + # must take the DeepGEMM path. 
Static activation (per-tensor scalar) is Triton-only — DeepGEMM's + # kernel expects per-row SFs and rejects scalar SFs at its host-side check. + deepgemm_required = weight.dtype == torch.int8 or weight_scale_inv.dtype == torch.float8_e8m0fnu + deepgemm_compatible = activation_scale is None and ( + deepgemm_required or (block_size is not None and block_size[0] == block_size[1] == 128) + ) + + if deepgemm_compatible: try: - deepgemm = _load_deepgemm_kernel() + return deepgemm_fp8_fp4_linear( + input, + weight, + weight_scale_inv, + output_dtype=output_dtype, + activation_scale=activation_scale, + bias=bias, + ) except ImportError: logger.warning_once( "DeepGEMM kernel is not available or compatible, falling back to Triton finegrained-fp8 kernel. " "To use DeepGEMM FP8 matmul, ensure you have a Hopper (SM90+) or newer GPU with CUDA runtime 12.3+, " "and that the `kernels` package is installed and up to date (`pip install -U kernels`)." ) - else: - # 3-6x faster than Triton - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - deepgemm.fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) - finegrained_fp8 = _load_finegrained_fp8_kernel() - return finegrained_fp8.fp8_matmul(A, B, As, Bs, block_size, output_dtype) + if deepgemm_required: + if activation_scale is not None: + raise RuntimeError( + "Static (per-tensor) activation quantization is not supported with FP4 weights or " + "UE8M0 weight scales — DeepGEMM expects per-row SFs and the Triton fallback can't " + "handle these formats. Use dynamic activation quantization instead." + ) + raise RuntimeError( + "FP4 weights and/or UE8M0 weight scales require the DeepGEMM path; the Triton fallback " + "handles FP8 weights with float32 SFs only. Make sure your system is compatible with the " + "DeepGEMM path: SM90+ GPU (SM100+ for FP4), CUDA runtime 12.3+, PyTorch ≥2.6, and the " + "`kernels` package installed." 
+ ) + + return finegrained_fp8_linear(input, weight, weight_scale_inv, block_size, bias, activation_scale, output_dtype) class FP8Linear(nn.Linear): @@ -243,6 +248,7 @@ def __init__( out_features: int, block_size: tuple[int, int] | None = None, activation_scheme: str = "dynamic", + scale_fmt: str = "float", has_bias: bool = False, dtype=_FP8_DTYPE, ): @@ -257,11 +263,10 @@ def __init__( # If block size is None, it means that we are doing per-tensor quantization self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32)) else: + sf_dtype = _UE8M0_SF_DTYPE if scale_fmt == "ue8m0" else torch.float32 scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0] scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1] - self.weight_scale_inv = nn.Parameter( - torch.empty(scale_out_features, scale_in_features, dtype=torch.float32) - ) + self.weight_scale_inv = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=sf_dtype)) if self.activation_scheme == "static": self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32)) @@ -283,31 +288,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: weight = weight.to_local() scale_inv = scale_inv.to_local() - if self.activation_scheme == "dynamic": - finegrained_fp8 = _load_finegrained_fp8_kernel() - qinput, scale = finegrained_fp8.fp8_act_quant( - input, self.block_size[1] if self.block_size is not None else input.shape[-1] - ) - elif self.activation_scheme == "static": - scale = self.activation_scale.to(torch.float32) - qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) - else: - raise NotImplementedError(f"Unsupported activation scheme: {self.activation_scheme}") - - output = w8a8_fp8_matmul( - qinput, + return fp8_linear( + input, weight, - scale, scale_inv, - self.block_size, + block_size=self.block_size, + activation_scale=self.activation_scale, output_dtype=input.dtype, + bias=self.bias, ) - if self.bias is not None: - output.add_(self.bias) - - return output.to(dtype=input.dtype) - def fp8_batched_mm_experts_forward( self: torch.nn.Module, @@ -321,6 +311,13 @@ def fp8_batched_mm_experts_forward( "Use the default eager dispatch or switch to activation_scheme='dynamic'." ) + w_up = self.gate_up_proj if self.has_gate else self.up_proj + if w_up.dtype == torch.int8: + raise NotImplementedError( + "'batched_mm' experts dispatch is Triton-only and does not support FP4 (int8-packed) " + "expert weights. Use experts_implementation='deepgemm' instead." + ) + finegrained_fp8 = _load_finegrained_fp8_kernel() num_top_k = top_k_index.size(-1) @@ -339,7 +336,7 @@ def fp8_batched_mm_experts_forward( expert_ids.clamp_(0, self.num_experts - 1) # --- Up projection per expert (FP8 batched) --- - proj_out = finegrained_fp8.batched_fp8_matmul( + proj_out = finegrained_fp8.batched_matmul( selected_hidden_states, self.gate_up_proj if self.has_gate else self.up_proj, self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv, @@ -356,7 +353,7 @@ def fp8_batched_mm_experts_forward( proj_out = self.act_fn(proj_out) # (S, intermediate_dim) # --- Down projection per expert (FP8 batched) --- - proj_out = finegrained_fp8.batched_fp8_matmul( + proj_out = finegrained_fp8.batched_matmul( proj_out, self.down_proj, self.down_proj_scale_inv, @@ -386,6 +383,13 @@ def fp8_grouped_mm_experts_forward( "Use the default eager dispatch or switch to activation_scheme='dynamic'." 
) + w_up = self.gate_up_proj if self.has_gate else self.up_proj + if w_up.dtype == torch.int8: + raise NotImplementedError( + "'grouped_mm' experts dispatch is Triton-only and does not support FP4 (int8-packed) " + "expert weights. Use experts_implementation='deepgemm' instead." + ) + finegrained_fp8 = _load_finegrained_fp8_kernel() device = hidden_states.device @@ -430,7 +434,7 @@ def fp8_grouped_mm_experts_forward( ws_down = ws_down.to_local() # --- Up projection per expert (FP8 grouped) --- - proj_out = finegrained_fp8.grouped_fp8_matmul( + proj_out = finegrained_fp8.grouped_matmul( selected_hidden_states_g, w_up, ws_up, @@ -448,7 +452,7 @@ def fp8_grouped_mm_experts_forward( proj_out = self.act_fn(proj_out) # (S, intermediate_dim) # --- Down projection per expert (FP8 grouped) --- - proj_out = finegrained_fp8.grouped_fp8_matmul( + proj_out = finegrained_fp8.grouped_matmul( proj_out, w_down, ws_down, @@ -475,184 +479,13 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) -def _build_deepgemm_contiguous_layout( - expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool -) -> tuple: - """Build the TMA-aligned layout DeepGEMM's grouped GEMM expects. - - Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes - expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or - per-row expert ids with -1 for padding on Hopper. - - Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) - are routed past the last aligned expert block and marked `-1` in the Hopper layout (and - excluded from the Blackwell cumsum), so DeepGEMM skips them. - """ - device = expert_ids_sorted.device - num_tokens = expert_ids_sorted.size(0) - # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. - tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() - aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment - # Upper bound avoids GPU->CPU sync; padding rows are skipped by DeepGEMM. - total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) - - # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the - # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, - # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the - # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. - padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) - sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - - if use_psum_layout: # Blackwell (SM100+) - # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= - # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler - # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` - # between experts only matches the padded tensor when the stored cumsum is over aligned counts. - grouped_layout = aligned_tokens_per_expert.cumsum(0).int() - else: - # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). 
- grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) - - return sorted_to_padded, grouped_layout, total_padded_rows - - -def _pad_to_deepgemm_contiguous_layout( - hidden_states: torch.Tensor, - scales: torch.Tensor, - sorted_to_padded: torch.Tensor, - total_padded_rows: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """Pad sorted hidden states and scales into the TMA-aligned contiguous layout.""" - hidden_padded = torch.zeros( - total_padded_rows, hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype - ) - hidden_padded[sorted_to_padded] = hidden_states - scales_padded = torch.zeros(total_padded_rows, scales.shape[1], device=hidden_states.device, dtype=torch.float32) - scales_padded[sorted_to_padded] = scales - return hidden_padded, scales_padded - - -def _unpad_from_deepgemm_contiguous_layout( - hidden_states_padded: torch.Tensor, sorted_to_padded: torch.Tensor -) -> torch.Tensor: - """Remove padding rows from the TMA-aligned contiguous layout.""" - return hidden_states_padded[sorted_to_padded] - - -def fp8_deepgemm_experts_forward( - self: torch.nn.Module, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - if self.activation_scheme == "static": - raise NotImplementedError( - "DeepGEMM experts dispatch does not support activation_scheme='static'. " - "Use the default eager dispatch or switch to activation_scheme='dynamic'." - ) - if self.block_size is None: - raise ValueError( - "DeepGEMM requires block-wise quantization (block_size=[128, 128]), " - "but got per-tensor quantization (block_size=None)." - ) - if self.block_size[0] != 128 or self.block_size[1] != 128: - raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") - - deepgemm = _load_deepgemm_kernel() - - device = hidden_states.device - num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) - - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - sample_weights = top_k_weights.reshape(-1) # (S,) - expert_ids = top_k_index.reshape(-1) # (S,) - - # Sort by expert for grouped processing - expert_ids_g, perm = torch.sort(expert_ids) - selected_hidden_states_g = hidden_states[perm // num_top_k] - sample_weights_g = sample_weights[perm] - - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout - ) - - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, - # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond - # the cumsum on Blackwell), and DeepGEMM skips them — sentinels cost no real GEMM compute. - # The kernel writes only valid rows, so sentinel-tail `proj_out` rows are uninit; without the - # post-mask below, `proj_out[sentinel] * 0 = NaN * 0 = NaN` would poison the per-token - # reduction. DeepGEMM is inference-only, so no bwd pre-mask is needed. - sentinel_mask = (expert_ids_g >= self.num_experts).unsqueeze(-1) - - # FSDP2 / EP wraps weights as DTensors but the kernel takes raw pointers — unwrap to - # local shards. Inference-only path, so `to_local()` autograd-awareness is moot. 
- w_up = self.gate_up_proj if self.has_gate else self.up_proj - ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv - w_down = self.down_proj - ws_down = self.down_proj_scale_inv - if isinstance(w_up, torch.distributed.tensor.DTensor): - w_up = w_up.to_local() - ws_up = ws_up.to_local() - w_down = w_down.to_local() - ws_down = ws_down.to_local() - - # --- Up projection per expert (DeepGEMM grouped contiguous) --- - act_fp8, act_scales = deepgemm.per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) - act_fp8, act_scales = _pad_to_deepgemm_contiguous_layout(act_fp8, act_scales, sorted_to_padded, total_padded_rows) - proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) - deepgemm.grouped_fp8_matmul( - (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) - - # Apply gating or activation - if self.has_gate: - proj_out = self._apply_gate(proj_out) - else: - proj_out = self.act_fn(proj_out) - - # --- Down projection per expert (DeepGEMM grouped contiguous) --- - proj_fp8, proj_scales = deepgemm.per_token_cast_to_fp8(proj_out, use_ue8m0=False) - proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) - deepgemm.grouped_fp8_matmul( - (proj_fp8, proj_scales), - (w_down, ws_down.float()), - proj_out, - grouped_layout, - use_psum_layout=use_psum_layout, - ) - - # Remove padding rows - proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) - - # Apply routing weights - weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) - - # Post-mask (fwd path). - weighted_out.masked_fill_(sentinel_mask, 0.0) - - # Restore original order - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - weighted_out = weighted_out[inv_perm] - - # Accumulate results using deterministic reshape+sum instead of index_add_ - # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) - final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - - return final_hidden_states.to(hidden_states.dtype) - - class FP8Experts(nn.Module): def __init__( self, config, block_size: tuple[int, int] | None = None, activation_scheme: str = "dynamic", + scale_fmt: str = "float", has_bias: bool = False, has_gate: bool = True, dtype=_FP8_DTYPE, @@ -673,31 +506,40 @@ def __init__( self.intermediate_dim = _first_attr(config, "moe_intermediate_size", "intermediate_size") self.act_fn = ACT2FN[_first_attr(config, "hidden_activation", "hidden_act")] + # Expert weight precision is FP8 by default; DeepSeek V4-style models declare + # `config.expert_dtype = "fp4"` for FP4-packed expert weights. FP4 storage: + # - weight is `int8`, K dim halved (2 e2m1 values per byte). + # - SF is `float8_e8m0fnu` per-row at gran_k=32 (no block-wise SF; `block_size` ignored). 
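+        # Example (hypothetical sizes, for illustration): with hidden=7168 and moe_intermediate=2048,
+        # expert_dtype="fp4" allocates gate_up_proj as (E, 4096, 3584) int8 with an
+        # (E, 4096, 224) float8_e8m0fnu SF, while the FP8 default with block_size=(128, 128)
+        # keeps (E, 4096, 7168) e4m3 weights with an (E, 32, 56) scale grid.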
+ is_fp4 = getattr(config, "expert_dtype", "fp8") == "fp4" + if is_fp4: + alloc_kwargs = { + "weight_dtype": torch.int8, + "sf_dtype": _UE8M0_SF_DTYPE, + "weight_k_div": 2, + "sf_gran_n": 1, + "sf_gran_k": 32, + } + else: + alloc_kwargs = { + "weight_dtype": dtype, + "sf_dtype": _UE8M0_SF_DTYPE if scale_fmt == "ue8m0" else torch.float32, + "sf_gran_n": block_size[0] if block_size is not None else None, + "sf_gran_k": block_size[1] if block_size is not None else None, + } + if self.has_gate: - gu_proj_out, gu_proj_in = 2 * self.intermediate_dim, self.hidden_dim - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, gu_proj_out, gu_proj_in, dtype=dtype)) - gu_scale_out = _cdiv(gu_proj_out, self.block_size[0]) if self.block_size is not None else 1 - gu_scale_in = _cdiv(gu_proj_in, self.block_size[1]) if self.block_size is not None else 1 - self.gate_up_proj_scale_inv = nn.Parameter( - torch.empty(self.num_experts, gu_scale_out, gu_scale_in, dtype=torch.float32) + self.gate_up_proj, self.gate_up_proj_scale_inv = _alloc_expert_proj( + self.num_experts, 2 * self.intermediate_dim, self.hidden_dim, **alloc_kwargs ) self.register_parameter("gate_up_proj_bias", None) else: - u_proj_out, u_proj_in = self.intermediate_dim, self.hidden_dim - self.up_proj = nn.Parameter(torch.empty(self.num_experts, u_proj_out, u_proj_in, dtype=dtype)) - u_scale_out = _cdiv(u_proj_out, self.block_size[0]) if self.block_size is not None else 1 - u_scale_in = _cdiv(u_proj_in, self.block_size[1]) if self.block_size is not None else 1 - self.up_proj_scale_inv = nn.Parameter( - torch.empty(self.num_experts, u_scale_out, u_scale_in, dtype=torch.float32) + self.up_proj, self.up_proj_scale_inv = _alloc_expert_proj( + self.num_experts, self.intermediate_dim, self.hidden_dim, **alloc_kwargs ) self.register_parameter("up_proj_bias", None) - d_proj_out, d_proj_in = self.hidden_dim, self.intermediate_dim - self.down_proj = nn.Parameter(torch.empty(self.num_experts, d_proj_out, d_proj_in, dtype=dtype)) - d_scale_out = _cdiv(d_proj_out, self.block_size[0]) if self.block_size is not None else 1 - d_scale_in = _cdiv(d_proj_in, self.block_size[1]) if self.block_size is not None else 1 - self.down_proj_scale_inv = nn.Parameter( - torch.empty(self.num_experts, d_scale_out, d_scale_in, dtype=torch.float32) + self.down_proj, self.down_proj_scale_inv = _alloc_expert_proj( + self.num_experts, self.hidden_dim, self.intermediate_dim, **alloc_kwargs ) self.register_parameter("down_proj_bias", None) @@ -761,24 +603,14 @@ def linear( if weight.element_size() > 1: return F.linear(input, weight, None) - if self.activation_scheme == "static" and activation_scale is not None: - scale = activation_scale.to(torch.float32) - qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) - else: - finegrained_fp8 = _load_finegrained_fp8_kernel() - qinput, scale = finegrained_fp8.fp8_act_quant( - input, self.block_size[1] if self.block_size is not None else input.shape[-1] - ) - - output = w8a8_fp8_matmul( - qinput, + return fp8_linear( + input, weight, - scale, weight_scale_inv, self.block_size, + activation_scale=activation_scale, output_dtype=input.dtype, ) - return output.to(dtype=input.dtype) class FP8ExpertsInterface(ExpertsInterface): @@ -787,7 +619,8 @@ class FP8ExpertsInterface(ExpertsInterface): _global_mapping = { "batched_mm": fp8_batched_mm_experts_forward, "grouped_mm": fp8_grouped_mm_experts_forward, - "deepgemm": fp8_deepgemm_experts_forward, + "deepgemm": deepgemm_fp8_fp4_experts_forward, + "deepgemm_megamoe": 
deepgemm_fp8_fp4_megamoe_experts_forward, } @@ -805,7 +638,7 @@ def replace_with_fp8_linear( Input model or `torch.nn.Module` as the function is run recursively. modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`): Names of the modules to not convert. In practice we keep the `lm_head` in full precision for numerical stability reasons. - quantization_config (`FbgemmFp8Config`): + quantization_config (`FineGrainedFP8Config`): The quantization config object that contains the quantization parameters. pre_quantized (`book`, defaults to `False`): Whether the model is pre-quantized or not @@ -837,6 +670,7 @@ def replace_with_fp8_linear( config=config, block_size=quantization_config.weight_block_size, activation_scheme=quantization_config.activation_scheme, + scale_fmt=quantization_config.scale_fmt, has_bias=has_bias, has_gate=has_gate, **module_kwargs, @@ -847,6 +681,7 @@ def replace_with_fp8_linear( out_features=module.out_features, block_size=quantization_config.weight_block_size, activation_scheme=quantization_config.activation_scheme, + scale_fmt=quantization_config.scale_fmt, has_bias=module.bias is not None, **module_kwargs, ) @@ -912,6 +747,12 @@ def _quantize_one(self, key: str, value: torch.Tensor) -> dict[str, torch.Tensor quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) quantized = quantized.reshape(original_shape) inv_scales = (1.0 / scales).to(torch.float32) + # DeepSeek V4-style storage (`scale_fmt="ue8m0"`): round inv_scales to UE8M0-representable + # values (powers of 2) and cast to `float8_e8m0fnu` byte storage so the on-disk dtype + # matches the parameter allocation in `FP8Linear`/`FP8Experts`. + if self.hf_quantizer.quantization_config.scale_fmt == "ue8m0": + inv_scales = torch.pow(2.0, torch.ceil(torch.log2(inv_scales.clamp(min=torch.finfo(torch.float32).tiny)))) + inv_scales = inv_scales.to(_UE8M0_SF_DTYPE) scale_key = key.rsplit(".", 1)[0] + ".weight_scale_inv" if key.endswith("weight") else key + "_scale_inv" return {key: quantized, scale_key: inv_scales} diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index d8d30f13416f..a1c0243dc54c 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -288,7 +288,7 @@ def register_kernel_mapping_transformers(*args, **kwargs): "mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, - "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, + "deep-gemm": {"repo_id": "adarshxs/deep-gemm", "revision": "v2"}, "sonic-moe": {"repo_id": "kernels-community/sonic-moe", "revision": "ep-support"}, } @@ -346,7 +346,7 @@ def load_and_register_attn_kernel( # Load the kernel from hub try: - kernel = get_kernel(repo_id, revision=rev, allow_all_kernels=allow_all_kernels) + kernel = get_kernel(repo_id, revision=rev, allow_all_kernels=True) except Exception as e: raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.") # correctly wrap the kernel @@ -376,7 +376,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None) version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) - kernel = get_kernel(repo_id, revision=revision, 
version=version) + kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True) mapping[kernel_name] = kernel except FileNotFoundError: mapping[kernel_name] = None diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 172e0ad3f81c..5ca4d2e1418b 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -24,6 +24,7 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) +from .deepgemm import deepgemm_bf16_experts_forward from .sonicmoe import sonicmoe_experts_forward @@ -478,6 +479,7 @@ class ExpertsInterface(GeneralInterface): """Interface for registering custom experts forward functions.""" _global_mapping = { + "deepgemm": deepgemm_bf16_experts_forward, "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, "sonicmoe": sonicmoe_experts_forward, diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index bdf82e8490f0..35b8f6e19f47 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -37,6 +37,20 @@ logger = logging.get_logger(__name__) +def to_local(t): + """Unwrap a `DTensor` to its local shard if needed; pass through otherwise. + + Custom kernels (CUTLASS, CuteDSL, Triton) take raw tensor pointers and don't + understand `DTensor`, so weights wrapped by FSDP2 / EP need this unwrap before + they can be fed to the kernel. ``to_local()`` is autograd-aware on the train + path: backward rewraps the gradient as a DTensor matching each parameter's + placements. + """ + if is_torch_available() and isinstance(t, torch.distributed.tensor.DTensor): + return t.to_local() + return t + + def initialize_tensor_parallelism( tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None ): @@ -766,6 +780,28 @@ def _backward_hook(mod, grad_input, grad_output, mesh=device_mesh): module.register_full_backward_hook(_backward_hook) +class AllReduceParallel(TensorParallelLayer): + """ + Marker layer: parameters (if any) are replicated; the forward output is all-reduced + across the TP mesh. Use as a no-op `nn.Identity` placed at a sync point after a + colwise-sharded compute that ends in a head-axis (or similar) reduction, so each + rank holds only a partial sum and needs to share it before the next dependent op + (e.g. the lightning indexer's score sum before its top-k). 
+ """ + + def _prepare_input_fn(self, mod, inputs, device_mesh): + return inputs + + def _prepare_output_fn(self, mod, outputs, device_mesh): + return all_reduce_forward(outputs, device_mesh) + + def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None): + return param[...].to(device=device, dtype=dtype) + + def prepare_module_tp(self, module, device_mesh, **kwargs): + distribute_module(module, device_mesh, output_fn=self._prepare_output_fn) + + class MlaKvAProjParallel(TensorParallelLayer): """ For MLA attention used in DeepSeek-V2 style models (deepseek_v2, longcat_flash, glm_moe_dsa, glm4_moe_lite): @@ -1088,7 +1124,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def _prepare_input_fn(self, mod, inputs, device_mesh): - return inputs[0] if inputs else inputs + return inputs def _prepare_output_fn(self, mod, outputs, device_mesh): """ @@ -1135,7 +1171,13 @@ def _prepare_output_fn(self, mod, outputs, device_mesh): The sentinel index (num_local_experts) is skipped by one_hot encoding or clamped + masked in grouped_mm/batched_mm. After the expert forward, an all_reduce sums partial outputs across EP ranks to produce the full result. + + Mega MoE skips this remap: its kernel does the EP dispatch itself and wants raw + global expert ids with unmasked routing weights. """ + if _is_megamoe(mod): + return outputs + ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() num_experts = getattr(mod, "num_experts", None) if num_experts is None: @@ -1183,6 +1225,12 @@ def _prepare_input_fn(self, mod, inputs, device_mesh): top_k_index = inputs[1] top_k_weights = inputs[2] + # Mega MoE is inference-only (the kernel has no backward) and handles EP + # dispatch + combine + per-rank token sharding internally. Skip the gradient + # sync hooks and append the EP `process_group` so the forward can rendezvous. + if _is_megamoe(mod): + return hidden_states, top_k_index, top_k_weights, device_mesh.get_group() + # all_reduce_backward on hidden_states for correct colwise (gate_up_proj) gradient hidden_states = all_reduce_backward(hidden_states, device_mesh) @@ -1191,9 +1239,12 @@ def _prepare_input_fn(self, mod, inputs, device_mesh): # and partial_expert_output is different on each GPU before all-reduce top_k_weights = all_reduce_backward(top_k_weights, device_mesh) - return (hidden_states, top_k_index, top_k_weights) + return hidden_states, top_k_index, top_k_weights def _prepare_output_fn(self, mod, outputs, device_mesh): + # Mega MoE returned the fully-combined gathered output; skip the all-reduce. + if _is_megamoe(mod): + return outputs # all_reduce_forward to sum partial expert outputs across GPUs return all_reduce_forward(outputs, device_mesh) @@ -1205,6 +1256,10 @@ def shard_tensor( return param[...].to(device=device, dtype=dtype) +def _is_megamoe(mod: nn.Module) -> bool: + return getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe" + + class MoeIdentityExpertParallel(TensorParallelLayer): """ TP class for zero/identity experts in MoE layers. 
@@ -1247,6 +1302,7 @@ class ParallelInterface(GeneralInterface): "moe_identity_expert": MoeIdentityExpertParallel(), "replicated_with_grad_allreduce": ReplicatedWithGradAllReduce(), "mla_kv_a_proj": MlaKvAProjParallel(), + "all_reduce": AllReduceParallel(), } if is_torch_available() and _torch_distributed_available else {} @@ -1267,6 +1323,7 @@ class ParallelInterface(GeneralInterface): "sequence_parallel": None, "replicated_with_grad_allreduce": None, "mla_kv_a_proj": None, + "all_reduce": None, } # Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced) @@ -1282,6 +1339,7 @@ class ParallelInterface(GeneralInterface): "sequence_parallel": None, "replicated_with_grad_allreduce": None, "mla_kv_a_proj": None, + "all_reduce": None, } @classmethod diff --git a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py index 2cbc02c6d0f7..2f9e3527d430 100644 --- a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py @@ -112,20 +112,32 @@ class DeepseekV4Config(PreTrainedConfig): "norm": (["hidden_states"], ["hidden_states"]), } base_model_ep_plan = { - # EP-only by default, same shape as gpt-oss: route on the gate, run the - # routed experts as a grouped-GEMM kernel sharded along the expert axis, - # and wrap the experts module with `moe_tp_experts` so its output gets - # all-reduced across ranks. Attention stays replicated (V4 is shared-KV - # MQA + a CSA / HCA compressor branch — both broadcast a single KV head - # across all attention heads via `repeat_kv`, so colwise-sharding - # `q_b_proj` would leave KV replicated and `repeat_kv` would no longer - # match the rank-local query head count). The shared MLP also stays - # replicated — it's small and not worth TP-ing. There's deliberately - # no `base_model_tp_plan` for V4: we don't ship a pure-TP plan, only EP. + # V4 ships EP only (no `base_model_tp_plan` — the runtime picks one plan or + # the other, never both, and V4 is MoE so EP is the only sensible config). + # MoE parallelism: route on the gate, run the routed experts as a grouped-GEMM + # kernel sharded along the expert axis, and wrap the experts module with + # `moe_tp_experts` so its output gets all-reduced across ranks. Same shape as + # gpt-oss. Main attention stays replicated: V4 is shared-KV MQA + a CSA / HCA + # compressor branch — both broadcast a single KV head across all attention + # heads via `repeat_kv`, so colwise-sharding `q_b_proj` would leave KV + # replicated and `repeat_kv` would no longer match the rank-local query head + # count. The shared MLP also stays replicated — it's small and not worth + # sharding. The Lightning Indexer is the one carve-out: its keys are + # replicated (own compressor at index_head_dim fed by replicated + # hidden_states), so head-sharding is well-formed; `q_b_proj` and + # `weights_proj` go colwise, and `scores_sync` is a `nn.Identity` whose + # `"all_reduce"` output hook sums the per-rank partial `index_scores` across + # the mesh so every rank picks the same top-k. Mirrors the reference + # inference (`inference/model.py:393, 394, 422-423`). 
"layers.*.mlp.gate": "ep_router", "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm", "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm", "layers.*.mlp.experts": "moe_tp_experts", + "layers.*.self_attn.compressor.indexer.q_b_proj": "colwise", + "layers.*.self_attn.compressor.indexer.weights_proj": "colwise", + "layers.*.self_attn.compressor.indexer.scores_sync": "all_reduce", } vocab_size: int = 129280 @@ -187,7 +199,7 @@ class DeepseekV4Config(PreTrainedConfig): # back to wrapping the whole dict as a single set of params when the subset check # fails, which then warns about `main` / `compress` as unrecognized keys. Override # to iterate the rope-type-keyed sub-dicts directly. - _rope_type_labels = ("main", "compress") + _rope_type_labels = ("sliding", "compress") def validate_rope(self): rope_parameters_dict = getattr(self, "rope_parameters", None) or {} @@ -285,19 +297,27 @@ def __post_init__(self, **kwargs): # `rope_parameters`: split the flat dict (left by `convert_rope_params_to_dict`, # which folded any legacy `rope_scaling` block in) into per-rope-type - # `{main, compress}` sub-dicts. Idempotent: re-loading an already-split config - # is a no-op via the `isinstance` short-circuit. The two sub-dicts differ only - # in `rope_theta` (main: 10000, compress: 160000). + # `{sliding, compress}` sub-dicts. Mirrors reference `inference/model.py:475-481`: + # sliding-window attention layers use base RoPE (`rope_theta=10000`, no YaRN — + # `original_seq_len=0` disables it); CSA/HCA layers (and their internal + # compressors/indexer) use `compress_rope_theta=160000` with YaRN frequency + # interpolation. Idempotent: re-loading an already-split config is a no-op via + # the `isinstance` short-circuit. rp = self.rope_parameters or {} - if isinstance(rp.get("main"), dict) and isinstance(rp.get("compress"), dict): + if isinstance(rp.get("sliding"), dict) and isinstance(rp.get("compress"), dict): # Already nested — drop any leftover top-level keys. 
- self.rope_parameters = {"main": rp["main"], "compress": rp["compress"]} + self.rope_parameters = {"sliding": rp["sliding"], "compress": rp["compress"]} else: - base = {k: v for k, v in rp.items() if k not in ("main", "compress")} - base.setdefault("rope_theta", self.rope_theta) + base = {k: v for k, v in rp.items() if k not in ("sliding", "compress")} base.setdefault("rope_type", "default") base["partial_rotary_factor"] = self.partial_rotary_factor - self.rope_parameters = {"main": dict(base), "compress": {**base, "rope_theta": self.compress_rope_theta}} + sliding = { + "rope_theta": self.rope_theta, + "rope_type": "default", + "partial_rotary_factor": self.partial_rotary_factor, + } + compress = {**base, "rope_theta": self.compress_rope_theta} + self.rope_parameters = {"sliding": sliding, "compress": compress} __all__ = ["DeepseekV4Config"] diff --git a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py index 746219c11138..ec0abffff670 100644 --- a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py @@ -44,6 +44,13 @@ @use_kernel_forward_from_hub("RMSNorm") class DeepseekV4RMSNorm(nn.Module): + """Like V3's, but the weight·hidden multiply happens in fp32 before the cast + back to the input dtype — mirrors reference `inference/model.py:191-196` + where `weight` is declared `dtype=torch.float32` and the multiply runs in + fp32 before the final `.to(dtype)`. V3 downcasts hidden_states first and + multiplies in bf16, which loses precision across the ~5 norms × 43 layers. + """ + def __init__(self, hidden_size, eps: float = 1e-6) -> None: """ DeepseekV4RMSNorm is equivalent to T5LayerNorm @@ -53,11 +60,11 @@ def __init__(self, hidden_size, eps: float = 1e-6) -> None: self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) + dtype = hidden_states.dtype + hidden_states = hidden_states.float() variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return (self.weight.float() * hidden_states).to(dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" @@ -106,10 +113,13 @@ def __init__(self, config: DeepseekV4Config): rope_init_fn = self.compute_default_rope_parameters if self.rope_type[layer_type] != "default": rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] - inv_freq, attention_scaling = rope_init_fn(config, layer_type=layer_type) + # Reference (`inference/model.py:228`) builds cos/sin via + # `torch.polar(torch.ones_like(freqs), freqs)` — unit magnitude, no YaRN + # `attention_factor` scaling. We discard `rope_init_fn`'s scaling factor + # and emit raw `cos = freqs.cos()` / `sin = freqs.sin()` in `forward`. 
+ inv_freq, _ = rope_init_fn(config, layer_type=layer_type) self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False) self.register_buffer(f"{layer_type}_original_inv_freq", inv_freq.clone(), persistent=False) - setattr(self, f"{layer_type}_attention_scaling", attention_scaling) @staticmethod def compute_default_rope_parameters( @@ -157,14 +167,13 @@ def forward(self, x, position_ids, layer_type=None): # the `repeat_interleave(2)` next to the rotation math, where the link between # the doubled dim and `rotate_half` is local and obvious. inv_freq = getattr(self, f"{layer_type}_inv_freq") - attention_scaling = getattr(self, f"{layer_type}_attention_scaling") inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - cos = freqs.cos() * attention_scaling - sin = freqs.sin() * attention_scaling + cos = freqs.cos() + sin = freqs.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -398,7 +407,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, None]: batch, _, _ = hidden_states.shape cache_layer: DeepseekV4HCACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -427,7 +436,9 @@ def forward( if cache_layer is not None: compressed = cache_layer.update_compressor_states("compressor", compressed) - return compressed.unsqueeze(1) + # `None` validity: HCA has no indexer / per-query gather — every query attends to + # every entry in the compressor section, so attention just right-pads its mask with 0s. + return compressed.unsqueeze(1), None class DeepseekV4Indexer(nn.Module): @@ -476,6 +487,10 @@ def __init__(self, config: DeepseekV4Config): self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.head_dim, bias=False) self.weights_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False) self.rotary_emb = DeepseekV4RotaryEmbedding(config) + # No-op marker. Under TP it picks up the `"all_reduce"` plan entry, which adds a + # forward output hook that sums the (partial, per-rank) `index_scores` across the + # TP mesh so every rank picks the same top-k. See `inference/model.py:422-423`. 
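Roughly, the `"all_reduce"` plan entry amounts to a forward output hook on this `nn.Identity` marker. The sketch below is illustrative only: the real wiring goes through `distribute_module` and the TP device mesh, and the hook body here is an assumption rather than the actual transformers code.

```python
import torch.distributed as dist
from torch import nn

scores_sync = nn.Identity()

def _sum_partial_scores(module, inputs, output):
    # Sum the per-rank partial index_scores so every rank ranks the same top-k.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(output, op=dist.ReduceOp.SUM)
    return output

scores_sync.register_forward_hook(_sum_partial_scores)
```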
+ self.scores_sync = nn.Identity() def forward( self, @@ -484,7 +499,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.LongTensor: + ) -> tuple[torch.LongTensor, torch.BoolTensor]: batch, seq_len, _ = hidden_states.shape cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -539,9 +554,24 @@ def forward( scores = torch.matmul(q.float(), compressed_kv.transpose(-1, -2).float().unsqueeze(1)) # [B, S, H, T] scores = F.relu(scores) * self.softmax_scale weights = self.weights_proj(hidden_states).float() * self.weights_scaling # [B, S, H] - index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] + index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] — partial under TP + index_scores = self.scores_sync(index_scores) + # Causal mask before topk (mirrors `inference/model.py:424-426`): query at absolute + # position `p` may only pick compressed entries `i` whose source-token window ends + # at or before `p` — i.e. ``i < (p + 1) // compress_rate``. Without this, prefill + # queries can score and pick entries that aggregate future source tokens. + if compressed_kv.shape[1] > 0: + cutoff = (position_ids + 1).div(self.compress_rate, rounding_mode="floor").unsqueeze(-1) # [B, S, 1] + i_idx = torch.arange(compressed_kv.shape[1], device=index_scores.device) + index_scores = index_scores.masked_fill(i_idx >= cutoff, float("-inf")) topk = min(self.index_topk, compressed_kv.shape[1]) - return index_scores.topk(topk, dim=-1).indices + result = index_scores.topk(topk, dim=-1) + # `valid[b, s, j] == False` iff the j-th pick for query `s` was a forced `-inf` + # placeholder (the query had fewer than `index_topk` causally-valid entries). The + # pick still indexes a real but non-causal entry, so attention has to skip it — + # mirrors `inference/model.py:428-430` (`topk_idxs == -1` slots). + valid = result.values > float("-inf") + return result.indices, valid class DeepseekV4CSACompressor(nn.Module): @@ -585,7 +615,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.BoolTensor]: batch, seq_len, _ = hidden_states.shape cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -641,11 +671,14 @@ def forward( compressed = cache_layer.update_compressor_states("compressor", compressed) compressed_kv = compressed.unsqueeze(1) - # Lightning Indexer: gather top-`index_topk` compressed entries per query. - topk = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) # [B, S, k] + # Lightning Indexer: gather top-`index_topk` compressed entries per query and pass + # the per-pick validity mask up so attention can ``-inf``-mask the placeholder picks + # (queries with fewer than `index_topk` causally-valid entries). 
+ topk_indices, topk_valid = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) expanded = compressed_kv.unsqueeze(2).expand(-1, -1, seq_len, -1, -1) - idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) - return torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + idx = topk_indices.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) + gathered = torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + return gathered, topk_valid.reshape(batch, -1) # valid: [B, S*k] over the flat-packed KV axis def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -740,11 +773,18 @@ def __init__(self, config: DeepseekV4Config, layer_idx: int): self.compressor = ( COMPRESSOR_CLASSES[self.layer_type](config) if self.layer_type != "sliding_attention" else None ) + # Reference `inference/model.py:475-481` picks `freqs_cis` per attention layer: + # sliding layers use base RoPE (`rope_theta=10000`, no YaRN); CSA/HCA layers + # use `compress_rope_theta=160000` with YaRN. The compressor inside a CSA/HCA + # layer shares its parent attention's `freqs_cis`, so main Q/K and compressor + # rotate consistently. We mirror that by selecting from a per-branch dict + # built once in `DeepseekV4Model.forward`. + self.rope_branch = "sliding" if self.layer_type == "sliding_attention" else "compress" def forward( self, hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], + position_embeddings: dict[str, tuple[torch.Tensor, torch.Tensor]], position_ids: torch.Tensor, attention_mask: torch.Tensor | None, past_key_values: Cache | None = None, @@ -752,7 +792,7 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - cos, sin = position_embeddings + cos, sin = position_embeddings[self.rope_branch] q_residual = self.q_a_norm(self.q_a_proj(hidden_states)) q = self.q_b_proj(q_residual).view(*hidden_shape).transpose(1, 2) @@ -765,17 +805,40 @@ def forward( if past_key_values is not None: # sliding where K==V kv = past_key_values.update(kv, kv, self.layer_idx)[0] + compressor_valid = None if self.compressor is not None: # Compressed KV (CSA or HCA) - compressed_kv = self.compressor(hidden_states, q_residual, position_ids, past_key_values, self.layer_idx) + compressed_kv, compressor_valid = self.compressor( + hidden_states, q_residual, position_ids, past_key_values, self.layer_idx + ) kv = torch.cat([kv, compressed_kv], dim=2) # The compressor path concatenates extra entries onto the KV axis after the # standard sliding-window cache update, so a tensor `attention_mask` (built - # for the pre-concat KV length) needs to be right-padded to cover them. + # for the pre-concat KV length) needs to be extended to cover them. # Flex-attention passes a `BlockMask` whose KV-length axis comes from its - # own `mask_mod`, not from a dense tensor — skip the pad in that case. + # own `mask_mod`, not from a dense tensor — skip the extend in that case. if isinstance(attention_mask, torch.Tensor) and kv.shape[2] > attention_mask.shape[-1]: - attention_mask = F.pad(attention_mask, (0, kv.shape[2] - attention_mask.shape[-1]), value=0.0) + n_compressor = kv.shape[2] - attention_mask.shape[-1] + S = q.shape[2] + if self.layer_type == "compressed_sparse_attention" and S > 1 and n_compressor == S * (n_compressor // S): + # CSA's compressor returns per-query top-`k` entries flat-packed as + # `[B, 1, S*k, D]`. 
The official inference's `sparse_attn` gathers per query + # so each query only attends to its own `k` slots; we mirror it with a + # block-diagonal mask over the compressor section. `compressor_valid` is + # False at slots the indexer flagged as `-inf` placeholders — strip those too. + # Decode (S=1) collapses to a no-op block-diagonal so we let the simple + # right-pad branch handle it. + k_per_query = n_compressor // S + q_idx = torch.arange(S, device=q.device).unsqueeze(1) # [S, 1] + kv_idx = torch.arange(n_compressor, device=q.device).div(k_per_query, rounding_mode="floor") # [S*k] + allowed = (q_idx == kv_idx).unsqueeze(0) & compressor_valid.unsqueeze(1) # [B, S, S*k] + extra = torch.where(allowed, 0.0, float("-inf")).to(attention_mask.dtype).unsqueeze(1) + extra = extra.expand(attention_mask.shape[0], attention_mask.shape[1], -1, -1) + attention_mask = torch.cat([attention_mask, extra], dim=-1) + else: + # HCA (every query sees every entry) and CSA decode (single query's k slots + # span the whole compressor section): plain 0.0 right-pad is correct. + attention_mask = F.pad(attention_mask, (0, n_compressor), value=0.0) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward @@ -831,10 +894,10 @@ class DeepseekV4HyperConnection(nn.Module): [B, S, H] [B, S, H] [B, S, H, H] × scale[0] × scale[1] × scale[2] + base[:H] + base[H:2H] + base[2H:] - σ() + eps σ() + eps σ() + eps + σ() + eps 2·σ() softmax(-1) + eps │ │ │ - pre post Sinkhorn(iters) - (stream collapse weights) (block-output placement) row/col normalise + pre post Sinkhorn(iters) + (stream collapse weights) (block-output placement, range [0, 2]) row/col normalise │ comb (stream mixer) @@ -856,28 +919,23 @@ def __init__(self, config: DeepseekV4Config): self.scale = nn.Parameter(torch.empty(3)) def forward(self, hidden_streams: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - r""" - Compute `pre`, `post`, `comb` from the mHC mapping (paper §2.2 eq. 8). - `comb` is projected onto the doubly-stochastic manifold via Sinkhorn- - Knopp: starting from the sigmoid-positive matrix, alternate row and - column normalisation for `hc_sinkhorn_iters` steps. `pre` then collapses - the `hc_mult` parallel streams into a single sequence (input projection - into the sublayer); `post` and `comb` are returned for the caller to - apply on the sublayer output. - """ flat = self.input_norm(hidden_streams.flatten(start_dim=2).float()) mix = F.linear(flat, self.fn.float()) # [B, S, (2+H)*H] pre_scale, post_scale, comb_scale = self.scale.unbind(0) hc = self.hc_mult pre = torch.sigmoid(mix[..., :hc] * pre_scale + self.base[:hc]) + self.hc_eps - post = torch.sigmoid(mix[..., hc : 2 * hc] * post_scale + self.base[hc : 2 * hc]) + self.hc_eps - comb = ( - torch.sigmoid( - mix[..., 2 * hc :].view(*mix.shape[:-1], hc, hc) * comb_scale + self.base[2 * hc :].view(hc, hc) - ) - + self.hc_eps - ) - for _ in range(self.hc_sinkhorn_iters): + # `post` is `2 * sigmoid` (range [0, 2], no eps) to match the reference kernel + # (`inference/kernel.py:394`). `pre` and `comb` keep the `sigmoid + eps` form + # they share with the kernel (lines 392, 408). + post = 2 * torch.sigmoid(mix[..., hc : 2 * hc] * post_scale + self.base[hc : 2 * hc]) + # Sinkhorn init mirrors `inference/kernel.py:401-413`: row-softmax + eps, + # then one column-normalisation, then `iters - 1` symmetric (row, col) rounds. 
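A standalone numeric check of that initialisation order (illustrative 4×4 size, `eps` terms dropped, 4 Sinkhorn iterations assumed): after the final column pass the column sums are exactly 1 and the row sums are very close to 1, i.e. the matrix is effectively doubly stochastic.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 4)
comb = torch.softmax(logits, dim=-1)          # rows already sum to 1
comb = comb / comb.sum(dim=-2, keepdim=True)  # first column normalisation
for _ in range(3):                            # remaining symmetric (row, col) rounds
    comb = comb / comb.sum(dim=-1, keepdim=True)
    comb = comb / comb.sum(dim=-2, keepdim=True)
print(comb.sum(dim=-1))   # ~1 per row
print(comb.sum(dim=-2))   # 1 per column (the last pass was a column normalisation)
```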
+ # Different positive starting matrix → different doubly-stochastic fixed point + # than a `sigmoid+eps` init, so this matters even though row/col count nets out. + comb_logits = mix[..., 2 * hc :].view(*mix.shape[:-1], hc, hc) * comb_scale + self.base[2 * hc :].view(hc, hc) + comb = torch.softmax(comb_logits, dim=-1) + self.hc_eps + comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) + for _ in range(self.hc_sinkhorn_iters - 1): comb = comb / (comb.sum(dim=-1, keepdim=True) + self.hc_eps) comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) # Collapse the `hc_mult` parallel streams down to a single sequence using @@ -907,6 +965,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DeepseekV4MLP(nn.Module): + """Used by the shared expert. Mirrors reference `inference/model.py:587-606`: + gate/up promoted to fp32, optionally clamped by `swiglu_limit`, SiLU+mul stay in + fp32, then the product is cast back to the input dtype before `down_proj`. + """ + def __init__(self, config): super().__init__() self.config = config @@ -917,9 +980,15 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + gate = self.gate_proj(x).float() + up = self.up_proj(x).float() + limit = self.config.swiglu_limit + if limit > 0: + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + return self.down_proj((self.act_fn(gate) * up).to(dtype)) @use_experts_implementation @@ -1069,16 +1138,22 @@ def forward( # `post` / `comb` come out of the HC modules in fp32 (Sinkhorn projection runs # in float); the .to(dtype) puts everything back to the input dtype before mixing # so both sites stay consistent with `hidden_states`'s entry dtype. + # `comb` is consumed transposed: reference `inference/model.py:685` indexes it as + # `sum_j comb[j, k] * residual[j, d]` (sum over the FIRST hc axis), which is + # equivalent to `comb.T @ residual`. Sinkhorn produces a doubly-stochastic but + # non-symmetric matrix, so the direction matters. dtype = hidden_states.dtype post, comb, collapsed = self.attn_hc(hidden_states) attn_output, _ = self.self_attn(self.input_layernorm(collapsed), **kwargs) hidden_states = post.to(dtype).unsqueeze(-1) * attn_output.unsqueeze(-2) + torch.matmul( - comb.to(dtype), hidden_states + comb.to(dtype).transpose(-1, -2), hidden_states ) post, comb, collapsed = self.ffn_hc(hidden_states) mlp_output = self.mlp(self.post_attention_layernorm(collapsed), input_ids=input_ids) - return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul(comb.to(dtype), hidden_states) + return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul( + comb.to(dtype).transpose(-1, -2), hidden_states + ) @auto_docstring @@ -1223,7 +1298,14 @@ def forward( position_ids=position_ids, ) hidden_states = inputs_embeds.unsqueeze(2).expand(-1, -1, self.config.hc_mult, -1).contiguous() - position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="main") + # Per-layer rope: sliding-attention layers use base RoPE (`rope_theta=10000`, + # no YaRN), CSA/HCA layers use `compress_rope_theta=160000` with YaRN — + # mirrors reference `inference/model.py:475-481`. Each attention picks the + # right branch via `self.rope_branch`. 
+ position_embeddings = { + "sliding": self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="sliding"), + "compress": self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="compress"), + } for layer in self.layers: hidden_states = layer( diff --git a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py index 759bfabf017b..96ca7e1e3edb 100644 --- a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py @@ -64,7 +64,19 @@ def apply_rotary_pos_emb( class DeepseekV4RMSNorm(DeepseekV3RMSNorm): - pass + """Like V3's, but the weight·hidden multiply happens in fp32 before the cast + back to the input dtype — mirrors reference `inference/model.py:191-196` + where `weight` is declared `dtype=torch.float32` and the multiply runs in + fp32 before the final `.to(dtype)`. V3 downcasts hidden_states first and + multiplies in bf16, which loses precision across the ~5 norms × 43 layers. + """ + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + dtype = hidden_states.dtype + hidden_states = hidden_states.float() + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight.float() * hidden_states).to(dtype) class DeepseekV4UnweightedRMSNorm(nn.Module): @@ -108,10 +120,13 @@ def __init__(self, config: DeepseekV4Config): rope_init_fn = self.compute_default_rope_parameters if self.rope_type[layer_type] != "default": rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] - inv_freq, attention_scaling = rope_init_fn(config, layer_type=layer_type) + # Reference (`inference/model.py:228`) builds cos/sin via + # `torch.polar(torch.ones_like(freqs), freqs)` — unit magnitude, no YaRN + # `attention_factor` scaling. We discard `rope_init_fn`'s scaling factor + # and emit raw `cos = freqs.cos()` / `sin = freqs.sin()` in `forward`. + inv_freq, _ = rope_init_fn(config, layer_type=layer_type) self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False) self.register_buffer(f"{layer_type}_original_inv_freq", inv_freq.clone(), persistent=False) - setattr(self, f"{layer_type}_attention_scaling", attention_scaling) def forward(self, x, position_ids, layer_type=None): # Key difference vs Laguna's forward: no `torch.cat([freqs, freqs], dim=-1)` @@ -120,14 +135,13 @@ def forward(self, x, position_ids, layer_type=None): # the `repeat_interleave(2)` next to the rotation math, where the link between # the doubled dim and `rotate_half` is local and obvious. 
inv_freq = getattr(self, f"{layer_type}_inv_freq") - attention_scaling = getattr(self, f"{layer_type}_attention_scaling") inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - cos = freqs.cos() * attention_scaling - sin = freqs.sin() * attention_scaling + cos = freqs.cos() + sin = freqs.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -334,7 +348,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, None]: batch, _, _ = hidden_states.shape cache_layer: DeepseekV4HCACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -363,7 +377,9 @@ def forward( if cache_layer is not None: compressed = cache_layer.update_compressor_states("compressor", compressed) - return compressed.unsqueeze(1) + # `None` validity: HCA has no indexer / per-query gather — every query attends to + # every entry in the compressor section, so attention just right-pads its mask with 0s. + return compressed.unsqueeze(1), None class DeepseekV4Indexer(nn.Module): @@ -412,6 +428,10 @@ def __init__(self, config: DeepseekV4Config): self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.head_dim, bias=False) self.weights_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False) self.rotary_emb = DeepseekV4RotaryEmbedding(config) + # No-op marker. Under TP it picks up the `"all_reduce"` plan entry, which adds a + # forward output hook that sums the (partial, per-rank) `index_scores` across the + # TP mesh so every rank picks the same top-k. See `inference/model.py:422-423`. + self.scores_sync = nn.Identity() def forward( self, @@ -420,7 +440,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.LongTensor: + ) -> tuple[torch.LongTensor, torch.BoolTensor]: batch, seq_len, _ = hidden_states.shape cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -475,9 +495,24 @@ def forward( scores = torch.matmul(q.float(), compressed_kv.transpose(-1, -2).float().unsqueeze(1)) # [B, S, H, T] scores = F.relu(scores) * self.softmax_scale weights = self.weights_proj(hidden_states).float() * self.weights_scaling # [B, S, H] - index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] + index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] — partial under TP + index_scores = self.scores_sync(index_scores) + # Causal mask before topk (mirrors `inference/model.py:424-426`): query at absolute + # position `p` may only pick compressed entries `i` whose source-token window ends + # at or before `p` — i.e. ``i < (p + 1) // compress_rate``. Without this, prefill + # queries can score and pick entries that aggregate future source tokens. 
+ if compressed_kv.shape[1] > 0: + cutoff = (position_ids + 1).div(self.compress_rate, rounding_mode="floor").unsqueeze(-1) # [B, S, 1] + i_idx = torch.arange(compressed_kv.shape[1], device=index_scores.device) + index_scores = index_scores.masked_fill(i_idx >= cutoff, float("-inf")) topk = min(self.index_topk, compressed_kv.shape[1]) - return index_scores.topk(topk, dim=-1).indices + result = index_scores.topk(topk, dim=-1) + # `valid[b, s, j] == False` iff the j-th pick for query `s` was a forced `-inf` + # placeholder (the query had fewer than `index_topk` causally-valid entries). The + # pick still indexes a real but non-causal entry, so attention has to skip it — + # mirrors `inference/model.py:428-430` (`topk_idxs == -1` slots). + valid = result.values > float("-inf") + return result.indices, valid class DeepseekV4CSACompressor(nn.Module): @@ -521,7 +556,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.BoolTensor]: batch, seq_len, _ = hidden_states.shape cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -577,11 +612,14 @@ def forward( compressed = cache_layer.update_compressor_states("compressor", compressed) compressed_kv = compressed.unsqueeze(1) - # Lightning Indexer: gather top-`index_topk` compressed entries per query. - topk = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) # [B, S, k] + # Lightning Indexer: gather top-`index_topk` compressed entries per query and pass + # the per-pick validity mask up so attention can ``-inf``-mask the placeholder picks + # (queries with fewer than `index_topk` causally-valid entries). + topk_indices, topk_valid = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) expanded = compressed_kv.unsqueeze(2).expand(-1, -1, seq_len, -1, -1) - idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) - return torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + idx = topk_indices.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) + gathered = torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + return gathered, topk_valid.reshape(batch, -1) # valid: [B, S*k] over the flat-packed KV axis COMPRESSOR_CLASSES = { @@ -633,11 +671,18 @@ def __init__(self, config: DeepseekV4Config, layer_idx: int): self.compressor = ( COMPRESSOR_CLASSES[self.layer_type](config) if self.layer_type != "sliding_attention" else None ) + # Reference `inference/model.py:475-481` picks `freqs_cis` per attention layer: + # sliding layers use base RoPE (`rope_theta=10000`, no YaRN); CSA/HCA layers + # use `compress_rope_theta=160000` with YaRN. The compressor inside a CSA/HCA + # layer shares its parent attention's `freqs_cis`, so main Q/K and compressor + # rotate consistently. We mirror that by selecting from a per-branch dict + # built once in `DeepseekV4Model.forward`. 
+ self.rope_branch = "sliding" if self.layer_type == "sliding_attention" else "compress" def forward( self, hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], + position_embeddings: dict[str, tuple[torch.Tensor, torch.Tensor]], position_ids: torch.Tensor, attention_mask: torch.Tensor | None, past_key_values: Cache | None = None, @@ -645,7 +690,7 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - cos, sin = position_embeddings + cos, sin = position_embeddings[self.rope_branch] q_residual = self.q_a_norm(self.q_a_proj(hidden_states)) q = self.q_b_proj(q_residual).view(*hidden_shape).transpose(1, 2) @@ -658,17 +703,40 @@ def forward( if past_key_values is not None: # sliding where K==V kv = past_key_values.update(kv, kv, self.layer_idx)[0] + compressor_valid = None if self.compressor is not None: # Compressed KV (CSA or HCA) - compressed_kv = self.compressor(hidden_states, q_residual, position_ids, past_key_values, self.layer_idx) + compressed_kv, compressor_valid = self.compressor( + hidden_states, q_residual, position_ids, past_key_values, self.layer_idx + ) kv = torch.cat([kv, compressed_kv], dim=2) # The compressor path concatenates extra entries onto the KV axis after the # standard sliding-window cache update, so a tensor `attention_mask` (built - # for the pre-concat KV length) needs to be right-padded to cover them. + # for the pre-concat KV length) needs to be extended to cover them. # Flex-attention passes a `BlockMask` whose KV-length axis comes from its - # own `mask_mod`, not from a dense tensor — skip the pad in that case. + # own `mask_mod`, not from a dense tensor — skip the extend in that case. if isinstance(attention_mask, torch.Tensor) and kv.shape[2] > attention_mask.shape[-1]: - attention_mask = F.pad(attention_mask, (0, kv.shape[2] - attention_mask.shape[-1]), value=0.0) + n_compressor = kv.shape[2] - attention_mask.shape[-1] + S = q.shape[2] + if self.layer_type == "compressed_sparse_attention" and S > 1 and n_compressor == S * (n_compressor // S): + # CSA's compressor returns per-query top-`k` entries flat-packed as + # `[B, 1, S*k, D]`. The official inference's `sparse_attn` gathers per query + # so each query only attends to its own `k` slots; we mirror it with a + # block-diagonal mask over the compressor section. `compressor_valid` is + # False at slots the indexer flagged as `-inf` placeholders — strip those too. + # Decode (S=1) collapses to a no-op block-diagonal so we let the simple + # right-pad branch handle it. + k_per_query = n_compressor // S + q_idx = torch.arange(S, device=q.device).unsqueeze(1) # [S, 1] + kv_idx = torch.arange(n_compressor, device=q.device).div(k_per_query, rounding_mode="floor") # [S*k] + allowed = (q_idx == kv_idx).unsqueeze(0) & compressor_valid.unsqueeze(1) # [B, S, S*k] + extra = torch.where(allowed, 0.0, float("-inf")).to(attention_mask.dtype).unsqueeze(1) + extra = extra.expand(attention_mask.shape[0], attention_mask.shape[1], -1, -1) + attention_mask = torch.cat([attention_mask, extra], dim=-1) + else: + # HCA (every query sees every entry) and CSA decode (single query's k slots + # span the whole compressor section): plain 0.0 right-pad is correct. 
+ attention_mask = F.pad(attention_mask, (0, n_compressor), value=0.0) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward @@ -724,10 +792,10 @@ class DeepseekV4HyperConnection(nn.Module): [B, S, H] [B, S, H] [B, S, H, H] × scale[0] × scale[1] × scale[2] + base[:H] + base[H:2H] + base[2H:] - σ() + eps σ() + eps σ() + eps + σ() + eps 2·σ() softmax(-1) + eps │ │ │ - pre post Sinkhorn(iters) - (stream collapse weights) (block-output placement) row/col normalise + pre post Sinkhorn(iters) + (stream collapse weights) (block-output placement, range [0, 2]) row/col normalise │ comb (stream mixer) @@ -749,28 +817,23 @@ def __init__(self, config: DeepseekV4Config): self.scale = nn.Parameter(torch.empty(3)) def forward(self, hidden_streams: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - r""" - Compute `pre`, `post`, `comb` from the mHC mapping (paper §2.2 eq. 8). - `comb` is projected onto the doubly-stochastic manifold via Sinkhorn- - Knopp: starting from the sigmoid-positive matrix, alternate row and - column normalisation for `hc_sinkhorn_iters` steps. `pre` then collapses - the `hc_mult` parallel streams into a single sequence (input projection - into the sublayer); `post` and `comb` are returned for the caller to - apply on the sublayer output. - """ flat = self.input_norm(hidden_streams.flatten(start_dim=2).float()) mix = F.linear(flat, self.fn.float()) # [B, S, (2+H)*H] pre_scale, post_scale, comb_scale = self.scale.unbind(0) hc = self.hc_mult pre = torch.sigmoid(mix[..., :hc] * pre_scale + self.base[:hc]) + self.hc_eps - post = torch.sigmoid(mix[..., hc : 2 * hc] * post_scale + self.base[hc : 2 * hc]) + self.hc_eps - comb = ( - torch.sigmoid( - mix[..., 2 * hc :].view(*mix.shape[:-1], hc, hc) * comb_scale + self.base[2 * hc :].view(hc, hc) - ) - + self.hc_eps - ) - for _ in range(self.hc_sinkhorn_iters): + # `post` is `2 * sigmoid` (range [0, 2], no eps) to match the reference kernel + # (`inference/kernel.py:394`). `pre` and `comb` keep the `sigmoid + eps` form + # they share with the kernel (lines 392, 408). + post = 2 * torch.sigmoid(mix[..., hc : 2 * hc] * post_scale + self.base[hc : 2 * hc]) + # Sinkhorn init mirrors `inference/kernel.py:401-413`: row-softmax + eps, + # then one column-normalisation, then `iters - 1` symmetric (row, col) rounds. + # Different positive starting matrix → different doubly-stochastic fixed point + # than a `sigmoid+eps` init, so this matters even though row/col count nets out. + comb_logits = mix[..., 2 * hc :].view(*mix.shape[:-1], hc, hc) * comb_scale + self.base[2 * hc :].view(hc, hc) + comb = torch.softmax(comb_logits, dim=-1) + self.hc_eps + comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) + for _ in range(self.hc_sinkhorn_iters - 1): comb = comb / (comb.sum(dim=-1, keepdim=True) + self.hc_eps) comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) # Collapse the `hc_mult` parallel streams down to a single sequence using @@ -800,7 +863,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DeepseekV4MLP(LlamaMLP): - pass + """Used by the shared expert. Mirrors reference `inference/model.py:587-606`: + gate/up promoted to fp32, optionally clamped by `swiglu_limit`, SiLU+mul stay in + fp32, then the product is cast back to the input dtype before `down_proj`. 
+ """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + gate = self.gate_proj(x).float() + up = self.up_proj(x).float() + limit = self.config.swiglu_limit + if limit > 0: + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + return self.down_proj((self.act_fn(gate) * up).to(dtype)) @use_experts_implementation @@ -936,16 +1012,22 @@ def forward( # `post` / `comb` come out of the HC modules in fp32 (Sinkhorn projection runs # in float); the .to(dtype) puts everything back to the input dtype before mixing # so both sites stay consistent with `hidden_states`'s entry dtype. + # `comb` is consumed transposed: reference `inference/model.py:685` indexes it as + # `sum_j comb[j, k] * residual[j, d]` (sum over the FIRST hc axis), which is + # equivalent to `comb.T @ residual`. Sinkhorn produces a doubly-stochastic but + # non-symmetric matrix, so the direction matters. dtype = hidden_states.dtype post, comb, collapsed = self.attn_hc(hidden_states) attn_output, _ = self.self_attn(self.input_layernorm(collapsed), **kwargs) hidden_states = post.to(dtype).unsqueeze(-1) * attn_output.unsqueeze(-2) + torch.matmul( - comb.to(dtype), hidden_states + comb.to(dtype).transpose(-1, -2), hidden_states ) post, comb, collapsed = self.ffn_hc(hidden_states) mlp_output = self.mlp(self.post_attention_layernorm(collapsed), input_ids=input_ids) - return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul(comb.to(dtype), hidden_states) + return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul( + comb.to(dtype).transpose(-1, -2), hidden_states + ) class DeepseekV4PreTrainedModel(MixtralPreTrainedModel): @@ -1077,7 +1159,14 @@ def forward( position_ids=position_ids, ) hidden_states = inputs_embeds.unsqueeze(2).expand(-1, -1, self.config.hc_mult, -1).contiguous() - position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="main") + # Per-layer rope: sliding-attention layers use base RoPE (`rope_theta=10000`, + # no YaRN), CSA/HCA layers use `compress_rope_theta=160000` with YaRN — + # mirrors reference `inference/model.py:475-481`. Each attention picks the + # right branch via `self.rope_branch`. + position_embeddings = { + "sliding": self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="sliding"), + "compress": self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="compress"), + } for layer in self.layers: hidden_states = layer( diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index be10624d4842..10e5ca45cbf0 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -105,6 +105,24 @@ def _process_model_before_weight_loading( model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules ) + # DeepSeek-V4-Flash checkpoints mix FP8 and bf16 projections in the attention + # compressor/indexer branch: modules below are stored directly in bf16 (no + # companion scale tensor), so converting them to FP8Linear would create + # missing ``weight_scale_inv`` keys at load and random-init those params. + # Use `config.model_type` (not class-name substring) so this still applies + # with wrappers / auto classes where `model.__class__.__name__` may differ. 
+ if self.pre_quantized and getattr(getattr(model, "config", None), "model_type", None) == "deepseek_v4": + self.modules_to_not_convert.extend( + [ + "self_attn.compressor.kv_proj", + "self_attn.compressor.gate_proj", + "self_attn.compressor.indexer.kv_proj", + "self_attn.compressor.indexer.gate_proj", + "self_attn.compressor.indexer.weights_proj", + ] + ) + self.modules_to_not_convert = list(set(self.modules_to_not_convert)) + model = replace_with_fp8_linear( model, modules_to_not_convert=self.modules_to_not_convert, @@ -168,59 +186,82 @@ def get_weight_conversions(self): return [] def update_weight_conversions(self, weight_conversions): - """When loading with ``dequantize=True``, attach an :class:`Fp8Dequantize` op to - every existing :class:`WeightConverter` so that per-block scales are folded into - the weight *before* any later merge/concat ops collapse the per-expert structure. - - For each model-supplied converter that has a ``.weight`` source, we: - 1. anchor the existing weight patterns with ``$`` so they don't accidentally - also match the ``.weight_scale_inv`` keys (the regex is searched, so the - unanchored prefix would match both, sending scales to the wrong bucket); - 2. add anchored ``*.weight_scale_inv`` sources next to each weight pattern so - the loader collects scale tensors alongside the weight tensors into the - *same* converter bucket (both keys rewrite to the same target); - 3. prepend a fresh :class:`Fp8Dequantize` op so dequant runs first, before - any merge/concat collapses the per-expert structure. - - The generic ``weight$ + weight_scale_inv → weight`` converter from - :meth:`get_weight_conversions` is still appended at the end as a fallback for - plain ``nn.Linear`` weights with no model-specific converter. + """Adapt model-supplied :class:`WeightConverter` instances for FP8 loading. + + Two paths share the same underlying problem: the model's converters merge per-expert + ``.weight`` keys into batched ``gate_up_proj`` / ``down_proj`` tensors, but the + FP8 ``.weight_scale_inv`` siblings need the same merge to populate + ``gate_up_proj_scale_inv`` / ``down_proj_scale_inv`` correctly. The unanchored + ``.weight`` source patterns also accidentally match ``.weight_scale_inv`` keys + (regex-search), so we anchor them first and synthesize parallel scale converters. + + - ``dequantize=True``: scales fold into the weight via :class:`Fp8Dequantize` + *before* merge/concat — both `.weight` and `.weight_scale_inv` go to the same + target bucket. + - ``dequantize=False`` (native quantized): scales stay as separate parameters, + so we synthesize a parallel ``WeightConverter`` per model converter that + targets ``_scale_inv`` and reuses the same merge ops. + + We also prepend a ``.scale → .weight_scale_inv`` rename for upstream checkpoints + (e.g. DeepSeek-V4-Flash) that ship per-block scales under the ``.scale`` suffix. """ - if not (self.pre_quantized and self.quantization_config.dequantize): - return weight_conversions + self.get_weight_conversions() - from ..core_model_loading import WeightConverter, WeightRenaming from ..integrations.finegrained_fp8 import Fp8Dequantize - # Some upstream FP8 checkpoints (e.g. DeepSeek-V4-Flash) ship per-block scales - # under a ``.scale`` suffix instead of HF's canonical ``.weight_scale_inv``. 
- # Prepending the rename here (instead of in each model's conversion_mapping) - # keeps the model-side mapping clean — the rename only kicks in when FP8 dequant - # is actually active, so a non-FP8 save / load round-trip doesn't see a stray - # rule that ``test_reverse_loading_mapping`` can't match. + # `.scale` → `.weight_scale_inv`. Confined to this quantizer, so non-FP8 loads never see it. scale_rename = WeightRenaming(source_patterns=r"^(.+)\.scale$", target_patterns=r"\1.weight_scale_inv") weight_conversions = [scale_rename] + list(weight_conversions) + dequantize = self.pre_quantized and self.quantization_config.dequantize updated: list = [] for conv in weight_conversions: - # Only WeightConverter has ``.operations`` to extend with the dequant op; - # WeightRenaming (e.g. the ``scale_rename`` we prepended) just passes through. + # Only WeightConverter has `.operations` and `.source_patterns` to manipulate. if not isinstance(conv, WeightConverter): updated.append(conv) continue weight_sources = [p for p in conv.source_patterns if p.endswith(".weight")] - if weight_sources: - anchored_weight = [p + "$" for p in weight_sources] - scale_sources = [p[: -len(".weight")] + ".weight_scale_inv$" for p in weight_sources] - other = [p for p in conv.source_patterns if not p.endswith(".weight")] - new_sources = anchored_weight + scale_sources + other - new_ops = [Fp8Dequantize(self)] + list(conv.operations) - conv = WeightConverter( - source_patterns=new_sources, - target_patterns=conv._original_target_patterns, - operations=new_ops, + if not weight_sources: + updated.append(conv) + continue + + anchored_weight = [p + "$" for p in weight_sources] + scale_sources = [p[: -len(".weight")] + ".weight_scale_inv$" for p in weight_sources] + other = [p for p in conv.source_patterns if not p.endswith(".weight")] + + if dequantize: + # Both .weight and .weight_scale_inv go to the same target; Fp8Dequantize + # folds scales into the weight before any merge/concat downstream. + updated.append( + WeightConverter( + source_patterns=anchored_weight + scale_sources + other, + target_patterns=conv._original_target_patterns, + operations=[Fp8Dequantize(self)] + list(conv.operations), + ) + ) + else: + # Native quantized: anchor the existing weight converter so .weight_scale_inv + # keys don't leak in, then synthesize a parallel converter routing the scale + # keys to `_scale_inv` with the same merge ops. + updated.append( + WeightConverter( + source_patterns=anchored_weight + other, + target_patterns=conv._original_target_patterns, + operations=list(conv.operations), + ) ) - updated.append(conv) - # Generic fallback for plain ``nn.Linear`` weights with no model-specific converter. + target = conv._original_target_patterns + if isinstance(target, str): + scale_target = target + "_scale_inv" + else: + scale_target = [t + "_scale_inv" for t in target] + updated.append( + WeightConverter( + source_patterns=scale_sources, + target_patterns=scale_target, + operations=list(conv.operations), + ) + ) + + # Generic fallback for plain `nn.Linear` weights with no model-specific converter. updated.extend(self.get_weight_conversions()) return updated diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index a6c1f5334516..11371a81505c 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1686,6 +1686,12 @@ class FineGrainedFP8Config(QuantizationConfigMixin): Whether to dequantize the model during loading. 
modules_to_not_convert (`list`, *optional*): A list of module names that should not be converted during quantization. + scale_fmt (`str`, *optional*, defaults to `"float"`): + Storage dtype of the per-block weight scales: + - `"float"`: fp32 scales (DeepSeek V3-style; Hopper SM90 1D2D path). + - `"ue8m0"`: 1-byte `torch.float8_e8m0fnu` scales (DeepSeek V4-style). At GEMM time + the underlying memory is reinterpreted as `torch.int32` (4 UE8M0 bytes per int), + feeding the SM100 1D1D dispatch directly with no float→int transform per call. """ def __init__( @@ -1694,6 +1700,7 @@ def __init__( weight_block_size: tuple[int, int] = (128, 128), dequantize: bool = False, modules_to_not_convert: list | None = None, + scale_fmt: str = "float", **kwargs, ): self.quant_method = QuantizationMethod.FP8 @@ -1701,6 +1708,7 @@ def __init__( self.activation_scheme = activation_scheme self.weight_block_size = weight_block_size self.dequantize = dequantize + self.scale_fmt = scale_fmt self.post_init() def post_init(self): @@ -1714,6 +1722,8 @@ def post_init(self): raise ValueError("weight_block_size must be a tuple of two integers") if self.weight_block_size is not None and (self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0): raise ValueError("weight_block_size must be a tuple of two positive integers") + if self.scale_fmt not in ("float", "ue8m0"): + raise ValueError(f"scale_fmt must be 'float' or 'ue8m0'; got {self.scale_fmt!r}") def get_loading_attributes(self): return {"dequantize": self.dequantize, "modules_to_not_convert": self.modules_to_not_convert} diff --git a/test_deepseek.py b/test_deepseek.py new file mode 100644 index 000000000000..93914db804ce --- /dev/null +++ b/test_deepseek.py @@ -0,0 +1,302 @@ +"""End-to-end DeepGEMM EP test on a real DeepSeek checkpoint. + +Drives the FP8/FP4 DeepGEMM dispatches against `deepseek-ai/DeepSeek-V4-Flash` +with one model load shared across dispatches via `model.set_experts_implementation`: + + - `deepgemm` → `deepgemm_fp8_fp4_experts_forward` + - `deepgemm_megamoe` → fused Mega MoE (skipped on < SM100) + +Dequantize-on-load is intentionally NOT exercised here: V4-Flash is ~671B +parameters; dequantizing to bf16 needs ~1.3 TB across the world, which doesn't +fit in 8× B200 (178 GB each). For the bf16 / dequantized expert dispatches +(`grouped_mm`, `sonicmoe`, `deepgemm` bf16) use a smaller MoE checkpoint. + +Run on B200 with a writable HF cache on the raid mount and torchrun. First run +downloads the checkpoint (hundreds of GB). + + HF_HOME=/raid/ilyas \\ + CUDA_HOME=$HOME/cuda-12.9 \\ + torchrun --nproc_per_node=8 test_deepseek.py + +DeepSeek-V3.2 is intentionally not included: this transformers checkout only +registers `deepseek_v3` / `deepseek_v4`, not `deepseek_v32`. 
+""" + +from __future__ import annotations + +import gc +import os +import sys +import traceback + +import torch +import torch.distributed as dist + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.distributed import DistributedConfig +from transformers.utils.quantization_config import FineGrainedFP8Config + + +_CHECKPOINT = "deepseek-ai/DeepSeek-V4-Flash" +_PROMPTS = [ + "The capital of France is", + "List the first ten prime numbers:", + "Translate to French: 'The quick brown fox jumps over the lazy dog.'", + "Write a Python function fibonacci(n) that returns the nth Fibonacci number.", + "What are the three properties of the UE8M0 scale factor format?", + 'Write a short story that begins with: "Once upon a time, in a forest far away, there lived a..."', +] + + +def _format_chat(prompt: str) -> str: + """Wrap a single-turn user prompt in V4-Flash's chat-mode template. + + The tokenizer doesn't ship a Jinja `chat_template`; the canonical format is in + `encoding/encoding_dsv4.py` on the model repo. Chat mode places `` + right after `<|Assistant|>` to skip the reasoning block and answer directly. + """ + return f"<|begin▁of▁sentence|><|User|>{prompt}<|Assistant|>" + +# (label, dispatch, min_sm). All entries share one model load. +_QUANTIZED_DISPATCHES = [ + ("quantized + deepgemm", "deepgemm", 9), + # ("quantized + deepgemm_megamoe", "deepgemm_megamoe", 10), +] + + +def _rank0_print(msg: str) -> None: + if int(os.environ.get("RANK", "0")) == 0: + print(msg, flush=True) + + +def _render_report(results: list[tuple[str, str, str]], completions: dict[str, list[str]]) -> None: + """Render a side-by-side rich table comparing each dispatch's per-prompt completion.""" + from rich.console import Console + from rich.table import Table + + console = Console() + console.print("") + + # Per-dispatch status row first. + status_table = Table(title="Run summary", show_lines=False) + status_table.add_column("dispatch", style="bold") + status_table.add_column("status") + status_table.add_column("detail", overflow="fold") + for label, status, detail in results: + style = "green" if status == "PASS" else "yellow" if status == "SKIP" else "red" + status_table.add_row(label, f"[{style}]{status}[/{style}]", detail) + console.print(status_table) + + if not completions: + return + + # Per-prompt completions side by side. Each dispatch gets its own column. 
+ dispatch_keys = list(completions.keys()) + title = "completions: " + " vs ".join(dispatch_keys) + completion_table = Table(title=title, show_lines=True) + completion_table.add_column("#", justify="right", no_wrap=True) + completion_table.add_column("prompt", overflow="fold", max_width=30) + for d in dispatch_keys: + completion_table.add_column(d, overflow="fold", max_width=42) + + for i, prompt in enumerate(_PROMPTS): + row = [str(i + 1), prompt] + for d in dispatch_keys: + comps = completions.get(d, []) + row.append(comps[i] if i < len(comps) else "") + completion_table.add_row(*row) + console.print(completion_table) + + +def _generate_and_check(model, tok, label: str, rank: int, completions: list[str]) -> None: + for i, prompt in enumerate(_PROMPTS): + inputs = tok(_format_chat(prompt), return_tensors="pt", add_special_tokens=False).to(model.device) + dist.barrier() + with torch.no_grad(): + out_ids = model.generate( + **inputs, + max_new_tokens=64, + do_sample=False, + pad_token_id=tok.eos_token_id, + ) + if rank == 0: + finite = torch.isfinite(out_ids.float()).all().item() + new_tokens = out_ids[0, inputs.input_ids.size(1):] + completion = tok.decode(new_tokens, skip_special_tokens=True) + completion_raw = tok.decode(new_tokens, skip_special_tokens=False) + completions.append(completion) + print(f"[{label}] prompt {i + 1}/{len(_PROMPTS)} — {new_tokens.numel()} tokens (finite={finite}):", flush=True) + print(f" completion: {completion!r}", flush=True) + print(f" raw decode: {completion_raw!r}", flush=True) + print(f" token ids: {new_tokens.tolist()}", flush=True) + if not finite or new_tokens.numel() == 0: + raise RuntimeError(f"{label}: generation failed (finite={finite}, n={new_tokens.numel()})") + dist.barrier() + + +def _run_dequant_phase(results: list, completions: dict[str, list[str]]) -> None: + """Dequantized (bf16) baseline via FineGrainedFP8Config(dequantize=True) + + `device_map="auto"` CPU offload. Runs only on rank 0 — the dequantized model + is ~1.3 TB and we use the host's RAM (`max_memory`) to keep all 8 ranks' + activations from contending with the model weights on GPU 0. + """ + label = "dequantized (bf16)" + dispatch = "dequantized" + print(f"\n--- loading {_CHECKPOINT} (dequantize=True, device_map=auto across all GPUs) ---", flush=True) + # Spread the bf16 model across all 8 GPUs (~1.36 TB total) instead of GPU 0 + CPU offload — + # the latter forces every forward pass to stream weights over PCIe and is ~10× slower. + # Other torchrun ranks have torn down by now, so GPUs 1..N are free. 
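+    # 170 GiB per GPU (of the 178 GB on a B200) × 8 ≈ 1.36 TB — enough for the ~1.3 TB of
+    # bf16 weights, with the "cpu" entry below as overflow for anything device_map="auto"
+    # cannot place on a GPU.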
+ n_gpus = torch.cuda.device_count() + max_memory = {i: "170GiB" for i in range(n_gpus)} + max_memory["cpu"] = "1500GiB" # fallback for any leftover + qcfg = FineGrainedFP8Config(dequantize=True) + try: + model = AutoModelForCausalLM.from_pretrained( + _CHECKPOINT, + device_map="auto", + dtype="auto", + quantization_config=qcfg, + max_memory=max_memory, + ) + model.eval() + tok = AutoTokenizer.from_pretrained(_CHECKPOINT) + except BaseException as exc: + print(f"[load] FAIL — {type(exc).__name__}: {exc}", flush=True) + results.append((label, "FAIL", f"load: {type(exc).__name__}: {exc}")) + return + + completions[dispatch] = [] + print(f"\n=== {label} ===", flush=True) + try: + for i, prompt in enumerate(_PROMPTS): + inputs = tok(_format_chat(prompt), return_tensors="pt", add_special_tokens=False).to(model.device) + with torch.no_grad(): + out_ids = model.generate( + **inputs, + max_new_tokens=64, + do_sample=False, + pad_token_id=tok.eos_token_id, + ) + new_tokens = out_ids[0, inputs.input_ids.size(1):] + completion = tok.decode(new_tokens, skip_special_tokens=True) + completion_raw = tok.decode(new_tokens, skip_special_tokens=False) + completions[dispatch].append(completion) + print(f"[{label}] prompt {i + 1}/{len(_PROMPTS)} — {new_tokens.numel()} tokens:", flush=True) + print(f" completion: {completion!r}", flush=True) + print(f" raw decode: {completion_raw!r}", flush=True) + print(f" token ids: {new_tokens.tolist()}", flush=True) + results.append((label, "PASS", "")) + except BaseException as exc: + print(f"[{label}] FAIL — {type(exc).__name__}: {exc}", flush=True) + traceback.print_exc() + results.append((label, "FAIL", f"{type(exc).__name__}: {exc}")) + finally: + del model, tok + gc.collect() + torch.cuda.empty_cache() + + +def _run_phase( + load_kwargs: dict, + dispatches, + cap_major: int, + rank: int, + results: list, + completions: dict[str, list[str]], +) -> None: + runnable = [(lab, d) for (lab, d, sm) in dispatches if cap_major >= sm] + skipped = [(lab, d, sm) for (lab, d, sm) in dispatches if cap_major < sm] + for lab, _, sm in skipped: + _rank0_print(f"[{lab}] SKIP: needs SM{sm}0+, got SM{cap_major}0") + results.append((lab, "SKIP", f"needs SM{sm}0+")) + if not runnable: + return + + _rank0_print(f"\n--- loading {_CHECKPOINT} (kwargs: {sorted(load_kwargs)}) ---") + try: + model = AutoModelForCausalLM.from_pretrained( + _CHECKPOINT, + tp_plan="auto", + distributed_config=DistributedConfig(enable_expert_parallel=True), + **load_kwargs, + ) + model.eval() + tok = AutoTokenizer.from_pretrained(_CHECKPOINT) + except BaseException as exc: + if rank == 0: + print(f"[load] FAIL — {type(exc).__name__}: {exc}", flush=True) + for label, _ in runnable: + results.append((label, "FAIL", f"load: {type(exc).__name__}: {exc}")) + return + + try: + for label, dispatch in runnable: + _rank0_print(f"\n=== {label} ===") + try: + model.set_experts_implementation(dispatch) + completions[dispatch] = [] + _generate_and_check(model, tok, label, rank, completions[dispatch]) + results.append((label, "PASS", "")) + except BaseException as exc: + if rank == 0: + print(f"[{label}] FAIL — {type(exc).__name__}: {exc}", flush=True) + traceback.print_exc() + results.append((label, "FAIL", f"{type(exc).__name__}: {exc}")) + finally: + del model, tok + gc.collect() + torch.cuda.empty_cache() + + +def main() -> int: + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + if world_size < 2: + sys.exit("EP test needs 
>=2 ranks (run with `torchrun --nproc_per_node=N`).") + + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group("nccl", device_id=torch.device("cuda", local_rank)) + + cap_major = torch.cuda.get_device_capability()[0] + _rank0_print(f"device cap: SM{cap_major}0, world_size={world_size}") + + results: list[tuple[str, str, str]] = [] # (label, status, detail) + completions: dict[str, list[str]] = {} # dispatch → per-prompt completions + + _run_phase( + load_kwargs={"dtype": "auto"}, + dispatches=_QUANTIZED_DISPATCHES, + cap_major=cap_major, + rank=rank, + results=results, + completions=completions, + ) + + dist.barrier() + dist.destroy_process_group() + + # Dequantized bf16 baseline: rank 0 only, spread across all GPUs once + # ranks 1..N have exited and released their per-rank quantized weights. + if rank == 0: + import time + time.sleep(5) + torch.cuda.empty_cache() + _run_dequant_phase(results, completions) + + if rank == 0: + passed = [r for r in results if r[1] == "PASS"] + failed = [r for r in results if r[1] == "FAIL"] + skipped = [r for r in results if r[1] == "SKIP"] + _render_report(results, completions) + print( + f"\n totals: {len(passed)} passed, {len(failed)} failed, {len(skipped)} skipped", + flush=True, + ) + return 1 if failed else 0 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f7c2bf3fbb76..97c76d87737a 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -53,6 +53,7 @@ ) from transformers.integrations.moe import ( batched_mm_experts_forward, + deepgemm_bf16_experts_forward, grouped_mm_experts_forward, sonicmoe_experts_forward, ) @@ -118,6 +119,7 @@ is_torch_bf16_available_on_device, is_torch_fp16_available_on_device, ) +from transformers.utils.import_utils import get_cuda_runtime_version from transformers.utils.output_capturing import CompileableContextVar from .generation.test_utils import GenerationTesterMixin @@ -600,6 +602,17 @@ def _test_eager_matches_batched_and_grouped_inference(self, name, dtype): mocks["sonicmoe"] = Mock(wraps=sonicmoe_experts_forward) implementations.append("sonicmoe") + if ( + dtype == torch.bfloat16 + and is_kernels_available() + and torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and get_cuda_runtime_version() >= (12, 3) + ): + # DeepGEMM BF16 grouped forward requires Hopper+, CUDA runtime 12.3+, and bf16 hidden states + mocks["deepgemm"] = Mock(wraps=deepgemm_bf16_experts_forward) + implementations.append("deepgemm") + outputs = {} # This is needed because we call the functions through the interface's global mapping with patch.dict("transformers.integrations.moe.ALL_EXPERTS_FUNCTIONS._global_mapping", mocks):
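
For a smaller smoke test than the multi-rank EP script above, a minimal single-process sketch follows. The checkpoint id is a placeholder (substitute any FP8 block-quantized MoE checkpoint that fits on one GPU), and it assumes `set_experts_implementation("deepgemm")` is usable outside expert parallelism — the EP script only demonstrates it under `torchrun`, so treat this as a sketch rather than a verified recipe. Expect the same environment requirements as the gating added to `test_modeling_common.py`: a CUDA device with SM90+, CUDA runtime ≥ 12.3, and the `kernels` package.

    import torch

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.utils.quantization_config import FineGrainedFP8Config

    # The new scale_fmt knob is validated at construction time: "float" or "ue8m0".
    FineGrainedFP8Config(scale_fmt="ue8m0")   # ok
    # FineGrainedFP8Config(scale_fmt="bf16")  # raises ValueError

    checkpoint = "your-org/small-fp8-moe"  # placeholder, not a real repo id

    # Pre-quantized checkpoint: the FP8 config is read from the checkpoint itself.
    model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto", device_map="cuda")
    model.eval()
    # Route expert MLPs through the DeepGEMM grouped-GEMM forward.
    model.set_experts_implementation("deepgemm")

    tok = AutoTokenizer.from_pretrained(checkpoint)
    inputs = tok("The capital of France is", return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
    print(tok.decode(out[0, inputs.input_ids.size(1):], skip_special_tokens=True))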