Commit 694506b

Author: zhangyue
test(rotary_embedding): add pre_gathered=True coverage
Fold the deleted `test_apply_rotary_pos_emb` / `_3d` cases into a single `test_rotary_embedding_pre_gathered` that exercises the `pre_gathered` fast path directly on the `rotary_embedding` overload (no shim). Parametrize over 2D / 3D query-key layouts, impls 0 and 1 (impl 2 asserts `!pre_gathered_`), neox / GPT-J styles, fp16 / bf16. The new `_build_pre_gathered_cache` helper constructs the `[2*T, head_size]` wire format that `src/ascend/rotary_embedding/kernel.h` expects — cos rows 0..T-1, sin rows T..2T-1, both neox-expanded per token. Coverage: 12 new cases pass (4 skip for `impl=0 + not-neox`, same as the `test_rotary_embedding_full` skip — V2 only supports `rotaryMode="half"`). Full rotary suite: 88 passed, 8 skipped (was 80 passed, 4 skipped before this test was added).
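
For reference, a minimal sketch of the wire format described above, using illustrative toy values (T = 2 tokens, head_size = 4, neox style); it mirrors the `_build_pre_gathered_cache` helper in the diff below:

import torch

# Toy cache: each row holds [cos_0, cos_1, sin_0, sin_1] for one position.
cos_sin_cache = torch.tensor([[0.9, 0.8, 0.1, 0.2],   # position 0
                              [0.7, 0.6, 0.3, 0.4]])  # position 1
positions = torch.tensor([1, 0])  # per-token positions, T = 2

half = 2  # head_size // 2
cos_half = cos_sin_cache[:, :half].index_select(0, positions)  # [T, half]
sin_half = cos_sin_cache[:, half:].index_select(0, positions)

# neox expansion: duplicate the half-dim block out to full head_size.
wire = torch.cat(
    [torch.cat([cos_half, cos_half], dim=-1),   # rows 0..T-1: cos
     torch.cat([sin_half, sin_half], dim=-1)],  # rows T..2T-1: sin
    dim=0,
)
assert wire.shape == (2 * 2, 4)  # [2*T, head_size]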
Parent: fdeb779

1 file changed: tests/test_rotary_embedding.py (105 additions, 0 deletions)
@@ -619,3 +619,108 @@ def test_rotary_embedding_inplace(implementation_index, dtype, rtol, atol, device):
     _assert_close(query, ref_q, rtol, atol)
     _assert_close(key, ref_k, rtol, atol)
 
+
+def _build_pre_gathered_cache(cos_sin_cache, positions, head_size, is_neox_style):
+    """Build the `[2 * T, head_size]` pre-gathered cache the kernel expects.
+
+    Layout (see `src/ascend/rotary_embedding/kernel.h` pre-gathered branch):
+    - rows `0..T-1`: neox-expanded cos for each token (row `t` holds the
+      cos values for `positions[t]`, broadcast to full `head_size`).
+    - rows `T..2T-1`: neox-expanded sin, same indexing.
+    """
+    half = head_size // 2
+    cos_half = cos_sin_cache[:, :half].index_select(0, positions)
+    sin_half = cos_sin_cache[:, half:].index_select(0, positions)
+
+    if is_neox_style:
+        cos_full = torch.cat([cos_half, cos_half], dim=-1)
+        sin_full = torch.cat([sin_half, sin_half], dim=-1)
+    else:
+        # GPT-J interleave: pair-wise expansion `(x[0],x[0],x[1],x[1],…)`.
+        cos_full = cos_half.repeat_interleave(2, dim=-1)
+        sin_full = sin_half.repeat_interleave(2, dim=-1)
+
+    return torch.cat([cos_full, sin_full], dim=0)
+
+
+# Hardcoded `(0, 1)` — impl 2 (`aclnnRopeWithSinCosCache`) asserts
+# `!pre_gathered_` at construction. Cannot use conftest auto-injection.
+@pytest.mark.parametrize("implementation_index", (0, 1))
+@pytest.mark.parametrize("layout", ("2d", "3d"))
+@pytest.mark.parametrize("is_neox_style", (True, False))
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float16, 1e-2, 5e-3),
+        (torch.bfloat16, 1e-2, 5e-3),
+    ),
+)
+@pytest.mark.parametrize("device", ("npu",))
+def test_rotary_embedding_pre_gathered(
+    implementation_index, layout, is_neox_style, dtype, rtol, atol, device
+):
+    """`pre_gathered=True` fast path: caller hands in `[2*T, head_size]` with
+    cos/sin already gathered and neox-expanded per token. Exercises both 2D
+    `[T, N*D]` and 3D `[T, N, D]` query/key layouts."""
+    if not (hasattr(torch, "npu") and torch.npu.is_available()):
+        pytest.skip("NPU not available")
+
+    if not is_neox_style and implementation_index == 0:
+        pytest.skip(
+            'Ascend `aclnnApplyRotaryPosEmbV2` only supports `rotaryMode="half"`'
+        )
+
+    num_tokens = 8
+    num_heads = 16
+    num_kv_heads = 4
+    head_size = 128
+    rotary_dim = head_size
+    max_seq_len = 64
+
+    positions = randint_strided(
+        0, max_seq_len, (num_tokens,), None, dtype=torch.int64, device=device
+    )
+    cos_sin_cache = randn_strided(
+        (max_seq_len, rotary_dim), None, dtype=dtype, device=device
+    )
+
+    if layout == "3d":
+        q_shape = (num_tokens, num_heads, head_size)
+        k_shape = (num_tokens, num_kv_heads, head_size)
+    else:
+        q_shape = (num_tokens, num_heads * head_size)
+        k_shape = (num_tokens, num_kv_heads * head_size)
+
+    query = randn_strided(q_shape, None, dtype=dtype, device=device)
+    key = randn_strided(k_shape, None, dtype=dtype, device=device)
+    query_out = torch.empty_like(query)
+    key_out = torch.empty_like(key)
+
+    pre_gathered_cache = _build_pre_gathered_cache(
+        cos_sin_cache, positions, head_size, is_neox_style
+    )
+    # Kernel reads `positions` as `0..T-1` in the pre-gathered path (the
+    # gather has already happened); the actual values are not indexed.
+    arange_positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
+
+    infini.ops.rotary_embedding(
+        arange_positions,
+        query,
+        key,
+        head_size,
+        pre_gathered_cache,
+        is_neox_style,
+        rotary_dim,
+        query_out,
+        key_out,
+        True,  # pre_gathered
+        implementation_index=implementation_index,
+        stream=get_stream(query.device),
+    )
+
+    ref_q, ref_k = _ref_rotary_embedding(
+        positions, query, key, cos_sin_cache, head_size, rotary_dim, is_neox_style
+    )
+
+    _assert_close(query_out, ref_q, rtol, atol)
+    _assert_close(key_out, ref_k, rtol, atol)
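
For context, not part of this diff: `_ref_rotary_embedding` lives elsewhere in the test file. A plausible sketch of the neox-style rotation such a reference would compute, where the helper name and the `[T, N, head_size]` layout are assumptions:

import torch

def _ref_rotary_neox_sketch(positions, x, cos_sin_cache, head_size):
    """Hypothetical reference: standard neox rotation on a [T, N, head_size]
    tensor, assuming rotary_dim == head_size (as in the test above)."""
    half = head_size // 2
    cos = cos_sin_cache[:, :half].index_select(0, positions).unsqueeze(1)  # [T, 1, half]
    sin = cos_sin_cache[:, half:].index_select(0, positions).unsqueeze(1)
    x1, x2 = x[..., :half], x[..., half:]  # split into the two rotated halves
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)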
