Skip to content

Commit 10b2d08

Browse files
niyatic21 authored and Google-ML-Automation committed
Use Tokamax's representative group sizes.
PiperOrigin-RevId: 885246948
1 parent 51c7f2b commit 10b2d08

3 files changed

Lines changed: 2 additions & 19 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -955,10 +955,7 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
955955
elif self.config.attention == "vllm_rpa":
956956
return group_sizes
957957
else:
958-
return tokamax.RaggedDotGroupSizes(
959-
group_sizes,
960-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
961-
)
958+
return tokamax.RaggedDotGroupSizes(group_sizes, len(inputs))
962959

963960
def get_quantization_dtypes():
964961
lhs_quantize_dtype, rhs_quantize_dtype = None, None

src/maxtext/models/deepseek_batchsplit_fp8.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -970,10 +970,7 @@ def gmm(
970970
output = tokamax.ragged_dot(
971971
lhs=inputs,
972972
rhs=kernel,
973-
group_sizes=tokamax.RaggedDotGroupSizes(
974-
group_sizes,
975-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
976-
),
973+
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
977974
precision=jax.lax.Precision.DEFAULT,
978975
preferred_element_type=preferred_element_type,
979976
implementation="mosaic",

src/maxtext/utils/max_utils.py

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1136,17 +1136,6 @@ def transformer_engine_context():
11361136
except (ImportError, AttributeError):
11371137
yield
11381138

1139-
1140-
def generate_representative_group_sizes(target_m: int, g: int) -> tuple[int, ...]:
1141-
"""Generate group sizes for a given target m."""
1142-
np.random.seed(0)
1143-
repr_val = np.random.uniform(size=(g,))
1144-
repr_val = np.random.binomial(1, 0.9, (g,)) * repr_val
1145-
repr_val = np.int32((repr_val / np.sum(repr_val)) * target_m)
1146-
repr_val[0] += target_m - np.sum(repr_val)
1147-
return tuple(map(int, repr_val))
1148-
1149-
11501139
def maybe_pad(inputs, tile_size):
11511140
"""Pads the inputs leading dimension to be divisible by tile_size."""
11521141
inputs_dim = inputs.shape[0]

0 commit comments

Comments
 (0)