Skip to content

Commit 61a14b9

Browse files
committed
Add support for ragged buffer factor
1 parent 64a7746 commit 61a14b9

5 files changed

Lines changed: 113 additions & 10 deletions

File tree

docs/reference/core_concepts/moe_configuration.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ Dropping:
8888
- Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
8989
- Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.
9090

91+
`ragged_buffer_factor`: A scalar multiplier for the size of the ragged buffer (effectively expert capacity). Effective only when `sparse_matmul` is True.
92+
93+
- Value > 0: Uses an explicit buffer size, which may drop tokens when this size is exceeded.
94+
- Value = -1: Uses a worst-case calculated buffer size that is guaranteed not to drop any tokens.
95+
9196
`use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul. Recommended to replace the inefficient scatter-add generated by the `jax.numpy.take` in the backward pass.
9297

9398
`mlp_bias`: If enabled, add learnable bias terms for MLP matmul. Originally implemented to support the GPT-OSS model architecture.

src/maxtext/configs/base.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ num_experts_per_tok: 1
195195
megablox: true
196196
sparse_matmul: true
197197
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
198+
ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations.
199+
# By default (-1), the routed buffer is worst case size to ensure no dropping.
200+
# When set to 1.0 this buffer is set to the size assuming perfectly balanced routing. If the routing dictates
201+
# a size larger than this then tokens are dropped.
202+
# In general if ragged_buffer_factor > 0, the ragged_buffer_size is balanced_size * ragged_buffer_factor.
198203
load_balance_loss_weight: 0.0 # weight for the load balance loss
199204
use_random_routing: false # whether to use random routing for debug/test purpose
200205
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul

src/maxtext/configs/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,7 @@ class MoEGeneral(BaseModel):
646646
num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
647647
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
648648
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
649+
ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
649650
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
650651
use_custom_sort_vjp: bool = Field(
651652
True,
@@ -2085,6 +2086,12 @@ def load_model_specific_defaults(cls, values: dict[str, Any]) -> dict[str, Any]:
20852086
"""This method is a no-op because `pyconfig` handles model-specific config loading."""
20862087
return values
20872088

2089+
def validate_ragged_buffer_factor(self):
  """Rejects unsupported combinations when an explicit ragged buffer factor is set.

  A positive `ragged_buffer_factor` is only implemented for the ragged
  all-to-all path; combining it with ring-of-experts is not supported.
  A non-positive factor means the feature is off, so there is nothing to check.
  """
  if self.ragged_buffer_factor > 0 and self.use_ring_of_experts:
    raise ValueError("Currently we only support ragged buffer factor with ragged a2a approach.")
2094+
20882095
@model_validator(mode="after")
20892096
def set_derived_and_validate_values(self) -> "MaxTextConfig":
20902097
"""
@@ -2566,6 +2573,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25662573
raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
25672574
if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK:
25682575
raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.")
2576+
self.validate_ragged_buffer_factor()
25692577
if self.use_multimodal:
25702578
valid_mm_models = (
25712579
"gemma3-4b",

src/maxtext/layers/moe.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,42 @@ def transform_bias(self, experts_index, *biases):
891891
"""Selects bias values for a variable number of bias tensors based on chosen experts."""
892892
return tuple(bias[experts_index] for bias in biases)
893893

894+
@staticmethod
895+
def get_ragged_buffer_size(local_batch, ep_degree, global_experts, top_k, ragged_buffer_factor):
896+
"""Calculates the token batch size of the ragged buffer.
897+
When explicitly setting ragged_buffer_factor>0, this is balanced_size * ragged_buffer_factor, which can drop tokens.
898+
Otherwise this will be worst case size to ensure no dropping.
899+
900+
Inputs:
901+
local_batch: local token batch (batch*seq blown up by top_k) shard on this device (e.g. inside shard_map)
902+
ep_degree: degree of expert parallelism, generally equal to ici_expert_parallelism
903+
global_experts: unsharded expert count, e.g. 256 for deepseek
904+
top_k: aka num_experts_per_tok, 8 for deepseek.
905+
ragged_buffer_factor: When set > 0, the buffer is balanced_size * ragged_buffer_factor.
906+
The value 1.0 will be dropless only in the perfectly balanced case, else tokens will be dropped.
907+
Outputs:
908+
The ragged buffer's token batch size.
909+
"""
910+
balanced_size = local_batch
911+
if ragged_buffer_factor > 0.0:
912+
# This will drop tokens if the true distribution exceeds this buffer.
913+
return int(balanced_size * ragged_buffer_factor)
914+
else:
915+
# Worst case
916+
# Either determined by degree of EP, or can be less when num_local_exp is smaller than top_k:
917+
# Example: If we have 4 EP shards, top_k=8, and experts=256 (deepseek), then worst case is
918+
# all tokens in our EP replica get routed to a single shard, e.g. rank 0 - thus is |EP|=4x larger than perfectly
919+
# balanced. However if we use EP=128, then there are only 256/128 = 2 local experts, and thus at most in an EP
920+
# replica group only the 2 experts of top_k=8 can be chosen, so at most 1/4 of all tokens goes to the most
921+
# popular shard. Thus the imbalance factor goes like |EP|/(top_k/local_exp) = 128/4 = 32.
922+
# In general for local_experts < top_k (e.g. |EP|>32), the balance will go as
923+
# EP * local_experts / top_k = EP * (global_exp/EP) / top_k = global_exp / top_k.
924+
# This is constant as a function of the model - e.g. for deepseek the imbalance is never worse than
925+
# 256 exp / 8 top_k = 32. In practice the imbalance should be much less and potentially can use
926+
# ragged_buffer_factor set to >1 e.g. 3.0, and likely have no dropping (not guaranteed)
927+
worst_case_factor = min(ep_degree, global_experts / top_k)
928+
return int(balanced_size * worst_case_factor)
929+
894930
def sparse_matmul(
895931
self,
896932
inputs,
@@ -1155,16 +1191,13 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
11551191
num_expert_parallelism,
11561192
)
11571193

1158-
# TODO(ranran): For better performance, we could update output buffer to a smaller
1159-
# size to replace self.get_expert_parallelism_size() for efficiency,
1160-
# Or we could apply capacity_factor for excessive experts.
1161-
# Note: Reducing buffer increase the risk of token dropping under unbalanced distribution.
1162-
1163-
# In the worst case, all of the global input data is assigned to each expert in the current shard.
1164-
# This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
1165-
# experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
1166-
max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
1167-
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
1194+
buffer_size = self.get_ragged_buffer_size(
1195+
jnp.shape(x)[0],
1196+
num_expert_parallelism,
1197+
self.config.num_experts,
1198+
self.config.num_experts_per_tok,
1199+
self.config.ragged_buffer_factor,
1200+
)
11681201
output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
11691202

11701203
x = jax.lax.ragged_all_to_all(

tests/unit/moe_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,58 @@ def test_get_all_to_all_params_unsharded_batch(self):
10631063
jnp.array_equal(recv_sz, exp_recv_sz), f"Unsharded Batch: Receive sizes mismatch for shard {expert_shard_id}"
10641064
)
10651065

1066+
def test_ragged_buffer_balanced(self):
  """A factor of 1.0 sizes the buffer exactly to the local batch."""
  # ep_degree, global_experts, and top_k are ignored once the factor is > 0.
  buffer_size = moe.RoutedMoE.get_ragged_buffer_size(
      local_batch=32768, ep_degree=4, global_experts=256, top_k=8, ragged_buffer_factor=1.0
  )
  self.assertEqual(32768, buffer_size)
1078+
1079+
def test_ragged_buffer_larger(self):
  """A factor of 2.0 doubles the balanced buffer size."""
  # ep_degree, global_experts, and top_k are ignored once the factor is > 0.
  buffer_size = moe.RoutedMoE.get_ragged_buffer_size(
      local_batch=32768, ep_degree=4, global_experts=256, top_k=8, ragged_buffer_factor=2.0
  )
  self.assertEqual(2 * 32768, buffer_size)
1091+
1092+
def test_small_ep_worst_case(self):
  """With a small EP degree the worst case is local_batch * ep_degree."""
  # Factor -1.0 requests worst-case sizing; ep_degree=4 < global_experts/top_k=32.
  buffer_size = moe.RoutedMoE.get_ragged_buffer_size(
      local_batch=32768, ep_degree=4, global_experts=256, top_k=8, ragged_buffer_factor=-1.0
  )
  self.assertEqual(32768 * 4, buffer_size)
1104+
1105+
def test_large_ep_worst_case(self):
  """With a large EP degree the worst case caps at global_experts / top_k."""
  # Factor -1.0 requests worst-case sizing; ep_degree=128 > global_experts/top_k=32.
  buffer_size = moe.RoutedMoE.get_ragged_buffer_size(
      local_batch=32768, ep_degree=128, global_experts=256, top_k=8, ragged_buffer_factor=-1.0
  )
  self.assertEqual(32768 * (256 // 8), buffer_size)
1117+
10661118

10671119
if __name__ == "__main__":
10681120
unittest.main()

0 commit comments

Comments
 (0)