Skip to content

Commit 9b9dda2

Browse files
committed
feat(hygon-gemm): add Hygon backend support for Gemm
- add a Hygon `Gemm` backend on top of the shared CUDA BLAS path
- use DTK-friendly compute and algo settings for fp32/fp16 gemm
- fall back to `cublasGemmEx` for single-batch Hygon gemm to avoid DTK crashes
- release Hygon cublas handles after each call and re-enable the `gemm` example
- verified with `pip install -e .[dev]`, `pytest tests/test_gemm.py -k cuda`, and `pytest tests/test_gemm.py`
1 parent e10671d commit 9b9dda2

File tree

4 files changed

+136
-10
lines changed

4 files changed

+136
-10
lines changed

examples/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@ file(GLOB_RECURSE EXAMPLE_SOURCES CONFIGURE_DEPENDS "*.cc")
22

33
# Iterate through each file and create an executable.
44
foreach(source_file ${EXAMPLE_SOURCES})
5-
if(WITH_HYGON AND source_file MATCHES "/gemm\\.cc$")
6-
continue()
7-
endif()
8-
95
get_filename_component(example_name ${source_file} NAME_WE)
106

117
add_executable(${example_name} ${source_file})

examples/gemm/gemm.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
#if WITH_ILUVATAR
1212
#include "iluvatar/gemm/cublas.h"
1313
#endif
14+
#if WITH_HYGON
15+
#include "hygon/gemm/cublas.h"
16+
#endif
1417
#if WITH_METAX
1518
#include "metax/gemm/mcblas.h"
1619
#endif

src/cuda/gemm/blas.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ class Blas : public Gemm {
2121
// TODO: Check constraints.
2222
}
2323

24-
~Blas() { Backend::blasDestroy(handle_); }
24+
~Blas() {
25+
if (handle_ != nullptr) {
26+
Backend::blasDestroy(handle_);
27+
}
28+
}
2529

2630
Blas(const Tensor a, const Tensor b, std::optional<float> alpha,
2731
std::optional<float> beta, Tensor c)
@@ -69,7 +73,6 @@ class Blas : public Gemm {
6973
return &beta;
7074
}
7175

72-
private:
7376
auto GetOpA(int trans_a, int trans_b) const {
7477
if (swap_a_and_b_) {
7578
return (b_is_col_major_ == trans_b) ? Backend::BLAS_OP_T
@@ -88,13 +91,14 @@ class Blas : public Gemm {
8891
: Backend::BLAS_OP_N;
8992
}
9093

91-
bool a_is_col_major_{false};
94+
bool swap_a_and_b_{false};
9295

93-
bool b_is_col_major_{false};
96+
mutable typename Backend::blasHandle_t handle_{};
9497

95-
bool swap_a_and_b_{false};
98+
private:
99+
bool a_is_col_major_{false};
96100

97-
typename Backend::blasHandle_t handle_;
101+
bool b_is_col_major_{false};
98102
};
99103

100104
} // namespace infini::ops

