5555 DLHS_RAGGED_DOT_DIM_NUMS ,
5656 DRHS_RAGGED_DOT_DIM_NUMS ,
5757)
58+ from tokamax ._src .ops .ragged_dot import base
5859from tokamax ._src .ops import op
60+ import dataclasses
61+
62+
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class PallasMosaicTpuRaggedDotCustom(PallasMosaicTpuRaggedDot):
  """`PallasMosaicTpuRaggedDot` variant with per-instance heuristic tile sizes.

  Each instance carries one `(tile_m, tile_k, tile_n)` triple per supported
  `ragged_dot_dimension_numbers` value. `_get_heuristics_config` returns the
  triple that matches the dimension numbers of the call being configured, so
  different call sites can use different tilings without patching the base
  class.
  """

  config: Config | None = None
  # (tile_m, tile_k, tile_n) used for DEFAULT_RAGGED_DOT_DIM_NUMS.
  fwd_tile: tuple[int, int, int] = (128, 128, 128)
  # (tile_m, tile_k, tile_n) used for DLHS_RAGGED_DOT_DIM_NUMS.
  dlhs_tile: tuple[int, int, int] = (128, 128, 128)
  # (tile_m, tile_k, tile_n) used for DRHS_RAGGED_DOT_DIM_NUMS.
  drhs_tile: tuple[int, int, int] = (128, 128, 128)

  def __post_init__(self):
    """Installs a default `vjp` when the caller did not supply one.

    The default is `base.vjp` with both `dlhs_ragged_dot` and
    `drhs_ragged_dot` bound to a callable that builds a fresh
    `PallasMosaicTpuRaggedDotCustom` with this instance's `interpret` flag and
    tile triples, then applies it to the forwarded arguments.
    """
    # `base` is the module-level import; no function-scope re-import needed.
    # Computed before the `vjp` check, so an invalid `qdtype` is rejected by
    # `jnp.dtype` even when a custom `vjp` was supplied.
    qdtype = self.qdtype if self.qdtype is None else jnp.dtype(self.qdtype).name
    if self.vjp is None:
      fn = lambda *args, **kw: PallasMosaicTpuRaggedDotCustom(
          qdtype=qdtype,
          interpret=self.interpret,
          fwd_tile=self.fwd_tile,
          dlhs_tile=self.dlhs_tile,
          drhs_tile=self.drhs_tile,
      )(*args, **kw)
      # The dataclass is frozen, so assign through `object.__setattr__`.
      object.__setattr__(
          self,
          "vjp",
          functools.partial(base.vjp, dlhs_ragged_dot=fn, drhs_ragged_dot=fn),
      )

  def _get_heuristics_config(self, ba) -> Config:
    """Returns the tile `Config` matching the call's dimension numbers.

    Args:
      ba: Bound call arguments; only
        `ba.arguments["ragged_dot_dimension_numbers"]` is read, defaulting to
        `DEFAULT_RAGGED_DOT_DIM_NUMS` when the key is absent.

    Returns:
      A `Config` built from `fwd_tile`, `dlhs_tile` or `drhs_tile`, or a
      default-constructed `Config()` when the dimension numbers match none of
      the three known constants.
    """
    dims = ba.arguments.get("ragged_dot_dimension_numbers", DEFAULT_RAGGED_DOT_DIM_NUMS)
    if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
      tile = self.fwd_tile
    elif dims == DLHS_RAGGED_DOT_DIM_NUMS:
      tile = self.dlhs_tile
    elif dims == DRHS_RAGGED_DOT_DIM_NUMS:
      tile = self.drhs_tile
    else:
      return Config()
    return Config(tile_m=tile[0], tile_k=tile[1], tile_n=tile[2])
