Skip to content

Commit 1f4c15e

Browse files
author
zhangyue
committed
feat(ascend/rotary_embedding): add impl=2 via aclnnRopeWithSinCosCache
Partial rotary (`rotary_dim < head_size`) cannot be expressed with the V2 API (`aclnnApplyRotaryPosEmbV2`, impl=0) or the ATB `RopeParam` API (impl=1): both require `cos.D == sin.D == x.D`. `aclnnRopeWithSinCosCache` is the only Ascend fused API that accepts partial rotary natively; it also supports both the NeoX and interleaved (GPT-J) styles through its `isNeoxStyle` bool. `test_rotary_embedding_partial` now routes through impl=2, resolving the 4 G-case skips.
1 parent ccc7b5d commit 1f4c15e

2 files changed

Lines changed: 154 additions & 7 deletions

File tree

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#ifndef INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_SINCOS_CACHE_H_
2+
#define INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_SINCOS_CACHE_H_
3+
4+
#include <cassert>
5+
#include <cstdint>
6+
7+
#include "acl/acl.h"
8+
#include "aclnn/aclnn_base.h"
9+
#include "aclnnop/aclnn_rope_with_sin_cos_cache.h"
10+
#include "ascend/common.h"
11+
#include "ascend/workspace_pool_.h"
12+
#include "base/rotary_embedding.h"
13+
#include "operator.h"
14+
15+
namespace infini::ops {
16+
17+
// Rotary position embedding via `aclnnRopeWithSinCosCache` (implementation
18+
// index 2). This is the only Ascend fused rotary API that supports partial
19+
// rotary (`rotary_dim < head_size`); it also natively supports both
20+
// GPT-NeoX (`is_neox_style=true`) and GPT-J (`is_neox_style=false`) styles
21+
// from the same interface.
22+
//
23+
// Input format: 2D contiguous `[num_tokens, num_heads * head_size]`. The
24+
// aclnn wrapper reads strides from the tensor descriptor — we pass a 2D
25+
// descriptor even when the caller holds a 3D view `[T, N, D]`, since the
26+
// memory layout is identical for contiguous tensors. The 2D descriptor is
27+
// what the aclnn sample in the CANN 8.5 docs uses.
28+
//
29+
// `cos_sin_cache` layout: `[max_seq_len, rotary_dim]` where the first
30+
// `rotary_dim / 2` columns are cos and the next `rotary_dim / 2` are sin.
31+
// The aclnn API splits internally via `cosSin.chunk(2, dim=-1)`.
32+
//
33+
// cf. `aclnn_rope_with_sin_cos_cache_hidden_attrs` memory: the public
34+
// header hides four `REG_OP` attrs (`numQHeads`, `numKHeads`, `qStride`,
35+
// `kStride`). For 2D contiguous inputs the aclnn wrapper infers them
36+
// correctly from the tensor descriptor; for 3D descriptors a previous
37+
// attempt produced garbage output.
38+
template <>
39+
class Operator<RotaryEmbedding, Device::Type::kAscend, 2>
40+
: public RotaryEmbedding {
41+
public:
42+
Operator(const Tensor positions, const Tensor query, const Tensor key,
43+
const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim,
44+
bool is_neox_style, Tensor query_out, Tensor key_out)
45+
: RotaryEmbedding(positions, query, key, cos_sin_cache, head_size,
46+
rotary_dim, is_neox_style, query_out, key_out),
47+
max_seq_len_{cos_sin_cache.size(0)} {
48+
const int64_t T = num_tokens_;
49+
const int64_t Nq = num_heads_;
50+
const int64_t Nkv = num_kv_heads_;
51+
const int64_t D = head_size_;
52+
aclDataType acl_dt = ascend::toAclDtype(query.dtype());
53+
54+
positions_cache_ = ascend::AclTensorCache(
55+
{T}, ACL_INT64, const_cast<void*>(positions.data()));
56+
q_in_cache_ = ascend::AclTensorCache(
57+
{T, Nq * D}, acl_dt, const_cast<void*>(query.data()));
58+
k_in_cache_ = ascend::AclTensorCache(
59+
{T, Nkv * D}, acl_dt, const_cast<void*>(key.data()));
60+
cos_sin_cache_cache_ = ascend::AclTensorCache(
61+
{max_seq_len_, rotary_dim_}, acl_dt,
62+
const_cast<void*>(cos_sin_cache.data()));
63+
q_out_cache_ =
64+
ascend::AclTensorCache({T, Nq * D}, acl_dt, query_out.data());
65+
k_out_cache_ =
66+
ascend::AclTensorCache({T, Nkv * D}, acl_dt, key_out.data());
67+
}
68+
69+
~Operator() {
70+
if (!ascend::isAclRuntimeAlive()) return;
71+
72+
positions_cache_.release();
73+
q_in_cache_.release();
74+
k_in_cache_.release();
75+
cos_sin_cache_cache_.release();
76+
q_out_cache_.release();
77+
k_out_cache_.release();
78+
}
79+
80+
Operator(const Operator&) = delete;
81+
82+
Operator& operator=(const Operator&) = delete;
83+
84+
void operator()(const Tensor positions, const Tensor query, const Tensor key,
85+
const Tensor cos_sin_cache, int64_t head_size,
86+
int64_t rotary_dim, bool is_neox_style, Tensor query_out,
87+
Tensor key_out) const override {
88+
auto stream = static_cast<aclrtStream>(stream_);
89+
90+
// Refresh cached descriptors with the current-call data pointers —
91+
// `Operator::call()` cache matches on shape/stride/dtype, so one
92+
// instance may serve multiple calls with different underlying buffers.
93+
auto t_pos = positions_cache_.get(const_cast<void*>(positions.data()));
94+
auto t_q = q_in_cache_.get(const_cast<void*>(query.data()));
95+
auto t_k = k_in_cache_.get(const_cast<void*>(key.data()));
96+
auto t_cache =
97+
cos_sin_cache_cache_.get(const_cast<void*>(cos_sin_cache.data()));
98+
auto t_q_out = q_out_cache_.get(query_out.data());
99+
auto t_k_out = k_out_cache_.get(key_out.data());
100+
101+
uint64_t ws_size = 0;
102+
aclOpExecutor* executor = nullptr;
103+
104+
auto ret = aclnnRopeWithSinCosCacheGetWorkspaceSize(
105+
t_pos, t_q, t_k, t_cache, /*mropeSection=*/nullptr, head_size,
106+
is_neox_style, t_q_out, t_k_out, &ws_size, &executor);
107+
assert(ret == 0 && "aclnnRopeWithSinCosCacheGetWorkspaceSize failed");
108+
109+
void* ws_buf = nullptr;
110+
111+
if (ws_size > 0) {
112+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size);
113+
ws_buf = arena.buf;
114+
}
115+
116+
ret = aclnnRopeWithSinCosCache(ws_buf, ws_size, executor, stream);
117+
assert(ret == 0 && "aclnnRopeWithSinCosCache failed");
118+
}
119+
120+
private:
121+
int64_t max_seq_len_;
122+
123+
mutable ascend::AclTensorCache positions_cache_;
124+
125+
mutable ascend::AclTensorCache q_in_cache_;
126+
127+
mutable ascend::AclTensorCache k_in_cache_;
128+
129+
mutable ascend::AclTensorCache cos_sin_cache_cache_;
130+
131+
mutable ascend::AclTensorCache q_out_cache_;
132+
133+
mutable ascend::AclTensorCache k_out_cache_;
134+
};
135+
136+
} // namespace infini::ops
137+
138+
#endif