src/hygon/gemm/cublas.h

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#ifndef INFINI_OPS_HYGON_GEMM_CUBLAS_H_
2+
#define INFINI_OPS_HYGON_GEMM_CUBLAS_H_
3+
4+
#include <utility>
5+
6+
// clang-format off
7+
#include "cublas_v2.h"
8+
// clang-format on
9+
10+
#include "cuda/gemm/blas.h"
11+
12+
namespace infini::ops {
13+
14+
namespace gemm {
15+
16+
struct HygonBackend {
17+
using blasHandle_t = cublasHandle_t;
18+
19+
using stream_t = cudaStream_t;
20+
21+
static constexpr auto BLAS_OP_N = CUBLAS_OP_N;
22+
23+
static constexpr auto BLAS_OP_T = CUBLAS_OP_T;
24+
25+
static constexpr auto R_16F = CUDA_R_16F;
26+
27+
static constexpr auto R_16BF = CUDA_R_16BF;
28+
29+
static constexpr auto R_32F = CUDA_R_32F;
30+
31+
static constexpr auto BLAS_COMPUTE_32F = CUBLAS_COMPUTE_32F;
32+
33+
// DTK exposes the TF32 enum for compatibility, but BW/GFX9-class Hygon
34+
// devices do not provide a working TF32 GEMM fast path.
35+
static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUBLAS_COMPUTE_32F;
36+
37+
static constexpr auto BLAS_GEMM_DEFAULT = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
38+
39+
static constexpr auto blasCreate = cublasCreate;
40+
41+
static constexpr auto blasSetStream = cublasSetStream;
42+
43+
static constexpr auto blasDestroy = cublasDestroy;
44+
45+
static constexpr auto blasGemmEx = [](auto&&... args) {
46+
return cublasGemmEx(std::forward<decltype(args)>(args)...);
47+
};
48+
49+
static constexpr auto blasGemmStridedBatchedEx = [](auto&&... args) {
50+
return cublasGemmStridedBatchedEx(std::forward<decltype(args)>(args)...);
51+
};
52+
53+
static auto GetDataType(DataType dtype) {
54+
if (dtype == DataType::kFloat16) return R_16F;
55+
if (dtype == DataType::kBFloat16) return R_16BF;
56+
return R_32F;
57+
}
58+
59+
static auto GetComputeType(DataType dtype) {
60+
if (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16)
61+
return BLAS_COMPUTE_32F;
62+
return BLAS_COMPUTE_32F_FAST_TF32;
63+
}
64+
};
65+
66+
} // namespace gemm
67+
68+
template <>
69+
class Operator<Gemm, Device::Type::kHygon> : public Blas<gemm::HygonBackend> {
70+
public:
71+
using Blas<gemm::HygonBackend>::Blas;
72+
73+
void operator()(const Tensor a, const Tensor b, std::optional<float> alpha,
74+
std::optional<float> beta, std::optional<int> trans_a,
75+
std::optional<int> trans_b, Tensor c) const override {
76+
if (this->handle_ == nullptr) {
77+
gemm::HygonBackend::blasCreate(&this->handle_);
78+
}
79+
80+
gemm::HygonBackend::blasSetStream(
81+
this->handle_, static_cast<gemm::HygonBackend::stream_t>(this->stream_));
82+
83+
const auto& alpha_value{alpha.value_or(this->alpha_)};
84+
const auto& beta_value{beta.value_or(this->beta_)};
85+
86+
const auto& trans_a_value{trans_a.value_or(this->trans_a_)};
87+
const auto& trans_b_value{trans_b.value_or(this->trans_b_)};
88+
auto op_a{this->GetOpA(trans_a_value, trans_b_value)};
89+
auto op_b{this->GetOpB(trans_a_value, trans_b_value)};
90+
const void* alpha_ptr{this->GetAlphaPtr(alpha_value, c.dtype())};
91+
const void* beta_ptr{this->GetBetaPtr(beta_value, c.dtype())};
92+
93+
if (this->batch_count_ == 1) {
94+
gemm::HygonBackend::blasGemmEx(
95+
this->handle_, op_a, op_b,
96+
this->swap_a_and_b_ ? this->n_ : this->m_,
97+
this->swap_a_and_b_ ? this->m_ : this->n_, this->k_, alpha_ptr,
98+
this->swap_a_and_b_ ? b.data() : a.data(),
99+
gemm::HygonBackend::GetDataType(this->swap_a_and_b_ ? b.dtype()
100+
: a.dtype()),
101+
this->swap_a_and_b_ ? this->ldb_ : this->lda_,
102+
this->swap_a_and_b_ ? a.data() : b.data(),
103+
gemm::HygonBackend::GetDataType(this->swap_a_and_b_ ? a.dtype()
104+
: b.dtype()),
105+
this->swap_a_and_b_ ? this->lda_ : this->ldb_, beta_ptr, c.data(),
106+
gemm::HygonBackend::GetDataType(c.dtype()), this->ldc_,
107+
gemm::HygonBackend::GetComputeType(c.dtype()),
108+
gemm::HygonBackend::BLAS_GEMM_DEFAULT);
109+
gemm::HygonBackend::blasDestroy(this->handle_);
110+
this->handle_ = nullptr;
111+
return;
112+
}
113+
114+
Blas<gemm::HygonBackend>::operator()(a, b, alpha, beta, trans_a, trans_b,
115+
c);
116+
gemm::HygonBackend::blasDestroy(this->handle_);
117+
this->handle_ = nullptr;
118+
}
119+
};
120+
121+
} // namespace infini::ops
122+
123+
#endif

0 commit comments

Comments
 (0)