Skip to content

Commit ccc7b5d

Browse files
author
zhangyue
committed
feat(ascend): support non-neox rotaryMode via ATB RopeParam rotaryCoeff
RotaryEmbedding impl=1 (ATB Rope) now plumbs both rotary styles: - is_neox_style=true -> rotaryCoeff=2 (half split + cat) - is_neox_style=false -> rotaryCoeff=head_size (interleave) The cos/sin expand path also branches: neox layout duplicates the half values front/back, while interleave layout repeats each value pair-wise. Test skip is narrowed to impl=0 only, which still uses aclnnApplyRotaryPosEmbV2 (declares "interleave" but only implements "half"). G (partial rotary) skip message updated to reflect that neither aclnn nor ATB fused APIs support rotary_dim < head_size.
1 parent 222ea13 commit ccc7b5d

2 files changed

Lines changed: 68 additions & 25 deletions

File tree

src/ascend/rotary_embedding/kernel_atb.h

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,14 @@ namespace infini::ops {
4444
// gathered `[T, D]` tensors to ATB Rope. The `seqlen` input is a single
4545
// int32 element equal to T (all tokens treated as one batch).
4646
//
47-
// Restrictions (implementation choices, not ATB API limits):
47+
// Restrictions:
4848
// - `rotary_dim` must equal `head_size` (full rotation only). ATB
4949
// RopeParam supports `rotaryCoeff=2/4/head_size/head_size_2` per the
50-
// CANN 8.5 ATB docs; this wrapper plumbs only `rotaryCoeff=2`.
51-
// - `is_neox_style` must be true.
50+
// CANN 8.5 ATB docs. This wrapper plumbs:
51+
// * `rotaryCoeff=2` when `is_neox_style=true` (half split + cat)
52+
// * `rotaryCoeff=head_size` when `is_neox_style=false` (interleave)
53+
// Partial rotary (`rotary_dim < head_size`) is not supported by either
54+
// the aclnn or ATB fused APIs; callers must pad to `head_size` upstream.
5255
template <>
5356
class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
5457
: public RotaryEmbedding {
@@ -57,11 +60,10 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
5760
const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim,
5861
bool is_neox_style, Tensor query_out, Tensor key_out)
5962
: RotaryEmbedding(positions, query, key, cos_sin_cache, head_size,
60-
rotary_dim, is_neox_style, query_out, key_out) {
63+
rotary_dim, is_neox_style, query_out, key_out),
64+
is_neox_style_{is_neox_style} {
6165
assert(rotary_dim == head_size &&
6266
"ATB `RotaryEmbedding` requires rotary_dim == head_size");
63-
assert(is_neox_style &&
64-
"ATB `RotaryEmbedding` requires neox style (rotaryCoeff=2)");
6567

6668
const int64_t D = head_size_;
6769
const size_t elem_sz = cos_sin_cache.element_size();
@@ -110,10 +112,14 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
110112
cos_out_cache_ = ascend::AclTensorCache({T, D}, acl_dt_, cos_dev_);
111113
sin_out_cache_ = ascend::AclTensorCache({T, D}, acl_dt_, sin_dev_);
112114

113-
// Create the ATB Rope operation.
115+
// Create the ATB Rope operation. `rotaryCoeff` selects the rotation
116+
// pattern: 2 for neox (split-then-rotate halves), `head_size` for
117+
// interleave (pair-wise rotate adjacent elements).
114118
atb::infer::RopeParam param;
115-
param.rotaryCoeff = 2; // Neox half-rotation.
116-
param.cosFormat = 0; // Inference mode.
119+
param.rotaryCoeff = is_neox_style
120+
? 2
121+
: static_cast<int32_t>(D);
122+
param.cosFormat = 0; // Inference mode.
117123
atb::Status s = atb::CreateOperation(param, &op_);
118124

119125
assert(s == atb::NO_ERROR && "atb::CreateOperation(Rope) failed");
@@ -254,8 +260,16 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
254260
}
255261

256262
private:
257-
// D2H copy cos_sin_cache, split into cos/sin, neox-expand, and upload to
258-
// device. Called once at construction.
263+
// D2H copy cos_sin_cache, split into cos/sin, expand to `[max_seq_len, D]`
264+
// in the layout that ATB Rope expects for the chosen `rotaryCoeff`, and
265+
// upload to device. Called once at construction.
266+
//
267+
// For `rotaryCoeff=2` (neox): cos tensor holds the same `half_D` values
268+
// duplicated front/back — `[c0 .. c_{half-1}, c0 .. c_{half-1}]`.
269+
//
270+
// For `rotaryCoeff=head_size` (interleave): cos tensor holds each of the
271+
// `half_D` values repeated pair-wise —
272+
// `[c0, c0, c1, c1, .., c_{half-1}, c_{half-1}]`.
259273
void uploadCosSinCache(const Tensor cos_sin_cache) const {
260274
const int64_t D = head_size_;
261275
const int64_t half_D = D / 2;
@@ -277,16 +291,35 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
277291
const auto* s_src = cache_host.data() +
278292
static_cast<size_t>(p * D + half_D + j) * elem_sz;
279293

280-
std::memcpy(cos_host.data() + static_cast<size_t>(p * D + j) * elem_sz,
281-
c_src, elem_sz);
282-
std::memcpy(
283-
cos_host.data() + static_cast<size_t>(p * D + half_D + j) * elem_sz,
284-
c_src, elem_sz);
285-
std::memcpy(sin_host.data() + static_cast<size_t>(p * D + j) * elem_sz,
286-
s_src, elem_sz);
287-
std::memcpy(
288-
sin_host.data() + static_cast<size_t>(p * D + half_D + j) * elem_sz,
289-
s_src, elem_sz);
294+
if (is_neox_style_) {
295+
// Neox layout: [c_j ... , c_j ...] front/back duplication.
296+
std::memcpy(
297+
cos_host.data() + static_cast<size_t>(p * D + j) * elem_sz,
298+
c_src, elem_sz);
299+
std::memcpy(
300+
cos_host.data() + static_cast<size_t>(p * D + half_D + j) * elem_sz,
301+
c_src, elem_sz);
302+
std::memcpy(
303+
sin_host.data() + static_cast<size_t>(p * D + j) * elem_sz,
304+
s_src, elem_sz);
305+
std::memcpy(
306+
sin_host.data() + static_cast<size_t>(p * D + half_D + j) * elem_sz,
307+
s_src, elem_sz);
308+
} else {
309+
// Interleave layout: each value repeated pair-wise.
310+
std::memcpy(
311+
cos_host.data() + static_cast<size_t>(p * D + 2 * j) * elem_sz,
312+
c_src, elem_sz);
313+
std::memcpy(
314+
cos_host.data() + static_cast<size_t>(p * D + 2 * j + 1) * elem_sz,
315+
c_src, elem_sz);
316+
std::memcpy(
317+
sin_host.data() + static_cast<size_t>(p * D + 2 * j) * elem_sz,
318+
s_src, elem_sz);
319+
std::memcpy(
320+
sin_host.data() + static_cast<size_t>(p * D + 2 * j + 1) * elem_sz,
321+
s_src, elem_sz);
322+
}
290323
}
291324
}
292325