tests/test_rotary_embedding.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -483,19 +483,27 @@ def test_rotary_embedding_partial(
483483
atol,
484484
device,
485485
):
486-
"""Partial rotary: ``rotary_dim < head_size``."""
486+
"""Partial rotary: ``rotary_dim < head_size`` via implementation_index=2.
487+
488+
Only `aclnnRopeWithSinCosCache` (impl=2) supports partial rotary among
489+
the Ascend fused APIs — V2 (impl=0) and ATB `RopeParam` (impl=1) both
490+
require `cos.D == sin.D == x.D`.
491+
"""
487492
if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()):
488493
pytest.skip("NPU not available")
489494

490495
if device == "npu":
491-
pytest.skip(
492-
"Partial rotary (`rotary_dim < head_size`) is not supported by "
493-
"any Ascend fused API: `aclnnApplyRotaryPosEmbV2`, "
494-
"`aclnnRotaryPositionEmbedding`, and ATB `RopeParam` all require "
495-
"`cos.D == sin.D == x.D`. A decomposed implementation is "
496-
"forbidden by project policy."
496+
active_indices = infini.ops.RotaryEmbedding.active_implementation_indices(
497+
device
497498
)
498499

500+
if 2 not in active_indices:
501+
pytest.skip(
502+
"`aclnnRopeWithSinCosCache` (implementation_index=2) not "
503+
"active on this build; it is the only Ascend fused API "
504+
"that supports partial rotary (`rotary_dim < head_size`)."
505+
)
506+
499507
num_tokens = 16
500508
max_seq_len = 64
501509

@@ -539,6 +547,7 @@ def test_rotary_embedding_partial(
539547
query_out,
540548
key_out,
541549
device,
550+
implementation_index=2,
542551
)
543552

544553
ref_q, ref_k = _ref_rotary_embedding(

0 commit comments

Comments
 (0)