5555 DLHS_RAGGED_DOT_DIM_NUMS ,
5656 DRHS_RAGGED_DOT_DIM_NUMS ,
5757)
58+ from tokamax ._src .ops .ragged_dot import base
5859from tokamax ._src .ops import op
5960
6061set_xla_metadata = xla_metadata .set_xla_metadata
@@ -557,8 +558,7 @@ def __init__(
557558 ):
558559 self .wo .value = self .wo .value * self .per_expert_scale .value [:, None , None ]
559560
560- # Monkey-patch Tokamax heuristics globally once
561- _monkey_patch_tokamax_heuristics (self .config )
561+
562562
563563 def _maybe_shard_with_logical (self , inputs , logical_name ):
564564 return maybe_shard_with_logical (
@@ -1095,6 +1095,86 @@ def sparse_matmul(
10951095 wo_bias ,
10961096 ):
10971097 """Perform sparse matrix multiplication of inputs and Experts."""
1098+ config = self .config
1099+
1100+ class PallasMosaicTpuRaggedDotWI (PallasMosaicTpuRaggedDot ):
1101+
1102+ def __post_init__ (self ):
1103+ from tokamax ._src .ops .ragged_dot import base
1104+ qdtype = self .qdtype if self .qdtype is None else jnp .dtype (self .qdtype ).name
1105+ if self .vjp is None :
1106+ fn = lambda * args , ** kw : PallasMosaicTpuRaggedDotWI (
1107+ qdtype = qdtype ,
1108+ interpret = self .interpret ,
1109+ )(* args , ** kw )
1110+ object .__setattr__ (
1111+ self ,
1112+ "vjp" ,
1113+ functools .partial (base .vjp , dlhs_ragged_dot = fn , drhs_ragged_dot = fn ),
1114+ )
1115+
1116+ def _get_heuristics_config (self , ba ) -> Config :
1117+ dims = ba .arguments .get ("ragged_dot_dimension_numbers" , DEFAULT_RAGGED_DOT_DIM_NUMS )
1118+ if dims == DEFAULT_RAGGED_DOT_DIM_NUMS :
1119+ return Config (
1120+ tile_m = config .wi_tile_fwd_batch_seq ,
1121+ tile_k = config .wi_tile_fwd_embed_dim ,
1122+ tile_n = config .wi_tile_fwd_mlp_dim ,
1123+ )
1124+ elif dims == DLHS_RAGGED_DOT_DIM_NUMS :
1125+ return Config (
1126+ tile_m = config .wi_tile_dlhs_batch_seq ,
1127+ tile_k = config .wi_tile_dlhs_mlp_dim ,
1128+ tile_n = config .wi_tile_dlhs_embed_dim ,
1129+ )
1130+ elif dims == DRHS_RAGGED_DOT_DIM_NUMS :
1131+ return Config (
1132+ tile_m = config .wi_tile_drhs_batch_seq ,
1133+ tile_k = config .wi_tile_drhs_embed_dim ,
1134+ tile_n = config .wi_tile_drhs_mlp_dim ,
1135+ )
1136+ return Config ()
1137+
1138+ class PallasMosaicTpuRaggedDotWO (PallasMosaicTpuRaggedDot ):
1139+
1140+ def __post_init__ (self ):
1141+ from tokamax ._src .ops .ragged_dot import base
1142+ qdtype = self .qdtype if self .qdtype is None else jnp .dtype (self .qdtype ).name
1143+ if self .vjp is None :
1144+ fn = lambda * args , ** kw : PallasMosaicTpuRaggedDotWO (
1145+ qdtype = qdtype ,
1146+ interpret = self .interpret ,
1147+ )(* args , ** kw )
1148+ object .__setattr__ (
1149+ self ,
1150+ "vjp" ,
1151+ functools .partial (base .vjp , dlhs_ragged_dot = fn , drhs_ragged_dot = fn ),
1152+ )
1153+
1154+ def _get_heuristics_config (self , ba ) -> Config :
1155+ dims = ba .arguments .get ("ragged_dot_dimension_numbers" , DEFAULT_RAGGED_DOT_DIM_NUMS )
1156+ if dims == DEFAULT_RAGGED_DOT_DIM_NUMS :
1157+ return Config (
1158+ tile_m = config .wo_tile_fwd_batch_seq ,
1159+ tile_k = config .wo_tile_fwd_mlp_dim ,
1160+ tile_n = config .wo_tile_fwd_embed_dim ,
1161+ )
1162+ elif dims == DLHS_RAGGED_DOT_DIM_NUMS :
1163+ return Config (
1164+ tile_m = config .wo_tile_dlhs_batch_seq ,
1165+ tile_k = config .wo_tile_dlhs_embed_dim ,
1166+ tile_n = config .wo_tile_dlhs_mlp_dim ,
1167+ )
1168+ elif dims == DRHS_RAGGED_DOT_DIM_NUMS :
1169+ return Config (
1170+ tile_m = config .wo_tile_drhs_batch_seq ,
1171+ tile_k = config .wo_tile_drhs_mlp_dim ,
1172+ tile_n = config .wo_tile_drhs_embed_dim ,
1173+ )
1174+ return Config ()
1175+
1176+ gmm_impl_wi = PallasMosaicTpuRaggedDotWI (qdtype = None , interpret = False )
1177+ gmm_impl_wo = PallasMosaicTpuRaggedDotWO (qdtype = None , interpret = False )
10981178
10991179 def jax_ragged_dot_gmm (inputs , kernel , tiling , group_sizes , expert_assignments , padding_amount ):
11001180 """Execute jax.lax.ragged_dot, with potential quantization"""
@@ -1139,6 +1219,11 @@ def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments,
11391219 output *= scales
11401220 return output
11411221
1222+ def get_gmm_group_sizes (inputs , kernel , ep ):
1223+ # Calculates perfectly balanced group sizes where each local expert receives an equal
1224+ # share of local tokens, adjusted for expert parallelism.
1225+ return (inputs .shape [0 ] // kernel .shape [0 ] // ep ,) * kernel .shape [0 ]
1226+
11421227 def get_tokamax_group_sizes (group_sizes , inputs , kernel ):
11431228 # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
11441229 if self .config .use_qwix_quantization or (
@@ -1151,7 +1236,7 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
11511236 ep = self .get_expert_parallelism_size ()
11521237 return tokamax .RaggedDotGroupSizes (
11531238 group_sizes ,
1154- (inputs . shape [ 0 ] // kernel . shape [ 0 ] // ep ,) * kernel . shape [ 0 ] ,
1239+ get_gmm_group_sizes (inputs , kernel , ep ) ,
11551240 )
11561241
11571242 def get_quantization_dtypes ():
@@ -1162,7 +1247,7 @@ def get_quantization_dtypes():
11621247 rhs_quantize_dtype = quant_dg .fwd .dg_quantizer .rhs .numerics .get_dtype ()
11631248 return lhs_quantize_dtype , rhs_quantize_dtype
11641249
1165- def gmm (inputs , kernel , tiling , group_sizes , expert_assignments , weight_gather_axes ):
1250+ def gmm (inputs , kernel , tiling , group_sizes , expert_assignments , weight_gather_axes , gmm_impl = None ):
11661251 if inputs .shape [0 ] != expert_assignments .shape [0 ]:
11671252 raise ValueError ("The number of input tokens must match the number of expert assignments!" )
11681253
@@ -1196,7 +1281,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
11961281 group_sizes = tokamax_group_sizes ,
11971282 precision = jax .lax .Precision .DEFAULT ,
11981283 preferred_element_type = self .dtype ,
1199- implementation = "mosaic" ,
1284+ implementation = "mosaic" if gmm_impl is None else [ gmm_impl ] ,
12001285 )
12011286 elif self .config .megablox : # Older forked megablox
12021287 output = mblx .gmm (
@@ -1485,6 +1570,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14851570 w0 ,
14861571 tiling = wi_tile_size ,
14871572 weight_gather_axes = wi_gather_axes ,
1573+ gmm_impl = gmm_impl_wi ,
14881574 )
14891575 if self .get_tensor_transpose_parallelism_size () > 1 :
14901576 layer_w0 = jax .lax .psum (layer_w0 , "tensor_transpose" )
@@ -1497,6 +1583,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14971583 w1 ,
14981584 tiling = wi_tile_size ,
14991585 weight_gather_axes = wi_gather_axes ,
1586+ gmm_impl = gmm_impl_wi ,
15001587 )
15011588 if self .get_tensor_transpose_parallelism_size () > 1 :
15021589 layer_w1 = jax .lax .psum (layer_w1 , "tensor_transpose" )
@@ -1510,6 +1597,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15101597 wo ,
15111598 tiling = wo_tile_size ,
15121599 weight_gather_axes = wo_gather_axes ,
1600+ gmm_impl = gmm_impl_wo ,
15131601 )
15141602 if self .get_tensor_parallelism_size () > 1 :
15151603 intermediate_output = jax .lax .psum_scatter (
@@ -2553,74 +2641,3 @@ def get_routed_and_shared_moe(
25532641 abstract_init = False ,
25542642 )
25552643 return module
2556-
2557-
2558- _heuristics_patched = False
2559-
2560-
2561- def _monkey_patch_tokamax_heuristics (config , force = False ):
2562- """Globally monkey-patches Tokamax GMM heuristics with manual tiling overrides."""
2563- global _heuristics_patched
2564- if _heuristics_patched and not force :
2565- return
2566-
2567- def custom_heuristics (self , ba : op .BoundArguments ) -> Config :
2568- lhs , rhs = ba .arguments ["lhs" ], ba .arguments ["rhs" ]
2569- dims = ba .arguments .get ("ragged_dot_dimension_numbers" , DEFAULT_RAGGED_DOT_DIM_NUMS )
2570-
2571- is_wo = False
2572- if dims == DEFAULT_RAGGED_DOT_DIM_NUMS :
2573- is_wo = rhs .shape [1 ] == config .base_mlp_dim
2574- elif dims == DLHS_RAGGED_DOT_DIM_NUMS :
2575- is_wo = rhs .shape [2 ] == config .base_emb_dim
2576- elif dims == DRHS_RAGGED_DOT_DIM_NUMS :
2577- is_wo = lhs .shape [1 ] == config .base_mlp_dim
2578-
2579- if is_wo :
2580- # Return wo tile sizes
2581- if dims == DEFAULT_RAGGED_DOT_DIM_NUMS :
2582- return Config (
2583- tile_m = config .wo_tile_fwd_batch_seq ,
2584- tile_k = config .wo_tile_fwd_mlp_dim ,
2585- tile_n = config .wo_tile_fwd_embed_dim ,
2586- )
2587- elif dims == DLHS_RAGGED_DOT_DIM_NUMS :
2588- return Config (
2589- tile_m = config .wo_tile_dlhs_batch_seq ,
2590- tile_k = config .wo_tile_dlhs_embed_dim ,
2591- tile_n = config .wo_tile_dlhs_mlp_dim ,
2592- )
2593- elif dims == DRHS_RAGGED_DOT_DIM_NUMS :
2594- return Config (
2595- tile_m = config .wo_tile_drhs_batch_seq ,
2596- tile_k = config .wo_tile_drhs_mlp_dim ,
2597- tile_n = config .wo_tile_drhs_embed_dim ,
2598- )
2599- else :
2600- # Return wi tile sizes
2601- if dims == DEFAULT_RAGGED_DOT_DIM_NUMS :
2602- return Config (
2603- tile_m = config .wi_tile_fwd_batch_seq ,
2604- tile_k = config .wi_tile_fwd_embed_dim ,
2605- tile_n = config .wi_tile_fwd_mlp_dim ,
2606- )
2607- elif dims == DLHS_RAGGED_DOT_DIM_NUMS :
2608- return Config (
2609- tile_m = config .wi_tile_dlhs_batch_seq ,
2610- tile_k = config .wi_tile_dlhs_mlp_dim ,
2611- tile_n = config .wi_tile_dlhs_embed_dim ,
2612- )
2613- elif dims == DRHS_RAGGED_DOT_DIM_NUMS :
2614- return Config (
2615- tile_m = config .wi_tile_drhs_batch_seq ,
2616- tile_k = config .wi_tile_drhs_embed_dim ,
2617- tile_n = config .wi_tile_drhs_mlp_dim ,
2618- )
2619-
2620- return Config ()
2621-
2622- # Apply class-level monkey patch!
2623- # pylint: disable=protected-access
2624- PallasMosaicTpuRaggedDot ._get_heuristics_config = custom_heuristics
2625- _heuristics_patched = True
2626- print ("[TOKAMAX_PATCH] Successfully monkey-patched Tokamax GMM heuristics globally!" )
0 commit comments