Skip to content

Commit 8381dad

Browse files
niyatic21 authored and Google-ML-Automation committed
Use Tokamax's representative group sizes.
PiperOrigin-RevId: 885246948
1 parent 093ab89 commit 8381dad

3 files changed

Lines changed: 2 additions & 19 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -900,10 +900,7 @@ def gmm(
900900
if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
901901
tokamax_group_sizes = group_sizes
902902
else:
903-
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
904-
group_sizes,
905-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
906-
)
903+
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(group_sizes, len(inputs))
907904
pad_length = self.config.wi_tile_fwd_batch_seq
908905
hs_shape = inputs.shape
909906
# pad length is the 1st dimension of tiling size in gmm call

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from maxtext.layers import attention_op
2828
from maxtext.layers import moe as moe_lib
2929
from maxtext.layers import quantizations
30-
from maxtext.utils import max_utils
3130
import qwix.pallas as qpl
3231
import tokamax
3332

@@ -970,10 +969,7 @@ def gmm(
970969
output = tokamax.ragged_dot(
971970
lhs=inputs,
972971
rhs=kernel,
973-
group_sizes=tokamax.RaggedDotGroupSizes(
974-
group_sizes,
975-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
976-
),
972+
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
977973
precision=jax.lax.Precision.DEFAULT,
978974
preferred_element_type=preferred_element_type,
979975
implementation="mosaic",

src/maxtext/utils/max_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,13 +1078,3 @@ def transformer_engine_context():
10781078
yield
10791079
except (ImportError, AttributeError):
10801080
yield
1081-
1082-
1083-
def generate_representative_group_sizes(target_m: int, g: int) -> tuple[int, ...]:
1084-
"""Generate group sizes for a given target m."""
1085-
np.random.seed(0)
1086-
repr_val = np.random.uniform(size=(g,))
1087-
repr_val = np.random.binomial(1, 0.9, (g,)) * repr_val
1088-
repr_val = np.int32((repr_val / np.sum(repr_val)) * target_m)
1089-
repr_val[0] += target_m - np.sum(repr_val)
1090-
return tuple(map(int, repr_val))

0 commit comments

Comments
 (0)