|
| 1 | +#ifndef INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_SINCOS_CACHE_H_ |
| 2 | +#define INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_SINCOS_CACHE_H_ |
| 3 | + |
| 4 | +#include <cassert> |
| 5 | +#include <cstdint> |
| 6 | + |
| 7 | +#include "acl/acl.h" |
| 8 | +#include "aclnn/aclnn_base.h" |
| 9 | +#include "aclnnop/aclnn_rope_with_sin_cos_cache.h" |
| 10 | +#include "ascend/common.h" |
| 11 | +#include "ascend/workspace_pool_.h" |
| 12 | +#include "base/rotary_embedding.h" |
| 13 | +#include "operator.h" |
| 14 | + |
| 15 | +namespace infini::ops { |
| 16 | + |
| 17 | +// Rotary position embedding via `aclnnRopeWithSinCosCache` (implementation |
| 18 | +// index 2). This is the only Ascend fused rotary API that supports partial |
| 19 | +// rotary (`rotary_dim < head_size`); it also natively supports both |
| 20 | +// GPT-NeoX (`is_neox_style=true`) and GPT-J (`is_neox_style=false`) styles |
| 21 | +// from the same interface. |
| 22 | +// |
| 23 | +// Input format: 2D contiguous `[num_tokens, num_heads * head_size]`. The |
| 24 | +// aclnn wrapper reads strides from the tensor descriptor — we pass a 2D |
| 25 | +// descriptor even when the caller holds a 3D view `[T, N, D]`, since the |
| 26 | +// memory layout is identical for contiguous tensors. The 2D descriptor is |
| 27 | +// what the aclnn sample in the CANN 8.5 docs uses. |
| 28 | +// |
| 29 | +// `cos_sin_cache` layout: `[max_seq_len, rotary_dim]` where the first |
| 30 | +// `rotary_dim / 2` columns are cos and the next `rotary_dim / 2` are sin. |
| 31 | +// The aclnn API splits internally via `cosSin.chunk(2, dim=-1)`. |
| 32 | +// |
| 33 | +// cf. `aclnn_rope_with_sin_cos_cache_hidden_attrs` memory: the public |
| 34 | +// header hides four `REG_OP` attrs (`numQHeads`, `numKHeads`, `qStride`, |
| 35 | +// `kStride`). For 2D contiguous inputs the aclnn wrapper infers them |
| 36 | +// correctly from the tensor descriptor; for 3D descriptors a previous |
| 37 | +// attempt produced garbage output. |
| 38 | +template <> |
| 39 | +class Operator<RotaryEmbedding, Device::Type::kAscend, 2> |
| 40 | + : public RotaryEmbedding { |
| 41 | + public: |
| 42 | + Operator(const Tensor positions, const Tensor query, const Tensor key, |
| 43 | + const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim, |
| 44 | + bool is_neox_style, Tensor query_out, Tensor key_out) |
| 45 | + : RotaryEmbedding(positions, query, key, cos_sin_cache, head_size, |
| 46 | + rotary_dim, is_neox_style, query_out, key_out), |
| 47 | + max_seq_len_{cos_sin_cache.size(0)} { |
| 48 | + const int64_t T = num_tokens_; |
| 49 | + const int64_t Nq = num_heads_; |
| 50 | + const int64_t Nkv = num_kv_heads_; |
| 51 | + const int64_t D = head_size_; |
| 52 | + aclDataType acl_dt = ascend::toAclDtype(query.dtype()); |
| 53 | + |
| 54 | + positions_cache_ = ascend::AclTensorCache( |
| 55 | + {T}, ACL_INT64, const_cast<void*>(positions.data())); |
| 56 | + q_in_cache_ = ascend::AclTensorCache( |
| 57 | + {T, Nq * D}, acl_dt, const_cast<void*>(query.data())); |
| 58 | + k_in_cache_ = ascend::AclTensorCache( |
| 59 | + {T, Nkv * D}, acl_dt, const_cast<void*>(key.data())); |
| 60 | + cos_sin_cache_cache_ = ascend::AclTensorCache( |
| 61 | + {max_seq_len_, rotary_dim_}, acl_dt, |
| 62 | + const_cast<void*>(cos_sin_cache.data())); |
| 63 | + q_out_cache_ = |
| 64 | + ascend::AclTensorCache({T, Nq * D}, acl_dt, query_out.data()); |
| 65 | + k_out_cache_ = |
| 66 | + ascend::AclTensorCache({T, Nkv * D}, acl_dt, key_out.data()); |
| 67 | + } |
| 68 | + |
| 69 | + ~Operator() { |
| 70 | + if (!ascend::isAclRuntimeAlive()) return; |
| 71 | + |
| 72 | + positions_cache_.release(); |
| 73 | + q_in_cache_.release(); |
| 74 | + k_in_cache_.release(); |
| 75 | + cos_sin_cache_cache_.release(); |
| 76 | + q_out_cache_.release(); |
| 77 | + k_out_cache_.release(); |
| 78 | + } |
| 79 | + |
| 80 | + Operator(const Operator&) = delete; |
| 81 | + |
| 82 | + Operator& operator=(const Operator&) = delete; |
| 83 | + |
| 84 | + void operator()(const Tensor positions, const Tensor query, const Tensor key, |
| 85 | + const Tensor cos_sin_cache, int64_t head_size, |
| 86 | + int64_t rotary_dim, bool is_neox_style, Tensor query_out, |
| 87 | + Tensor key_out) const override { |
| 88 | + auto stream = static_cast<aclrtStream>(stream_); |
| 89 | + |
| 90 | + // Refresh cached descriptors with the current-call data pointers — |
| 91 | + // `Operator::call()` cache matches on shape/stride/dtype, so one |
| 92 | + // instance may serve multiple calls with different underlying buffers. |
| 93 | + auto t_pos = positions_cache_.get(const_cast<void*>(positions.data())); |
| 94 | + auto t_q = q_in_cache_.get(const_cast<void*>(query.data())); |
| 95 | + auto t_k = k_in_cache_.get(const_cast<void*>(key.data())); |
| 96 | + auto t_cache = |
| 97 | + cos_sin_cache_cache_.get(const_cast<void*>(cos_sin_cache.data())); |
| 98 | + auto t_q_out = q_out_cache_.get(query_out.data()); |
| 99 | + auto t_k_out = k_out_cache_.get(key_out.data()); |
| 100 | + |
| 101 | + uint64_t ws_size = 0; |
| 102 | + aclOpExecutor* executor = nullptr; |
| 103 | + |
| 104 | + auto ret = aclnnRopeWithSinCosCacheGetWorkspaceSize( |
| 105 | + t_pos, t_q, t_k, t_cache, /*mropeSection=*/nullptr, head_size, |
| 106 | + is_neox_style, t_q_out, t_k_out, &ws_size, &executor); |
| 107 | + assert(ret == 0 && "aclnnRopeWithSinCosCacheGetWorkspaceSize failed"); |
| 108 | + |
| 109 | + void* ws_buf = nullptr; |
| 110 | + |
| 111 | + if (ws_size > 0) { |
| 112 | + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size); |
| 113 | + ws_buf = arena.buf; |
| 114 | + } |
| 115 | + |
| 116 | + ret = aclnnRopeWithSinCosCache(ws_buf, ws_size, executor, stream); |
| 117 | + assert(ret == 0 && "aclnnRopeWithSinCosCache failed"); |
| 118 | + } |
| 119 | + |
| 120 | + private: |
| 121 | + int64_t max_seq_len_; |
| 122 | + |
| 123 | + mutable ascend::AclTensorCache positions_cache_; |
| 124 | + |
| 125 | + mutable ascend::AclTensorCache q_in_cache_; |
| 126 | + |
| 127 | + mutable ascend::AclTensorCache k_in_cache_; |
| 128 | + |
| 129 | + mutable ascend::AclTensorCache cos_sin_cache_cache_; |
| 130 | + |
| 131 | + mutable ascend::AclTensorCache q_out_cache_; |
| 132 | + |
| 133 | + mutable ascend::AclTensorCache k_out_cache_; |
| 134 | +}; |
| 135 | + |
| 136 | +} // namespace infini::ops |
| 137 | + |
| 138 | +#endif |
0 commit comments