1616
1717# pylint: disable=too-many-positional-arguments
1818
19- import functools
2019import dataclasses
21- from typing import Literal , List , Tuple
20+ import functools
21+ from typing import List , Literal , Tuple
2222import jax
2323import jax .numpy as jnp
2424from maxtext .kernels .megablox import backend
25- from tokamax ._src .ops .ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
2625import qwix
2726import qwix .pallas as qpl
27+ import tokamax
28+
29+
30+ DRHS_RAGGED_DOT_DIM_NUMS = jax .lax .RaggedDotDimensionNumbers (  # dimension numbers for the drhs ragged_dot_general call in _gmm_bwd
31+ dot_dimension_numbers = (([0 ], [0 ]), ([], [])),  # contract lhs dim 0 with rhs dim 0; no batch dims
32+ lhs_ragged_dimensions = [0 ],  # lhs dim 0 is the ragged dim (presumably the one partitioned by group_sizes -- confirm against JAX docs)
33+ rhs_group_dimensions = [],  # rhs has no group dimension
34+ )
2835
2936
3037def gmm (
3138 lhs : jnp .ndarray ,
3239 rhs : jnp .ndarray ,
3340 group_sizes : jnp .ndarray ,
3441 preferred_element_type : jnp .dtype = jnp .float32 ,
35- tiling : tuple [int , int , int , int , int , int , int , int , int ] = (128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 ),
42+ tiling : tuple [int , int , int , int , int , int , int , int , int ] = (
43+ 128 ,
44+ 128 ,
45+ 128 ,
46+ 128 ,
47+ 128 ,
48+ 128 ,
49+ 128 ,
50+ 128 ,
51+ 128 ,
52+ ),
3653 group_offset : jnp .ndarray | None = None ,
3754 existing_out : jnp .ndarray | None = None ,
3855 transpose_rhs : bool = False ,
@@ -42,8 +59,6 @@ def gmm(
4259 use_qwix_quantization : bool = False ,
4360 use_tokamax_backend : bool = False ,
4461 weight_gather_axes : List [Tuple [str , int ]] | None = None ,
45- input_buffer_count : tuple [int , int , int ] = (2 , 2 , 2 ),
46- combine_scopes : bool = False ,
4762 # TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
4863 qwix_rule : qwix .QtRule | None = None ,
4964):
@@ -65,16 +80,14 @@ def gmm(
6580 )
6681
6782 gmm_fwd_bwd = lambda * args : _gmm_fwd (* args )[0 ] # pylint: disable=C3001
68- gmm_fwd_bwd = jax .custom_vjp (gmm_fwd_bwd , nondiff_argnums = (3 , 4 , 5 , 6 , 9 , 10 , 11 , 12 , 13 ))
83+ gmm_fwd_bwd = jax .custom_vjp (gmm_fwd_bwd , nondiff_argnums = (3 , 4 , 7 , 8 , 9 , 10 , 11 ))
6984 gmm_fwd_bwd .defvjp (_gmm_fwd , functools .partial (_gmm_bwd , lhs .dtype , rhs .dtype ))
7085 return gmm_fwd_bwd (
7186 lhs ,
7287 rhs ,
7388 group_sizes ,
7489 preferred_element_type ,
7590 tiling ,
76- input_buffer_count ,
77- combine_scopes ,
7891 group_offset ,
7992 existing_out ,
8093 transpose_rhs ,
@@ -90,9 +103,17 @@ def _gmm_fwd(
90103 rhs : jnp .ndarray ,
91104 group_sizes : jnp .ndarray ,
92105 preferred_element_type : jnp .dtype = jnp .float32 ,
93- tiling : tuple [int , int , int , int , int , int , int , int , int ] = (128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 ),
94- input_buffer_count : tuple [int , int , int ] = (2 , 2 , 2 ),
95- combine_scopes : bool = False ,
106+ tiling : tuple [int , int , int , int , int , int , int , int , int ] = (
107+ 128 ,
108+ 128 ,
109+ 128 ,
110+ 128 ,
111+ 128 ,
112+ 128 ,
113+ 128 ,
114+ 128 ,
115+ 128 ,
116+ ),
96117 group_offset : jnp .ndarray | None = None ,
97118 existing_out : jnp .ndarray | None = None ,
98119 transpose_rhs : bool = False ,
@@ -136,17 +157,18 @@ def _gmm_fwd(
136157 for axis_name , axis_idx in weight_gather_axes :
137158 rhs_qvalue = jax .lax .all_gather (rhs .qvalue , axis_name , axis = axis_idx , tiled = True )
138159 rhs = dataclasses .replace (rhs , qvalue = rhs_qvalue )
139- out = tokamax_backend .gmm (
160+ # ragged_dot assumes rhs is laid out as (G, K, N) and is not given transpose_rhs, so swap the last two axes here when rhs is transposed
161+ if transpose_rhs :
162+ rhs = rhs .swapaxes (1 , 2 )
163+
164+ out = tokamax .ragged_dot (
140165 lhs = lhs ,
141166 rhs = rhs ,
142167 group_sizes = group_sizes ,
143168 precision = jax .lax .Precision .DEFAULT ,
144- out_dtype = preferred_element_type ,
145- tiling = tiling [:3 ],
169+ preferred_element_type = preferred_element_type ,
146170 group_offset = group_offset ,
147- transpose_rhs = transpose_rhs ,
148- interpret = interpret ,
149- input_buffer_count = input_buffer_count [0 ],
171+ implementation = "mosaic" ,
150172 )
151173 else :
152174 out = backend .gmm (
@@ -168,8 +190,6 @@ def _gmm_bwd(
168190 rhs_dtype : jax .typing .DTypeLike ,
169191 preferred_element_type : jnp .dtype ,
170192 tiling : tuple [int , int , int , int , int , int , int , int , int ],
171- input_buffer_count : tuple [int , int , int ],
172- combine_scopes : bool ,
173193 transpose_rhs : bool ,
174194 interpret : bool ,
175195 quantization_rule : qwix .QtRule | None ,
@@ -224,30 +244,29 @@ def _gmm_bwd(
224244 calibration_method = quantization_rule .bwd_calibration_method ,
225245 )
226246 if use_tokamax_backend :
227- dlhs = tokamax_backend .gmm (
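247+ # dlhs needs rhs in the opposite orientation from the forward pass (previously transpose_rhs = not transpose_rhs), so swap axes 1 and 2 only when rhs is not already transposed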
248+ dlhs_rhs = rhs
249+ if not transpose_rhs :
250+ dlhs_rhs = dlhs_rhs .swapaxes (1 , 2 )
251+
252+ dlhs = tokamax .ragged_dot (
228253 lhs = dlhs_dout ,
229- rhs = rhs ,
254+ rhs = dlhs_rhs ,
230255 group_sizes = group_sizes ,
231256 precision = jax .lax .Precision .DEFAULT ,
232- out_dtype = lhs_dtype ,
233- tiling = tiling [3 :6 ],
257+ preferred_element_type = lhs_dtype ,
234258 group_offset = group_offset ,
235- transpose_rhs = not transpose_rhs ,
236- interpret = interpret ,
237- input_buffer_count = input_buffer_count [1 ],
259+ implementation = "mosaic" ,
238260 )
239- drhs = tokamax_backend . tgmm (
240- lhs = lhs . swapaxes ( 0 , 1 ) ,
261+ drhs = tokamax . ragged_dot_general (
262+ lhs = lhs ,
241263 rhs = drhs_dout ,
242264 group_sizes = group_sizes ,
265+ ragged_dot_dimension_numbers = DRHS_RAGGED_DOT_DIM_NUMS ,
243266 precision = jax .lax .Precision .DEFAULT ,
244- out_dtype = rhs_dtype ,
245- tiling = tiling [- 3 :],
267+ preferred_element_type = rhs_dtype ,
246268 group_offset = group_offset ,
247- num_actual_groups = num_actual_groups ,
248- interpret = interpret ,
249- input_buffer_count = input_buffer_count [2 ],
250- combine_scopes = combine_scopes ,
269+ implementation = "mosaic" ,
251270 )
252271 if quantization_rule and quantization_rule .bwd_qtype and weight_gather_axes :
253272 # Scatter back in reverse order of gather
0 commit comments