Commit 7649042

Author: zhangyue
feat(ascend): op-simple group — Add, Mul, Cast, Cat, Matmul, Gemm, Linear
Seven foundational Ascend operators:

| op | impl |
| --- | --- |
| Add | aclnnAdd |
| Mul | aclnnMul |
| Cast | aclnnCast |
| Cat | aclnnCat |
| Matmul | aclnnMatmul |
| Gemm | aclnnAddmm / aclnnBaddbmm (also carries the cached-executor / workspace-pool rework) |
| Linear | aclnnMatmul + optional bias |

Also ships:

- `src/base/<op>.h` for the 5 new ops (cast/cat/linear/matmul/mul); `add.h` and `gemm.h` existed on master and are updated in place
- `src/cpu/<op>/<op>.h` reference impls for cast/cat/linear/mul (add/gemm/matmul had CPU refs on master already)
- `tests/test_<op>.py` for each operator (add and gemm have MODIFY diffs; the others are new)
1 parent: a05713b

22 files changed

Lines changed: 1508 additions & 56 deletions

src/ascend/add/kernel.h

Lines changed: 92 additions & 0 deletions
```cpp
#ifndef INFINI_OPS_ASCEND_ADD_KERNEL_H_
#define INFINI_OPS_ASCEND_ADD_KERNEL_H_

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_add.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/add.h"
#include "data_type.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<Add, Device::Type::kAscend> : public Add {
 public:
  Operator(const Tensor input, const Tensor other, Tensor out)
      : Add(input, other, out),
        in_cache_(input),
        oth_cache_(other),
        out_cache_(out) {
    // `aclCreateScalar` stores the pointer rather than copying the value, so
    // `alpha_*_storage_` must remain alive for the lifetime of `alpha_`.
    // The alpha scalar type must match the tensor dtype: use int64 for
    // integer dtypes and float for floating-point dtypes.
    if (ascend::IsIntegerDtype(input.dtype())) {
      alpha_ = aclCreateScalar(&alpha_int_storage_, ACL_INT64);
    } else {
      alpha_ = aclCreateScalar(&alpha_float_storage_, ACL_FLOAT);
    }
  }

  ~Operator() {
    if (!ascend::IsAclRuntimeAlive()) return;

    // Destroy cached tensors and the executor, then the scalar.
    // Historical note: this active-destroy pattern works for `Add` at
    // process exit but crashed for most other operators — see `64c367c`
    // and the rest of `src/ascend/*/kernel.h`, which use `release()` only.
    in_cache_.destroy();
    oth_cache_.destroy();
    out_cache_.destroy();

    if (executor_) aclDestroyAclOpExecutor(executor_);
    if (alpha_) aclDestroyScalar(alpha_);
  }

  void operator()(const Tensor input, const Tensor other,
                  Tensor out) const override {
    auto stream = static_cast<aclrtStream>(stream_);
    auto t_in = in_cache_.get(const_cast<void*>(input.data()));
    auto t_oth = oth_cache_.get(const_cast<void*>(other.data()));
    auto t_out = out_cache_.get(out.data());

    if (!executor_) {
      aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_size_,
                               &executor_);
      aclSetAclOpExecutorRepeatable(executor_);
    } else {
      aclSetInputTensorAddr(executor_, 0, t_in,
                            const_cast<void*>(input.data()));
      aclSetInputTensorAddr(executor_, 1, t_oth,
                            const_cast<void*>(other.data()));
      aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
    }

    auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
    aclnnAdd(arena.buf, ws_size_, executor_, stream);
  }

 private:
  mutable ascend::AclTensorCache in_cache_;
  mutable ascend::AclTensorCache oth_cache_;
  mutable ascend::AclTensorCache out_cache_;

  mutable aclOpExecutor* executor_ = nullptr;
  mutable uint64_t ws_size_ = 0;

  float alpha_float_storage_ = 1.0f;  // Stable address for `aclCreateScalar` (float).
  int64_t alpha_int_storage_ = 1;     // Stable address for `aclCreateScalar` (int).
  aclScalar* alpha_ = nullptr;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_ASCEND_ADD_KERNEL_H_
```
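The lifetime comment in the constructor above is the subtle bit: `aclCreateScalar` keeps the address it is handed, so backing storage with automatic lifetime dangles the moment its scope ends. A minimal sketch of the pitfall and the member-storage fix; `MakeAlphaDangling` and `AlphaHolder` are illustrative names, not anything from the codebase:

```cpp
#include "acl/acl.h"

// DON'T: `v` is destroyed when the function returns, but the scalar
// still points at it, so any later use reads freed stack memory.
aclScalar* MakeAlphaDangling() {
  float v = 1.0f;
  return aclCreateScalar(&v, ACL_FLOAT);  // stores &v, not the value
}

// DO: tie the storage's lifetime to the scalar's, as the Add operator
// does with its `alpha_*_storage_` members.
struct AlphaHolder {
  float storage = 1.0f;  // stable address, outlives `alpha`
  aclScalar* alpha = aclCreateScalar(&storage, ACL_FLOAT);

  AlphaHolder() = default;
  AlphaHolder(const AlphaHolder&) = delete;  // the address must stay unique
  ~AlphaHolder() {
    if (alpha) aclDestroyScalar(alpha);
  }
};
```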

src/ascend/cast/kernel.h

Lines changed: 64 additions & 0 deletions
```cpp
#ifndef INFINI_OPS_ASCEND_CAST_KERNEL_H_
#define INFINI_OPS_ASCEND_CAST_KERNEL_H_

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnnop/aclnn_cast.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/cast.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<Cast, Device::Type::kAscend> : public Cast {
 public:
  Operator(const Tensor input, Tensor out)
      : Cast(input, out),
        in_cache_(input),
        out_cache_(out),
        acl_out_dtype_(ascend::ToAclDtype(out.dtype())) {}

  ~Operator() {
    if (!ascend::IsAclRuntimeAlive()) return;

    // Null cached descriptors — see `AclTensorCache::release()`.
    in_cache_.release();
    out_cache_.release();
  }

  void operator()(const Tensor input, Tensor out) const override {
    auto stream = static_cast<aclrtStream>(stream_);
    auto t_in = in_cache_.get(const_cast<void*>(input.data()));
    auto t_out = out_cache_.get(out.data());

    if (!executor_) {
      aclnnCastGetWorkspaceSize(t_in, acl_out_dtype_, t_out, &ws_size_,
                                &executor_);
      aclSetAclOpExecutorRepeatable(executor_);
    } else {
      aclSetInputTensorAddr(executor_, 0, t_in,
                            const_cast<void*>(input.data()));
      aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
    }

    auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
    aclnnCast(arena.buf, ws_size_, executor_, stream);
  }

 private:
  mutable ascend::AclTensorCache in_cache_;
  mutable ascend::AclTensorCache out_cache_;

  aclDataType acl_out_dtype_;

  mutable aclOpExecutor* executor_ = nullptr;
  mutable uint64_t ws_size_ = 0;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_ASCEND_CAST_KERNEL_H_
```
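`ascend::ToAclDtype` is referenced here but is not part of the diff. A hedged sketch of the kind of mapping it presumably performs; the `DataType` enumerators below are stand-ins, since the project's `data_type.h` is not shown and the real mapping in `ascend/common.h` may cover more types:

```cpp
#include <stdexcept>

#include "acl/acl.h"  // aclDataType

// Hypothetical dtype enum standing in for the project's data_type.h.
enum class DataType { kFloat32, kFloat16, kInt64, kInt32, kBool };

// Sketch: translate a framework dtype to the ACL equivalent.
inline aclDataType ToAclDtype(DataType dtype) {
  switch (dtype) {
    case DataType::kFloat32: return ACL_FLOAT;
    case DataType::kFloat16: return ACL_FLOAT16;
    case DataType::kInt64:   return ACL_INT64;
    case DataType::kInt32:   return ACL_INT32;
    case DataType::kBool:    return ACL_BOOL;
  }
  throw std::runtime_error("unsupported dtype for Ascend");
}
```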

src/ascend/cat/kernel.h

Lines changed: 98 additions & 0 deletions
```cpp
#ifndef INFINI_OPS_ASCEND_CAT_KERNEL_H_
#define INFINI_OPS_ASCEND_CAT_KERNEL_H_

#include <vector>

#include "acl/acl.h"
#include "aclnn/acl_meta.h"
#include "aclnn/aclnn_base.h"
#include "aclnnop/aclnn_cat.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/cat.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<Cat, Device::Type::kAscend> : public Cat {
 public:
  Operator(const Tensor first_input, std::vector<Tensor> rest_inputs,
           int64_t dim, Tensor out)
      : Cat(first_input, rest_inputs, dim, out), out_cache_(out) {
    // Build an `AclTensorCache` for each input tensor.
    in_caches_.reserve(input_count_);
    in_caches_.emplace_back(first_input);
    for (const auto& t : rest_inputs) {
      in_caches_.emplace_back(t);
    }
  }

  ~Operator() {
    if (!ascend::IsAclRuntimeAlive()) return;

    // Null cached descriptors — see `AclTensorCache::release()`.
    out_cache_.release();

    if (tensor_list_) aclDestroyTensorList(tensor_list_);
  }

  void operator()(const Tensor first_input, std::vector<Tensor> rest_inputs,
                  int64_t /*dim*/, Tensor out) const override {
    auto stream = static_cast<aclrtStream>(stream_);

    // Collect all input tensors in order.
    std::vector<const Tensor*> inputs;
    inputs.reserve(input_count_);
    inputs.push_back(&first_input);
    for (const auto& t : rest_inputs) {
      inputs.push_back(&t);
    }

    auto t_out = out_cache_.get(out.data());

    if (!executor_) {
      // First call: create descriptors, tensor list, and executor.
      std::vector<aclTensor*> acl_tensors(input_count_);
      for (size_t i = 0; i < input_count_; ++i) {
        acl_tensors[i] =
            in_caches_[i].get(const_cast<void*>(inputs[i]->data()));
      }

      tensor_list_ =
          aclCreateTensorList(const_cast<const aclTensor**>(acl_tensors.data()),
                              static_cast<uint64_t>(input_count_));

      aclnnCatGetWorkspaceSize(tensor_list_, dim_, t_out, &ws_size_,
                               &executor_);
      aclSetAclOpExecutorRepeatable(executor_);
    } else {
      // Subsequent calls: update data pointers on cached descriptors via
      // `aclSetRawTensorAddr`. The executor holds references to the same
      // `aclTensor*` objects inside `tensor_list_`, so updating their data
      // pointers is sufficient — no `aclSetInputTensorAddr` needed.
      for (size_t i = 0; i < input_count_; ++i) {
        in_caches_[i].get(const_cast<void*>(inputs[i]->data()));
      }
      aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
    }

    auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
    aclnnCat(arena.buf, ws_size_, executor_, stream);
  }

 private:
  mutable std::vector<ascend::AclTensorCache> in_caches_;
  mutable ascend::AclTensorCache out_cache_;

  mutable aclTensorList* tensor_list_ = nullptr;
  mutable aclOpExecutor* executor_ = nullptr;
  mutable uint64_t ws_size_ = 0;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_ASCEND_CAT_KERNEL_H_
```
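Every kernel in this group leans on `ascend::AclTensorCache`, which is also not part of this diff. From its call sites (`get` repoints a cached descriptor at a new device address, `release` forgets it without freeing, `destroy` actively frees it, and the Cat comment above names `aclSetRawTensorAddr` as the repointing mechanism), its shape is presumably close to the sketch below; the `Build` helper and all member names are inferred, not confirmed:

```cpp
#include "acl/acl.h"

class Tensor;  // the project's tensor type, assumed here

// Inferred sketch of ascend::AclTensorCache; the real class lives in
// ascend/common.h and this reconstruction comes only from call sites.
class AclTensorCache {
 public:
  explicit AclTensorCache(const Tensor& t, bool transpose = false);

  // First use builds the aclTensor descriptor from the cached shape,
  // strides, and dtype; later uses only repoint it. Because executors
  // reference the same aclTensor*, the repointing is visible to them
  // too, which is what the Cat kernel relies on.
  aclTensor* get(void* data) const {
    if (!tensor_) {
      tensor_ = Build(data);
    } else {
      aclSetRawTensorAddr(tensor_, data);
    }
    return tensor_;
  }

  // Forget the descriptor without aclDestroyTensor: safe even when the
  // ACL runtime has already shut down at process exit.
  void release() { tensor_ = nullptr; }

  // Actively free the descriptor (only Add's destructor takes this path).
  void destroy() {
    if (tensor_) {
      aclDestroyTensor(tensor_);
      tensor_ = nullptr;
    }
  }

 private:
  aclTensor* Build(void* data) const;  // defined with the real class
  mutable aclTensor* tensor_ = nullptr;
};
```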

src/ascend/gemm/kernel.h

Lines changed: 50 additions & 25 deletions
```diff
@@ -21,50 +21,63 @@ class Operator<Gemm, Device::Type::kAscend> : public Gemm {
       : Gemm(a, b, alpha, beta, trans_a, trans_b, c),
         batched_{batch_count_ > 1},
         alpha_val_{alpha.value_or(1.0f)},
-        beta_val_{beta.value_or(1.0f)} {
+        beta_val_{beta.value_or(1.0f)},
+        self_cache_(c),
+        a_cache_(a, trans_a_),
+        b_cache_(b, trans_b_),
+        out_cache_(c) {
     alpha_scalar_ = aclCreateScalar(&alpha_val_, ACL_FLOAT);
     beta_scalar_ = aclCreateScalar(&beta_val_, ACL_FLOAT);
   }
 
   ~Operator() {
-    aclDestroyScalar(alpha_scalar_);
-    aclDestroyScalar(beta_scalar_);
+    if (!ascend::IsAclRuntimeAlive()) return;
+
+    // Null cached descriptors — see `AclTensorCache::release()`.
+    self_cache_.release();
+    a_cache_.release();
+    b_cache_.release();
+    out_cache_.release();
+
+    if (alpha_scalar_) aclDestroyScalar(alpha_scalar_);
+    if (beta_scalar_) aclDestroyScalar(beta_scalar_);
   }
 
   void operator()(const Tensor a, const Tensor b, std::optional<float> alpha,
                   std::optional<float> beta, std::optional<int> trans_a,
                   std::optional<int> trans_b, Tensor c) const override {
     auto stream = static_cast<aclrtStream>(stream_);
 
-    auto t_self = ascend::BuildAclTensor(c);
-    auto t_a = ascend::BuildAclTensor(a, trans_a_);
-    auto t_b = ascend::BuildAclTensor(b, trans_b_);
-    auto t_out = ascend::BuildAclTensor(c);
-
-    uint64_t ws_needed = 0;
-    aclOpExecutor* executor = nullptr;
-
-    if (batched_) {
-      aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_,
-                                   alpha_scalar_, t_out, 0, &ws_needed,
-                                   &executor);
+    auto t_self = self_cache_.get(c.data());
+    auto t_a = a_cache_.get(const_cast<void*>(a.data()));
+    auto t_b = b_cache_.get(const_cast<void*>(b.data()));
+    auto t_out = out_cache_.get(c.data());
+
+    if (!executor_) {
+      if (batched_) {
+        aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_,
+                                     alpha_scalar_, t_out, 0, &ws_size_,
+                                     &executor_);
+      } else {
+        aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_,
+                                   alpha_scalar_, t_out, 0, &ws_size_,
+                                   &executor_);
+      }
+      aclSetAclOpExecutorRepeatable(executor_);
     } else {
-      aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, alpha_scalar_,
-                                 t_out, 0, &ws_needed, &executor);
+      aclSetInputTensorAddr(executor_, 0, t_self, c.data());
+      aclSetInputTensorAddr(executor_, 1, t_a, const_cast<void*>(a.data()));
+      aclSetInputTensorAddr(executor_, 2, t_b, const_cast<void*>(b.data()));
+      aclSetOutputTensorAddr(executor_, 0, t_out, c.data());
     }
 
-    auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed);
+    auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
 
     if (batched_) {
-      aclnnBaddbmm(arena.buf, ws_needed, executor, stream);
+      aclnnBaddbmm(arena.buf, ws_size_, executor_, stream);
     } else {
-      aclnnAddmm(arena.buf, ws_needed, executor, stream);
+      aclnnAddmm(arena.buf, ws_size_, executor_, stream);
    }
-
-    aclDestroyTensor(t_self);
-    aclDestroyTensor(t_a);
-    aclDestroyTensor(t_b);
-    aclDestroyTensor(t_out);
   }
 
  private:
@@ -77,6 +90,18 @@ class Operator<Gemm, Device::Type::kAscend> : public Gemm {
   aclScalar* alpha_scalar_ = nullptr;
 
   aclScalar* beta_scalar_ = nullptr;
+
+  mutable ascend::AclTensorCache self_cache_;
+
+  mutable ascend::AclTensorCache a_cache_;
+
+  mutable ascend::AclTensorCache b_cache_;
+
+  mutable ascend::AclTensorCache out_cache_;
+
+  mutable aclOpExecutor* executor_ = nullptr;
+
+  mutable uint64_t ws_size_ = 0;
 };
 
 }  // namespace infini::ops
```
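The other half of the rework is `ascend::GetWorkspacePool().Ensure(stream, size)`, which replaces the old allocate-per-call workspace handling. A minimal sketch of a grow-only, per-stream pool under assumed names (`Arena`, `WorkspacePool`, the member layout); the real implementation lives in `ascend/workspace_pool_.h` and surely differs in detail:

```cpp
#include <cstdint>
#include <unordered_map>

#include "acl/acl.h"

namespace ascend_sketch {

struct Arena {
  void* buf = nullptr;
  uint64_t size = 0;
};

// Grow-only, per-stream workspace pool: each stream keeps one buffer
// that is enlarged on demand and otherwise reused across kernel calls.
class WorkspacePool {
 public:
  Arena& Ensure(aclrtStream stream, uint64_t needed) {
    Arena& a = arenas_[stream];
    if (needed > a.size) {
      // Resizing frees the old buffer, so make sure no in-flight kernel
      // on this stream is still using it before reallocating.
      aclrtSynchronizeStream(stream);
      if (a.buf) aclrtFree(a.buf);
      aclrtMalloc(&a.buf, needed, ACL_MEM_MALLOC_HUGE_FIRST);
      a.size = needed;
    }
    return a;  // a.buf may be nullptr when needed == 0, which aclnn accepts
  }

 private:
  std::unordered_map<aclrtStream, Arena> arenas_;
};

}  // namespace ascend_sketch
```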
