Skip to content

Commit 873d79b

Browse files
committed
feat(gpu): add linear kernels for metax iluvatar and moore
1 parent d50872a commit 873d79b

3 files changed

Lines changed: 213 additions & 0 deletions

File tree

src/iluvatar/linear/kernel.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#ifndef INFINI_OPS_ILUVATAR_LINEAR_KERNEL_H_
2+
#define INFINI_OPS_ILUVATAR_LINEAR_KERNEL_H_
3+
4+
#include <optional>
5+
6+
#include "base/linear.h"
7+
#include "iluvatar/add/kernel.h"
8+
#include "iluvatar/gemm/cublas.h"
9+
10+
namespace infini::ops {
11+
12+
template <>
13+
class Operator<Linear, Device::Type::kIluvatar> : public Linear {
14+
public:
15+
Operator(const Tensor a, const Tensor b, std::optional<Tensor> bias,
16+
bool trans_a, bool trans_b, Tensor out)
17+
: Linear(a, b, bias, trans_a, trans_b, out),
18+
gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
19+
std::optional<int>{static_cast<int>(trans_a)},
20+
std::optional<int>{static_cast<int>(trans_b)}, out) {
21+
if (has_bias_) {
22+
add_.emplace(out, BroadcastBias(*bias, out), out);
23+
}
24+
}
25+
26+
void operator()(const Tensor a, const Tensor b, std::optional<Tensor> bias,
27+
bool trans_a, bool trans_b, Tensor out) const override {
28+
assert(has_bias_ == bias.has_value());
29+
30+
ConfigureSubOperator(gemm_);
31+
gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
32+
std::optional<int>{static_cast<int>(trans_a)},
33+
std::optional<int>{static_cast<int>(trans_b)}, out);
34+
35+
if (has_bias_) {
36+
auto bias_view = BroadcastBias(*bias, out);
37+
ConfigureSubOperator(*add_);
38+
(*add_)(out, bias_view, out);
39+
}
40+
}
41+
42+
private:
43+
static Tensor BroadcastBias(const Tensor& bias, const Tensor& out) {
44+
assert(bias.ndim() == 1);
45+
assert(bias.size(0) == out.size(-1));
46+
47+
auto shape = out.shape();
48+
Tensor::Strides strides(out.ndim(), 0);
49+
strides.back() = bias.stride(0);
50+
51+
return Tensor{const_cast<void*>(bias.data()), shape, bias.dtype(),
52+
bias.device(), strides};
53+
}
54+
55+
template <typename Op>
56+
void ConfigureSubOperator(Op& op) const {
57+
op.set_handle(handle_);
58+
op.set_config(config_);
59+
op.set_stream(stream_);
60+
op.set_workspace(workspace_);
61+
op.set_workspace_size_in_bytes(workspace_size_in_bytes_);
62+
}
63+
64+
mutable Operator<Gemm, Device::Type::kIluvatar> gemm_;
65+
66+
mutable std::optional<Operator<Add, Device::Type::kIluvatar>> add_;
67+
};
68+
69+
} // namespace infini::ops
70+
71+
#endif

