Skip to content

Commit 592b493

Browse files
author
zhangyue
committed
feat(rotary_embedding): make query_out / key_out optional (inplace-default)
Align with vLLM's `RotaryEmbedding.forward(positions, query, key)` signature by letting callers omit the output buffers — the kernel then writes back in place on `query` / `key`. This removes a signature mismatch that forced vllm-infini to allocate and pass explicit out tensors it doesn't need.

Base class signature: `query_out` / `key_out` → `std::optional<Tensor>` with `std::nullopt` default. Shape / stride members fall back to `query` / `key` when the optional is empty.

All three Ascend impls resolve the optional to a concrete `Tensor` at the top of `operator()` via `value_or(query)` / `value_or(key)`:

- impl=0 (aclnn V2): skips the D2D memcpy in the inplace case since `query.data() == q_out.data()`
- impl=1 (ATB RopeParam): same short-circuit on the D2D copy
- impl=2 (aclnnRopeWithSinCosCache): descriptors reuse `q_out` / `k_out` pointers, so the kernel writes to whichever tensor is resolved

Adds `test_rotary_embedding_inplace` covering both fp16 / bf16 on impl=0 and impl=1. Tolerance is rtol=1e-2, atol=5e-3 — matches the V2 ~4 ULP fp16 accumulator error documented in `kernel.h`.
1 parent f757ed6 commit 592b493

5 files changed

Lines changed: 157 additions & 37 deletions

File tree

src/ascend/rotary_embedding/kernel.h

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <cassert>
55
#include <cstddef>
66
#include <cstring>
7+
#include <optional>
78
#include <vector>
89

910
#include "acl/acl.h"
@@ -38,11 +39,17 @@ class Operator<RotaryEmbedding, Device::Type::kAscend>
3839
public:
3940
Operator(const Tensor positions, const Tensor query, const Tensor key,
4041
const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim,
41-
bool is_neox_style, Tensor query_out, Tensor key_out)
42+
bool is_neox_style,
43+
std::optional<Tensor> query_out = std::nullopt,
44+
std::optional<Tensor> key_out = std::nullopt)
4245
: RotaryEmbedding(positions, query, key, cos_sin_cache, head_size,
4346
rotary_dim, is_neox_style, query_out, key_out),
4447
max_seq_len_{cos_sin_cache.size(0)},
4548
elem_sz_{cos_sin_cache.element_size()} {
49+
// Resolve optional out buffers; when omitted, RoPE writes back in place
50+
// on `query` / `key` — vLLM-style inplace semantics.
51+
Tensor q_out = query_out.value_or(query);
52+
Tensor k_out = key_out.value_or(key);
4653
assert(rotary_dim == head_size &&
4754
"Ascend `RotaryEmbedding` requires rotary_dim == head_size "
4855
"(partial rotation not implemented in this wrapper)");
@@ -85,9 +92,9 @@ class Operator<RotaryEmbedding, Device::Type::kAscend>
8592
cos_v2_cache_ = ascend::AclTensorCache({T, 1, D}, acl_dt, cos_dev_);
8693
sin_v2_cache_ = ascend::AclTensorCache({T, 1, D}, acl_dt, sin_dev_);
8794
q_cache_ = ascend::AclTensorCache({T, Nq, D}, acl_dt,
88-
const_cast<void*>(query_out.data()));
95+
const_cast<void*>(q_out.data()));
8996
k_cache_ = ascend::AclTensorCache({T, Nkv, D}, acl_dt,
90-
const_cast<void*>(key_out.data()));
97+
const_cast<void*>(k_out.data()));
9198
}
9299

