Skip to content

Commit acd00f9

Browse files
Re-enable JAX 0.10.0 with MoE layout-constraint fix (#2683)
Signed-off-by: Qiliang Cui <derrhein@gmail.com>
Co-authored-by: Qiliang Cui <derrhein@gmail.com>
1 parent 9828b94 commit acd00f9

7 files changed

Lines changed: 43 additions & 46 deletions

File tree

requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ pytest-mock
55
absl-py
66
numpy
77
google-cloud-storage
8-
jax==0.9.2
9-
jaxlib==0.9.2
10-
libtpu==0.0.39
8+
jax==0.10.0
9+
jaxlib==0.10.0
10+
libtpu==0.0.40
1111
jaxtyping
1212
flax==0.12.4
1313
torchax==0.0.11

tests/kernels/mla_v2_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
import os
1616

17-
import pytest
18-
1917
os.environ["LIBTPU_INIT_ARGS"] = (os.environ.get("LIBTPU_INIT_ARGS", "") +
2018
" --xla_tpu_scoped_vmem_limit_kib=65536")
2119

@@ -408,9 +406,6 @@ def _test_mla_ragged_paged_attention(
408406
self.assertAllClose(expected_out, kernel_out, atol=0.1, rtol=0.2)
409407

410408

411-
# The test is slow on v6e, causing timeouts in presubmit. See b/513860288.
412-
@pytest.mark.skipif(not jtu.is_device_tpu_at_least(version=7),
413-
reason="Expect TPUv7+")
414409
@jtu.with_config(jax_numpy_dtype_promotion="standard")
415410
class MlaRaggedPagedAttentionKernelV2Test(MlaRaggedPagedAttentionTestBase):
416411

tests/kernels/ragged_gather_reduce_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,6 @@ class ScatterTest(jtu.JaxTestCase):
8383
@parameterized.parameters(*_test_cases)
8484
def test_sc_ragged_gather_reduce(self, out_size, hidden_size, start_end,
8585
dtype, reduce_group_size):
86-
# The test is slow on v6e, causing timeouts in presubmit. See b/513860288.
87-
if not jtu.is_device_tpu_at_least(version=7):
88-
self.skipTest("Expect TPUv7+")
8986
start, end = start_end
9087
start = min(start, out_size)
9188
end = min(end, out_size)

tests/layers/vllm/test_fp8.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,9 @@ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
431431

432432
a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
433433
w1 = torch.randn(
434-
(num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
434+
(num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 100
435435
w2 = torch.randn(
436-
(num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
436+
(num_experts, hidden_size, intermediate_size), dtype=dtype) / 100
437437
score = torch.randn((num_tokens, num_experts), dtype=dtype)
438438

439439
engine_args = EngineArgs(model=MODELS[0],

tpu_inference/kernels/sparse_core/gather_reduce.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def __getitem__(self, shape):
6363

6464
def is_supported_by_sc_gather_reduce(x_shape: int,
6565
sc_kernel_threshold: int) -> bool:
66+
# TODO: Skip until numeric issue is fixed.
67+
return False
6668
if x_shape > sc_kernel_threshold and pltpu.get_tpu_info().generation == 7:
6769
return True
6870
return False

tpu_inference/kernels/sparse_core/ragged_gather_reduce.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,10 @@ def main_kernel(
9999
# Inputs.
100100
num_rows_per_row_partition_ref: jax.Ref,
101101
in_hbm_ref: jax.Ref,
102-
src_indices_hbm_ref: jax.Ref,
102+
indices_hbm_ref: jax.Ref,
103103
dst_indices_hbm_ref: jax.Ref,
104104
topk_weights_hbm_ref: jax.Ref,
105+
sorted_by_validity_hbm_ref: jax.Ref,
105106
# Outputs.
106107
out_hbm_ref: jax.Ref,
107108
# Scratch.
@@ -111,6 +112,7 @@ def main_kernel(
111112
src_indices_vmem_ref: jax.Ref,
112113
dst_indices_vmem_ref: jax.Ref,
113114
topk_weights_vmem_ref: jax.Ref,
115+
sorted_by_validity_vmem_ref: jax.Ref,
114116
sem_ref: jax.Ref,
115117
*,
116118
core_axis_name: str,
@@ -176,9 +178,9 @@ def row_loop(row_block_id):
176178
dma_list = []
177179
dma_list.append(
178180
pltpu.make_async_copy(
179-
src_indices_hbm_ref.at[pl.ds(row_tile_start,
180-
num_simd_lanes)],
181-
src_indices_vmem_ref,
181+
sorted_by_validity_hbm_ref.at[pl.ds(
182+
row_tile_start, num_simd_lanes)],
183+
sorted_by_validity_vmem_ref,
182184
recv_sem,
183185
))
184186
dma_list.append(
@@ -188,13 +190,22 @@ def row_loop(row_block_id):
188190
dst_indices_vmem_ref,
189191
recv_sem,
190192
))
193+
jax.tree.map(lambda x: x.start(), dma_list)
194+
jax.tree.map(lambda x: x.wait(), dma_list)
195+
196+
dma_list = []
191197
dma_list.append(
192198
pltpu.make_async_copy(
193-
topk_weights_hbm_ref.at[pl.ds(row_tile_start,
194-
num_simd_lanes)],
199+
topk_weights_hbm_ref.at[sorted_by_validity_vmem_ref],
195200
topk_weights_vmem_ref,
196201
recv_sem,
197202
))
203+
dma_list.append(
204+
pltpu.make_async_copy(
205+
indices_hbm_ref.at[sorted_by_validity_vmem_ref],
206+
src_indices_vmem_ref,
207+
recv_sem,
208+
))
198209
jax.tree.map(lambda x: x.start(), dma_list)
199210
jax.tree.map(lambda x: x.wait(), dma_list)
200211

@@ -227,9 +238,12 @@ def row_loop(row_block_id):
227238

228239
# VMEM to HBM transfer.
229240
# Use dynamic loop to minimize register spills.
241+
@pl.loop(0,
242+
col_size,
243+
step=num_lanes,
244+
init_carry=(prev_dst_row_hbm, ))
230245
@jax.named_scope("dma_write_loop")
231-
def dma_write_loop(i, carry):
232-
col_vmem_start = i * num_lanes
246+
def dma_write_loop(col_vmem_start, carry):
233247
col_hbm_start = col_start + col_vmem_start
234248

235249
for _ in range(num_simd_lanes):
@@ -359,12 +373,6 @@ def dma_write_loop(i, carry):
359373

360374
return carry
361375

362-
jax.lax.fori_loop(
363-
0,
364-
pl.cdiv(col_size, num_lanes),
365-
dma_write_loop,
366-
init_val=(prev_dst_row_hbm, ),
367-
)
368376
# Wait for dma write to finish.
369377
for _ in range(0, col_size, num_lanes):
370378
for _ in range(num_simd_lanes):
@@ -380,12 +388,11 @@ def dma_write_loop(i, carry):
380388
# TODO(gxd): investigate if we can make the preprocessing more efficient.
381389
def _preprocess(
382390
indices: jax.Array,
383-
topk_weights: jax.Array,
384391
valid_rows_mask: jax.Array,
385392
reduce_group_size: int,
386393
num_row_partitions: int,
387394
num_simd_lanes: int,
388-
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
395+
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
389396
"""Preprocesses indices for ragged gather reduce."""
390397
assert indices.ndim == 1, "Ragged scatter only supports 1d indices."
391398

@@ -403,12 +410,10 @@ def _preprocess(
403410
) * row_partition_size)
404411
sorted_by_validity = sorted_by_validity.reshape(-1)
405412

406-
src_indices = indices[sorted_by_validity]
407413
# `reduce_group_size` source rows are mapped (and reduced) to the same output
408414
# row.
409415
dst_indices = sorted_by_validity // reduce_group_size
410-
topk_weights = topk_weights[sorted_by_validity]
411-
topk_weights = topk_weights.astype(jnp.float32)
416+
sorted_by_validity = sorted_by_validity.astype(jnp.int32)
412417

413418
num_src_rows_per_row_partition = jnp.sum(valid_rows_mask, axis=-1)
414419
assert num_row_partitions <= num_simd_lanes
@@ -421,9 +426,8 @@ def _preprocess(
421426
mask = jnp.any(valid_rows_mask.reshape(-1, reduce_group_size), axis=-1)
422427

423428
return (
424-
src_indices,
425429
dst_indices,
426-
topk_weights,
430+
sorted_by_validity,
427431
num_src_rows_per_row_partition,
428432
mask,
429433
)
@@ -521,14 +525,12 @@ def ragged_gather_reduce(
521525
col_size = x.shape[-1] // num_column_partitions
522526

523527
(
524-
src_indices,
525528
dst_indices,
526-
topk_weights,
529+
sorted_by_validity,
527530
num_src_rows_per_row_partition,
528531
mask,
529532
) = _preprocess(
530533
indices,
531-
topk_weights,
532534
valid_rows_mask,
533535
reduce_group_size,
534536
num_row_partitions,
@@ -566,12 +568,19 @@ def ragged_gather_reduce(
566568
pltpu.VMEM((num_simd_lanes, ), jnp.int32),
567569
pltpu.VMEM((num_simd_lanes, ), jnp.int32),
568570
pltpu.VMEM((num_simd_lanes, ), jnp.float32),
571+
pltpu.VMEM((num_simd_lanes, ), jnp.int32),
569572
pltpu.SemaphoreType.DMA((2, )),
570573
],
571574
mesh=vector_mesh,
572575
name="sc_ragged_gather_reduce",
573-
)(num_src_rows_per_row_partition, x, src_indices, dst_indices,
574-
topk_weights)
576+
)(
577+
num_src_rows_per_row_partition,
578+
x,
579+
indices,
580+
dst_indices,
581+
topk_weights.astype(jnp.float32),
582+
sorted_by_validity,
583+
)
575584

576585
# If there is no valid source row in a reduce group, set that group's output
577586
# to zero.

tpu_inference/layers/common/process_weights/moe_weights.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import jax
1818
import jax.numpy as jnp
19-
from jax.experimental.layout import Layout, with_layout_constraint
19+
from jax.experimental.layout import Layout
2020
from jax.sharding import Mesh, NamedSharding, PartitionSpec
2121
from torchax.tensor import Tensor
2222
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
@@ -323,10 +323,6 @@ def process_moe_weights(
323323
w13_weight = jnp.swapaxes(w13_weight, 1, 2)
324324
w2_weight = jnp.swapaxes(w2_weight, 1, 2)
325325

326-
# Workaround for JAX error "must have valid byte strides"
327-
w13_weight = with_layout_constraint(w13_weight, Layout((0, 1, 2)))
328-
w2_weight = with_layout_constraint(w2_weight, Layout((0, 1, 2)))
329-
330326
if w13_weight_scale is not None:
331327
# For block scales (experts, out_blocks, in_blocks), we need to maintain
332328
# the block dims
@@ -374,8 +370,6 @@ def process_moe_weights(
374370
intermediate_size,
375371
)
376372
w13_weight = jnp.swapaxes(w13_weight, 1, 2)
377-
w13_weight = with_layout_constraint(w13_weight, Layout(
378-
(0, 1, 2, 3)))
379373

380374
# Fused moe kernel expects dims to be multiple of 256.
381375
pad_width_intermediate_size = (align_to(intermediate_size, 256) -

0 commit comments

Comments
 (0)