Skip to content

Commit b43d692

Browse files
Kevin Wang (Google-ML-Automation)
authored and committed
Make pure-JAX batch-split scheduled version of DeepSeekMoELayer.
PiperOrigin-RevId: 864963827
1 parent a894502 commit b43d692

7 files changed

Lines changed: 988 additions & 407 deletions

File tree

src/MaxText/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']
6060
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
6161
logical_axis_rules: [
6262
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63+
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6364
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6465
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
6566
['activation_norm_length', ['context']],
@@ -68,10 +69,12 @@ logical_axis_rules: [
6869
['embed_no_exp', ['fsdp']],
6970
['q_lora', ['fsdp']],
7071
['kv_lora', ['fsdp']],
71-
['q_lora_up_proj', ['fsdp_transpose']],
72-
['kv_lora_up_proj', ['fsdp_transpose']],
73-
['q_heads', ['fsdp_transpose']],
74-
['kv_heads', ['fsdp_transpose']],
75-
['heads', ['fsdp_transpose']],
76-
['mlp', ['fsdp_transpose']],
72+
['q_lora_up_proj', ['fsdp_transpose', 'expert']],
73+
['kv_lora_up_proj', ['fsdp_transpose', 'expert']],
74+
['q_heads', ['fsdp_transpose', 'expert']],
75+
['kv_heads', ['fsdp_transpose', 'expert']],
76+
['heads', ['fsdp_transpose', 'expert']],
77+
['mlp', ['fsdp_transpose', 'expert']],
78+
['mlp_only_fsdp_transpose', ['fsdp_transpose']],
79+
['mlp_only_tensor', ['expert']],
7780
]

src/MaxText/configs/models/deepseek3-tiny.yml

Lines changed: 0 additions & 50 deletions
This file was deleted.
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Token sorting for MoE layers."""
16+
17+
import functools
18+
19+
import jax
20+
import jax.numpy as jnp
21+
22+
23+
@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def route(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> jax.Array:
  """Route tokens to selected experts.

  Replicates each token once per selected expert and gathers the copies into
  expert-sorted order. Differentiable: the custom backward pass applies the
  inverse permutation (scatter-add) instead of differentiating through the
  argsort-based gather.

  Args:
    tokens: `[num_tokens, ...]` token activations.
    selected_experts: `[num_tokens, experts_per_token]` integer expert ids.
    use_custom_mosaic_kernel: static flag selecting a custom kernel path;
      currently not implemented and raises `NotImplementedError` if True.

  Returns:
    `[num_tokens * experts_per_token, ...]` tokens in expert-sorted order.
  """
  return _route_fwd(tokens, selected_experts, use_custom_mosaic_kernel)[0]


def _route_fwd(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> tuple[jax.Array, jax.Array]:
  """Forward rule for `route`; saves `selected_experts` as the residual."""
  return (
      _route_impl(tokens, selected_experts, use_custom_mosaic_kernel),
      selected_experts,
  )


def _route_bwd(
    use_custom_mosaic_kernel: bool,
    residuals: jax.Array,
    grads: jax.Array,
) -> tuple[jax.Array, None]:
  """Backward rule for `route`: un-permute cotangents back to token order."""
  selected_experts = residuals
  # No gradient flows to the integer `selected_experts`, hence the None.
  return _unroute_impl(grads, selected_experts, use_custom_mosaic_kernel), None


route.defvjp(_route_fwd, _route_bwd)


@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def unroute(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> jax.Array:
  """Inverse of `route`: scatter expert-sorted rows back and sum per token.

  Args:
    tokens: `[num_tokens * experts_per_token, ...]` expert-sorted rows, as
      produced by `route`.
    selected_experts: `[num_tokens, experts_per_token]` integer expert ids
      (the same array that was passed to `route`).
    use_custom_mosaic_kernel: static flag; the kernel path is not implemented
      yet and raises `NotImplementedError`.

  Returns:
    `[num_tokens, ...]` per-token sums of the expert-sorted rows.
  """
  return _unroute_fwd(tokens, selected_experts, use_custom_mosaic_kernel)[0]


def _unroute_fwd(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> tuple[jax.Array, jax.Array]:
  """Forward rule for `unroute`; saves `selected_experts` as the residual."""
  return (
      _unroute_impl(tokens, selected_experts, use_custom_mosaic_kernel),
      selected_experts,
  )


def _unroute_bwd(
    use_custom_mosaic_kernel: bool, residuals: jax.Array, grads: jax.Array
) -> tuple[jax.Array, None]:
  """Backward rule for `unroute`: re-permute cotangents into expert order."""
  selected_experts = residuals
  return _route_impl(grads, selected_experts, use_custom_mosaic_kernel), None


unroute.defvjp(_unroute_fwd, _unroute_bwd)


def _route_impl(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> jax.Array:
  """Gather `tokens` according to `selected_experts`.

  Args:
    tokens: `[num_tokens, ...]` token activations.
    selected_experts: `[num_tokens, experts_per_token]` integer expert ids.
    use_custom_mosaic_kernel: if True, raises (kernel path unimplemented).

  Returns:
    `[num_tokens * experts_per_token, ...]` array where each token appears
    once per selected expert, ordered by ascending expert id (stable within
    an expert, since `jnp.argsort` is stable).
  """
  assert (
      tokens.shape[0] == selected_experts.shape[0]
      and selected_experts.ndim == 2
  ), f"{tokens.shape=}, {selected_experts.shape=}"
  if use_custom_mosaic_kernel:
    raise NotImplementedError("Custom Mosaic kernel not implemented.")
  # Stable argsort of the flattened expert ids gives, for each output slot,
  # the flat (token, k) index; floor-dividing by the top-k width recovers the
  # source token row to gather.
  inds = jnp.argsort(jnp.ravel(selected_experts)) // selected_experts.shape[1]
  return _sort_impl(tokens, inds, use_custom_mosaic_kernel)


def _unroute_impl(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> jax.Array:
  """Inverse of `_route_impl`: restore token order and sum over the k copies.

  Args:
    tokens: `[num_tokens * experts_per_token, ...]` expert-sorted rows.
    selected_experts: `[num_tokens, experts_per_token]` integer expert ids.
    use_custom_mosaic_kernel: if True, raises (kernel path unimplemented).

  Returns:
    `[num_tokens, ...]` array summing each token's `experts_per_token` rows.
  """
  # Diagnostic message added for consistency with `_route_impl`'s assert.
  assert (
      tokens.shape[0] == selected_experts.shape[0] * selected_experts.shape[1]
      and selected_experts.ndim == 2
  ), f"{tokens.shape=}, {selected_experts.shape=}"
  # argsort of argsort yields the inverse of the routing permutation.
  inds = jnp.argsort(jnp.argsort(jnp.ravel(selected_experts)))
  return jnp.sum(
      jnp.reshape(
          _sort_impl(tokens, inds, use_custom_mosaic_kernel),
          (-1, selected_experts.shape[1]) + tokens.shape[1:],
      ),
      axis=1,
  )


def _sort_impl(
    tokens: jax.Array, inds: jax.Array, use_custom_mosaic_kernel: bool
) -> jax.Array:
  """Gather rows of `tokens` at `inds` (custom kernel path unimplemented)."""
  if use_custom_mosaic_kernel:
    raise NotImplementedError("Custom Mosaic kernel not implemented.")
  else:
    return tokens[inds, ...]

src/MaxText/layers/decoders.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from MaxText.layers.quantizations import AqtQuantization as Quant
4444
from MaxText.layers import (
4545
deepseek,
46-
deepseek_batchsplit,
4746
gemma,
4847
gemma2,
4948
gemma3,
@@ -405,10 +404,10 @@ def get_decoder_layers(self):
405404
case DecoderBlockType.MIXTRAL:
406405
return [mixtral.MixtralDecoderLayerToLinen]
407406
case DecoderBlockType.DEEPSEEK:
408-
if self.config.use_batch_split_schedule:
409-
return [deepseek_batchsplit.DeepSeekDenseLayerToLinen, deepseek_batchsplit.DeepSeekMoELayerToLinen]
410-
else:
411-
return [deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen]
407+
return [
408+
deepseek.DeepSeekDenseLayerToLinen,
409+
deepseek.DeepSeekMoELayerToLinen,
410+
]
412411
case DecoderBlockType.GEMMA:
413412
return [gemma.GemmaDecoderLayerToLinen]
414413
case DecoderBlockType.GEMMA2:

src/MaxText/layers/deepseek.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,27 @@
1818

1919
from typing import Optional
2020

21+
from flax import nnx
2122
from jax.ad_checkpoint import checkpoint_name
22-
from jax.sharding import Mesh
2323
import jax.numpy as jnp
24-
25-
from flax import nnx
26-
24+
from jax.sharding import Mesh
2725
from MaxText.common_types import Config
2826
from MaxText.common_types import MODEL_MODE_PREFILL
27+
from maxtext.inference import page_manager
2928
from MaxText.layers import attention_mla
29+
from MaxText.layers import deepseek_batchsplit
3030
from MaxText.layers import initializers
3131
from MaxText.layers import linears
3232
from MaxText.layers import moe
3333
from MaxText.layers import nnx_wrappers
3434
from MaxText.layers import quantizations
3535
from MaxText.layers.linears import Dropout
3636
from MaxText.layers.normalizations import RMSNorm
37-
from MaxText.sharding import maybe_shard_with_logical, create_sharding
38-
from maxtext.inference import page_manager
37+
from MaxText.sharding import create_sharding
38+
from MaxText.sharding import maybe_shard_with_logical
3939
from maxtext.utils import max_utils
4040

41+
4142
# -----------------------------------------
4243
# The Decoder Layer for DeepSeek v3
4344
# -----------------------------------------
@@ -366,6 +367,21 @@ def __call__(
366367
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
367368
if isinstance(inputs, tuple):
368369
inputs = inputs[0]
370+
371+
# If using batch split schedule, call the batch split version of the layer.
372+
if self.config.use_batch_split_schedule:
373+
outputs = deepseek_batchsplit.batch_split_schedule(
374+
inputs,
375+
nnx.to_pure_dict(nnx.state(self, nnx.Param)),
376+
decoder_positions,
377+
decoder_segment_ids,
378+
model_mode=model_mode,
379+
mesh=self.mesh,
380+
quant=self.quant,
381+
cfg=self.config,
382+
)
383+
return outputs, None
384+
369385
x = self.with_logical_constraint(inputs)
370386
x = checkpoint_name(x, "decoder_layer_input")
371387

0 commit comments

Comments
 (0)