Skip to content

Commit 585e3ce

Browse files
committed
refactor: auto-detect operator implementations via SFINAE
Replace the manual `ActiveImplementationsImpl` slot system with `std::is_base_of`-based compile-time detection. A real `Operator` specialization inherits from `Key` (e.g., `Gemm`), while the primary template inherits only from `OperatorBase` — SFINAE distinguishes the two automatically, eliminating the need for `registry.h` files.
1 parent b2b1e6a commit 585e3ce

File tree

8 files changed

+30
-66
lines changed

8 files changed

+30
-66
lines changed

src/nvidia/gemm/cublas.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include "cuda/gemm/blas.h"
55
#include "nvidia/blas.h"
6-
#include "nvidia/gemm/registry.h"
76

87
namespace infini::ops {
98

src/nvidia/gemm/cublaslt.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
#include "base/gemm.h"
1212
#include "nvidia/blas_utils.h"
13-
#include "nvidia/gemm/registry.h"
1413
#include "nvidia/runtime_.h"
1514

1615
namespace infini::ops {

src/nvidia/gemm/registry.h

Lines changed: 0 additions & 15 deletions
This file was deleted.

src/operator.h

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -84,22 +84,9 @@ struct std::equal_to<infini::ops::detail::CacheKey> {
8484

8585
namespace infini::ops {
8686

87-
template <typename Key, Device::Type kDev, std::size_t N = 0>
88-
struct ActiveImplementationsImpl {
89-
using type = List<>;
90-
};
91-
87+
// Forward declaration — defined after `Operator` using SFINAE auto-detection.
9288
template <typename Key, Device::Type kDev>
93-
struct ActiveImplementationsImpl<Key, kDev, 0> {
94-
using type = List<0>;
95-
};
96-
97-
template <typename Key, Device::Type kDev>
98-
using ActiveImplementations = typename Flatten<
99-
typename ActiveImplementationsImpl<Key, kDev, 0>::type,
100-
typename ActiveImplementationsImpl<Key, kDev, 1>::type,
101-
typename ActiveImplementationsImpl<Key, kDev, 2>::type,
102-
typename ActiveImplementationsImpl<Key, kDev, 3>::type>::type;
89+
struct ActiveImplementations;
10390

10491
class OperatorBase {
10592
public:
@@ -161,7 +148,7 @@ class Operator : public OperatorBase {
161148
}
162149
},
163150
"Operator::make(implementation_index)",
164-
ActiveImplementations<Key, kDev>{});
151+
typename ActiveImplementations<Key, kDev>::type{});
165152
},
166153
"Operator::make");
167154

@@ -208,7 +195,8 @@ class Operator : public OperatorBase {
208195
dev_type,
209196
[&](auto device_tag) {
210197
constexpr Device::Type kDev = decltype(device_tag)::value;
211-
result = detail::ListToVector(ActiveImplementations<Key, kDev>{});
198+
result = detail::ListToVector(
199+
typename ActiveImplementations<Key, kDev>::type{});
212200
},
213201
"Operator::active_implementation_indices");
214202
return result;
@@ -235,6 +223,31 @@ class Operator : public OperatorBase {
235223
static constexpr std::size_t implementation_index_{implementation_index};
236224
};
237225

226+
// SFINAE-based implementation detection. A partial specialization
227+
// `Operator<Key, kDev, N>` inherits from `Key` (the operator base class),
228+
// while the unspecialized primary template inherits only from `OperatorBase`.
229+
// `std::is_base_of` distinguishes the two at compile time, eliminating the
230+
// need for manual `registry.h` files.
231+
template <typename Key, Device::Type kDev, std::size_t N,
232+
bool = std::is_base_of_v<Key, Operator<Key, kDev, N>>>
233+
struct ActiveImplementationsImpl {
234+
using type = List<>;
235+
};
236+
237+
template <typename Key, Device::Type kDev, std::size_t N>
238+
struct ActiveImplementationsImpl<Key, kDev, N, true> {
239+
using type = List<N>;
240+
};
241+
242+
template <typename Key, Device::Type kDev>
243+
struct ActiveImplementations {
244+
using type = typename Flatten<
245+
typename ActiveImplementationsImpl<Key, kDev, 0>::type,
246+
typename ActiveImplementationsImpl<Key, kDev, 1>::type,
247+
typename ActiveImplementationsImpl<Key, kDev, 2>::type,
248+
typename ActiveImplementationsImpl<Key, kDev, 3>::type>::type;
249+
};
250+
238251
} // namespace infini::ops
239252

240253
#endif

src/torch/add/add.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define INFINI_OPS_TORCH_ADD_H_
33

44
#include "base/add.h"
5-
#include "torch/add/registry.h"
65

76
namespace infini::ops {
87

src/torch/add/registry.h

Lines changed: 0 additions & 15 deletions
This file was deleted.

src/torch/gemm/gemm.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define INFINI_OPS_TORCH_GEMM_H_
33

44
#include "base/gemm.h"
5-
#include "torch/gemm/registry.h"
65

76
namespace infini::ops {
87

src/torch/gemm/registry.h

Lines changed: 0 additions & 15 deletions
This file was deleted.

0 commit comments

Comments
 (0)