Skip to content

Commit e06745f

Browse files
Merge pull request #3678 from AI-Hypercomputer:mattdavidow-ragged-buffer-a1
PiperOrigin-RevId: 901000348
2 parents 60282d1 + 61a14b9 commit e06745f

5 files changed

Lines changed: 113 additions & 10 deletions

File tree

docs/reference/core_concepts/moe_configuration.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ Dropping:
8888
- Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
8989
- Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.
9090

91+
`ragged_buffer_factor`: A scalar multiplier for the size of the ragged buffer (effectively expert capacity). Effective only when `sparse_matmul` is True.
92+
93+
- Value > 0: Uses an explicit buffer size, which may drop tokens when this size is exceeded.
94+
- Value = -1: Uses a worst-case calculated buffer size, which is guaranteed not to drop any tokens.
95+
9196
`use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul. Recommended to replace the inefficient scatter-add generated by the `jax.numpy.take` in the backward pass.
9297

9398
`mlp_bias`: If enabled, add learnable bias terms for MLP matmul. Originally implemented to support the GPT-OSS model architecture.

src/maxtext/configs/base.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ num_experts_per_tok: 1
195195
megablox: true
196196
sparse_matmul: true
197197
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
198+
ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations.
199+
# By default (-1), the routed buffer is worst case size to ensure no dropping.
200+
# When set to 1.0 this buffer is set to the size assuming perfectly balanced routing. If the routing dictates
201+
# a size larger than this then tokens are dropped.
202+
# In general if ragged_buffer_factor > 0, the ragged_buffer_size is balanced_size * ragged_buffer_factor.
198203
load_balance_loss_weight: 0.0 # weight for the load balance loss
199204
use_random_routing: false # whether to use random routing for debug/test purpose
200205
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul

src/maxtext/configs/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,7 @@ class MoEGeneral(BaseModel):
646646
num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
647647
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
648648
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
649+
ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
649650
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
650651
use_custom_sort_vjp: bool = Field(
651652
True,
@@ -2082,6 +2083,12 @@ def load_model_specific_defaults(cls, values: dict[str, Any]) -> dict[str, Any]:
20822083
"""This method is a no-op because `pyconfig` handles model-specific config loading."""
20832084
return values
20842085

2086+
def validate_ragged_buffer_factor(self):
    """Checks that an explicit ragged buffer factor is used with a supported a2a approach.

    Raises:
      ValueError: if `ragged_buffer_factor` is set (> 0) while `use_ring_of_experts`
        is enabled, since the two are currently incompatible.
    """
    # A non-positive factor selects the worst-case buffer size, so there is nothing to validate.
    if self.ragged_buffer_factor > 0 and self.use_ring_of_experts:
        raise ValueError("Currently we only support ragged buffer factor with ragged a2a approach.")
2091+
20852092
@model_validator(mode="after")
20862093
def set_derived_and_validate_values(self) -> "MaxTextConfig":
20872094
"""
@@ -2570,6 +2577,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25702577
raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
25712578
if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK:
25722579
raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.")
2580+
self.validate_ragged_buffer_factor()
25732581
if self.use_multimodal:
25742582
valid_mm_models = (
25752583
"gemma3-4b",

src/maxtext/layers/moe.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,42 @@ def transform_bias(self, experts_index, *biases):
901901
"""Selects bias values for a variable number of bias tensors based on chosen experts."""
902902
return tuple(bias[experts_index] for bias in biases)
903903

904+
@staticmethod
905+
def get_ragged_buffer_size(local_batch, ep_degree, global_experts, top_k, ragged_buffer_factor):
906+
"""Calculates the token batch size of the ragged buffer.
907+
When explicitly setting ragged_buffer_factor>0, this is balanced_size * ragged_buffer_factor, which can drop tokens.
908+
Otherwise this will be worst case size to ensure no dropping.
909+
910+
Inputs:
911+
local_batch: local token batch (batch*seq blown up by top_k) shard on this device (e.g. inside shard_map)
912+
ep_degree: degree of expert parallelism, generally equal to ici_expert_parallelism
913+
global_experts: unsharded expert count, e.g. 256 for deepseek
914+
top_k: aka num_experts_per_tok, 8 for deepseek.
915+
ragged_buffer_factor: When set > 0, the buffer is balanced_size * ragged_buffer_factor.
916+
The value 1.0 will be dropless only in the perfectly balanced case, else tokens will be dropped.
917+
Outputs:
918+
The ragged buffer's token batch size.
919+
"""
920+
balanced_size = local_batch
921+
if ragged_buffer_factor > 0.0:
922+
# This will drop tokens if the true distribution exceeds this buffer.
923+
return int(balanced_size * ragged_buffer_factor)
924+
else:
925+
# Worst case
926+
# Either determined by degree of EP, or can be less when num_local_exp is smaller than top_k:
927+
# Example: If we have 4 EP shards, top_k=8, and experts=256 (deepseek), then worst case is
928+
# all tokens in our EP replica get routed to a single shard, e.g. rank 0 - thus is |EP|=4x larger than perfectly
929+
# balanced. However if we use EP=128, then there are only 256/128 = 2 local experts, and thus at most in an EP
930+
# replica group only the 2 experts of top_k=8 can be chosen, so at most 1/4 of all tokens goes to the most
931+
# popular shard. Thus the imbalance factor goes like |EP|/(top_k/local_exp) = 128/4 = 32.
932+
# In general for local_experts < top_k (e.g. |EP|>32), the balance will go as
933+
# EP * local_experts / top_k = EP * (global_exp/EP) / top_k = global_exp / top_k.
934+
# This is constant as a function of the model - e.g. for deepseek the imbalance is never worse than
935+
# 256 exp / 8 top_k = 32. In practice the imbalance should be much less and potentially can use
936+
# ragged_buffer_factor set to >1 e.g. 3.0, and likely have no dropping (not guaranteed)
937+
worst_case_factor = min(ep_degree, global_experts / top_k)
938+
return int(balanced_size * worst_case_factor)
939+
904940
def sparse_matmul(
905941
self,
906942
inputs,
@@ -1165,16 +1201,13 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
11651201
num_expert_parallelism,
11661202
)
11671203

1168-
# TODO(ranran): For better performance, we could update output buffer to a smaller
1169-
# size to replace self.get_expert_parallelism_size() for efficiency,
1170-
# Or we could apply capacity_factor for excessive experts.
1171-
# Note: Reducing buffer increase the risk of token dropping under unbalanced distribution.
1172-
1173-
# In the worst case, all of the global input data is assigned to each expert in the current shard.
1174-
# This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
1175-
# experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
1176-
max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
1177-
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
1204+
buffer_size = self.get_ragged_buffer_size(
1205+
jnp.shape(x)[0],
1206+
num_expert_parallelism,
1207+
self.config.num_experts,
1208+
self.config.num_experts_per_tok,
1209+
self.config.ragged_buffer_factor,
1210+
)
11781211
output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
11791212

11801213
x = jax.lax.ragged_all_to_all(

tests/unit/moe_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,58 @@ def test_get_all_to_all_params_unsharded_batch(self):
10631063
jnp.array_equal(recv_sz, exp_recv_sz), f"Unsharded Batch: Receive sizes mismatch for shard {expert_shard_id}"
10641064
)
10651065

1066+
def test_ragged_buffer_balanced(self):
    """An explicit factor of 1.0 keeps the buffer at exactly the local batch size."""
    local_batch = 32768
    # ep_degree (4), global_experts (256) and top_k (8) are ignored when the factor is > 0.
    actual_ragged_buffer = moe.RoutedMoE.get_ragged_buffer_size(local_batch, 4, 256, 8, 1.0)
    self.assertEqual(local_batch, actual_ragged_buffer)
1078+
1079+
def test_ragged_buffer_larger(self):
    """An explicit factor of 2.0 doubles the balanced buffer size."""
    local_batch = 32768
    ragged_buffer_factor = 2.0
    # ep_degree (4), global_experts (256) and top_k (8) are ignored when the factor is > 0.
    actual_ragged_buffer = moe.RoutedMoE.get_ragged_buffer_size(
        local_batch, 4, 256, 8, ragged_buffer_factor
    )
    self.assertEqual(int(local_batch * ragged_buffer_factor), actual_ragged_buffer)
1091+
1092+
def test_small_ep_worst_case(self):
    """With few EP shards the worst-case buffer is local_batch * ep_degree."""
    local_batch, ep_degree = 32768, 4
    # A factor of -1.0 selects worst-case (dropless) sizing.
    actual_ragged_buffer = moe.RoutedMoE.get_ragged_buffer_size(
        local_batch, ep_degree, 256, 8, -1.0
    )
    self.assertEqual(local_batch * ep_degree, actual_ragged_buffer)
1104+
1105+
def test_large_ep_worst_case(self):
    """With many EP shards the worst case is capped by global_experts / top_k."""
    local_batch, global_experts, top_k = 32768, 256, 8
    # A factor of -1.0 selects worst-case (dropless) sizing; ep_degree=128 exceeds the
    # global_experts / top_k = 32 cap, so the cap applies.
    actual_ragged_buffer = moe.RoutedMoE.get_ragged_buffer_size(
        local_batch, 128, global_experts, top_k, -1.0
    )
    self.assertEqual(local_batch * (global_experts // top_k), actual_ragged_buffer)
1117+
10661118

10671119
if __name__ == "__main__":
10681120
unittest.main()

0 commit comments

Comments
 (0)