src/iluvatar/linear/kernel.h (new file, 71 additions, 0 deletions)
@@ -0,0 +1,71 @@
#ifndef INFINI_OPS_ILUVATAR_LINEAR_KERNEL_H_
#define INFINI_OPS_ILUVATAR_LINEAR_KERNEL_H_

#include <cassert>
#include <optional>

#include "base/linear.h"
#include "iluvatar/add/kernel.h"
#include "iluvatar/gemm/cublas.h"

namespace infini::ops {

// Fused linear layer for the Iluvatar backend: out = GEMM(a, b) followed by an
// optional broadcast bias add, composed from the backend's GEMM and Add
// operators.
template <>
class Operator<Linear, Device::Type::kIluvatar> : public Linear {
 public:
  Operator(const Tensor a, const Tensor b, std::optional<Tensor> bias,
           bool trans_a, bool trans_b, Tensor out)
      : Linear(a, b, bias, trans_a, trans_b, out),
        gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
              std::optional<int>{static_cast<int>(trans_a)},
              std::optional<int>{static_cast<int>(trans_b)}, out) {
    if (has_bias_) {
      add_.emplace(out, BroadcastBias(*bias, out), out);
    }
  }

  void operator()(const Tensor a, const Tensor b, std::optional<Tensor> bias,
                  bool trans_a, bool trans_b, Tensor out) const override {
    assert(has_bias_ == bias.has_value());

    ConfigureSubOperator(gemm_);
    gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
          std::optional<int>{static_cast<int>(trans_a)},
          std::optional<int>{static_cast<int>(trans_b)}, out);

    if (has_bias_) {
      auto bias_view = BroadcastBias(*bias, out);
      ConfigureSubOperator(*add_);
      (*add_)(out, bias_view, out);
    }
  }

 private:
  // Builds a zero-copy view of the 1-D bias broadcast to the output shape:
  // every stride is zero except the innermost one, so the same bias values
  // are reused for each leading index of `out` (e.g. out {B, M, N} with bias
  // {N} yields a view of shape {B, M, N} and strides {0, 0, bias.stride(0)}).
  static Tensor BroadcastBias(const Tensor& bias, const Tensor& out) {
    assert(bias.ndim() == 1);
    assert(bias.size(0) == out.size(-1));

    auto shape = out.shape();
    Tensor::Strides strides(out.ndim(), 0);
    strides.back() = bias.stride(0);

    return Tensor{const_cast<void*>(bias.data()), shape, bias.dtype(),
                  bias.device(), strides};
  }

  // Forwards the parent operator's execution context to a sub-operator before
  // it is launched.
  template <typename Op>
  void ConfigureSubOperator(Op& op) const {
    op.set_handle(handle_);
    op.set_config(config_);
    op.set_stream(stream_);
    op.set_workspace(workspace_);
    op.set_workspace_size_in_bytes(workspace_size_in_bytes_);
  }

  // Sub-operators are mutable so the const call operator can refresh their
  // handle, stream and workspace before each launch.
  mutable Operator<Gemm, Device::Type::kIluvatar> gemm_;

  mutable std::optional<Operator<Add, Device::Type::kIluvatar>> add_;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_ILUVATAR_LINEAR_KERNEL_H_
src/metax/linear/kernel.h (new file, 71 additions, 0 deletions)
@@ -0,0 +1,71 @@
#ifndef INFINI_OPS_METAX_LINEAR_KERNEL_H_
#define INFINI_OPS_METAX_LINEAR_KERNEL_H_

#include <cassert>
#include <optional>

#include "base/linear.h"
#include "metax/add/kernel.h"
#include "metax/gemm/mcblas.h"

namespace infini::ops {

// Fused linear layer for the Metax backend: out = GEMM(a, b) followed by an
// optional broadcast bias add, composed from the backend's GEMM and Add
// operators.
template <>
class Operator<Linear, Device::Type::kMetax> : public Linear {
 public:
  Operator(const Tensor a, const Tensor b, std::optional<Tensor> bias,
           bool trans_a, bool trans_b, Tensor out)
      : Linear(a, b, bias, trans_a, trans_b, out),
        gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
              std::optional<int>{static_cast<int>(trans_a)},
              std::optional<int>{static_cast<int>(trans_b)}, out) {
    if (has_bias_) {
      add_.emplace(out, BroadcastBias(*bias, out), out);
    }
  }

  void operator()(const Tensor a, const Tensor b, std::optional<Tensor> bias,
                  bool trans_a, bool trans_b, Tensor out) const override {
    assert(has_bias_ == bias.has_value());

    ConfigureSubOperator(gemm_);
    gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
          std::optional<int>{static_cast<int>(trans_a)},
          std::optional<int>{static_cast<int>(trans_b)}, out);

    if (has_bias_) {
      auto bias_view = BroadcastBias(*bias, out);
      ConfigureSubOperator(*add_);
      (*add_)(out, bias_view, out);
    }
  }

 private:
  // Builds a zero-copy view of the 1-D bias broadcast to the output shape:
  // every stride is zero except the innermost one, so the same bias values
  // are reused for each leading index of `out` (e.g. out {B, M, N} with bias
  // {N} yields a view of shape {B, M, N} and strides {0, 0, bias.stride(0)}).
  static Tensor BroadcastBias(const Tensor& bias, const Tensor& out) {
    assert(bias.ndim() == 1);
    assert(bias.size(0) == out.size(-1));

    auto shape = out.shape();
    Tensor::Strides strides(out.ndim(), 0);
    strides.back() = bias.stride(0);

    return Tensor{const_cast<void*>(bias.data()), shape, bias.dtype(),
                  bias.device(), strides};
  }

  // Forwards the parent operator's execution context to a sub-operator before
  // it is launched.
  template <typename Op>
  void ConfigureSubOperator(Op& op) const {
    op.set_handle(handle_);
    op.set_config(config_);
    op.set_stream(stream_);
    op.set_workspace(workspace_);
    op.set_workspace_size_in_bytes(workspace_size_in_bytes_);
  }

  // Sub-operators are mutable so the const call operator can refresh their
  // handle, stream and workspace before each launch.
  mutable Operator<Gemm, Device::Type::kMetax> gemm_;

  mutable std::optional<Operator<Add, Device::Type::kMetax>> add_;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_METAX_LINEAR_KERNEL_H_
src/moore/linear/kernel.h (new file, 71 additions, 0 deletions)
@@ -0,0 +1,71 @@
#ifndef INFINI_OPS_MOORE_LINEAR_KERNEL_H_
#define INFINI_OPS_MOORE_LINEAR_KERNEL_H_

#include <cassert>
#include <optional>

#include "base/linear.h"
#include "moore/add/kernel.h"
#include "moore/gemm/mublas.h"

namespace infini::ops {

// Fused linear layer for the Moore backend: out = GEMM(a, b) followed by an
// optional broadcast bias add, composed from the backend's GEMM and Add
// operators.
template <>
class Operator<Linear, Device::Type::kMoore> : public Linear {
 public:
  Operator(const Tensor a, const Tensor b, std::optional<Tensor> bias,
           bool trans_a, bool trans_b, Tensor out)
      : Linear(a, b, bias, trans_a, trans_b, out),
        gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
              std::optional<int>{static_cast<int>(trans_a)},
              std::optional<int>{static_cast<int>(trans_b)}, out) {
    if (has_bias_) {
      add_.emplace(out, BroadcastBias(*bias, out), out);
    }
  }

  void operator()(const Tensor a, const Tensor b, std::optional<Tensor> bias,
                  bool trans_a, bool trans_b, Tensor out) const override {
    assert(has_bias_ == bias.has_value());

    ConfigureSubOperator(gemm_);
    gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
          std::optional<int>{static_cast<int>(trans_a)},
          std::optional<int>{static_cast<int>(trans_b)}, out);

    if (has_bias_) {
      auto bias_view = BroadcastBias(*bias, out);
      ConfigureSubOperator(*add_);
      (*add_)(out, bias_view, out);
    }
  }

 private:
  // Builds a zero-copy view of the 1-D bias broadcast to the output shape:
  // every stride is zero except the innermost one, so the same bias values
  // are reused for each leading index of `out` (e.g. out {B, M, N} with bias
  // {N} yields a view of shape {B, M, N} and strides {0, 0, bias.stride(0)}).
  static Tensor BroadcastBias(const Tensor& bias, const Tensor& out) {
    assert(bias.ndim() == 1);
    assert(bias.size(0) == out.size(-1));

    auto shape = out.shape();
    Tensor::Strides strides(out.ndim(), 0);
    strides.back() = bias.stride(0);

    return Tensor{const_cast<void*>(bias.data()), shape, bias.dtype(),
                  bias.device(), strides};
  }

  // Forwards the parent operator's execution context to a sub-operator before
  // it is launched.
  template <typename Op>
  void ConfigureSubOperator(Op& op) const {
    op.set_handle(handle_);
    op.set_config(config_);
    op.set_stream(stream_);
    op.set_workspace(workspace_);
    op.set_workspace_size_in_bytes(workspace_size_in_bytes_);
  }

  // Sub-operators are mutable so the const call operator can refresh their
  // handle, stream and workspace before each launch.
  mutable Operator<Gemm, Device::Type::kMoore> gemm_;

  mutable std::optional<Operator<Add, Device::Type::kMoore>> add_;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_MOORE_LINEAR_KERNEL_H_
src/nvidia/linear/kernel.h (new file, 71 additions, 0 deletions)
@@ -0,0 +1,71 @@
#ifndef INFINI_OPS_NVIDIA_LINEAR_KERNEL_H_
#define INFINI_OPS_NVIDIA_LINEAR_KERNEL_H_

#include <cassert>
#include <optional>

#include "base/linear.h"
#include "nvidia/add/kernel.h"
#include "nvidia/gemm/cublas.h"

namespace infini::ops {

// Fused linear layer for the NVIDIA backend: out = GEMM(a, b) followed by an
// optional broadcast bias add, composed from the backend's GEMM and Add
// operators.
template <>
class Operator<Linear, Device::Type::kNvidia> : public Linear {
 public:
  Operator(const Tensor a, const Tensor b, std::optional<Tensor> bias,
           bool trans_a, bool trans_b, Tensor out)
      : Linear(a, b, bias, trans_a, trans_b, out),
        gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
              std::optional<int>{static_cast<int>(trans_a)},
              std::optional<int>{static_cast<int>(trans_b)}, out) {
    if (has_bias_) {
      add_.emplace(out, BroadcastBias(*bias, out), out);
    }
  }

  void operator()(const Tensor a, const Tensor b, std::optional<Tensor> bias,
                  bool trans_a, bool trans_b, Tensor out) const override {
    assert(has_bias_ == bias.has_value());

    ConfigureSubOperator(gemm_);
    gemm_(a, b, std::optional<float>{1.0f}, std::optional<float>{0.0f},
          std::optional<int>{static_cast<int>(trans_a)},
          std::optional<int>{static_cast<int>(trans_b)}, out);

    if (has_bias_) {
      auto bias_view = BroadcastBias(*bias, out);
      ConfigureSubOperator(*add_);
      (*add_)(out, bias_view, out);
    }
  }

 private:
  // Builds a zero-copy view of the 1-D bias broadcast to the output shape:
  // every stride is zero except the innermost one, so the same bias values
  // are reused for each leading index of `out` (e.g. out {B, M, N} with bias
  // {N} yields a view of shape {B, M, N} and strides {0, 0, bias.stride(0)}).
  static Tensor BroadcastBias(const Tensor& bias, const Tensor& out) {
    assert(bias.ndim() == 1);
    assert(bias.size(0) == out.size(-1));

    auto shape = out.shape();
    Tensor::Strides strides(out.ndim(), 0);
    strides.back() = bias.stride(0);

    return Tensor{const_cast<void*>(bias.data()), shape, bias.dtype(),
                  bias.device(), strides};
  }

  // Forwards the parent operator's execution context to a sub-operator before
  // it is launched.
  template <typename Op>
  void ConfigureSubOperator(Op& op) const {
    op.set_handle(handle_);
    op.set_config(config_);
    op.set_stream(stream_);
    op.set_workspace(workspace_);
    op.set_workspace_size_in_bytes(workspace_size_in_bytes_);
  }

  // Sub-operators are mutable so the const call operator can refresh their
  // handle, stream and workspace before each launch.
  mutable Operator<Gemm, Device::Type::kNvidia> gemm_;

  mutable std::optional<Operator<Add, Device::Type::kNvidia>> add_;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_NVIDIA_LINEAR_KERNEL_H_
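All four backend headers share the same composition and differ only in the Device::Type tag and the GEMM backend header they include. The sketch below is illustrative only and not part of this diff: it shows how the NVIDIA specialization is expected to be driven, assuming tensors a, b, bias and out have already been created elsewhere with compatible shapes, and that the execution context (handle, stream, workspace) is set on the operator through the same setters that ConfigureSubOperator forwards.

#include <optional>

#include "nvidia/linear/kernel.h"

namespace infini::ops {

// Hypothetical caller, for illustration only.
void RunLinear(const Tensor& a, const Tensor& b, std::optional<Tensor> bias,
               Tensor out) {
  // Construct once: the GEMM sub-operator (and, when a bias is present, the
  // broadcasting Add) is set up in the constructor.
  Operator<Linear, Device::Type::kNvidia> linear(a, b, bias,
                                                 /*trans_a=*/false,
                                                 /*trans_b=*/true, out);

  // Call repeatedly; each call re-propagates the handle, stream and workspace
  // to the sub-operators before launching GEMM and the bias add.
  linear(a, b, bias, /*trans_a=*/false, /*trans_b=*/true, out);
}

}  // namespace infini::ops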