@@ -296,6 +329,8 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
296329
ACL_MEMCPY_HOST_TO_DEVICE);
297330
}
298331

332+
bool is_neox_style_;
333+
299334
atb::Operation* op_ = nullptr;
300335

301336
// Neox-expanded cos/sin tables on device: [max_seq_len, D].

tests/test_rotary_embedding.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,12 @@ def test_rotary_embedding_full(
164164
f"Implementation index={implementation_index} not active on this build"
165165
)
166166

167-
if device == "npu" and not is_neox_style:
167+
# Only implementation 0 (`aclnnApplyRotaryPosEmbV2`) is still limited to
168+
# `rotaryMode="half"`; implementation 1 (ATB `RopeParam`) plumbs
169+
# `rotaryCoeff=head_size` for the non-neox (interleave) case.
170+
if device == "npu" and not is_neox_style and implementation_index == 0:
168171
pytest.skip(
169-
'Ascend `RotaryEmbedding` wrappers only plumb `rotaryMode="half"` '
170-
"through the underlying V2/ATB APIs."
172+
'Ascend `aclnnApplyRotaryPosEmbV2` only supports `rotaryMode="half"`'
171173
)
172174

173175
# `aclnnApplyRotaryPosEmbV2` accumulates with ~4 ULP error for float16.
@@ -486,7 +488,13 @@ def test_rotary_embedding_partial(
486488
pytest.skip("NPU not available")
487489

488490
if device == "npu":
489-
pytest.skip("Ascend aclnnApplyRotaryPosEmbV2 requires rotary_dim == head_size")
491+
pytest.skip(
492+
"Partial rotary (`rotary_dim < head_size`) is not supported by "
493+
"any Ascend fused API: `aclnnApplyRotaryPosEmbV2`, "
494+
"`aclnnRotaryPositionEmbedding`, and ATB `RopeParam` all require "
495+
"`cos.D == sin.D == x.D`. A decomposed implementation is "
496+
"forbidden by project policy."
497+
)
490498

491499
num_tokens = 16
492500
max_seq_len = 64

0 commit comments

Comments
 (0)