From 764a10a63706a6d66f974c995b0dc7b3eec8afc8 Mon Sep 17 00:00:00 2001 From: Nitin Gangahar Date: Tue, 21 Apr 2026 20:59:26 -0700 Subject: [PATCH] Move tokamax_group_sizes calculation inside tokamax branch. PiperOrigin-RevId: 903583531 --- src/maxtext/layers/moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 25fb55e367..e3537c11cf 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1043,7 +1043,6 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a if inputs.shape[0] != expert_assignments.shape[0]: raise ValueError("The number of input tokens must match the number of expert assignments!") - tokamax_group_sizes = get_tokamax_group_sizes(group_sizes, inputs, kernel) orig_inputs_shape = inputs.shape # save shape of inputs before potentially padding. inputs, padding_amount = max_utils.maybe_pad(inputs, self.config.wi_tile_fwd_batch_seq) inputs = inputs.astype(self.dtype) @@ -1067,6 +1066,9 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a weight_gather_axes=weight_gather_axes, ) else: # tokamax (unquantized) + tokamax_group_sizes = get_tokamax_group_sizes( + group_sizes, inputs, kernel + ) output = tokamax.ragged_dot( lhs=inputs, rhs=kernel,