Draft
97 commits
94a2786
init
IlyasMoutawwakil Apr 23, 2026
8fc43d7
Merge branch 'main' into deepgemm
IlyasMoutawwakil Apr 23, 2026
357a035
style
IlyasMoutawwakil Apr 23, 2026
741b5eb
full support
IlyasMoutawwakil Apr 23, 2026
9fc3662
support EP better using offsets !
IlyasMoutawwakil Apr 24, 2026
84552ae
comments
IlyasMoutawwakil Apr 24, 2026
1d9f319
get rid of neutralize_ep_sentinels
IlyasMoutawwakil Apr 24, 2026
9b86043
remove deepgemm stuff
IlyasMoutawwakil Apr 24, 2026
8a3f88b
Merge branch 'main' into deepgemm
IlyasMoutawwakil Apr 24, 2026
996d67d
fix
IlyasMoutawwakil Apr 24, 2026
d033a83
prefix
IlyasMoutawwakil Apr 24, 2026
e15cfe6
move
IlyasMoutawwakil Apr 24, 2026
10b6d90
fix
IlyasMoutawwakil Apr 24, 2026
d4a6b30
remove comment
IlyasMoutawwakil Apr 24, 2026
1d6054f
fix unintilized outputs leaking
IlyasMoutawwakil Apr 24, 2026
137393c
revert unnecessary changes
IlyasMoutawwakil Apr 24, 2026
774f901
more unnecessary changes
IlyasMoutawwakil Apr 24, 2026
81230fe
revert downcast
IlyasMoutawwakil Apr 24, 2026
9f2ff08
keep it simple
IlyasMoutawwakil Apr 24, 2026
c55b7b7
guard deepgemm cuda version
IlyasMoutawwakil Apr 24, 2026
20858db
fix style
IlyasMoutawwakil Apr 24, 2026
bfea94f
update
IlyasMoutawwakil Apr 24, 2026
eada47e
add deepgemm testing
IlyasMoutawwakil Apr 24, 2026
89d2f0b
moe sentinel support
IlyasMoutawwakil Apr 26, 2026
ef72b0e
Merge branch 'main' into deepgemm
IlyasMoutawwakil Apr 26, 2026
60db1ca
fix
IlyasMoutawwakil Apr 26, 2026
e732af1
Merge branch 'main' into deepgemm
IlyasMoutawwakil Apr 27, 2026
68b7b0f
compilable sonicmoe
IlyasMoutawwakil Apr 27, 2026
faaa7aa
mega moe kernel support attempt
IlyasMoutawwakil Apr 29, 2026
b502cd6
use package for now
IlyasMoutawwakil Apr 29, 2026
cc753ca
Merge branch 'deepgemm' into deepgemm-isolation
IlyasMoutawwakil Apr 29, 2026
1c17452
skip ep router and experts pre/post processing
IlyasMoutawwakil Apr 29, 2026
5a8ceae
simpler
IlyasMoutawwakil Apr 29, 2026
a05aa39
fix
IlyasMoutawwakil Apr 29, 2026
854f4fb
Merge branch 'main' into deepgemm-isolation
IlyasMoutawwakil Apr 29, 2026
053c9df
fix
IlyasMoutawwakil Apr 29, 2026
80a6fe5
fix
IlyasMoutawwakil May 1, 2026
ad8226c
dtensor support
IlyasMoutawwakil May 1, 2026
a663f4d
more dtensor
IlyasMoutawwakil May 1, 2026
74c3f2e
simpler
IlyasMoutawwakil May 1, 2026
d3cae33
remove comment
IlyasMoutawwakil May 1, 2026
c1bfa0f
Merge branch 'deepgemm' into deepgemm-isolation
IlyasMoutawwakil May 2, 2026
6c43611
Merge branch 'main' into deepgemm-isolation
IlyasMoutawwakil May 4, 2026
0528e0e
revert
IlyasMoutawwakil May 4, 2026
2bfd029
bc order
IlyasMoutawwakil May 4, 2026
368db00
revert extra indent
IlyasMoutawwakil May 4, 2026
28fbb84
revert unnecessary change
IlyasMoutawwakil May 4, 2026
6ef27ab
update
IlyasMoutawwakil May 4, 2026
e436b77
less defensive
IlyasMoutawwakil May 4, 2026
8e8f0ee
allow all kernels
IlyasMoutawwakil May 5, 2026
c494e35
alow all kernels
IlyasMoutawwakil May 5, 2026
03b0442
hub only
IlyasMoutawwakil May 5, 2026
527d8b4
Merge branch 'main' into deepgemm-isolation
IlyasMoutawwakil May 5, 2026
2e51b3c
fix
IlyasMoutawwakil May 5, 2026
82c5fb5
fix
IlyasMoutawwakil May 5, 2026
99fdf71
test
IlyasMoutawwakil May 5, 2026
21398a8
test
IlyasMoutawwakil May 5, 2026
f21db7c
sync
IlyasMoutawwakil May 6, 2026
1fe3768
check nvcc
IlyasMoutawwakil May 6, 2026
bbfda35
probe
IlyasMoutawwakil May 6, 2026
68a373c
fix
IlyasMoutawwakil May 6, 2026
825f6d4
test psum
IlyasMoutawwakil May 6, 2026
028a39f
test
IlyasMoutawwakil May 6, 2026
ee173a5
test
IlyasMoutawwakil May 6, 2026
900e984
probe
IlyasMoutawwakil May 6, 2026
09a711f
fix
IlyasMoutawwakil May 6, 2026
a0d4940
test
IlyasMoutawwakil May 6, 2026
653b7b3
nan issue
IlyasMoutawwakil May 6, 2026
ed8af6b
repro
IlyasMoutawwakil May 6, 2026
fb8d338
repro
IlyasMoutawwakil May 6, 2026
562eb51
fix
IlyasMoutawwakil May 6, 2026
481e5e6
simplifications
IlyasMoutawwakil May 6, 2026
d3dbd32
fix
IlyasMoutawwakil May 6, 2026
9f168ff
fix
IlyasMoutawwakil May 6, 2026
7274c22
fix
IlyasMoutawwakil May 6, 2026
a800b8c
fix
IlyasMoutawwakil May 6, 2026
804c988
fix
IlyasMoutawwakil May 6, 2026
7f36562
fix
IlyasMoutawwakil May 6, 2026
6df9f47
fix
IlyasMoutawwakil May 6, 2026
c3432c9
empty
IlyasMoutawwakil May 6, 2026
77f09fe
simplify
IlyasMoutawwakil May 6, 2026
7dfbedd
test deepseek
IlyasMoutawwakil May 6, 2026
689cc29
dsv4 only
IlyasMoutawwakil May 6, 2026
8de089c
download dsv4
IlyasMoutawwakil May 6, 2026
cb4d6f9
fix test
IlyasMoutawwakil May 6, 2026
389eee8
Merge branch 'main' into deepgemm-isolation
IlyasMoutawwakil May 6, 2026
8f29ed6
push
IlyasMoutawwakil May 6, 2026
946e200
test
IlyasMoutawwakil May 6, 2026
6619647
fix
IlyasMoutawwakil May 6, 2026
913339a
fix ep plan
IlyasMoutawwakil May 6, 2026
7cfc2b2
fix attempt
IlyasMoutawwakil May 6, 2026
4c12f6e
debug
IlyasMoutawwakil May 6, 2026
92e6337
attempt
IlyasMoutawwakil May 6, 2026
a9997f5
debug
IlyasMoutawwakil May 6, 2026
197a8f6
fixes in modeling
IlyasMoutawwakil May 11, 2026
30d0780
more modeling changes
IlyasMoutawwakil May 11, 2026
3b1c470
more modeling changes
IlyasMoutawwakil May 11, 2026
566 changes: 566 additions & 0 deletions src/transformers/integrations/deepgemm.py

Large diffs are not rendered by default.
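The new deepgemm.py is collapsed above, so as a rough mental model only: deepgemm_bf16_experts_forward (registered in moe.py below) plays the same role as the existing batched/grouped experts forwards. The following is a naive reference loop, not the PR's kernel path; the argument order and the gate_up_proj / down_proj layouts are assumptions:

import torch

def naive_bf16_experts_forward(self, hidden_states, top_k_index, top_k_weights):
    # Illustrative reference only: loops over experts instead of calling the
    # deep-gemm grouped kernel. Assumed shapes: hidden_states (tokens, hidden),
    # top_k_index / top_k_weights (tokens, top_k),
    # self.gate_up_proj (experts, hidden, 2 * intermediate),
    # self.down_proj (experts, intermediate, hidden).
    out = torch.zeros_like(hidden_states)
    for expert_id in range(self.gate_up_proj.shape[0]):
        token_idx, k_idx = torch.where(top_k_index == expert_id)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]
        gate, up = (x @ self.gate_up_proj[expert_id]).chunk(2, dim=-1)
        y = (torch.nn.functional.silu(gate) * up) @ self.down_proj[expert_id]
        out.index_add_(0, token_idx, y * top_k_weights[token_idx, k_idx, None])
    return out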

551 changes: 196 additions & 355 deletions src/transformers/integrations/finegrained_fp8.py

Large diffs are not rendered by default.
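The finegrained_fp8.py rewrite is also collapsed. For context on what "fine-grained" refers to here (and what the *_scale_inv tensors in the deepseek_v4 EP plan further down correspond to), a block-wise FP8 dequantization sketch; the 128-wide block size and the parameter names are assumptions, not taken from this diff:

import torch

def dequantize_blockwise_fp8(weight_fp8, scale_inv, block_size=128):
    # weight_fp8: (out_features, in_features), stored as torch.float8_e4m3fn
    # scale_inv:  (ceil(out/block), ceil(in/block)) float32, one scale per block
    w = weight_fp8.to(torch.float32)
    scales = scale_inv.repeat_interleave(block_size, dim=0)[: w.shape[0]]
    scales = scales.repeat_interleave(block_size, dim=1)[:, : w.shape[1]]
    return w * scales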

6 changes: 3 additions & 3 deletions src/transformers/integrations/hub_kernels.py
@@ -288,7 +288,7 @@ def register_kernel_mapping_transformers(*args, **kwargs):
"mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1},
"falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1},
"finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1},
"deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1},
"deep-gemm": {"repo_id": "adarshxs/deep-gemm", "revision": "v2"},
"sonic-moe": {"repo_id": "kernels-community/sonic-moe", "revision": "ep-support"},
}

@@ -346,7 +346,7 @@ def load_and_register_attn_kernel(

# Load the kernel from hub
try:
kernel = get_kernel(repo_id, revision=rev, allow_all_kernels=allow_all_kernels)
kernel = get_kernel(repo_id, revision=rev, allow_all_kernels=True)
except Exception as e:
raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.")
# correctly wrap the kernel
@@ -376,7 +376,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None)
kernel = get_kernel(repo_id, revision=revision, version=version)
kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True)
mapping[kernel_name] = kernel
except FileNotFoundError:
mapping[kernel_name] = None
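For orientation, the mapping entries above are consumed by lazy_load_kernel roughly like this; this is a sketch, and whether your installed kernels release accepts the version and allow_all_kernels arguments directly is an assumption based on the calls shown in this diff:

from kernels import get_kernel

entry = {"repo_id": "adarshxs/deep-gemm", "revision": "v2"}  # as registered above
kernel = get_kernel(
    entry["repo_id"],
    revision=entry.get("revision"),
    version=entry.get("version"),
    allow_all_kernels=True,  # forced on by this PR so non kernels-community repos resolve
)

The behavioral change in this file is exactly that flag: both load_and_register_attn_kernel and lazy_load_kernel now pass allow_all_kernels=True unconditionally, which is what lets the temporary adarshxs/deep-gemm repo load.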
2 changes: 2 additions & 0 deletions src/transformers/integrations/moe.py
@@ -24,6 +24,7 @@
is_torch_less_or_equal,
is_torchdynamo_compiling,
)
from .deepgemm import deepgemm_bf16_experts_forward
from .sonicmoe import sonicmoe_experts_forward


@@ -478,6 +479,7 @@ class ExpertsInterface(GeneralInterface):
"""Interface for registering custom experts forward functions."""

_global_mapping = {
"deepgemm": deepgemm_bf16_experts_forward,
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
"sonicmoe": sonicmoe_experts_forward,
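Since ExpertsInterface._global_mapping is a plain name-to-callable dict, a third-party experts forward can be registered the same way this PR registers deepgemm. A minimal sketch; the argument order mirrors the built-in forwards and is an assumption:

import torch
from transformers.integrations.moe import ExpertsInterface

def my_experts_forward(self, hidden_states, top_k_index, top_k_weights):
    # Stand-in kernel with the assumed experts-forward signature; a real
    # implementation would route tokens through self.gate_up_proj / self.down_proj.
    return torch.zeros_like(hidden_states)

ExpertsInterface._global_mapping["my_kernel"] = my_experts_forward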
62 changes: 60 additions & 2 deletions src/transformers/integrations/tensor_parallel.py
@@ -37,6 +37,20 @@
logger = logging.get_logger(__name__)


def to_local(t):
"""Unwrap a `DTensor` to its local shard if needed; pass through otherwise.

Custom kernels (CUTLASS, CuteDSL, Triton) take raw tensor pointers and don't
understand `DTensor`, so weights wrapped by FSDP2 / EP need this unwrap before
they can be fed to the kernel. ``to_local()`` is autograd-aware on the train
path: backward rewraps the gradient as a DTensor matching each parameter's
placements.
"""
if is_torch_available() and isinstance(t, torch.distributed.tensor.DTensor):
return t.to_local()
return t


def initialize_tensor_parallelism(
tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
):
@@ -766,6 +780,28 @@ def _backward_hook(mod, grad_input, grad_output, mesh=device_mesh):
module.register_full_backward_hook(_backward_hook)


class AllReduceParallel(TensorParallelLayer):
"""
Marker layer: parameters (if any) are replicated; the forward output is all-reduced
across the TP mesh. Use as a no-op `nn.Identity` placed at a sync point after a
colwise-sharded compute that ends in a head-axis (or similar) reduction, so each
rank holds only a partial sum and needs to share it before the next dependent op
(e.g. the lightning indexer's score sum before its top-k).
"""

def _prepare_input_fn(self, mod, inputs, device_mesh):
return inputs

def _prepare_output_fn(self, mod, outputs, device_mesh):
return all_reduce_forward(outputs, device_mesh)

def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
return param[...].to(device=device, dtype=dtype)

def prepare_module_tp(self, module, device_mesh, **kwargs):
distribute_module(module, device_mesh, output_fn=self._prepare_output_fn)


class MlaKvAProjParallel(TensorParallelLayer):
"""
For MLA attention used in DeepSeek-V2 style models (deepseek_v2, longcat_flash, glm_moe_dsa, glm4_moe_lite):
@@ -1088,7 +1124,7 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)

def _prepare_input_fn(self, mod, inputs, device_mesh):
return inputs[0] if inputs else inputs
return inputs

def _prepare_output_fn(self, mod, outputs, device_mesh):
"""
@@ -1135,7 +1171,13 @@ def _prepare_output_fn(self, mod, outputs, device_mesh):
The sentinel index (num_local_experts) is skipped by one_hot encoding or clamped
+ masked in grouped_mm/batched_mm. After the expert forward, an all_reduce sums
partial outputs across EP ranks to produce the full result.

Mega MoE skips this remap: its kernel does the EP dispatch itself and wants raw
global expert ids with unmasked routing weights.
"""
if _is_megamoe(mod):
return outputs

ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
num_experts = getattr(mod, "num_experts", None)
if num_experts is None:
@@ -1183,6 +1225,12 @@ def _prepare_input_fn(self, mod, inputs, device_mesh):
top_k_index = inputs[1]
top_k_weights = inputs[2]

# Mega MoE is inference-only (the kernel has no backward) and handles EP
# dispatch + combine + per-rank token sharding internally. Skip the gradient
# sync hooks and append the EP `process_group` so the forward can rendezvous.
if _is_megamoe(mod):
return hidden_states, top_k_index, top_k_weights, device_mesh.get_group()

# all_reduce_backward on hidden_states for correct colwise (gate_up_proj) gradient
hidden_states = all_reduce_backward(hidden_states, device_mesh)

@@ -1191,9 +1239,12 @@ def _prepare_input_fn(self, mod, inputs, device_mesh):
# and partial_expert_output is different on each GPU before all-reduce
top_k_weights = all_reduce_backward(top_k_weights, device_mesh)

return (hidden_states, top_k_index, top_k_weights)
return hidden_states, top_k_index, top_k_weights

def _prepare_output_fn(self, mod, outputs, device_mesh):
# Mega MoE returned the fully-combined gathered output; skip the all-reduce.
if _is_megamoe(mod):
return outputs
# all_reduce_forward to sum partial expert outputs across GPUs
return all_reduce_forward(outputs, device_mesh)

@@ -1205,6 +1256,10 @@ def shard_tensor(
return param[...].to(device=device, dtype=dtype)


def _is_megamoe(mod: nn.Module) -> bool:
return getattr(getattr(mod, "config", None), "_experts_implementation", None) == "deepgemm_megamoe"


class MoeIdentityExpertParallel(TensorParallelLayer):
"""
TP class for zero/identity experts in MoE layers.
Expand Down Expand Up @@ -1247,6 +1302,7 @@ class ParallelInterface(GeneralInterface):
"moe_identity_expert": MoeIdentityExpertParallel(),
"replicated_with_grad_allreduce": ReplicatedWithGradAllReduce(),
"mla_kv_a_proj": MlaKvAProjParallel(),
"all_reduce": AllReduceParallel(),
}
if is_torch_available() and _torch_distributed_available
else {}
@@ -1267,6 +1323,7 @@ class ParallelInterface(GeneralInterface):
"sequence_parallel": None,
"replicated_with_grad_allreduce": None,
"mla_kv_a_proj": None,
"all_reduce": None,
}

# Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced)
@@ -1282,6 +1339,7 @@ class ParallelInterface(GeneralInterface):
"sequence_parallel": None,
"replicated_with_grad_allreduce": None,
"mla_kv_a_proj": None,
"all_reduce": None,
}

@classmethod
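The sentinel mechanism described in the _prepare_output_fn docstring above (global expert ids remapped to rank-local ids, off-rank picks pointed at num_local_experts with zeroed weights, partial outputs summed by all_reduce afterwards) boils down to something like the following; the helper name and exact masking are illustrative rather than the folded lines of this diff:

import torch

def remap_to_local_experts(top_k_index, top_k_weights, ep_rank, num_local_experts):
    # Experts owned by this rank occupy the global id range
    # [ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts).
    local = top_k_index - ep_rank * num_local_experts
    on_this_rank = (local >= 0) & (local < num_local_experts)
    # Off-rank picks go to the sentinel id (== num_local_experts) and contribute
    # zero weight, so grouped_mm / batched_mm can clamp-and-mask them away;
    # the subsequent all_reduce then sums each rank's partial expert output.
    local = torch.where(on_this_rank, local, torch.full_like(local, num_local_experts))
    weights = torch.where(on_this_rank, top_k_weights, torch.zeros_like(top_k_weights))
    return local, weights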
58 changes: 39 additions & 19 deletions src/transformers/models/deepseek_v4/configuration_deepseek_v4.py
@@ -112,20 +112,32 @@ class DeepseekV4Config(PreTrainedConfig):
"norm": (["hidden_states"], ["hidden_states"]),
}
base_model_ep_plan = {
# EP-only by default, same shape as gpt-oss: route on the gate, run the
# routed experts as a grouped-GEMM kernel sharded along the expert axis,
# and wrap the experts module with `moe_tp_experts` so its output gets
# all-reduced across ranks. Attention stays replicated (V4 is shared-KV
# MQA + a CSA / HCA compressor branch — both broadcast a single KV head
# across all attention heads via `repeat_kv`, so colwise-sharding
# `q_b_proj` would leave KV replicated and `repeat_kv` would no longer
# match the rank-local query head count). The shared MLP also stays
# replicated — it's small and not worth TP-ing. There's deliberately
# no `base_model_tp_plan` for V4: we don't ship a pure-TP plan, only EP.
# V4 ships EP only (no `base_model_tp_plan` — the runtime picks one plan or
# the other, never both, and V4 is MoE so EP is the only sensible config).
# MoE parallelism: route on the gate, run the routed experts as a grouped-GEMM
# kernel sharded along the expert axis, and wrap the experts module with
# `moe_tp_experts` so its output gets all-reduced across ranks. Same shape as
# gpt-oss. Main attention stays replicated: V4 is shared-KV MQA + a CSA / HCA
# compressor branch — both broadcast a single KV head across all attention
# heads via `repeat_kv`, so colwise-sharding `q_b_proj` would leave KV
# replicated and `repeat_kv` would no longer match the rank-local query head
# count. The shared MLP also stays replicated — it's small and not worth
# sharding. The Lightning Indexer is the one carve-out: its keys are
# replicated (own compressor at index_head_dim fed by replicated
# hidden_states), so head-sharding is well-formed; `q_b_proj` and
# `weights_proj` go colwise, and `scores_sync` is a `nn.Identity` whose
# `"all_reduce"` output hook sums the per-rank partial `index_scores` across
# the mesh so every rank picks the same top-k. Mirrors the reference
# inference (`inference/model.py:393, 394, 422-423`).
"layers.*.mlp.gate": "ep_router",
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
"layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm",
"layers.*.mlp.experts.down_proj": "grouped_gemm",
"layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm",
"layers.*.mlp.experts": "moe_tp_experts",
"layers.*.self_attn.compressor.indexer.q_b_proj": "colwise",
"layers.*.self_attn.compressor.indexer.weights_proj": "colwise",
"layers.*.self_attn.compressor.indexer.scores_sync": "all_reduce",
}

vocab_size: int = 129280
@@ -187,7 +199,7 @@ class DeepseekV4Config(PreTrainedConfig):
# back to wrapping the whole dict as a single set of params when the subset check
# fails, which then warns about `main` / `compress` as unrecognized keys. Override
# to iterate the rope-type-keyed sub-dicts directly.
_rope_type_labels = ("main", "compress")
_rope_type_labels = ("sliding", "compress")

def validate_rope(self):
rope_parameters_dict = getattr(self, "rope_parameters", None) or {}
@@ -285,19 +297,27 @@ def __post_init__(self, **kwargs):

# `rope_parameters`: split the flat dict (left by `convert_rope_params_to_dict`,
# which folded any legacy `rope_scaling` block in) into per-rope-type
# `{main, compress}` sub-dicts. Idempotent: re-loading an already-split config
# is a no-op via the `isinstance` short-circuit. The two sub-dicts differ only
# in `rope_theta` (main: 10000, compress: 160000).
# `{sliding, compress}` sub-dicts. Mirrors reference `inference/model.py:475-481`:
# sliding-window attention layers use base RoPE (`rope_theta=10000`, no YaRN —
# `original_seq_len=0` disables it); CSA/HCA layers (and their internal
# compressors/indexer) use `compress_rope_theta=160000` with YaRN frequency
# interpolation. Idempotent: re-loading an already-split config is a no-op via
# the `isinstance` short-circuit.
rp = self.rope_parameters or {}
if isinstance(rp.get("main"), dict) and isinstance(rp.get("compress"), dict):
if isinstance(rp.get("sliding"), dict) and isinstance(rp.get("compress"), dict):
# Already nested — drop any leftover top-level keys.
self.rope_parameters = {"main": rp["main"], "compress": rp["compress"]}
self.rope_parameters = {"sliding": rp["sliding"], "compress": rp["compress"]}
else:
base = {k: v for k, v in rp.items() if k not in ("main", "compress")}
base.setdefault("rope_theta", self.rope_theta)
base = {k: v for k, v in rp.items() if k not in ("sliding", "compress")}
base.setdefault("rope_type", "default")
base["partial_rotary_factor"] = self.partial_rotary_factor
self.rope_parameters = {"main": dict(base), "compress": {**base, "rope_theta": self.compress_rope_theta}}
sliding = {
"rope_theta": self.rope_theta,
"rope_type": "default",
"partial_rotary_factor": self.partial_rotary_factor,
}
compress = {**base, "rope_theta": self.compress_rope_theta}
self.rope_parameters = {"sliding": sliding, "compress": compress}


__all__ = ["DeepseekV4Config"]
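A worked example of the {sliding, compress} split above, with made-up but representative values (rope_theta=10000.0, compress_rope_theta=160000.0, partial_rotary_factor=0.5):

# Flat dict as left by convert_rope_params_to_dict (illustrative input):
rp = {"rope_theta": 10000.0, "rope_type": "yarn", "factor": 40}

# Tracing the __post_init__ logic above then gives:
# rope_parameters == {
#     "sliding":  {"rope_theta": 10000.0, "rope_type": "default",
#                  "partial_rotary_factor": 0.5},
#     "compress": {"rope_theta": 160000.0, "rope_type": "yarn", "factor": 40,
#                  "partial_rotary_factor": 0.5},
# }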