@@ -172,18 +172,27 @@ def _gmm_fwd(
172172 if transpose_rhs :
173173 rhs = rhs .swapaxes (1 , 2 )
174174
175- out = tokamax .ragged_dot (
176- lhs = lhs ,
177- rhs = rhs ,
178- group_sizes = group_sizes ,
179- precision = jax .lax .Precision .DEFAULT ,
180- preferred_element_type = preferred_element_type ,
181- group_offset = group_offset ,
182- implementation = "mosaic" ,
183- manual_axis_type = jax .sharding .ManualAxisType (
184- varying = frozenset (["data" , "fsdp" , "expert" ])
185- ) if use_manual_quantization else None ,
186- )
175+ if use_manual_quantization :
176+ out = tokamax .ragged_dot (
177+ lhs = lhs ,
178+ rhs = rhs ,
179+ group_sizes = group_sizes ,
180+ precision = jax .lax .Precision .DEFAULT ,
181+ preferred_element_type = preferred_element_type ,
182+ group_offset = group_offset ,
183+ implementation = "mosaic" ,
184+ manual_axis_type = jax .sharding .ManualAxisType (varying = frozenset (["data" , "fsdp" , "expert" ])),
185+ )
186+ else :
187+ out = tokamax .ragged_dot (
188+ lhs = lhs ,
189+ rhs = rhs ,
190+ group_sizes = group_sizes ,
191+ precision = jax .lax .Precision .DEFAULT ,
192+ preferred_element_type = preferred_element_type ,
193+ group_offset = group_offset ,
194+ implementation = "mosaic" ,
195+ )
187196 else :
188197 out = backend .gmm (
189198 lhs ,
@@ -264,32 +273,53 @@ def _gmm_bwd(
264273 if not transpose_rhs :
265274 dlhs_rhs = dlhs_rhs .swapaxes (1 , 2 )
266275
267- dlhs = tokamax .ragged_dot (
268- lhs = dlhs_dout ,
269- rhs = dlhs_rhs ,
270- group_sizes = group_sizes ,
271- precision = jax .lax .Precision .DEFAULT ,
272- preferred_element_type = lhs_dtype ,
273- group_offset = group_offset ,
274- implementation = "mosaic" ,
275- manual_axis_type = jax .sharding .ManualAxisType (
276- varying = frozenset (["data" , "fsdp" , "expert" ])
277- ) if use_manual_quantization else None ,
278- )
279- drhs = tokamax .ragged_dot_general (
280- lhs = lhs ,
281- rhs = drhs_dout ,
282- group_sizes = group_sizes ,
283- ragged_dot_dimension_numbers = DRHS_RAGGED_DOT_DIM_NUMS ,
284- precision = jax .lax .Precision .DEFAULT ,
285- preferred_element_type = rhs_dtype ,
286- group_offset = group_offset ,
287- implementation = "mosaic" ,
288- manual_axis_type = jax .sharding .ManualAxisType (
289- varying = frozenset (["expert" ]),
290- unreduced = frozenset (["data" , "fsdp" ])
291- ) if use_manual_quantization else None ,
292- )
276+ if use_manual_quantization :
277+ dlhs = tokamax .ragged_dot (
278+ lhs = dlhs_dout ,
279+ rhs = dlhs_rhs ,
280+ group_sizes = group_sizes ,
281+ precision = jax .lax .Precision .DEFAULT ,
282+ preferred_element_type = lhs_dtype ,
283+ group_offset = group_offset ,
284+ implementation = "mosaic" ,
285+ manual_axis_type = jax .sharding .ManualAxisType (varying = frozenset (["data" , "fsdp" , "expert" ])),
286+ )
287+ else :
288+ dlhs = tokamax .ragged_dot (
289+ lhs = dlhs_dout ,
290+ rhs = dlhs_rhs ,
291+ group_sizes = group_sizes ,
292+ precision = jax .lax .Precision .DEFAULT ,
293+ preferred_element_type = lhs_dtype ,
294+ group_offset = group_offset ,
295+ implementation = "mosaic" ,
296+ )
297+ if use_manual_quantization :
298+ drhs = tokamax .ragged_dot_general (
299+ lhs = lhs ,
300+ rhs = drhs_dout ,
301+ group_sizes = group_sizes ,
302+ ragged_dot_dimension_numbers = DRHS_RAGGED_DOT_DIM_NUMS ,
303+ precision = jax .lax .Precision .DEFAULT ,
304+ preferred_element_type = rhs_dtype ,
305+ group_offset = group_offset ,
306+ implementation = "mosaic" ,
307+ manual_axis_type = jax .sharding .ManualAxisType (
308+ varying = frozenset (["expert" ]),
309+ unreduced = frozenset (["data" , "fsdp" ]),
310+ ),
311+ )
312+ else :
313+ drhs = tokamax .ragged_dot_general (
314+ lhs = lhs ,
315+ rhs = drhs_dout ,
316+ group_sizes = group_sizes ,
317+ ragged_dot_dimension_numbers = DRHS_RAGGED_DOT_DIM_NUMS ,
318+ precision = jax .lax .Precision .DEFAULT ,
319+ preferred_element_type = rhs_dtype ,
320+ group_offset = group_offset ,
321+ implementation = "mosaic" ,
322+ )
293323 if quantization_rule and quantization_rule .bwd_qtype and weight_gather_axes :
294324 # Scatter back in reverse order of gather
295325 for axis_name , axis_idx in reversed (weight_gather_axes ):
0 commit comments