diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 25fb55e367..7abe0782f3 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1296,26 +1296,26 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): expert_assignments=selected_experts, ) wi_tile_size = ( - self.config.wi_tile_fwd_batch_seq, - self.config.wi_tile_fwd_embed_dim, - self.config.wi_tile_fwd_mlp_dim, - self.config.wi_tile_dlhs_batch_seq, - self.config.wi_tile_dlhs_embed_dim, - self.config.wi_tile_dlhs_mlp_dim, - self.config.wi_tile_drhs_batch_seq, - self.config.wi_tile_drhs_embed_dim, - self.config.wi_tile_drhs_mlp_dim, + self.config.wi_tile_fwd_batch_seq, # m (LHS batch) + self.config.wi_tile_fwd_embed_dim, # k (contracting) + self.config.wi_tile_fwd_mlp_dim, # n (RHS batch) + self.config.wi_tile_dlhs_batch_seq, # m (LHS batch) + self.config.wi_tile_dlhs_mlp_dim, # k (contracting) + self.config.wi_tile_dlhs_embed_dim, # n (RHS batch) + self.config.wi_tile_drhs_batch_seq, # Called m in megablox, but this is contracting + self.config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim + self.config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is the RHS batch dim ) wo_tile_size = ( - self.config.wo_tile_fwd_batch_seq, - self.config.wo_tile_fwd_embed_dim, - self.config.wo_tile_fwd_mlp_dim, - self.config.wo_tile_dlhs_batch_seq, - self.config.wo_tile_dlhs_embed_dim, - self.config.wo_tile_dlhs_mlp_dim, - self.config.wo_tile_drhs_batch_seq, - self.config.wo_tile_drhs_embed_dim, - self.config.wo_tile_drhs_mlp_dim, + self.config.wo_tile_fwd_batch_seq, # m (LHS batch) + self.config.wo_tile_fwd_mlp_dim, # k (contracting) + self.config.wo_tile_fwd_embed_dim, # n (RHS batch) + self.config.wo_tile_dlhs_batch_seq, # m (LHS batch) + self.config.wo_tile_dlhs_embed_dim, # k (contracting) + self.config.wo_tile_dlhs_mlp_dim, # n (RHS batch) + self.config.wo_tile_drhs_batch_seq, # Called m in megablox, but this is contracting + 
self.config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim + self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim ) layer_w0 = gmm_fn( diff --git a/src/maxtext/models/deepseek_batchsplit_fp8.py b/src/maxtext/models/deepseek_batchsplit_fp8.py index db830406b3..3f7508687a 100644 --- a/src/maxtext/models/deepseek_batchsplit_fp8.py +++ b/src/maxtext/models/deepseek_batchsplit_fp8.py @@ -977,26 +977,27 @@ def gmm( wo_gather_axes = [] wi_tile_size = ( - config.wi_tile_fwd_batch_seq, - config.wi_tile_fwd_embed_dim, - config.wi_tile_fwd_mlp_dim, - config.wi_tile_dlhs_batch_seq, - config.wi_tile_dlhs_embed_dim, - config.wi_tile_dlhs_mlp_dim, - config.wi_tile_drhs_batch_seq, - config.wi_tile_drhs_embed_dim, - config.wi_tile_drhs_mlp_dim, + config.wi_tile_fwd_batch_seq, # m (LHS batch) + config.wi_tile_fwd_embed_dim, # k (contracting) + config.wi_tile_fwd_mlp_dim, # n (RHS batch) + config.wi_tile_dlhs_batch_seq, # m (LHS batch) + config.wi_tile_dlhs_mlp_dim, # k (contracting) + config.wi_tile_dlhs_embed_dim, # n (RHS batch) + config.wi_tile_drhs_batch_seq, # Called m in megablox, but this is contracting + config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim + config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is the RHS batch dim ) + wo_tile_size = ( - config.wo_tile_fwd_batch_seq, - config.wo_tile_fwd_embed_dim, - config.wo_tile_fwd_mlp_dim, - config.wo_tile_dlhs_batch_seq, - config.wo_tile_dlhs_embed_dim, - config.wo_tile_dlhs_mlp_dim, - config.wo_tile_drhs_batch_seq, - config.wo_tile_drhs_embed_dim, - config.wo_tile_drhs_mlp_dim, + config.wo_tile_fwd_batch_seq, # m (LHS batch) + config.wo_tile_fwd_mlp_dim, # k (contracting) + config.wo_tile_fwd_embed_dim, # n (RHS batch) + config.wo_tile_dlhs_batch_seq, # m (LHS batch) + config.wo_tile_dlhs_embed_dim, # k (contracting) + config.wo_tile_dlhs_mlp_dim, # n (RHS batch) + config.wo_tile_drhs_batch_seq, # Called m in megablox, but 
this is contracting + config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim + config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim ) if config.use_qwix_quantization: