Commit 0c7a31a

zhangyue committed
feat(ascend): add Ascend operator kernels for all operators
Add ACLNN-based implementations for: Add, Cast, Cat, CausalSoftmax, FlashAttention, Linear, Matmul, Mul, RmsNorm, RotaryEmbedding, ReshapeAndCache (+ v2), Swiglu, SiluAndMul. All kernels use AclTensorCache for descriptor reuse and WorkspacePool for device memory management. Executor instances are cached with aclSetAclOpExecutorRepeatable for repeat dispatch.
1 parent 5e00c34 commit 0c7a31a

15 files changed

Lines changed: 1932 additions & 0 deletions

src/ascend/add/kernel.h

Lines changed: 81 additions & 0 deletions
#ifndef INFINI_OPS_ASCEND_ADD_KERNEL_H_
#define INFINI_OPS_ASCEND_ADD_KERNEL_H_

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnnop/aclnn_add.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/add.h"
#include "data_type.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<Add, Device::Type::kAscend> : public Add {
 public:
  Operator(const Tensor input, const Tensor other, Tensor out)
      : Add(input, other, out),
        in_cache_(input),
        oth_cache_(other),
        out_cache_(out) {
    // aclCreateScalar stores the pointer rather than copying the value, so
    // the alpha_*_storage_ member must remain alive for the lifetime of
    // alpha_. The alpha scalar type must match the tensor dtype: int64 for
    // integer dtypes, float for floating-point dtypes.
    if (ascend::isIntegerDtype(input.dtype())) {
      alpha_ = aclCreateScalar(&alpha_int_storage_, ACL_INT64);
    } else {
      alpha_ = aclCreateScalar(&alpha_float_storage_, ACL_FLOAT);
    }
  }

  ~Operator() {
    if (executor_) aclDestroyAclOpExecutor(executor_);
    aclDestroyScalar(alpha_);
  }

  void operator()(const Tensor input, const Tensor other,
                  Tensor out) const override {
    auto stream = static_cast<aclrtStream>(stream_);
    auto t_in = in_cache_.get(const_cast<void*>(input.data()));
    auto t_oth = oth_cache_.get(const_cast<void*>(other.data()));
    auto t_out = out_cache_.get(out.data());

    if (!executor_) {
      aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_size_,
                               &executor_);
      aclSetAclOpExecutorRepeatable(executor_);
    } else {
      aclSetInputTensorAddr(executor_, 0, t_in,
                            const_cast<void*>(input.data()));
      aclSetInputTensorAddr(executor_, 1, t_oth,
                            const_cast<void*>(other.data()));
      aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
    }

    auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
    aclnnAdd(arena.buf, ws_size_, executor_, stream);
  }

 private:
  mutable ascend::AclTensorCache in_cache_;
  mutable ascend::AclTensorCache oth_cache_;
  mutable ascend::AclTensorCache out_cache_;
  mutable aclOpExecutor* executor_ = nullptr;
  mutable uint64_t ws_size_ = 0;

  float alpha_float_storage_ = 1.0f;  // stable address for aclCreateScalar (float)
  int64_t alpha_int_storage_ = 1;     // stable address for aclCreateScalar (int)
  aclScalar* alpha_ = nullptr;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_ASCEND_ADD_KERNEL_H_
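
Note: `AclTensorCache` is pulled in via `ascend/common.h` and is not among the files shown on this page. What follows is a speculative minimal sketch of the interface the kernels rely on (one cached `aclTensor*` descriptor per tensor, rebound to a new device address on each `get()`). The `aclSetRawTensorAddr` name and its assumed signature come from the cat kernel's comment below; the contiguous-stride computation is an assumption of this sketch, not something the commit states.

// Speculative sketch of ascend::AclTensorCache, inferred from usage only.
// Assumes contiguous ND layout and the CANN aclCreateTensor entry point.
#include <cstdint>
#include <vector>

#include "acl/acl.h"
#include "ascend/data_type_.h"
#include "tensor.h"

namespace infini::ops::ascend {

class AclTensorCache {
 public:
  explicit AclTensorCache(const Tensor& t) : dtype_(toAclDtype(t.dtype())) {
    dims_.resize(t.ndim());
    for (uint64_t i = 0; i < t.ndim(); ++i) {
      dims_[i] = static_cast<int64_t>(t.size(i));
    }
    // Assumption: tensors are contiguous, so strides follow from the shape.
    strides_.assign(dims_.size(), 1);
    for (int64_t i = static_cast<int64_t>(dims_.size()) - 2; i >= 0; --i) {
      strides_[i] = strides_[i + 1] * dims_[i + 1];
    }
  }

  ~AclTensorCache() {
    if (acl_) aclDestroyTensor(acl_);
  }

  // Return the cached descriptor, rebinding its device address to `data`
  // so the same aclTensor* survives across dispatches.
  aclTensor* get(void* data) {
    if (!acl_) {
      acl_ = aclCreateTensor(dims_.data(),
                             static_cast<uint64_t>(dims_.size()), dtype_,
                             strides_.data(), /*offset=*/0, ACL_FORMAT_ND,
                             dims_.data(),
                             static_cast<uint64_t>(dims_.size()), data);
    } else {
      // Name per the cat kernel's comment; exact signature assumed here.
      aclSetRawTensorAddr(acl_, data);
    }
    return acl_;
  }

 private:
  std::vector<int64_t> dims_;
  std::vector<int64_t> strides_;
  aclDataType dtype_;
  aclTensor* acl_ = nullptr;
};

}  // namespace infini::ops::ascend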

src/ascend/atb_common_.h

Lines changed: 95 additions & 0 deletions
#ifndef INFINI_OPS_ASCEND_ATB_COMMON__H_
#define INFINI_OPS_ASCEND_ATB_COMMON__H_

#ifdef INFINI_HAS_ATB

#include <cassert>
#include <cstdint>
#include <vector>

#include "acl/acl.h"
#include "atb/context.h"
#include "atb/operation.h"
#include "atb/types.h"
#include "ascend/data_type_.h"
#include "tensor.h"

namespace infini::ops::ascend {

// Thread-local ATB context.
//
// ATB requires a `Context` for Setup/Execute. Creating one per call is
// expensive (internal tiling-buffer allocation), so we cache one per thread.
// `SetExecuteStream` is called before every `Execute` to match the caller's
// stream.
inline atb::Context*& threadLocalAtbContext() {
  thread_local atb::Context* ctx = nullptr;
  return ctx;
}

inline atb::Context* getAtbContext(aclrtStream stream) {
  auto*& ctx = threadLocalAtbContext();

  if (!ctx) {
    [[maybe_unused]] atb::Status s = atb::CreateContext(&ctx);
    assert(s == atb::NO_ERROR && "atb::CreateContext failed");
  }

  [[maybe_unused]] atb::Status s = ctx->SetExecuteStream(stream);
  assert(s == atb::NO_ERROR && "atb::Context::SetExecuteStream failed");

  return ctx;
}

// Build an `atb::Tensor` from an InfiniOps Tensor.
//
// Sets dtype, ND format, shape dimensions, and the device data pointer.
// The caller must keep the InfiniOps Tensor alive for the duration of the
// ATB operation.
inline atb::Tensor toAtbTensor(const Tensor& t) {
  atb::Tensor out;
  out.desc.dtype = toAclDtype(t.dtype());
  out.desc.format = ACL_FORMAT_ND;
  out.desc.shape.dimNum = t.ndim();
  assert(t.ndim() <= atb::MAX_DIM);

  for (uint64_t i = 0; i < t.ndim(); ++i) {
    out.desc.shape.dims[i] = static_cast<int64_t>(t.size(i));
  }

  out.deviceData = const_cast<void*>(t.data());
  out.dataSize = static_cast<uint64_t>(t.numel()) * t.element_size();
  return out;
}

// Build an `atb::Tensor` from an explicit shape, dtype, and data pointer.
//
// Useful for sub-views of a larger buffer (e.g. the K-cache and V-cache
// halves of a fused KV-cache tensor).
inline atb::Tensor toAtbTensor(const std::vector<int64_t>& shape,
                               aclDataType dtype, void* data,
                               uint64_t data_size) {
  atb::Tensor out;
  out.desc.dtype = dtype;
  out.desc.format = ACL_FORMAT_ND;
  out.desc.shape.dimNum = shape.size();
  assert(shape.size() <= atb::MAX_DIM);

  for (size_t i = 0; i < shape.size(); ++i) {
    out.desc.shape.dims[i] = shape[i];
  }

  out.deviceData = data;
  out.dataSize = data_size;
  return out;
}

}  // namespace infini::ops::ascend

#endif  // INFINI_HAS_ATB

#endif  // INFINI_OPS_ASCEND_ATB_COMMON__H_
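
Note: no kernel on this page exercises these ATB helpers yet, so here is a hedged usage sketch of the intended flow: wrap tensors with `toAtbTensor`, fetch the thread-local context via `getAtbContext`, then `Setup`/`Execute`. The `ElewiseParam`-based add operation and the workspace handling are illustrative assumptions; exact ATB param structs vary by version.

// Hedged usage sketch (not part of this commit): run an ATB elementwise add
// through the helpers above. ElewiseParam / CreateOperation / Setup / Execute
// follow the public ATB headers; field names may differ across ATB versions.
#ifdef INFINI_HAS_ATB

#include "acl/acl.h"
#include "ascend/atb_common_.h"
#include "atb/infer_op_params.h"

namespace infini::ops::ascend {

inline void atbAddExample(const Tensor& a, const Tensor& b, const Tensor& out,
                          aclrtStream stream) {
  atb::infer::ElewiseParam param;
  param.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
  atb::Operation* op = nullptr;
  atb::CreateOperation(param, &op);

  atb::VariantPack pack;
  pack.inTensors.push_back(toAtbTensor(a));
  pack.inTensors.push_back(toAtbTensor(b));
  pack.outTensors.push_back(toAtbTensor(out));

  atb::Context* ctx = getAtbContext(stream);  // thread-local, stream-bound

  uint64_t ws_size = 0;
  op->Setup(pack, ws_size, ctx);  // plan tiling + query workspace size

  void* ws = nullptr;
  if (ws_size > 0) aclrtMalloc(&ws, ws_size, ACL_MEM_MALLOC_HUGE_FIRST);
  op->Execute(pack, static_cast<uint8_t*>(ws), ws_size, ctx);

  aclrtSynchronizeStream(stream);  // so the workspace can be freed safely
  if (ws) aclrtFree(ws);
  atb::DestroyOperation(op);
}

}  // namespace infini::ops::ascend

#endif  // INFINI_HAS_ATB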

src/ascend/cast/kernel.h

Lines changed: 60 additions & 0 deletions
#ifndef INFINI_OPS_ASCEND_CAST_KERNEL_H_
#define INFINI_OPS_ASCEND_CAST_KERNEL_H_

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnnop/aclnn_cast.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/cast.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<Cast, Device::Type::kAscend> : public Cast {
 public:
  Operator(const Tensor input, Tensor out)
      : Cast(input, out),
        in_cache_(input),
        out_cache_(out),
        acl_out_dtype_(ascend::toAclDtype(out.dtype())) {}

  ~Operator() {
    if (executor_) aclDestroyAclOpExecutor(executor_);
  }

  void operator()(const Tensor input, Tensor out) const override {
    auto stream = static_cast<aclrtStream>(stream_);
    auto t_in = in_cache_.get(const_cast<void*>(input.data()));
    auto t_out = out_cache_.get(out.data());

    if (!executor_) {
      aclnnCastGetWorkspaceSize(t_in, acl_out_dtype_, t_out, &ws_size_,
                                &executor_);
      aclSetAclOpExecutorRepeatable(executor_);
    } else {
      aclSetInputTensorAddr(executor_, 0, t_in,
                            const_cast<void*>(input.data()));
      aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
    }

    auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
    aclnnCast(arena.buf, ws_size_, executor_, stream);
  }

 private:
  mutable ascend::AclTensorCache in_cache_;
  mutable ascend::AclTensorCache out_cache_;
  aclDataType acl_out_dtype_;
  mutable aclOpExecutor* executor_ = nullptr;
  mutable uint64_t ws_size_ = 0;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_ASCEND_CAST_KERNEL_H_
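
Note: `ascend/workspace_pool_.h` is referenced by every kernel here but sits outside this page's file list. Below is a speculative minimal sketch, inferred purely from the `workspacePool().ensure(stream, ws_size_)` / `arena.buf` usage; the real pool may handle growth and stream safety differently.

// Hypothetical sketch of WorkspacePool: one grow-only device buffer per
// stream, so repeated dispatches on a stream reuse a single allocation.
#include <cstdint>
#include <unordered_map>

#include "acl/acl.h"

namespace infini::ops::ascend {

struct WorkspaceArena {
  void* buf = nullptr;
  uint64_t size = 0;
};

class WorkspacePool {
 public:
  // Ensure the stream's arena holds at least `size` bytes, growing if needed.
  WorkspaceArena& ensure(aclrtStream stream, uint64_t size) {
    auto& arena = arenas_[stream];
    if (size > arena.size) {
      // Grow-only policy: free the old buffer and allocate a larger one.
      // NOTE: this sketch assumes no kernel still reads the old buffer; a
      // real pool would synchronize the stream (or defer the free) first.
      if (arena.buf) aclrtFree(arena.buf);
      aclrtMalloc(&arena.buf, size, ACL_MEM_MALLOC_HUGE_FIRST);
      arena.size = size;
    }
    return arena;
  }

 private:
  std::unordered_map<aclrtStream, WorkspaceArena> arenas_;
};

inline WorkspacePool& workspacePool() {
  static WorkspacePool pool;  // process-wide singleton, as the kernels assume
  return pool;
}

}  // namespace infini::ops::ascend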

src/ascend/cat/kernel.h

Lines changed: 94 additions & 0 deletions
#ifndef INFINI_OPS_ASCEND_CAT_KERNEL_H_
#define INFINI_OPS_ASCEND_CAT_KERNEL_H_

#include <vector>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn/acl_meta.h"
#include "aclnnop/aclnn_cat.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/cat.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<Cat, Device::Type::kAscend> : public Cat {
 public:
  Operator(const Tensor first_input, std::vector<Tensor> rest_inputs,
           int64_t dim, Tensor out)
      : Cat(first_input, rest_inputs, dim, out), out_cache_(out) {
    // Build an AclTensorCache for each input tensor.
    in_caches_.reserve(input_count_);
    in_caches_.emplace_back(first_input);
    for (const auto& t : rest_inputs) {
      in_caches_.emplace_back(t);
    }
  }

  ~Operator() {
    if (executor_) aclDestroyAclOpExecutor(executor_);
    if (tensor_list_) aclDestroyTensorList(tensor_list_);
  }

  void operator()(const Tensor first_input, std::vector<Tensor> rest_inputs,
                  int64_t /*dim*/, Tensor out) const override {
    auto stream = static_cast<aclrtStream>(stream_);

    // Collect all input tensors in order.
    std::vector<const Tensor*> inputs;
    inputs.reserve(input_count_);
    inputs.push_back(&first_input);
    for (const auto& t : rest_inputs) {
      inputs.push_back(&t);
    }

    auto t_out = out_cache_.get(out.data());

    if (!executor_) {
      // First call: create descriptors, the tensor list, and the executor.
      std::vector<aclTensor*> acl_tensors(input_count_);
      for (size_t i = 0; i < input_count_; ++i) {
        acl_tensors[i] =
            in_caches_[i].get(const_cast<void*>(inputs[i]->data()));
      }

      tensor_list_ = aclCreateTensorList(
          const_cast<const aclTensor**>(acl_tensors.data()),
          static_cast<uint64_t>(input_count_));

      aclnnCatGetWorkspaceSize(tensor_list_, dim_, t_out, &ws_size_,
                               &executor_);
      aclSetAclOpExecutorRepeatable(executor_);
    } else {
      // Subsequent calls: update data pointers on the cached descriptors via
      // `aclSetRawTensorAddr`. The executor holds references to the same
      // `aclTensor*` objects inside `tensor_list_`, so updating their data
      // pointers is sufficient; no `aclSetInputTensorAddr` is needed.
      for (size_t i = 0; i < input_count_; ++i) {
        in_caches_[i].get(const_cast<void*>(inputs[i]->data()));
      }
      aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
    }

    auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
    aclnnCat(arena.buf, ws_size_, executor_, stream);
  }

 private:
  mutable std::vector<ascend::AclTensorCache> in_caches_;
  mutable ascend::AclTensorCache out_cache_;
  mutable aclTensorList* tensor_list_ = nullptr;
  mutable aclOpExecutor* executor_ = nullptr;
  mutable uint64_t ws_size_ = 0;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_ASCEND_CAT_KERNEL_H_
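
A hypothetical usage sketch (not from the commit) showing why the executor caching pays off: the first dispatch builds the tensor list and plans the repeatable executor, while the second only patches device addresses. `Tensor` construction is elided; the call signatures mirror the operator above.

// Hypothetical usage: second call with a fresh output buffer takes the
// cheap address-patching path instead of re-planning the kernel.
#include <vector>

#include "ascend/cat/kernel.h"

namespace infini::ops {

void catTwiceExample(const Tensor& a, const Tensor& b, Tensor out0,
                     Tensor out1) {
  Operator<Cat, Device::Type::kAscend> cat(a, {b}, /*dim=*/0, out0);

  cat(a, {b}, 0, out0);  // builds tensor list + executor, marks repeatable
  cat(a, {b}, 0, out1);  // reuses executor; only the output address changes
}

}  // namespace infini::ops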
