Skip to content

Commit 9232585

Browse files
niyatic21Google-ML-Automation
authored andcommitted
Use Tokamax's representative group sizes.
PiperOrigin-RevId: 885246948
1 parent 51c7f2b commit 9232585

3 files changed

Lines changed: 2 additions & 19 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -955,10 +955,7 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
955955
elif self.config.attention == "vllm_rpa":
956956
return group_sizes
957957
else:
958-
return tokamax.RaggedDotGroupSizes(
959-
group_sizes,
960-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
961-
)
958+
return tokamax.RaggedDotGroupSizes(group_sizes, len(inputs))
962959

963960
def get_quantization_dtypes():
964961
lhs_quantize_dtype, rhs_quantize_dtype = None, None

src/maxtext/models/deepseek_batchsplit_fp8.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from maxtext.layers import attention_op
2828
from maxtext.layers import moe as moe_lib
2929
from maxtext.layers import quantizations
30-
from maxtext.utils import max_utils
3130
import qwix.pallas as qpl
3231
import tokamax
3332

@@ -970,10 +969,7 @@ def gmm(
970969
output = tokamax.ragged_dot(
971970
lhs=inputs,
972971
rhs=kernel,
973-
group_sizes=tokamax.RaggedDotGroupSizes(
974-
group_sizes,
975-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
976-
),
972+
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
977973
precision=jax.lax.Precision.DEFAULT,
978974
preferred_element_type=preferred_element_type,
979975
implementation="mosaic",

src/maxtext/utils/max_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,16 +1137,6 @@ def transformer_engine_context():
11371137
yield
11381138

11391139

1140-
def generate_representative_group_sizes(target_m: int, g: int) -> tuple[int, ...]:
1141-
"""Generate group sizes for a given target m."""
1142-
np.random.seed(0)
1143-
repr_val = np.random.uniform(size=(g,))
1144-
repr_val = np.random.binomial(1, 0.9, (g,)) * repr_val
1145-
repr_val = np.int32((repr_val / np.sum(repr_val)) * target_m)
1146-
repr_val[0] += target_m - np.sum(repr_val)
1147-
return tuple(map(int, repr_val))
1148-
1149-
11501140
def maybe_pad(inputs, tile_size):
11511141
"""Pads the inputs leading dimension to be divisible by tile_size."""
11521142
inputs_dim = inputs.shape[0]

0 commit comments

Comments
 (0)