kernel.h
#ifndef INFINI_OPS_NVIDIA_LINEAR_KERNEL_H_
#define INFINI_OPS_NVIDIA_LINEAR_KERNEL_H_

#include <cassert>  // assert
#include <optional>

#include "base/linear.h"
#include "nvidia/add/kernel.h"
#include "nvidia/gemm/cublas.h"
namespace infini::ops {

// Linear (fully connected) operator for NVIDIA devices, implemented by
// composing the cuBLAS GEMM operator with an optional broadcasted Add for
// the bias term.
template <>
class Operator<Linear, Device::Type::kNvidia> : public Linear {
 public:
  Operator(const Tensor a, const Tensor b, std::optional<Tensor> bias,
           bool trans_a, bool trans_b, Tensor out)
      : Linear(a, b, bias, trans_a, trans_b, out),
        gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
              std::optional<int>{static_cast<int>(trans_a)},
              std::optional<int>{static_cast<int>(trans_b)}, out) {
    // The Add sub-operator is only constructed when a bias was provided.
    if (has_bias_) {
      add_.emplace(out, BroadcastBias(*bias, out), out);
    }
  }
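
  // The constructor wires up, in effect:
  //
  //   out = 1.0f * op(a) * op(b) + 0.0f * out   (cuBLAS GEMM)
  //   out = out + broadcast(bias)               (element-wise Add, if bias is set)
  //
  // where op(x) is x or its transpose, selected by trans_a / trans_b.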

  // Launches the GEMM and, when a bias was supplied at construction time,
  // the in-place bias Add, using the configured execution context.
  void operator()(const Tensor a, const Tensor b, std::optional<Tensor> bias,
                  bool trans_a, bool trans_b, Tensor out) const override {
    assert(has_bias_ == bias.has_value());
    ConfigureSubOperator(gemm_);
    gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
          std::optional<int>{static_cast<int>(trans_a)},
          std::optional<int>{static_cast<int>(trans_b)}, out);
    if (has_bias_) {
      auto bias_view = BroadcastBias(*bias, out);
      ConfigureSubOperator(*add_);
      (*add_)(out, bias_view, out);
    }
  }

 private:
  // Returns a zero-copy broadcast view of the 1-D bias with out's shape:
  // every stride except the innermost one is zero, so the same bias row is
  // reused for every leading index.
  static Tensor BroadcastBias(const Tensor& bias, const Tensor& out) {
    assert(bias.ndim() == 1);
    assert(bias.size(0) == out.size(-1));
    auto shape = out.shape();
    Tensor::Strides strides(out.ndim(), 0);
    strides.back() = bias.stride(0);
    return Tensor{const_cast<void*>(bias.data()), shape, bias.dtype(),
                  bias.device(), strides};
  }
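
  // For example, a bias of shape [N] paired with an output of shape [M, N]
  // yields a view with shape [M, N] and strides [0, bias.stride(0)], so no
  // bias data is copied.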

  // Copies this operator's execution context (handle, config, stream,
  // workspace) into a sub-operator before launching it.
  template <typename Op>
  void ConfigureSubOperator(Op& op) const {
    op.set_handle(handle_);
    op.set_config(config_);
    op.set_stream(stream_);
    op.set_workspace(workspace_);
    op.set_workspace_size_in_bytes(workspace_size_in_bytes_);
  }

  mutable Operator<Gemm, Device::Type::kNvidia> gemm_;
  mutable std::optional<Operator<Add, Device::Type::kNvidia>> add_;
};
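
// Usage sketch (illustrative only): how the tensors are allocated is an
// assumption about the surrounding runtime, not a confirmed API; only the
// Operator construction and call mirror this header.
//
//   // Tensor a, b, bias, out;  // device tensors already allocated on kNvidia
//   // Operator<Linear, Device::Type::kNvidia> linear{
//   //     a, b, bias, /*trans_a=*/false, /*trans_b=*/true, out};
//   // linear(a, b, bias, /*trans_a=*/false, /*trans_b=*/true, out);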

}  // namespace infini::ops

#endif  // INFINI_OPS_NVIDIA_LINEAR_KERNEL_H_