Commit 299e858
feat: add PyTorch C++ backend for Add and Gemm (#51)
* refactor: extract `TorchDeviceName` into `src/torch/device_.h`
  Move the `TorchDeviceName<kDev>` template specializations from `pybind11_utils.h` into a standalone header so they can be reused by torch operator implementations without pulling in pybind11.
* feat: add slotted `ActiveImplementationsImpl` for composable registries
  Add a third template parameter `N` (slot index, default 0) to `ActiveImplementationsImpl`. Slot 0 holds the base/device-native indices, and higher slots let add-on backends register extra indices without conflicting with existing specializations. `ActiveImplementations` flattens slots 0-3 via `Flatten`.
* feat: add `WITH_TORCH` CMake option and wrapper generator support
  Add a `WITH_TORCH` option that finds PyTorch via pip, links against libtorch, and compiles sources under `src/torch/`. Pass `--with-torch` to `generate_wrappers.py` so it scans `src/torch/` for operator specializations.
* feat: add torch tensor conversion utilities
  Add `ToAtenDtype()` and `ToAtenTensor<kDev>()` in `src/torch/tensor_.h` for zero-copy conversion from `infini::ops::Tensor` to `at::Tensor` via `at::from_blob()`.
* feat: add torch `Add` operator (implementation index 1)
  Register torch `Add` via `ActiveImplementationsImpl` slot 1 for all devices. The implementation uses `at::add_out()` through ATen's device-generic dispatch.
* feat: add torch `Gemm` operator (implementation index 2)
  Register torch `Gemm` via `ActiveImplementationsImpl` slot 1 for all devices. The implementation uses `at::addmm_out()` / `at::baddbmm_out()` through ATen's device-generic dispatch.
* feat: add `AUTO_DETECT_BACKENDS` CMake option
  Auto-detect PyTorch by attempting `import torch`. When found, `WITH_TORCH` is enabled automatically.
* fix: replace `find_package(Torch)` with direct path queries
  `find_package(Torch)` pulls in Caffe2's CMake config, which calls `enable_language(CUDA)` and breaks on platforms with non-standard CUDA toolchains (e.g. Iluvatar). Query include and library paths directly via `torch.utils.cpp_extension` instead.
* fix(iluvatar): set `CMAKE_CUDA_ARCHITECTURES` before `enable_language(CUDA)`
  CMake 4.3+ requires `CMAKE_CUDA_ARCHITECTURES` to be set before `enable_language(CUDA)` when using non-standard CUDA compilers like Iluvatar's `clang++`. Without it, CMake fails to detect a default architecture.
* fix: auto-detect `pybind11_DIR` from Python when not set
  When `pybind11` is installed via pip but not in a standard CMake search path, `find_package(pybind11 CONFIG)` fails. Query `python -m pybind11 --cmakedir` as a fallback to locate the package.
* fix: cross-platform torch compilation and version compatibility
  - Iluvatar: set `CMAKE_CUDA_ARCHITECTURES` to `OFF` instead of `ivcore20` (CMake 4.3 rejects non-integer architecture names; the architecture is already passed via `CMAKE_CUDA_FLAGS`).
  - MetaX/Moore: split torch operator headers into declaration-only `.h` files and `.cc` implementation files with explicit instantiations. Compile the `.cc` files with the system `g++` instead of the vendor compiler (`mxcc`/`mcc`), which cannot parse vendor-forked torch headers in C++ extension mode.
  - Cambricon: guard `UInt16`/`UInt32`/`UInt64` scalar types in `ToAtenDtype()` with a `TORCH_VERSION` check (these types require PyTorch 2.4+; Cambricon ships torch 2.1).
  - Wrapper generator: scan only `.h` files to avoid including `.cc` explicit-instantiation files in the generated `ops.cc`.
* fix: resolve build failures on iluvatar, metax, and cambricon
  - Iluvatar: move `-x ivcore` from `CMAKE_CUDA_FLAGS` to compile-only options so it is not passed during linking (where it caused `clang++` to re-parse `.o` files as source code).
  - MetaX: add `-DUSE_MACA=1` to `g++` flags for torch source compilation (MetaX torch fork headers require this define).
  - Cambricon: query `torch.compiled_with_cxx11_abi()` and set `_GLIBCXX_USE_CXX11_ABI` globally to match torch's ABI setting (fixes an undefined reference to `c10::Device::Device(std::string)`).
* style: fix comments to comply with `CONTRIBUTING.md`
  Use backtick-fenced Markdown syntax for identifiers in comments and error messages, and ensure comments are complete sentences.
* style: apply `clang-format`
* refactor: auto-detect operator implementations via SFINAE
  Replace the manual `ActiveImplementationsImpl` slot system with `std::is_base_of`-based compile-time detection. A real `Operator` specialization inherits from `Key` (e.g., `Gemm`), while the primary template inherits only from `OperatorBase`; SFINAE distinguishes the two automatically, eliminating the need for `registry.h` files.
* refactor: use `std::index_sequence` for implementation auto-detection
  Replace the hand-unrolled `Flatten<..., 0>::type, ..., 3>::type>` with `std::index_sequence<0..kMaxImplementations>` expansion. Increase `kMaxImplementations` from 4 to 16.
* refactor: replace preprocessor conditionals with `constexpr` in `tensor_.h`
  Use `constexpr int kTorchVersion` and `if constexpr` instead of `#if` macros for PyTorch version checks. Extract unsigned dtype handling into `detail::ToAtenUnsignedDataType`.
* docs: add TODO comments for dynamic implementation index parametrization
* fix: use dependent type alias to support `if constexpr` on PyTorch < 2.4
  `c10::ScalarType::UInt16` is a non-dependent name resolved at template definition time. Introduce `DependentScalarType<kVersion>::type` so the enum member access becomes dependent and is properly discarded by `if constexpr` on older PyTorch versions.
* fix(tests): skip torch Gemm on CPU half-precision
  ATen `addmm`/`baddbmm` does not support `float16`/`bfloat16` on CPU.
* refactor: rename `ToAtenDtype` to `ToAtenDataType`
1 parent e5571b4 commit 299e858

File tree

17 files changed: +593 -91 lines

CMakeLists.txt

Lines changed: 82 additions & 1 deletion
```diff
@@ -16,7 +16,10 @@ option(WITH_CAMBRICON "Enable Cambricon backend" OFF)
 option(WITH_MOORE "Enable Moore backend" OFF)
 option(WITH_ASCEND "Enable Ascend backend" OFF)
 
+option(WITH_TORCH "Enable PyTorch C++ backend" OFF)
+
 option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
+option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF)
 option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)
 
 if(AUTO_DETECT_DEVICES)
@@ -79,6 +82,72 @@ if(AUTO_DETECT_DEVICES)
     endif()
 endif()
 
+if(AUTO_DETECT_BACKENDS)
+    message(STATUS "Auto-detecting available backends...")
+
+    find_package(Python COMPONENTS Interpreter QUIET)
+
+    if(Python_FOUND)
+        execute_process(
+            COMMAND ${Python_EXECUTABLE} -c "import torch"
+            RESULT_VARIABLE _torch_import_result
+            OUTPUT_QUIET
+            ERROR_QUIET
+        )
+
+        if(_torch_import_result EQUAL 0)
+            set(WITH_TORCH ON)
+            message(STATUS "Auto-detected PyTorch.")
+        endif()
+    endif()
+endif()
+
+if(WITH_TORCH)
+    find_package(Python COMPONENTS Interpreter REQUIRED)
+
+    # Query `torch` paths directly instead of using `find_package(Torch)`,
+    # which pulls in Caffe2's CMake config and may fail on platforms with
+    # non-standard CUDA toolchains.
+    execute_process(
+        COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))"
+        OUTPUT_VARIABLE TORCH_INCLUDE_DIRS
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+        RESULT_VARIABLE _torch_result
+    )
+
+    if(NOT _torch_result EQUAL 0)
+        message(FATAL_ERROR "`WITH_TORCH` is `ON` but `torch` is not installed.")
+    endif()
+
+    execute_process(
+        COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import library_paths; print(';'.join(library_paths()))"
+        OUTPUT_VARIABLE _torch_lib_dirs
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+
+    find_library(TORCH_LIB torch HINTS ${_torch_lib_dirs} REQUIRED)
+    find_library(TORCH_CPU_LIB torch_cpu HINTS ${_torch_lib_dirs} REQUIRED)
+    find_library(C10_LIB c10 HINTS ${_torch_lib_dirs} REQUIRED)
+    set(TORCH_LIBRARIES ${TORCH_LIB} ${TORCH_CPU_LIB} ${C10_LIB})
+
+    # Query the `CXX11` ABI setting that `torch` was compiled with.
+    # A mismatch causes linker errors (e.g. undefined reference to
+    # `c10::Device::Device(std::string const&)`).
+    execute_process(
+        COMMAND ${Python_EXECUTABLE} -c "import torch; print(int(torch.compiled_with_cxx11_abi()))"
+        OUTPUT_VARIABLE TORCH_CXX11_ABI
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+        RESULT_VARIABLE _torch_abi_result
+    )
+
+    if(_torch_abi_result EQUAL 0)
+        add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI})
+        message(STATUS "PyTorch CXX11 ABI: ${TORCH_CXX11_ABI}")
+    endif()
+
+    message(STATUS "Found PyTorch: ${TORCH_INCLUDE_DIRS}")
+endif()
+
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
 
 # Only one CUDA-like GPU backend can be enabled at a time.
@@ -110,11 +179,23 @@ if(WITH_ILUVATAR)
     else()
         set(CMAKE_CUDA_COMPILER "clang++" CACHE STRING "Iluvatar CUDA compiler (clang++)")
     endif()
-    set(CMAKE_CUDA_FLAGS "-x ivcore -std=c++17 --cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags")
+    # `-x ivcore` must not be in `CMAKE_CUDA_FLAGS` — CMake passes those flags
+    # to both compile and link steps. During linking, `-x ivcore` causes
+    # `clang++` to re-parse `.o` files as source code.
+    set(CMAKE_CUDA_FLAGS "--cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags")
     set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Iluvatar")
+    set(CMAKE_CUDA_ARCHITECTURES OFF CACHE STRING "Iluvatar CUDA architectures (passed via CMAKE_CUDA_FLAGS)")
+    # Iluvatar does not ship `libcudadevrt`, which CMake's compiler test
+    # tries to link. Disable automatic CUDA runtime linking and link
+    # manually via `find_package(CUDAToolkit)` instead.
+    set(CMAKE_CUDA_RUNTIME_LIBRARY NONE)
     message(STATUS "Iluvatar: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${ILUVATAR_ARCH}")
     enable_language(CUDA)
     find_package(CUDAToolkit REQUIRED)
+    # Add `-x ivcore` as a compile-only flag so it is not passed during
+    # linking, where it would cause `clang++` to re-parse `.o` files as
+    # source.
+    add_compile_options($<$<COMPILE_LANGUAGE:CUDA>:-x$<SEMICOLON>ivcore>)
 endif()
 
 if(WITH_METAX)
```

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,6 +14,7 @@ install-dir = "infini"
 
 [tool.scikit-build.cmake.define]
 AUTO_DETECT_DEVICES = "ON"
+AUTO_DETECT_BACKENDS = "ON"
 GENERATE_PYTHON_BINDINGS = "ON"
 
 [tool.pytest.ini_options]
```

scripts/generate_wrappers.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -413,7 +413,12 @@ def _snake_to_pascal(snake_str):
     return "".join(word.capitalize() for word in snake_str.split("_"))
 
 
-def _get_all_ops(devices):
+def _get_all_ops(devices, with_torch=False):
+    scan_dirs = set(devices)
+
+    if with_torch:
+        scan_dirs.add("torch")
+
     ops = {}
 
     for file_path in _BASE_DIR.iterdir():
@@ -424,8 +429,8 @@ def _get_all_ops(devices):
 
         ops[op_name] = []
 
-    for file_path in _SRC_DIR.rglob("*"):
-        if not file_path.is_file() or file_path.parent.parent.name not in devices:
+    for file_path in _SRC_DIR.rglob("*.h"):
+        if file_path.parent.parent.name not in scan_dirs:
             continue
 
         if f"class Operator<{_snake_to_pascal(op_name)}" in file_path.read_text():
@@ -445,6 +450,12 @@ def _get_all_ops(devices):
     help="Devices to use. Please pick from `cpu`, `nvidia`, `cambricon`, `ascend`, `metax`, `moore`, `iluvatar`, `kunlun`, `hygon`, and `qy`. (default: `cpu`)",
 )
 
+parser.add_argument(
+    "--with-torch",
+    action="store_true",
+    help="Include PyTorch C++ backend implementations.",
+)
+
 args = parser.parse_args()
 
 _BINDINGS_DIR.mkdir(parents=True, exist_ok=True)
@@ -456,7 +467,7 @@ def _get_all_ops(devices):
 if ops_json.exists():
     ops = json.loads(ops_json.read_text())
 else:
-    ops = _get_all_ops(args.devices)
+    ops = _get_all_ops(args.devices, with_torch=args.with_torch)
 
 header_paths = []
 bind_func_names = []
```

src/CMakeLists.txt

Lines changed: 85 additions & 2 deletions
```diff
@@ -16,7 +16,7 @@ if(WITH_CPU)
 
     target_compile_definitions(infiniops PUBLIC WITH_CPU=1)
 
-    find_package(OpenMP REQUIRED)
+    find_package(OpenMP REQUIRED COMPONENTS CXX)
     target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX)
 
     list(APPEND DEVICE_LIST "cpu")
@@ -218,6 +218,69 @@ if(WITH_ASCEND)
     list(APPEND DEVICE_LIST "ascend")
 endif()
 
+if(WITH_TORCH)
+    file(GLOB_RECURSE TORCH_SOURCES CONFIGURE_DEPENDS "torch/*.cc" "torch/*.cpp")
+
+    target_compile_definitions(infiniops PUBLIC WITH_TORCH=1)
+    target_link_libraries(infiniops PUBLIC ${TORCH_LIBRARIES})
+    target_include_directories(infiniops PUBLIC ${TORCH_INCLUDE_DIRS})
+
+    if(WITH_METAX OR WITH_MOORE)
+        # Vendor compilers (`mxcc`/`mcc`) cannot compile vendor-forked `torch`
+        # headers. Compile `torch` sources with the system C++ compiler instead.
+        find_program(SYSTEM_CXX NAMES g++ c++)
+
+        if(NOT SYSTEM_CXX)
+            message(FATAL_ERROR "Could not find system `g++` for `torch` compilation.")
+        endif()
+
+        set(_torch_include_flags "")
+        foreach(_dir ${TORCH_INCLUDE_DIRS})
+            list(APPEND _torch_include_flags "-isystem" "${_dir}")
+        endforeach()
+
+        # Vendor-specific defines required by forked `torch` headers.
+        set(_torch_extra_flags "")
+        if(WITH_METAX)
+            list(APPEND _torch_extra_flags "-DUSE_MACA=1")
+        endif()
+        if(DEFINED TORCH_CXX11_ABI)
+            list(APPEND _torch_extra_flags "-D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
+        endif()
+
+        set(TORCH_OBJECT_DIR "${CMAKE_CURRENT_BINARY_DIR}/torch_objs")
+        file(MAKE_DIRECTORY "${TORCH_OBJECT_DIR}")
+
+        set(TORCH_OBJECT_FILES)
+        foreach(_src ${TORCH_SOURCES})
+            file(RELATIVE_PATH _rel ${CMAKE_CURRENT_SOURCE_DIR} ${_src})
+            string(REPLACE "/" "_" _obj_name "${_rel}")
+            string(REPLACE ".cc" ".o" _obj_name "${_obj_name}")
+            string(REPLACE ".cpp" ".o" _obj_name "${_obj_name}")
+            set(_obj "${TORCH_OBJECT_DIR}/${_obj_name}")
+
+            add_custom_command(
+                OUTPUT "${_obj}"
+                COMMAND ${SYSTEM_CXX}
+                    -std=c++17 -fPIC -O2
+                    "-I${CMAKE_CURRENT_SOURCE_DIR}"
+                    ${_torch_include_flags}
+                    ${_torch_extra_flags}
+                    -c "${_src}" -o "${_obj}"
+                DEPENDS "${_src}"
+                COMMENT "Compiling ${_rel} with system C++ compiler"
+            )
+            list(APPEND TORCH_OBJECT_FILES "${_obj}")
+        endforeach()
+
+        set_source_files_properties(${TORCH_OBJECT_FILES}
+            PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
+        target_sources(infiniops PRIVATE ${TORCH_OBJECT_FILES})
+    else()
+        target_sources(infiniops PRIVATE ${TORCH_SOURCES})
+    endif()
+endif()
+
 target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
 
 if(GENERATE_PYTHON_BINDINGS)
@@ -226,8 +289,14 @@ if(GENERATE_PYTHON_BINDINGS)
     # active device list. Stale generated files (e.g., committed for one
     # platform) would omit specializations for other enabled backends,
     # causing link-time or runtime failures.
+
+    set(GENERATOR_ARGS --devices ${DEVICE_LIST})
+    if(WITH_TORCH)
+        list(APPEND GENERATOR_ARGS --with-torch)
+    endif()
+
     execute_process(
-        COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST}
+        COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS}
         WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
         RESULT_VARIABLE script_result
     )
@@ -246,6 +315,20 @@ if(GENERATE_PYTHON_BINDINGS)
     endif()
 
     find_package(Python COMPONENTS Interpreter Development)
+
+    if(NOT pybind11_DIR)
+        execute_process(
+            COMMAND ${Python_EXECUTABLE} -m pybind11 --cmakedir
+            OUTPUT_VARIABLE _pybind11_cmake_dir
+            OUTPUT_STRIP_TRAILING_WHITESPACE
+            RESULT_VARIABLE _pybind11_result
+        )
+
+        if(_pybind11_result EQUAL 0)
+            set(pybind11_DIR "${_pybind11_cmake_dir}" CACHE PATH "pybind11 CMake directory")
+        endif()
+    endif()
+
     find_package(pybind11 CONFIG)
 
     if(PYBIND11_ENABLE_EXTRAS)
```

src/nvidia/gemm/cublas.h

Lines changed: 0 additions & 1 deletion
```diff
@@ -3,7 +3,6 @@
 
 #include "cuda/gemm/blas.h"
 #include "nvidia/blas.h"
-#include "nvidia/gemm/registry.h"
 
 namespace infini::ops {
 
```
src/nvidia/gemm/cublaslt.h

Lines changed: 0 additions & 1 deletion
```diff
@@ -10,7 +10,6 @@
 
 #include "base/gemm.h"
 #include "nvidia/blas_utils.h"
-#include "nvidia/gemm/registry.h"
 #include "nvidia/runtime_.h"
 
 namespace infini::ops {
```

src/nvidia/gemm/registry.h

Lines changed: 0 additions & 15 deletions
This file was deleted.

src/operator.h

Lines changed: 44 additions & 9 deletions
```diff
@@ -84,14 +84,9 @@ struct std::equal_to<infini::ops::detail::CacheKey> {
 
 namespace infini::ops {
 
+// Forward declaration — defined after `Operator` using SFINAE auto-detection.
 template <typename Key, Device::Type kDev>
-struct ActiveImplementationsImpl {
-    using type = List<0>;
-};
-
-template <typename Key, Device::Type kDev>
-using ActiveImplementations =
-    typename ActiveImplementationsImpl<Key, kDev>::type;
+struct ActiveImplementations;
 
 class OperatorBase {
   public:
@@ -153,7 +148,7 @@ class Operator : public OperatorBase {
                 }
             },
             "Operator::make(implementation_index)",
-            ActiveImplementations<Key, kDev>{});
+            typename ActiveImplementations<Key, kDev>::type{});
         },
         "Operator::make");
 
@@ -200,7 +195,8 @@ class Operator : public OperatorBase {
         dev_type,
         [&](auto device_tag) {
             constexpr Device::Type kDev = decltype(device_tag)::value;
-            result = detail::ListToVector(ActiveImplementations<Key, kDev>{});
+            result = detail::ListToVector(
+                typename ActiveImplementations<Key, kDev>::type{});
         },
         "Operator::active_implementation_indices");
     return result;
@@ -227,6 +223,45 @@ class Operator : public OperatorBase {
     static constexpr std::size_t implementation_index_{implementation_index};
 };
 
+// Maximum number of implementation slots per (operator, device) pair.
+// Increase this value when adding operators with more implementations.
+constexpr std::size_t kMaxImplementations = 16;
+
+// SFINAE-based implementation detection. A partial specialization
+// `Operator<Key, kDev, N>` inherits from `Key` (the operator base class),
+// while the unspecialized primary template inherits only from `OperatorBase`.
+// `std::is_base_of` distinguishes the two at compile time, eliminating the
+// need for manual `registry.h` files.
+template <typename Key, Device::Type kDev, std::size_t N,
+          bool = std::is_base_of_v<Key, Operator<Key, kDev, N>>>
+struct ActiveImplementationsImpl {
+    using type = List<>;
+};
+
+template <typename Key, Device::Type kDev, std::size_t N>
+struct ActiveImplementationsImpl<Key, kDev, N, true> {
+    using type = List<N>;
+};
+
+namespace detail {
+
+template <typename Key, Device::Type kDev, typename Seq>
+struct ActiveImplementationsHelper;
+
+template <typename Key, Device::Type kDev, std::size_t... ns>
+struct ActiveImplementationsHelper<Key, kDev, std::index_sequence<ns...>> {
+    using type = typename Flatten<
+        typename ActiveImplementationsImpl<Key, kDev, ns>::type...>::type;
+};
+
+}  // namespace detail
+
+template <typename Key, Device::Type kDev>
+struct ActiveImplementations {
+    using type = typename detail::ActiveImplementationsHelper<
+        Key, kDev, std::make_index_sequence<kMaxImplementations>>::type;
+};
+
 }  // namespace infini::ops
 
 #endif
```
