Skip to content
Draft

draft #3703

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ topk_routing_group: -1 # number of top groups to route inputs. For EP,
# all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers.
use_batch_split_schedule: False # a flag if splitting batch into micro-batches to hide communications that yields performance benefits.
batch_split_factor: 1 # the factor by which to split the batch. Only used if use_batch_split_schedule is True.
use_fp8_for_batch_split: False # a flag if to use fp8 for batch split. Only used if use_batch_split_schedule is True.

# For complex architectures like llama4 there are repeated sets of
# inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope]
Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,10 @@ class DeepSeekMoE(BaseModel):
1,
description="Factor by which to split the batch into micro-batches. Only used if use_batch_split_schedule is True.",
)
use_fp8_for_batch_split: bool = Field(
False,
description="Whether to use fp8 for batch split. Only used if use_batch_split_schedule is True.",
)


class Qwen3Next(BaseModel):
Expand Down
6 changes: 4 additions & 2 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,8 @@ def __call__(
# as detected by immutable params, use deepseek_batchsplit custom
# scan with initialized parameters.
if cfg.use_batch_split_schedule and not self.is_mutable_collection("params"):
if cfg.use_qwix_quantization:
if not cfg.use_fp8_for_batch_split:
max_logging.log("Using deepseek_batchsplit_fp8")
y = deepseek_batchsplit_fp8.scan_batch_split_layers(
y,
self.variables["params"]["moe_layers"],
Expand All @@ -931,7 +932,8 @@ def __call__(
policy=policy,
)
else:
# bf16 code path
# bf16 and fp8
max_logging.log("Using deepseek_batchsplit")
y = deepseek_batchsplit.scan_batch_split_layers(
y,
self.variables["params"]["moe_layers"],
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def extract_fn(x):
weights = deepseek_batchsplit.fetch_weights(
nnx.to_pure_dict(nnx.state(self, nnx.Param), extract_fn), self.config.dtype
)
weights = deepseek_batchsplit.gather_weights(weights, self.mesh)
weights = deepseek_batchsplit.gather_weights(weights, self.mesh, use_fp8=self.config.use_fp8_for_batch_split)
outputs, _ = deepseek_batchsplit.batch_split_schedule(
inputs,
weights,
Expand Down
Loading
Loading