|
| 1 | +#include "torch/gemm/gemm.h" |
| 2 | + |
| 3 | +#include "torch/tensor_.h" |
| 4 | + |
| 5 | +namespace infini::ops { |
| 6 | + |
| 7 | +template <Device::Type kDev> |
| 8 | +Operator<Gemm, kDev, 2>::Operator(const Tensor a, const Tensor b, |
| 9 | + std::optional<float> alpha, |
| 10 | + std::optional<float> beta, |
| 11 | + std::optional<int> trans_a, |
| 12 | + std::optional<int> trans_b, Tensor c) |
| 13 | + : Gemm{a, b, alpha, beta, trans_a, trans_b, c}, |
| 14 | + a_shape_{a.shape()}, |
| 15 | + b_shape_{b.shape()}, |
| 16 | + c_shape_{c.shape()}, |
| 17 | + device_index_{c.device().index()} {} |
| 18 | + |
| 19 | +template <Device::Type kDev> |
| 20 | +void Operator<Gemm, kDev, 2>::operator()( |
| 21 | + const Tensor a, const Tensor b, std::optional<float> alpha, |
| 22 | + std::optional<float> beta, std::optional<int> trans_a, |
| 23 | + std::optional<int> trans_b, Tensor c) const { |
| 24 | + auto at_a = ToAtenTensor<kDev>(const_cast<void*>(a.data()), a_shape_, |
| 25 | + a_strides_, a_type_, device_index_); |
| 26 | + auto at_b = ToAtenTensor<kDev>(const_cast<void*>(b.data()), b_shape_, |
| 27 | + b_strides_, b_type_, device_index_); |
| 28 | + auto at_c = ToAtenTensor<kDev>(c.data(), c_shape_, c_strides_, c_type_, |
| 29 | + device_index_); |
| 30 | + |
| 31 | + auto alpha_val = alpha.value_or(alpha_); |
| 32 | + auto beta_val = beta.value_or(beta_); |
| 33 | + |
| 34 | + if (trans_a.value_or(trans_a_)) { |
| 35 | + at_a = at_a.transpose(-2, -1); |
| 36 | + } |
| 37 | + |
| 38 | + if (trans_b.value_or(trans_b_)) { |
| 39 | + at_b = at_b.transpose(-2, -1); |
| 40 | + } |
| 41 | + |
| 42 | + if (at_a.dim() == 2) { |
| 43 | + at::addmm_out(at_c, at_c, at_a, at_b, beta_val, alpha_val); |
| 44 | + } else { |
| 45 | + at::baddbmm_out(at_c, at_c, at_a, at_b, beta_val, alpha_val); |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +template class Operator<Gemm, Device::Type::kCpu, 2>; |
| 50 | +template class Operator<Gemm, Device::Type::kNvidia, 2>; |
| 51 | +template class Operator<Gemm, Device::Type::kCambricon, 2>; |
| 52 | +template class Operator<Gemm, Device::Type::kAscend, 2>; |
| 53 | +template class Operator<Gemm, Device::Type::kMetax, 2>; |
| 54 | +template class Operator<Gemm, Device::Type::kMoore, 2>; |
| 55 | +template class Operator<Gemm, Device::Type::kIluvatar, 2>; |
| 56 | +template class Operator<Gemm, Device::Type::kKunlun, 2>; |
| 57 | +template class Operator<Gemm, Device::Type::kHygon, 2>; |
| 58 | +template class Operator<Gemm, Device::Type::kQy, 2>; |
| 59 | + |
| 60 | +} // namespace infini::ops |
0 commit comments