@@ -64,7 +64,7 @@ def gmm(
6464 rhs_vma_axes : tuple = tuple (),
6565 # TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
6666 qwix_rule : qwix .QtRule | None = None ,
67- use_manual_quantization : bool = False ,
67+ use_manual_quantization : bool = False , # used in batchsplit
6868):
6969 """Grouped matrix multiplication operation."""
7070 quantization_rule = None
@@ -163,10 +163,10 @@ def _gmm_fwd(
163163 else :
164164 rhs = quantizations .manual_quantize (
165165 rhs ,
166- quantization_rule .weight_calibration_method ,
167166 quantization_rule .weight_qtype ,
167+ calibration_method = quantization_rule .weight_calibration_method ,
168168 )
169- # QAG is only supported for following conditions
169+ # QAG is only supported for following conditions
170170 if use_tokamax_backend :
171171 if quantization_rule and quantization_rule .bwd_qtype :
172172 if quantization_rule .weight_calibration_method .startswith ("fixed" ) and isinstance (rhs , qpl .QArray ):
@@ -178,27 +178,23 @@ def _gmm_fwd(
178178 if transpose_rhs :
179179 rhs = rhs .swapaxes (1 , 2 )
180180
181+ # manual_axis_type is for gmm with shard_map check_vma=True, needs tokamax > 0.0.12
182+ out_kwargs = {}
181183 if use_manual_quantization :
182- out = tokamax .ragged_dot (
183- lhs = lhs ,
184- rhs = rhs ,
185- group_sizes = group_sizes ,
186- precision = jax .lax .Precision .DEFAULT ,
187- preferred_element_type = preferred_element_type ,
188- group_offset = group_offset ,
189- implementation = "mosaic" ,
190- manual_axis_type = jax .sharding .ManualAxisType (varying = frozenset (["data" , "fsdp" , "expert" ])),
191- )
192- else :
193- out = tokamax .ragged_dot (
194- lhs = lhs ,
195- rhs = rhs ,
196- group_sizes = group_sizes ,
197- precision = jax .lax .Precision .DEFAULT ,
198- preferred_element_type = preferred_element_type ,
199- group_offset = group_offset ,
200- implementation = "mosaic" ,
201- )
184+ # used in batchsplit
185+ out_kwargs ["manual_axis_type" ] = jax .sharding .ManualAxisType (varying = frozenset (["data" , "fsdp" , "expert" ]))
186+
187+ out = tokamax .ragged_dot (
188+ lhs = lhs ,
189+ rhs = rhs ,
190+ group_sizes = group_sizes ,
191+ precision = jax .lax .Precision .DEFAULT ,
192+ preferred_element_type = preferred_element_type ,
193+ # `group_offset` is not yet supported
194+ group_offset = None ,
195+ implementation = "mosaic" ,
196+ ** out_kwargs ,
197+ )
202198 else :
203199 out = backend .gmm (
204200 lhs ,
@@ -284,53 +280,39 @@ def _gmm_bwd(
284280 if not transpose_rhs :
285281 dlhs_rhs = dlhs_rhs .swapaxes (1 , 2 )
286282
283+ # manual_axis_type is for gmm with shard_map check_vma=True, needs tokamax > 0.0.12
284+ dlhs_kwargs = {}
285+ drhs_kwargs = {}
287286 if use_manual_quantization :
288- dlhs = tokamax .ragged_dot (
289- lhs = dlhs_dout ,
290- rhs = dlhs_rhs ,
291- group_sizes = group_sizes ,
292- precision = jax .lax .Precision .DEFAULT ,
293- preferred_element_type = lhs_dtype ,
294- group_offset = group_offset ,
295- implementation = "mosaic" ,
296- manual_axis_type = jax .sharding .ManualAxisType (varying = frozenset (["data" , "fsdp" , "expert" ])),
297- )
298- else :
299- dlhs = tokamax .ragged_dot (
300- lhs = dlhs_dout ,
301- rhs = dlhs_rhs ,
302- group_sizes = group_sizes ,
303- precision = jax .lax .Precision .DEFAULT ,
304- preferred_element_type = lhs_dtype ,
305- group_offset = group_offset ,
306- implementation = "mosaic" ,
307- )
308- if use_manual_quantization :
309- drhs = tokamax .ragged_dot_general (
310- lhs = lhs ,
311- rhs = drhs_dout ,
312- group_sizes = group_sizes ,
313- ragged_dot_dimension_numbers = DRHS_RAGGED_DOT_DIM_NUMS ,
314- precision = jax .lax .Precision .DEFAULT ,
315- preferred_element_type = rhs_dtype ,
316- group_offset = group_offset ,
317- implementation = "mosaic" ,
318- manual_axis_type = jax .sharding .ManualAxisType (
319- varying = frozenset (["expert" ]),
320- unreduced = frozenset (["data" , "fsdp" ]),
321- ),
322- )
323- else :
324- drhs = tokamax .ragged_dot_general (
325- lhs = lhs ,
326- rhs = drhs_dout ,
327- group_sizes = group_sizes ,
328- ragged_dot_dimension_numbers = DRHS_RAGGED_DOT_DIM_NUMS ,
329- precision = jax .lax .Precision .DEFAULT ,
330- preferred_element_type = rhs_dtype ,
331- group_offset = group_offset ,
332- implementation = "mosaic" ,
287+ # used in batchsplit
288+ dlhs_kwargs ["manual_axis_type" ] = jax .sharding .ManualAxisType (varying = frozenset (["data" , "fsdp" , "expert" ]))
289+ drhs_kwargs ["manual_axis_type" ] = jax .sharding .ManualAxisType (
290+ varying = frozenset (["expert" ]), unreduced = frozenset (["data" , "fsdp" ])
333291 )
292+
293+ dlhs = tokamax .ragged_dot (
294+ lhs = dlhs_dout ,
295+ rhs = dlhs_rhs ,
296+ group_sizes = group_sizes ,
297+ precision = jax .lax .Precision .DEFAULT ,
298+ preferred_element_type = lhs_dtype ,
299+ # `group_offset` is not yet supported
300+ group_offset = None ,
301+ implementation = "mosaic" ,
302+ ** dlhs_kwargs ,
303+ )
304+ drhs = tokamax .ragged_dot_general (
305+ lhs = lhs ,
306+ rhs = drhs_dout ,
307+ group_sizes = group_sizes ,
308+ ragged_dot_dimension_numbers = DRHS_RAGGED_DOT_DIM_NUMS ,
309+ precision = jax .lax .Precision .DEFAULT ,
310+ preferred_element_type = rhs_dtype ,
311+ # `group_offset` is not yet supported
312+ group_offset = None ,
313+ implementation = "mosaic" ,
314+ ** drhs_kwargs ,
315+ )
334316 if quantization_rule and quantization_rule .bwd_qtype and weight_gather_axes :
335317 # Scatter back in reverse order of gather
336318 for axis_name , axis_idx in reversed (weight_gather_axes ):
0 commit comments