src/metax/linear/kernel.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#ifndef INFINI_OPS_METAX_LINEAR_KERNEL_H_
2+
#define INFINI_OPS_METAX_LINEAR_KERNEL_H_
3+
4+
#include <optional>
5+
6+
#include "base/linear.h"
7+
#include "metax/add/kernel.h"
8+
#include "metax/gemm/mcblas.h"
9+
10+
namespace infini::ops {
11+
12+
template <>
13+
class Operator<Linear, Device::Type::kMetax> : public Linear {
14+
public:
15+
Operator(const Tensor a, const Tensor b, std::optional<Tensor> bias,
16+
bool trans_a, bool trans_b, Tensor out)
17+
: Linear(a, b, bias, trans_a, trans_b, out),
18+
gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
19+
std::optional<int>{static_cast<int>(trans_a)},
20+
std::optional<int>{static_cast<int>(trans_b)}, out) {
21+
if (has_bias_) {
22+
add_.emplace(out, BroadcastBias(*bias, out), out);
23+
}
24+
}
25+
26+
void operator()(const Tensor a, const Tensor b, std::optional<Tensor> bias,
27+
bool trans_a, bool trans_b, Tensor out) const override {
28+
assert(has_bias_ == bias.has_value());
29+
30+
ConfigureSubOperator(gemm_);
31+
gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
32+
std::optional<int>{static_cast<int>(trans_a)},
33+
std::optional<int>{static_cast<int>(trans_b)}, out);
34+
35+
if (has_bias_) {
36+
auto bias_view = BroadcastBias(*bias, out);
37+
ConfigureSubOperator(*add_);
38+
(*add_)(out, bias_view, out);
39+
}
40+
}
41+
42+
private:
43+
static Tensor BroadcastBias(const Tensor& bias, const Tensor& out) {
44+
assert(bias.ndim() == 1);
45+
assert(bias.size(0) == out.size(-1));
46+
47+
auto shape = out.shape();
48+
Tensor::Strides strides(out.ndim(), 0);
49+
strides.back() = bias.stride(0);
50+
51+
return Tensor{const_cast<void*>(bias.data()), shape, bias.dtype(),
52+
bias.device(), strides};
53+
}
54+
55+
template <typename Op>
56+
void ConfigureSubOperator(Op& op) const {
57+
op.set_handle(handle_);
58+
op.set_config(config_);
59+
op.set_stream(stream_);
60+
op.set_workspace(workspace_);
61+
op.set_workspace_size_in_bytes(workspace_size_in_bytes_);
62+
}
63+
64+
mutable Operator<Gemm, Device::Type::kMetax> gemm_;
65+
66+
mutable std::optional<Operator<Add, Device::Type::kMetax>> add_;
67+
};
68+
69+
} // namespace infini::ops
70+
71+
#endif

src/moore/linear/kernel.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#ifndef INFINI_OPS_MOORE_LINEAR_KERNEL_H_
2+
#define INFINI_OPS_MOORE_LINEAR_KERNEL_H_
3+
4+
#include <optional>
5+
6+
#include "base/linear.h"
7+
#include "moore/add/kernel.h"
8+
#include "moore/gemm/mublas.h"
9+
10+
namespace infini::ops {
11+
12+
template <>
13+
class Operator<Linear, Device::Type::kMoore> : public Linear {
14+
public:
15+
Operator(const Tensor a, const Tensor b, std::optional<Tensor> bias,
16+
bool trans_a, bool trans_b, Tensor out)
17+
: Linear(a, b, bias, trans_a, trans_b, out),
18+
gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
19+
std::optional<int>{static_cast<int>(trans_a)},
20+
std::optional<int>{static_cast<int>(trans_b)}, out) {
21+
if (has_bias_) {
22+
add_.emplace(out, BroadcastBias(*bias, out), out);
23+
}
24+
}
25+
26+
void operator()(const Tensor a, const Tensor b, std::optional<Tensor> bias,
27+
bool trans_a, bool trans_b, Tensor out) const override {
28+
assert(has_bias_ == bias.has_value());
29+
30+
ConfigureSubOperator(gemm_);
31+
gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
32+
std::optional<int>{static_cast<int>(trans_a)},
33+
std::optional<int>{static_cast<int>(trans_b)}, out);
34+
35+
if (has_bias_) {
36+
auto bias_view = BroadcastBias(*bias, out);
37+
ConfigureSubOperator(*add_);
38+
(*add_)(out, bias_view, out);
39+
}
40+
}
41+
42+
private:
43+
static Tensor BroadcastBias(const Tensor& bias, const Tensor& out) {
44+
assert(bias.ndim() == 1);
45+
assert(bias.size(0) == out.size(-1));
46+
47+
auto shape = out.shape();
48+
Tensor::Strides strides(out.ndim(), 0);
49+
strides.back() = bias.stride(0);
50+
51+
return Tensor{const_cast<void*>(bias.data()), shape, bias.dtype(),
52+
bias.device(), strides};
53+
}
54+
55+
template <typename Op>
56+
void ConfigureSubOperator(Op& op) const {
57+
op.set_handle(handle_);
58+
op.set_config(config_);
59+
op.set_stream(stream_);
60+
op.set_workspace(workspace_);
61+
op.set_workspace_size_in_bytes(workspace_size_in_bytes_);
62+
}
63+
64+
mutable Operator<Gemm, Device::Type::kMoore> gemm_;
65+
66+
mutable std::optional<Operator<Add, Device::Type::kMoore>> add_;
67+
};
68+
69+
} // namespace infini::ops
70+
71+
#endif

0 commit comments

Comments
 (0)