Skip to content

Commit 6ce6ef6

Browse files
committed
Fix megablox tile sizes
1 parent f67d8b1 commit 6ce6ef6

2 files changed

Lines changed: 37 additions & 36 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,26 +1296,26 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
12961296
expert_assignments=selected_experts,
12971297
)
12981298
wi_tile_size = (
1299-
self.config.wi_tile_fwd_batch_seq,
1300-
self.config.wi_tile_fwd_embed_dim,
1301-
self.config.wi_tile_fwd_mlp_dim,
1302-
self.config.wi_tile_dlhs_batch_seq,
1303-
self.config.wi_tile_dlhs_embed_dim,
1304-
self.config.wi_tile_dlhs_mlp_dim,
1305-
self.config.wi_tile_drhs_batch_seq,
1306-
self.config.wi_tile_drhs_embed_dim,
1307-
self.config.wi_tile_drhs_mlp_dim,
1299+
self.config.wi_tile_fwd_batch_seq, # m (LHS batch)
1300+
self.config.wi_tile_fwd_embed_dim, # k (contracting)
1301+
self.config.wi_tile_fwd_mlp_dim, # n (RHS batch)
1302+
self.config.wi_tile_dlhs_batch_seq, # m (LHS batch)
1303+
self.config.wi_tile_dlhs_mlp_dim, # k (contracting)
1304+
self.config.wi_tile_dlhs_embed_dim, # n (RHS batch)
1305+
self.config.wi_tile_drhs_batch_seq, # Called m in megablox, but this is contracting
1306+
self.config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim
1307+
self.config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is the RHS batch dim
13081308
)
13091309
wo_tile_size = (
1310-
self.config.wo_tile_fwd_batch_seq,
1311-
self.config.wo_tile_fwd_embed_dim,
1312-
self.config.wo_tile_fwd_mlp_dim,
1313-
self.config.wo_tile_dlhs_batch_seq,
1314-
self.config.wo_tile_dlhs_embed_dim,
1315-
self.config.wo_tile_dlhs_mlp_dim,
1316-
self.config.wo_tile_drhs_batch_seq,
1317-
self.config.wo_tile_drhs_embed_dim,
1318-
self.config.wo_tile_drhs_mlp_dim,
1310+
self.config.wo_tile_fwd_batch_seq, # m (LHS batch)
1311+
self.config.wo_tile_fwd_mlp_dim, # k (contracting)
1312+
self.config.wo_tile_fwd_embed_dim, # n (RHS batch)
1313+
self.config.wo_tile_dlhs_batch_seq, # m (LHS batch)
1314+
self.config.wo_tile_dlhs_embed_dim, # k (contracting)
1315+
self.config.wo_tile_dlhs_mlp_dim, # n (RHS batch)
1316+
self.config.wo_tile_drhs_batch_seq, # Called m in megablox, but this is contracting
1317+
self.config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim
1318+
self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim
13191319
)
13201320

13211321
layer_w0 = gmm_fn(

src/maxtext/models/deepseek_batchsplit_fp8.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -977,26 +977,27 @@ def gmm(
977977
wo_gather_axes = []
978978

979979
wi_tile_size = (
980-
config.wi_tile_fwd_batch_seq,
981-
config.wi_tile_fwd_embed_dim,
982-
config.wi_tile_fwd_mlp_dim,
983-
config.wi_tile_dlhs_batch_seq,
984-
config.wi_tile_dlhs_embed_dim,
985-
config.wi_tile_dlhs_mlp_dim,
986-
config.wi_tile_drhs_batch_seq,
987-
config.wi_tile_drhs_embed_dim,
988-
config.wi_tile_drhs_mlp_dim,
980+
config.wi_tile_fwd_batch_seq, # m (LHS batch)
981+
config.wi_tile_fwd_embed_dim, # k (contracting)
982+
config.wi_tile_fwd_mlp_dim, # n (RHS batch)
983+
config.wi_tile_dlhs_batch_seq, # m (LHS batch)
984+
config.wi_tile_dlhs_mlp_dim, # k (contracting)
985+
config.wi_tile_dlhs_embed_dim, # n (RHS batch)
986+
config.wi_tile_drhs_batch_seq, # Called m in megablox, but this is contracting
987+
config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim
988+
config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is the RHS batch dim
989989
)
990+
990991
wo_tile_size = (
991-
config.wo_tile_fwd_batch_seq,
992-
config.wo_tile_fwd_embed_dim,
993-
config.wo_tile_fwd_mlp_dim,
994-
config.wo_tile_dlhs_batch_seq,
995-
config.wo_tile_dlhs_embed_dim,
996-
config.wo_tile_dlhs_mlp_dim,
997-
config.wo_tile_drhs_batch_seq,
998-
config.wo_tile_drhs_embed_dim,
999-
config.wo_tile_drhs_mlp_dim,
992+
config.wo_tile_fwd_batch_seq, # m (LHS batch)
993+
config.wo_tile_fwd_mlp_dim, # k (contracting)
994+
config.wo_tile_fwd_embed_dim, # n (RHS batch)
995+
config.wo_tile_dlhs_batch_seq, # m (LHS batch)
996+
config.wo_tile_dlhs_embed_dim, # k (contracting)
997+
config.wo_tile_dlhs_mlp_dim, # n (RHS batch)
998+
config.wo_tile_drhs_batch_seq, # Called m in megablox, but this is contracting
999+
config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim
1000+
config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim
10001001
)
10011002

10021003
if config.use_qwix_quantization:

0 commit comments

Comments
 (0)