diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index 25fb55e367..e3537c11cf 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -1043,7 +1043,6 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
     if inputs.shape[0] != expert_assignments.shape[0]:
       raise ValueError("The number of input tokens must match the number of expert assignments!")
-    tokamax_group_sizes = get_tokamax_group_sizes(group_sizes, inputs, kernel)
     orig_inputs_shape = inputs.shape  # save shape of inputs before potentially padding.
     inputs, padding_amount = max_utils.maybe_pad(inputs, self.config.wi_tile_fwd_batch_seq)
     inputs = inputs.astype(self.dtype)
@@ -1067,6 +1066,9 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
           weight_gather_axes=weight_gather_axes,
         )
       else:  # tokamax (unquantized)
+        tokamax_group_sizes = get_tokamax_group_sizes(
+            group_sizes, inputs, kernel
+        )
         output = tokamax.ragged_dot(
           lhs=inputs,
           rhs=kernel,