@@ -1296,26 +1296,26 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
12961296 expert_assignments = selected_experts ,
12971297 )
12981298 wi_tile_size = (
1299- self .config .wi_tile_fwd_batch_seq ,
1300- self .config .wi_tile_fwd_embed_dim ,
1301- self .config .wi_tile_fwd_mlp_dim ,
1302- self .config .wi_tile_dlhs_batch_seq ,
1303- self .config .wi_tile_dlhs_embed_dim ,
1304- self .config .wi_tile_dlhs_mlp_dim ,
1305- self .config .wi_tile_drhs_batch_seq ,
1306- self .config .wi_tile_drhs_embed_dim ,
1307- self .config .wi_tile_drhs_mlp_dim ,
1299+ self .config .wi_tile_fwd_batch_seq , # m (LHS batch)
1300+ self .config .wi_tile_fwd_embed_dim , # k (contracting)
1301+ self .config .wi_tile_fwd_mlp_dim , # n (RHS batch)
1302+ self .config .wi_tile_dlhs_batch_seq , # m (LHS batch)
1303+ self .config .wi_tile_dlhs_mlp_dim , # k (contracting)
1304+ self .config .wi_tile_dlhs_embed_dim , # n (RHS batch)
1305+ self .config .wi_tile_drhs_batch_seq , # Called m in megablox, but this is contracting
1306+ self .config .wi_tile_drhs_embed_dim , # Called k in megablox, but this is LHS batch dim
1307+ self .config .wi_tile_drhs_mlp_dim , # Called n in megablox, and indeed is RHS batch dim
13081308 )
13091309 wo_tile_size = (
1310- self .config .wo_tile_fwd_batch_seq ,
1311- self .config .wo_tile_fwd_embed_dim ,
1312- self .config .wo_tile_fwd_mlp_dim ,
1313- self .config .wo_tile_dlhs_batch_seq ,
1314- self .config .wo_tile_dlhs_embed_dim ,
1315- self .config .wo_tile_dlhs_mlp_dim ,
1316- self .config .wo_tile_drhs_batch_seq ,
1317- self .config .wo_tile_drhs_embed_dim ,
1318- self .config .wo_tile_drhs_mlp_dim ,
1310+ self .config .wo_tile_fwd_batch_seq , # m (LHS batch)
1311+ self .config .wo_tile_fwd_mlp_dim , # k (contracting)
1312+ self .config .wo_tile_fwd_embed_dim , # n (RHS batch)
1313+ self .config .wo_tile_dlhs_batch_seq , # m (LHS batch)
1314+ self .config .wo_tile_dlhs_embed_dim , # k (contracting)
1315+ self .config .wo_tile_dlhs_mlp_dim , # n (RHS batch)
1316+ self .config .wo_tile_drhs_batch_seq , # Called m in megablox, but this is contracting
1317+ self .config .wo_tile_drhs_mlp_dim , # Called k in megablox, but this is LHS batch dim
1318+ self .config .wo_tile_drhs_embed_dim , # Called n in megablox, and indeed is the RHS batch dim
13191319 )
13201320
13211321 layer_w0 = gmm_fn (
0 commit comments