93100
~Operator() {
@@ -112,10 +119,16 @@ class Operator<RotaryEmbedding, Device::Type::kAscend>
112119

113120
void operator()(const Tensor positions, const Tensor query, const Tensor key,
114121
const Tensor cos_sin_cache, int64_t head_size,
115-
int64_t rotary_dim, bool is_neox_style, Tensor query_out,
116-
Tensor key_out) const override {
122+
int64_t rotary_dim, bool is_neox_style,
123+
std::optional<Tensor> query_out,
124+
std::optional<Tensor> key_out) const override {
117125
auto stream = static_cast<aclrtStream>(stream_);
118126

127+
// Resolve optional out buffers (inplace on `query` / `key` when omitted).
128+
// Non-const so `.data()` returns a writable `void*`.
129+
Tensor q_out = query_out.value_or(query);
130+
Tensor k_out = key_out.value_or(key);
131+
119132
const int64_t T = query.size(0);
120133
const int64_t Nq = num_heads_;
121134
const int64_t Nkv = num_kv_heads_;
@@ -162,15 +175,15 @@ class Operator<RotaryEmbedding, Device::Type::kAscend>
162175
// Step 2: Copy q→q_out, k→k_out if not inplace (V2 operates inplace).
163176
size_t elem_sz = query.element_size();
164177

165-
if (query.data() != query_out.data()) {
166-
aclrtMemcpyAsync(query_out.data(),
178+
if (query.data() != q_out.data()) {
179+
aclrtMemcpyAsync(q_out.data(),
167180
static_cast<size_t>(T * Nq * D) * elem_sz, query.data(),
168181
static_cast<size_t>(T * Nq * D) * elem_sz,
169182
ACL_MEMCPY_DEVICE_TO_DEVICE, stream);
170183
}
171184

172-
if (key.data() != key_out.data()) {
173-
aclrtMemcpyAsync(key_out.data(),
185+
if (key.data() != k_out.data()) {
186+
aclrtMemcpyAsync(k_out.data(),
174187
static_cast<size_t>(T * Nkv * D) * elem_sz, key.data(),
175188
static_cast<size_t>(T * Nkv * D) * elem_sz,
176189
ACL_MEMCPY_DEVICE_TO_DEVICE, stream);
@@ -179,17 +192,17 @@ class Operator<RotaryEmbedding, Device::Type::kAscend>
179192
// Step 3: Apply V2 RoPE inplace on q_out and k_out.
180193
auto t_cos = cos_v2_cache_.get(cos_dev_);
181194
auto t_sin = sin_v2_cache_.get(sin_dev_);
182-
auto t_q = q_cache_.get(query_out.data());
183-
auto t_k = k_cache_.get(key_out.data());
195+
auto t_q = q_cache_.get(q_out.data());
196+
auto t_k = k_cache_.get(k_out.data());
184197

185198
if (!v2_exec_) {
186199
aclnnApplyRotaryPosEmbV2GetWorkspaceSize(
187200
t_q, t_k, t_cos, t_sin, /*layout=*/4, const_cast<char*>("half"),
188201
&v2_ws_, &v2_exec_);
189202
aclSetAclOpExecutorRepeatable(v2_exec_);
190203
} else {
191-
aclSetInputTensorAddr(v2_exec_, 0, t_q, query_out.data());
192-
aclSetInputTensorAddr(v2_exec_, 1, t_k, key_out.data());
204+
aclSetInputTensorAddr(v2_exec_, 0, t_q, q_out.data());
205+
aclSetInputTensorAddr(v2_exec_, 1, t_k, k_out.data());
193206
}
194207

195208
auto& arena = ascend::GetWorkspacePool().Ensure(stream, v2_ws_);

src/ascend/rotary_embedding/kernel_atb.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <cstddef>
88
#include <cstdint>
99
#include <cstring>
10+
#include <optional>
1011
#include <vector>
1112

1213
#include "acl/acl.h"
@@ -58,7 +59,9 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
5859
public:
5960
Operator(const Tensor positions, const Tensor query, const Tensor key,
6061
const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim,
61-
bool is_neox_style, Tensor query_out, Tensor key_out)
62+
bool is_neox_style,
63+
std::optional<Tensor> query_out = std::nullopt,
64+
std::optional<Tensor> key_out = std::nullopt)
6265
: RotaryEmbedding(positions, query, key, cos_sin_cache, head_size,
6366
rotary_dim, is_neox_style, query_out, key_out),
6467
is_neox_style_{is_neox_style} {
@@ -149,10 +152,16 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
149152

150153
void operator()(const Tensor positions, const Tensor query, const Tensor key,
151154
const Tensor cos_sin_cache, int64_t head_size,
152-
int64_t rotary_dim, bool is_neox_style, Tensor query_out,
153-
Tensor key_out) const override {
155+
int64_t rotary_dim, bool is_neox_style,
156+
std::optional<Tensor> query_out,
157+
std::optional<Tensor> key_out) const override {
154158
auto stream = static_cast<aclrtStream>(stream_);
155159

160+
// Resolve optional out buffers (inplace on `query` / `key` when omitted).
161+
// Non-const so `.data()` returns a writable `void*`.
162+
Tensor q_out = query_out.value_or(query);
163+
Tensor k_out = key_out.value_or(key);
164+
156165
int64_t T = query.size(0);
157166
int64_t D = head_size;
158167

@@ -202,15 +211,15 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
202211
// Step 2: Copy q->q_out, k->k_out if not in-place.
203212
size_t elem_sz = query.element_size();
204213

205-
if (query.data() != query_out.data()) {
206-
aclrtMemcpyAsync(query_out.data(),
214+
if (query.data() != q_out.data()) {
215+
aclrtMemcpyAsync(q_out.data(),
207216
static_cast<size_t>(T * hiddenQ) * elem_sz, query.data(),
208217
static_cast<size_t>(T * hiddenQ) * elem_sz,
209218
ACL_MEMCPY_DEVICE_TO_DEVICE, stream);
210219
}
211220

212-
if (key.data() != key_out.data()) {
213-
aclrtMemcpyAsync(key_out.data(),
221+
if (key.data() != k_out.data()) {
222+
aclrtMemcpyAsync(k_out.data(),
214223
static_cast<size_t>(T * hiddenK) * elem_sz, key.data(),
215224
static_cast<size_t>(T * hiddenK) * elem_sz,
216225
ACL_MEMCPY_DEVICE_TO_DEVICE, stream);
@@ -227,9 +236,9 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
227236
uint64_t gathered_bytes = static_cast<uint64_t>(T * D) * elem_size_;
228237

229238
atb::Tensor t_q =
230-
ascend::toAtbTensor(q_2d_shape_, acl_dt_, query_out.data(), q_bytes);
239+
ascend::toAtbTensor(q_2d_shape_, acl_dt_, q_out.data(), q_bytes);
231240
atb::Tensor t_k =
232-
ascend::toAtbTensor(k_2d_shape_, acl_dt_, key_out.data(), k_bytes);
241+
ascend::toAtbTensor(k_2d_shape_, acl_dt_, k_out.data(), k_bytes);
233242
atb::Tensor t_cos = ascend::toAtbTensor(cos_sin_gathered_shape_, acl_dt_,
234243
cos_dev_, gathered_bytes);
235244
atb::Tensor t_sin = ascend::toAtbTensor(cos_sin_gathered_shape_, acl_dt_,

src/ascend/rotary_embedding/kernel_sincos_cache.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <cassert>
55
#include <cstdint>
6+
#include <optional>
67

78
#include "acl/acl.h"
89
#include "aclnn/aclnn_base.h"
@@ -41,10 +42,17 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 2>
4142
public:
4243
Operator(const Tensor positions, const Tensor query, const Tensor key,
4344
const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim,
44-
bool is_neox_style, Tensor query_out, Tensor key_out)
45+
bool is_neox_style,
46+
std::optional<Tensor> query_out = std::nullopt,
47+
std::optional<Tensor> key_out = std::nullopt)
4548
: RotaryEmbedding(positions, query, key, cos_sin_cache, head_size,
4649
rotary_dim, is_neox_style, query_out, key_out),
4750
max_seq_len_{cos_sin_cache.size(0)} {
51+
// Resolve optional out buffers (inplace on `query` / `key` when omitted).
52+
// Non-const so `.data()` returns a writable `void*`.
53+
Tensor q_out = query_out.value_or(query);
54+
Tensor k_out = key_out.value_or(key);
55+
4856
const int64_t T = num_tokens_;
4957
const int64_t Nq = num_heads_;
5058
const int64_t Nkv = num_kv_heads_;
@@ -61,9 +69,9 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 2>
6169
{max_seq_len_, rotary_dim_}, acl_dt,
6270
const_cast<void*>(cos_sin_cache.data()));
6371
q_out_cache_ =
64-
ascend::AclTensorCache({T, Nq * D}, acl_dt, query_out.data());
72+
ascend::AclTensorCache({T, Nq * D}, acl_dt, q_out.data());
6573
k_out_cache_ =
66-
ascend::AclTensorCache({T, Nkv * D}, acl_dt, key_out.data());
74+
ascend::AclTensorCache({T, Nkv * D}, acl_dt, k_out.data());
6775
}
6876

6977
~Operator() {
@@ -83,10 +91,15 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 2>
8391

8492
void operator()(const Tensor positions, const Tensor query, const Tensor key,
8593
const Tensor cos_sin_cache, int64_t head_size,
86-
int64_t rotary_dim, bool is_neox_style, Tensor query_out,
87-
Tensor key_out) const override {
94+
int64_t rotary_dim, bool is_neox_style,
95+
std::optional<Tensor> query_out,
96+
std::optional<Tensor> key_out) const override {
8897
auto stream = static_cast<aclrtStream>(stream_);
8998

99+
// Resolve optional out buffers (inplace on `query` / `key` when omitted).
100+
Tensor q_out = query_out.value_or(query);
101+
Tensor k_out = key_out.value_or(key);
102+
90103
// Refresh cached descriptors with the current-call data pointers —
91104
// `Operator::call()` cache matches on shape/stride/dtype, so one
92105
// instance may serve multiple calls with different underlying buffers.
@@ -95,8 +108,8 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 2>
95108
auto t_k = k_in_cache_.get(const_cast<void*>(key.data()));
96109
auto t_cache =
97110
cos_sin_cache_cache_.get(const_cast<void*>(cos_sin_cache.data()));
98-
auto t_q_out = q_out_cache_.get(query_out.data());
99-
auto t_k_out = k_out_cache_.get(key_out.data());
111+
auto t_q_out = q_out_cache_.get(const_cast<void*>(q_out.data()));
112+
auto t_k_out = k_out_cache_.get(const_cast<void*>(k_out.data()));
100113

101114
uint64_t ws_size = 0;
102115
aclOpExecutor* executor = nullptr;

src/base/rotary_embedding.h

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define INFINI_OPS_BASE_ROTARY_EMBEDDING_H_
33

44
#include <cstddef>
5+
#include <optional>
56
#include <vector>
67

78
#include "operator.h"
@@ -13,10 +14,17 @@ class RotaryEmbedding : public Operator<RotaryEmbedding> {
1314
// Accepts 2D `[T, N*D]` (vLLM convention) or 3D `[T, N, D]`.
1415
// `num_heads_` and `num_kv_heads_` are derived from `numel / (T *
1516
// head_size)`.
17+
//
18+
// `query_out` / `key_out` are optional. When omitted, the kernel writes
19+
// back into `query` / `key` — matching vLLM's inplace
20+
// `RotaryEmbedding.forward(positions, query, key)` signature. Pass
21+
// explicit out buffers only when the caller needs a separate
22+
// destination.
1623
RotaryEmbedding(const Tensor positions, const Tensor query, const Tensor key,
1724
const Tensor cos_sin_cache, int64_t head_size,
18-
int64_t rotary_dim, bool is_neox_style, Tensor query_out,
19-
Tensor key_out)
25+
int64_t rotary_dim, bool is_neox_style,
26+
std::optional<Tensor> query_out = std::nullopt,
27+
std::optional<Tensor> key_out = std::nullopt)
2028
: num_tokens_{query.size(0)},
2129
num_heads_{static_cast<int64_t>(query.numel()) /
2230
(static_cast<int64_t>(query.size(0)) * head_size)},
@@ -28,12 +36,12 @@ class RotaryEmbedding : public Operator<RotaryEmbedding> {
2836
query_shape_{query.shape()},
2937
key_shape_{key.shape()},
3038
cos_sin_cache_shape_{cos_sin_cache.shape()},
31-
query_out_shape_{query_out.shape()},
32-
key_out_shape_{key_out.shape()},
39+
query_out_shape_{query_out.value_or(query).shape()},
40+
key_out_shape_{key_out.value_or(key).shape()},
3341
query_strides_{query.strides()},
3442
key_strides_{key.strides()},
35-
query_out_strides_{query_out.strides()},
36-
key_out_strides_{key_out.strides()} {
43+
query_out_strides_{query_out.value_or(query).strides()},
44+
key_out_strides_{key_out.value_or(key).strides()} {
3745
assert(
3846
(query.ndim() == 2 || query.ndim() == 3) &&
3947
"`RotaryEmbedding` requires query to be 2D [T, N*D] or 3D [T, N, D]");
@@ -47,8 +55,10 @@ class RotaryEmbedding : public Operator<RotaryEmbedding> {
4755
virtual void operator()(const Tensor positions, const Tensor query,
4856
const Tensor key, const Tensor cos_sin_cache,
4957
int64_t head_size, int64_t rotary_dim,
50-
bool is_neox_style, Tensor query_out,
51-
Tensor key_out) const = 0;
58+
bool is_neox_style,
59+
std::optional<Tensor> query_out = std::nullopt,
60+
std::optional<Tensor> key_out = std::nullopt)
61+
const = 0;
5262

5363
protected:
5464
Tensor::Size num_tokens_{0};

tests/test_rotary_embedding.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,3 +562,78 @@ def test_rotary_embedding_partial(
562562

563563
_assert_close(q_out, ref_q, rtol, atol)
564564
_assert_close(k_out, ref_k, rtol, atol)
565+
566+
567+
@pytest.mark.parametrize("implementation_index", (0, 1))
568+
@pytest.mark.parametrize(
569+
("dtype", "rtol", "atol"),
570+
(
571+
# V2 accumulates ~4 ULP error in fp16 (kernel.h doc: max diff ~0.008);
572+
# ATB `RopeParam` is similar. Use atol=5e-3 for honest headroom.
573+
(torch.float16, 1e-2, 5e-3),
574+
(torch.bfloat16, 1e-2, 5e-3),
575+
),
576+
)
577+
@pytest.mark.parametrize("device", ("npu",))
578+
def test_rotary_embedding_inplace(implementation_index, dtype, rtol, atol, device):
579+
"""Verify the inplace path (`query_out` / `key_out` omitted).
580+
581+
Matches vLLM's `RotaryEmbedding.forward(positions, query, key)`
582+
convention where the op mutates `query` / `key` directly.
583+
"""
584+
if not (hasattr(torch, "npu") and torch.npu.is_available()):
585+
pytest.skip("NPU not available")
586+
587+
active_indices = infini.ops.RotaryEmbedding.active_implementation_indices(device)
588+
589+
if implementation_index not in active_indices:
590+
pytest.skip(
591+
f"Implementation index={implementation_index} not active on this build"
592+
)
593+
594+
num_tokens = 4
595+
num_heads = 8
596+
num_kv_heads = 8
597+
head_size = 64
598+
rotary_dim = head_size
599+
max_seq_len = 32
600+
601+
positions = randint_strided(
602+
0, max_seq_len, (num_tokens,), None, dtype=torch.int64, device=device
603+
)
604+
query = randn_strided(
605+
(num_tokens, num_heads, head_size), None, dtype=dtype, device=device
606+
)
607+
key = randn_strided(
608+
(num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device
609+
)
610+
cos_sin_cache = randn_strided(
611+
(max_seq_len, rotary_dim), None, dtype=dtype, device=device
612+
)
613+
614+
# Reference: apply RoPE to clones of the original inputs.
615+
ref_q, ref_k = _ref_rotary_embedding(
616+
positions,
617+
query.clone(),
618+
key.clone(),
619+
cos_sin_cache,
620+
head_size,
621+
rotary_dim,
622+
is_neox_style=True,
623+
)
624+
625+
# Inplace call — no `query_out` / `key_out` supplied.
626+
infini.ops.rotary_embedding(
627+
positions,
628+
query,
629+
key,
630+
cos_sin_cache,
631+
head_size,
632+
rotary_dim,
633+
True,
634+
implementation_index=implementation_index,
635+
stream=get_npu_stream(query),
636+
)
637+
638+
_assert_close(query, ref_q, rtol, atol)
639+
_assert_close(key, ref_k, rtol, atol)

0 commit comments

Comments
 (0)