diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index 4275dc5cfd..cbaa297a9c 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -955,10 +955,7 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
     elif self.config.attention == "vllm_rpa":
       return group_sizes
     else:
-      return tokamax.RaggedDotGroupSizes(
-          group_sizes,
-          max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
-      )
+      return tokamax.RaggedDotGroupSizes(group_sizes, len(inputs))
 
   def get_quantization_dtypes():
     lhs_quantize_dtype, rhs_quantize_dtype = None, None
diff --git a/src/maxtext/models/deepseek_batchsplit_fp8.py b/src/maxtext/models/deepseek_batchsplit_fp8.py
index 02fb383d86..7a34241d26 100644
--- a/src/maxtext/models/deepseek_batchsplit_fp8.py
+++ b/src/maxtext/models/deepseek_batchsplit_fp8.py
@@ -27,7 +27,6 @@
 from maxtext.layers import attention_op
 from maxtext.layers import moe as moe_lib
 from maxtext.layers import quantizations
-from maxtext.utils import max_utils
 import qwix.pallas as qpl
 import tokamax
 
@@ -970,10 +969,7 @@ def gmm(
     output = tokamax.ragged_dot(
         lhs=inputs,
         rhs=kernel,
-        group_sizes=tokamax.RaggedDotGroupSizes(
-            group_sizes,
-            max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
-        ),
+        group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
         precision=jax.lax.Precision.DEFAULT,
         preferred_element_type=preferred_element_type,
         implementation="mosaic",
diff --git a/src/maxtext/utils/max_utils.py b/src/maxtext/utils/max_utils.py
index 292096017b..07a7c044d3 100644
--- a/src/maxtext/utils/max_utils.py
+++ b/src/maxtext/utils/max_utils.py
@@ -1137,16 +1137,6 @@ def transformer_engine_context():
   yield
 
 
-def generate_representative_group_sizes(target_m: int, g: int) -> tuple[int, ...]:
-  """Generate group sizes for a given target m."""
-  np.random.seed(0)
-  repr_val = np.random.uniform(size=(g,))
-  repr_val = np.random.binomial(1, 0.9, (g,)) * repr_val
-  repr_val = np.int32((repr_val / np.sum(repr_val)) * target_m)
-  repr_val[0] += target_m - np.sum(repr_val)
-  return tuple(map(int, repr_val))
-
-
 def maybe_pad(inputs, tile_size):
   """Pads the inputs leading dimension to be divisible by tile_size."""
   inputs_dim = inputs.shape[0]