Skip to content

Commit bac2c05

Browse files
NuojCheng authored and Shuwen-Fang committed
update tokamax group sizes for pipeline
1 parent cbb74d2 commit bac2c05

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -894,14 +894,16 @@ def sparse_matmul(
894894
vma_axes = tuple(axis for axis in self.config.mesh_axes if self.mesh.shape[axis] > 1)
895895
use_vma = not self.config.use_tokamax_gmm
896896

897-
vma_axes = tuple(axis for axis in self.config.mesh_axes if self.mesh.shape[axis] > 1)
898-
use_vma = not self.config.use_tokamax_gmm
899-
900897
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes):
901-
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
902-
group_sizes,
903-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
904-
)
898+
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
899+
if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
900+
tokamax_group_sizes = group_sizes
901+
else:
902+
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
903+
group_sizes,
904+
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
905+
)
906+
905907
pad_length = self.config.wi_tile_fwd_batch_seq
906908
hs_shape = inputs.shape
907909
# pad length is the 1st dimension of tiling size in gmm call

0 commit comments

Comments
 (0)