97+
5998
6099set_xla_metadata = xla_metadata .set_xla_metadata
61100
@@ -557,8 +596,7 @@ def __init__(
557596 ):
558597 self .wo .value = self .wo .value * self .per_expert_scale .value [:, None , None ]
559598
560- # Monkey-patch Tokamax heuristics globally once
561- _monkey_patch_tokamax_heuristics (self .config )
599+
562600
563601 def _maybe_shard_with_logical (self , inputs , logical_name ):
564602 return maybe_shard_with_logical (
@@ -1095,6 +1133,18 @@ def sparse_matmul(
10951133 wo_bias ,
10961134 ):
10971135 """Perform sparse matrix multiplication of inputs and Experts."""
1136+ config = self .config
1137+
1138+ gmm_impl_wi = PallasMosaicTpuRaggedDotCustom (
1139+ fwd_tile = (config .wi_tile_fwd_batch_seq , config .wi_tile_fwd_embed_dim , config .wi_tile_fwd_mlp_dim ),
1140+ dlhs_tile = (config .wi_tile_dlhs_batch_seq , config .wi_tile_dlhs_mlp_dim , config .wi_tile_dlhs_embed_dim ),
1141+ drhs_tile = (config .wi_tile_drhs_batch_seq , config .wi_tile_drhs_embed_dim , config .wi_tile_drhs_mlp_dim ),
1142+ )
1143+ gmm_impl_wo = PallasMosaicTpuRaggedDotCustom (
1144+ fwd_tile = (config .wo_tile_fwd_batch_seq , config .wo_tile_fwd_mlp_dim , config .wo_tile_fwd_embed_dim ),
1145+ dlhs_tile = (config .wo_tile_dlhs_batch_seq , config .wo_tile_dlhs_embed_dim , config .wo_tile_dlhs_mlp_dim ),
1146+ drhs_tile = (config .wo_tile_drhs_batch_seq , config .wo_tile_drhs_mlp_dim , config .wo_tile_drhs_embed_dim ),
1147+ )
10981148
10991149 def jax_ragged_dot_gmm (inputs , kernel , tiling , group_sizes , expert_assignments , padding_amount ):
11001150 """Execute jax.lax.ragged_dot, with potential quantization"""
@@ -1139,6 +1189,15 @@ def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments,
11391189 output *= scales
11401190 return output
11411191
def get_gmm_group_sizes(inputs, kernel, ep):
  """Returns perfectly balanced group sizes, one entry per leading kernel slot.

  Every entry is the same value: `inputs.shape[0]` floor-divided by
  `kernel.shape[0]` and then by `ep`. The result has `kernel.shape[0]`
  entries, i.e. each local expert is given an equal share of the local
  tokens, adjusted for expert parallelism.

  Note: this assumes the inputs are ragged and padded to the worst-case size
  (generally a factor of EP larger than perfectly balanced), which is why the
  share is additionally divided by `ep`.

  Args:
    inputs: Object whose `shape[0]` is the (padded) token count.
    kernel: Object whose `shape[0]` is the number of groups to size.
    ep: Expert-parallelism size used as the extra divisor.

  Returns:
    A tuple of `kernel.shape[0]` identical ints.
  """
  num_groups = kernel.shape[0]
  tokens_per_group = inputs.shape[0] // num_groups // ep
  return tuple(tokens_per_group for _ in range(num_groups))
1200+
11421201 def get_tokamax_group_sizes (group_sizes , inputs , kernel ):
11431202 # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
11441203 if self .config .use_qwix_quantization or (
@@ -1151,7 +1210,7 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
11511210 ep = self .get_expert_parallelism_size ()
11521211 return tokamax .RaggedDotGroupSizes (
11531212 group_sizes ,
1154- (inputs . shape [ 0 ] // kernel . shape [ 0 ] // ep ,) * kernel . shape [ 0 ] ,
1213+ get_gmm_group_sizes (inputs , kernel , ep ) ,
11551214 )
11561215
11571216 def get_quantization_dtypes ():
@@ -1162,7 +1221,7 @@ def get_quantization_dtypes():
11621221 rhs_quantize_dtype = quant_dg .fwd .dg_quantizer .rhs .numerics .get_dtype ()
11631222 return lhs_quantize_dtype , rhs_quantize_dtype
11641223
1165- def gmm (inputs , kernel , tiling , group_sizes , expert_assignments , weight_gather_axes ):
1224+ def gmm (inputs , kernel , tiling , group_sizes , expert_assignments , weight_gather_axes , gmm_impl = None ):
11661225 if inputs .shape [0 ] != expert_assignments .shape [0 ]:
11671226 raise ValueError ("The number of input tokens must match the number of expert assignments!" )
11681227
@@ -1196,7 +1255,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
11961255 group_sizes = tokamax_group_sizes ,
11971256 precision = jax .lax .Precision .DEFAULT ,
11981257 preferred_element_type = self .dtype ,
1199- implementation = "mosaic" ,
1258+ implementation = "mosaic" if gmm_impl is None else [ gmm_impl ] ,
12001259 )
12011260 elif self .config .megablox : # Older forked megablox
12021261 output = mblx .gmm (
@@ -1485,6 +1544,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14851544 w0 ,
14861545 tiling = wi_tile_size ,
14871546 weight_gather_axes = wi_gather_axes ,
1547+ gmm_impl = gmm_impl_wi ,
14881548 )
14891549 if self .get_tensor_transpose_parallelism_size () > 1 :
14901550 layer_w0 = jax .lax .psum (layer_w0 , "tensor_transpose" )
@@ -1497,6 +1557,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14971557 w1 ,
14981558 tiling = wi_tile_size ,
14991559 weight_gather_axes = wi_gather_axes ,
1560+ gmm_impl = gmm_impl_wi ,
15001561 )
15011562 if self .get_tensor_transpose_parallelism_size () > 1 :
15021563 layer_w1 = jax .lax .psum (layer_w1 , "tensor_transpose" )
@@ -1510,6 +1571,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15101571 wo ,
15111572 tiling = wo_tile_size ,
15121573 weight_gather_axes = wo_gather_axes ,
1574+ gmm_impl = gmm_impl_wo ,
15131575 )
15141576 if self .get_tensor_parallelism_size () > 1 :
15151577 intermediate_output = jax .lax .psum_scatter (
@@ -2553,74 +2615,3 @@ def get_routed_and_shared_moe(
25532615 abstract_init = False ,
25542616 )
25552617 return module
2556-
2557-
2558- _heuristics_patched = False
2559-
2560-
def _monkey_patch_tokamax_heuristics(config, force=False):
  """Replaces `PallasMosaicTpuRaggedDot._get_heuristics_config` process-wide.

  The replacement picks tile sizes from `config` based on the call's
  `ragged_dot_dimension_numbers` and on whether the operand shapes identify the
  call as the `wo` projection (otherwise the `wi` tiles are used).

  Args:
    config: Object providing `base_mlp_dim`, `base_emb_dim` and the
      `wi_tile_*` / `wo_tile_*` attributes read below.
    force: When true, re-applies the patch even if it was already applied.
  """
  global _heuristics_patched
  already_applied = _heuristics_patched and not force
  if already_applied:
    return

  def custom_heuristics(self, ba: op.BoundArguments) -> Config:
    call_args = ba.arguments
    lhs, rhs = call_args["lhs"], call_args["rhs"]
    dims = call_args.get("ragged_dot_dimension_numbers", DEFAULT_RAGGED_DOT_DIM_NUMS)

    if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
      # A `wo` call is recognised by rhs.shape[1] matching the MLP dim.
      if rhs.shape[1] == config.base_mlp_dim:
        return Config(
            tile_m=config.wo_tile_fwd_batch_seq,
            tile_k=config.wo_tile_fwd_mlp_dim,
            tile_n=config.wo_tile_fwd_embed_dim,
        )
      return Config(
          tile_m=config.wi_tile_fwd_batch_seq,
          tile_k=config.wi_tile_fwd_embed_dim,
          tile_n=config.wi_tile_fwd_mlp_dim,
      )

    if dims == DLHS_RAGGED_DOT_DIM_NUMS:
      # A `wo` call is recognised by rhs.shape[2] matching the embed dim.
      if rhs.shape[2] == config.base_emb_dim:
        return Config(
            tile_m=config.wo_tile_dlhs_batch_seq,
            tile_k=config.wo_tile_dlhs_embed_dim,
            tile_n=config.wo_tile_dlhs_mlp_dim,
        )
      return Config(
          tile_m=config.wi_tile_dlhs_batch_seq,
          tile_k=config.wi_tile_dlhs_mlp_dim,
          tile_n=config.wi_tile_dlhs_embed_dim,
      )

    if dims == DRHS_RAGGED_DOT_DIM_NUMS:
      # A `wo` call is recognised by lhs.shape[1] matching the MLP dim.
      if lhs.shape[1] == config.base_mlp_dim:
        return Config(
            tile_m=config.wo_tile_drhs_batch_seq,
            tile_k=config.wo_tile_drhs_mlp_dim,
            tile_n=config.wo_tile_drhs_embed_dim,
        )
      return Config(
          tile_m=config.wi_tile_drhs_batch_seq,
          tile_k=config.wi_tile_drhs_embed_dim,
          tile_n=config.wi_tile_drhs_mlp_dim,
      )

    # Unknown dimension numbers: fall back to a default-constructed Config.
    return Config()

  # Class-level patch: affects every PallasMosaicTpuRaggedDot instance.
  # pylint: disable=protected-access
  PallasMosaicTpuRaggedDot._get_heuristics_config = custom_heuristics
  _heuristics_patched = True
  print("[TOKAMAX_PATCH] Successfully monkey-patched Tokamax GMM heuristics globally!")
0 commit comments