Skip to content

Commit 19c63a6

Browse files
Merge pull request #3869 from AI-Hypercomputer:zxhe/tokamax_ragged_dot
PiperOrigin-RevId: 915115659
2 parents 1ea03b1 + c201e2a commit 19c63a6

1 file changed

Lines changed: 68 additions & 38 deletions

File tree

  • src/maxtext/kernels/megablox

src/maxtext/kernels/megablox/ops.py

Lines changed: 68 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -172,18 +172,27 @@ def _gmm_fwd(
172172
if transpose_rhs:
173173
rhs = rhs.swapaxes(1, 2)
174174

175-
out = tokamax.ragged_dot(
176-
lhs=lhs,
177-
rhs=rhs,
178-
group_sizes=group_sizes,
179-
precision=jax.lax.Precision.DEFAULT,
180-
preferred_element_type=preferred_element_type,
181-
group_offset=group_offset,
182-
implementation="mosaic",
183-
manual_axis_type=jax.sharding.ManualAxisType(
184-
varying=frozenset(["data", "fsdp", "expert"])
185-
) if use_manual_quantization else None,
186-
)
175+
if use_manual_quantization:
176+
out = tokamax.ragged_dot(
177+
lhs=lhs,
178+
rhs=rhs,
179+
group_sizes=group_sizes,
180+
precision=jax.lax.Precision.DEFAULT,
181+
preferred_element_type=preferred_element_type,
182+
group_offset=group_offset,
183+
implementation="mosaic",
184+
manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"])),
185+
)
186+
else:
187+
out = tokamax.ragged_dot(
188+
lhs=lhs,
189+
rhs=rhs,
190+
group_sizes=group_sizes,
191+
precision=jax.lax.Precision.DEFAULT,
192+
preferred_element_type=preferred_element_type,
193+
group_offset=group_offset,
194+
implementation="mosaic",
195+
)
187196
else:
188197
out = backend.gmm(
189198
lhs,
@@ -264,32 +273,53 @@ def _gmm_bwd(
264273
if not transpose_rhs:
265274
dlhs_rhs = dlhs_rhs.swapaxes(1, 2)
266275

267-
dlhs = tokamax.ragged_dot(
268-
lhs=dlhs_dout,
269-
rhs=dlhs_rhs,
270-
group_sizes=group_sizes,
271-
precision=jax.lax.Precision.DEFAULT,
272-
preferred_element_type=lhs_dtype,
273-
group_offset=group_offset,
274-
implementation="mosaic",
275-
manual_axis_type=jax.sharding.ManualAxisType(
276-
varying=frozenset(["data", "fsdp", "expert"])
277-
) if use_manual_quantization else None,
278-
)
279-
drhs = tokamax.ragged_dot_general(
280-
lhs=lhs,
281-
rhs=drhs_dout,
282-
group_sizes=group_sizes,
283-
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
284-
precision=jax.lax.Precision.DEFAULT,
285-
preferred_element_type=rhs_dtype,
286-
group_offset=group_offset,
287-
implementation="mosaic",
288-
manual_axis_type=jax.sharding.ManualAxisType(
289-
varying=frozenset(["expert"]),
290-
unreduced=frozenset(["data", "fsdp"])
291-
) if use_manual_quantization else None,
292-
)
276+
if use_manual_quantization:
277+
dlhs = tokamax.ragged_dot(
278+
lhs=dlhs_dout,
279+
rhs=dlhs_rhs,
280+
group_sizes=group_sizes,
281+
precision=jax.lax.Precision.DEFAULT,
282+
preferred_element_type=lhs_dtype,
283+
group_offset=group_offset,
284+
implementation="mosaic",
285+
manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"])),
286+
)
287+
else:
288+
dlhs = tokamax.ragged_dot(
289+
lhs=dlhs_dout,
290+
rhs=dlhs_rhs,
291+
group_sizes=group_sizes,
292+
precision=jax.lax.Precision.DEFAULT,
293+
preferred_element_type=lhs_dtype,
294+
group_offset=group_offset,
295+
implementation="mosaic",
296+
)
297+
if use_manual_quantization:
298+
drhs = tokamax.ragged_dot_general(
299+
lhs=lhs,
300+
rhs=drhs_dout,
301+
group_sizes=group_sizes,
302+
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
303+
precision=jax.lax.Precision.DEFAULT,
304+
preferred_element_type=rhs_dtype,
305+
group_offset=group_offset,
306+
implementation="mosaic",
307+
manual_axis_type=jax.sharding.ManualAxisType(
308+
varying=frozenset(["expert"]),
309+
unreduced=frozenset(["data", "fsdp"]),
310+
),
311+
)
312+
else:
313+
drhs = tokamax.ragged_dot_general(
314+
lhs=lhs,
315+
rhs=drhs_dout,
316+
group_sizes=group_sizes,
317+
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
318+
precision=jax.lax.Precision.DEFAULT,
319+
preferred_element_type=rhs_dtype,
320+
group_offset=group_offset,
321+
implementation="mosaic",
322+
)
293323
if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
294324
# Scatter back in reverse order of gather
295325
for axis_name, axis_idx in reversed(weight_gather_axes):

0 commit comments

Comments
 (0)