Skip to content

Commit 764a10a

Browse files
niting authored and Google-ML-Automation committed
Move tokamax_group_sizes calculation inside tokamax branch.
PiperOrigin-RevId: 903583531
1 parent 25f7ba0 commit 764a10a

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

src/maxtext/layers/moe.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1043,7 +1043,6 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
10431043
if inputs.shape[0] != expert_assignments.shape[0]:
10441044
raise ValueError("The number of input tokens must match the number of expert assignments!")
10451045

1046-
tokamax_group_sizes = get_tokamax_group_sizes(group_sizes, inputs, kernel)
10471046
orig_inputs_shape = inputs.shape # save shape of inputs before potentially padding.
10481047
inputs, padding_amount = max_utils.maybe_pad(inputs, self.config.wi_tile_fwd_batch_seq)
10491048
inputs = inputs.astype(self.dtype)
@@ -1067,6 +1066,9 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
10671066
weight_gather_axes=weight_gather_axes,
10681067
)
10691068
else: # tokamax (unquantized)
1069+
tokamax_group_sizes = get_tokamax_group_sizes(
1070+
group_sizes, inputs, kernel
1071+
)
10701072
output = tokamax.ragged_dot(
10711073
lhs=inputs,
10721074
rhs=kernel,

0 commit comments

Comments
 (0)