Skip to content

Commit 85d8771

Browse files
niyatic21Google-ML-Automation
authored andcommitted
Use Tokamax's representative group sizes.
PiperOrigin-RevId: 885246948
1 parent f7c1855 commit 85d8771

3 files changed

Lines changed: 2 additions & 19 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -914,10 +914,7 @@ def gmm(
914914
elif self.config.attention == "vllm_rpa":
915915
tokamax_group_sizes = group_sizes
916916
else:
917-
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
918-
group_sizes,
919-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
920-
)
917+
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(group_sizes, len(inputs))
921918
pad_length = self.config.wi_tile_fwd_batch_seq
922919
hs_shape = inputs.shape
923920
# pad length is the 1st dimension of tiling size in gmm call

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from maxtext.layers import attention_op
2828
from maxtext.layers import moe as moe_lib
2929
from maxtext.layers import quantizations
30-
from maxtext.utils import max_utils
3130
import qwix.pallas as qpl
3231
import tokamax
3332

@@ -970,10 +969,7 @@ def gmm(
970969
output = tokamax.ragged_dot(
971970
lhs=inputs,
972971
rhs=kernel,
973-
group_sizes=tokamax.RaggedDotGroupSizes(
974-
group_sizes,
975-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
976-
),
972+
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
977973
precision=jax.lax.Precision.DEFAULT,
978974
preferred_element_type=preferred_element_type,
979975
implementation="mosaic",

src/maxtext/utils/max_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,13 +1118,3 @@ def transformer_engine_context():
11181118
yield
11191119
except (ImportError, AttributeError):
11201120
yield
1121-
1122-
1123-
def generate_representative_group_sizes(target_m: int, g: int) -> tuple[int, ...]:
1124-
"""Generate group sizes for a given target m."""
1125-
np.random.seed(0)
1126-
repr_val = np.random.uniform(size=(g,))
1127-
repr_val = np.random.binomial(1, 0.9, (g,)) * repr_val
1128-
repr_val = np.int32((repr_val / np.sum(repr_val)) * target_m)
1129-
repr_val[0] += target_m - np.sum(repr_val)
1130-
return tuple(map(int, repr_val))

0 commit comments

Comments
 (0)