feat(rotary_embedding): make query_out / key_out optional (inplace-default)
Align with vLLM's `RotaryEmbedding.forward(positions, query, key)`
signature by letting callers omit the output buffers — the kernel then
writes back in place on `query` / `key`. This removes a signature
mismatch that forced vllm-infini to allocate and pass explicit out
tensors it doesn't need.
Base class signature:
`query_out` / `key_out` → `std::optional<Tensor>` with `std::nullopt`
default. Shape / stride members fall back to `query` / `key` when the
optional is empty.
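
A minimal sketch of the fallback described above, assuming a hypothetical
`ShapeTensor` stand-in and a hypothetical `RotaryEmbeddingSketch` class (the
real base class, its members, and its constructor arguments may differ):

```cpp
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for the project's Tensor type (illustration only).
struct ShapeTensor {
    std::vector<std::size_t> shape;
};

// Hypothetical base class: output shape members fall back to the input
// tensors when the corresponding optional output is empty.
struct RotaryEmbeddingSketch {
    std::vector<std::size_t> query_out_shape;
    std::vector<std::size_t> key_out_shape;

    RotaryEmbeddingSketch(const ShapeTensor& query, const ShapeTensor& key,
                          std::optional<ShapeTensor> query_out = std::nullopt,
                          std::optional<ShapeTensor> key_out = std::nullopt)
        : query_out_shape(query_out ? query_out->shape : query.shape),
          key_out_shape(key_out ? key_out->shape : key.shape) {}
};
```

With this shape, a caller that passes only `query` and `key` gets output
metadata identical to the inputs, which is what the inplace path needs.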
All three Ascend impls resolve the optional to a concrete `Tensor` at
the top of `operator()` via `value_or(query)`:
- impl=0 (aclnn V2): skips the D2D memcpy in the inplace case, since
  `query.data() == q_out.data()`
- impl=1 (ATB RopeParam): same short-circuit on the D2D copy
- impl=2 (aclnnRopeWithSinCosCache): descriptors reuse the `q_out` /
  `k_out` pointers, so the kernel writes to whichever tensor is resolved
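
The resolve-then-short-circuit pattern used by impl=0 and impl=1 can be
sketched roughly as follows. `BufferTensor`, `CopyUnlessInplace`, and
`Forward` are hypothetical names for illustration, and `std::memcpy` stands
in for the device-to-device copy so the example runs on the host:

```cpp
#include <cstddef>
#include <cstring>
#include <optional>

// Hypothetical stand-in for the project's Tensor type (illustration only).
struct BufferTensor {
    void* ptr = nullptr;
    std::size_t nbytes = 0;
    void* data() const { return ptr; }
};

// Copies src into dst unless both alias the same buffer.
// Returns true if a copy was actually performed.
bool CopyUnlessInplace(const BufferTensor& src, const BufferTensor& dst) {
    if (src.data() == dst.data()) {
        return false;  // inplace case: nothing to copy
    }
    // Host memcpy as a stand-in for the D2D memcpy (assumption).
    std::memcpy(dst.data(), src.data(), src.nbytes);
    return true;
}

// Resolves the optional output to a concrete tensor, then copies if needed.
bool Forward(const BufferTensor& query,
             const std::optional<BufferTensor>& query_out = std::nullopt) {
    const BufferTensor q_out = query_out.value_or(query);
    return CopyUnlessInplace(query, q_out);
}
```

Calling `Forward(query)` skips the copy because the resolved output aliases
`query`, while `Forward(query, out)` copies into the separate buffer first.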
Adds `test_rotary_embedding_inplace` covering both fp16 and bf16 on
impl=0 and impl=1. Tolerance is atol=5e-3 — matches the V2 ~4 ULP
fp16 accumulator error documented in `kernel.h`.