Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
443ae3b
refactor: extract `TorchDeviceName` into `src/torch/device_.h`
voltjia Apr 13, 2026
a9ffa02
feat: add slotted `ActiveImplementationsImpl` for composable registries
voltjia Apr 13, 2026
40e435f
feat: add `WITH_TORCH` CMake option and wrapper generator support
voltjia Apr 13, 2026
05c0b26
feat: add torch tensor conversion utilities
voltjia Apr 13, 2026
2c9a3d9
feat: add torch `Add` operator (implementation index 1)
voltjia Apr 13, 2026
93d7145
feat: add torch `Gemm` operator (implementation index 2)
voltjia Apr 13, 2026
43d4cc5
feat: add `AUTO_DETECT_BACKENDS` CMake option
voltjia Apr 14, 2026
34024d8
fix: replace `find_package(Torch)` with direct path queries
voltjia Apr 14, 2026
8a0acb6
fix(iluvatar): set `CMAKE_CUDA_ARCHITECTURES` before `enable_language…
voltjia Apr 14, 2026
55703d6
fix: auto-detect `pybind11_DIR` from Python when not set
voltjia Apr 14, 2026
55bef7d
fix: cross-platform torch compilation and version compatibility
voltjia Apr 14, 2026
32e16e8
fix: resolve build failures on iluvatar, metax, and cambricon
voltjia Apr 14, 2026
9c04549
style: fix comments to comply with `CONTRIBUTING.md`
voltjia Apr 14, 2026
b2b1e6a
style: apply `clang-format`
voltjia Apr 15, 2026
585e3ce
refactor: auto-detect operator implementations via SFINAE
voltjia Apr 16, 2026
76db40f
refactor: use `std::index_sequence` for implementation auto-detection
voltjia Apr 16, 2026
5527c3b
refactor: replace preprocessor conditionals with `constexpr` in `tens…
voltjia Apr 16, 2026
ac4d781
docs: add TODO comments for dynamic implementation index parametrization
voltjia Apr 16, 2026
c13b378
fix: use dependent type alias to support `if constexpr` on PyTorch < 2.4
voltjia Apr 16, 2026
24572c8
fix(tests): skip torch Gemm on CPU half-precision
voltjia Apr 16, 2026
cc42af0
refactor: rename `ToAtenDtype` to `ToAtenDataType`
voltjia Apr 16, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ option(WITH_CAMBRICON "Enable Cambricon backend" OFF)
option(WITH_MOORE "Enable Moore backend" OFF)
option(WITH_ASCEND "Enable Ascend backend" OFF)

option(WITH_TORCH "Enable PyTorch C++ backend" OFF)

option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF)
option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)

if(AUTO_DETECT_DEVICES)
Expand Down Expand Up @@ -79,6 +82,72 @@ if(AUTO_DETECT_DEVICES)
endif()
endif()

if(AUTO_DETECT_BACKENDS)
  message(STATUS "Auto-detecting available backends...")

  # Backend probing needs a Python interpreter; stay quiet so a missing
  # interpreter simply leaves every backend at its default (OFF).
  find_package(Python COMPONENTS Interpreter QUIET)

  if(Python_FOUND)
    # A zero exit status from `import torch` means an importable PyTorch
    # installation exists, so enable the PyTorch C++ backend.
    execute_process(
      COMMAND ${Python_EXECUTABLE} -c "import torch"
      RESULT_VARIABLE _torch_probe_status
      OUTPUT_QUIET
      ERROR_QUIET
    )

    if(_torch_probe_status EQUAL 0)
      set(WITH_TORCH ON)
      message(STATUS "Auto-detected PyTorch.")
    endif()
  endif()
endif()

if(WITH_TORCH)
  find_package(Python COMPONENTS Interpreter REQUIRED)

  # Query `torch` paths directly instead of using `find_package(Torch)`,
  # which pulls in Caffe2's CMake config and may fail on platforms with
  # non-standard CUDA toolchains.
  execute_process(
    COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))"
    OUTPUT_VARIABLE TORCH_INCLUDE_DIRS
    OUTPUT_STRIP_TRAILING_WHITESPACE
    RESULT_VARIABLE _torch_result
  )

  if(NOT _torch_result EQUAL 0)
    message(FATAL_ERROR "`WITH_TORCH` is `ON` but `torch` is not installed.")
  endif()

  # Query the library directories the same way. Check the exit status here
  # too: a partially broken install would otherwise leave `_torch_lib_dirs`
  # empty and surface later as a confusing `find_library` failure.
  execute_process(
    COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import library_paths; print(';'.join(library_paths()))"
    OUTPUT_VARIABLE _torch_lib_dirs
    OUTPUT_STRIP_TRAILING_WHITESPACE
    RESULT_VARIABLE _torch_lib_result
  )

  if(NOT _torch_lib_result EQUAL 0)
    message(FATAL_ERROR "Failed to query `torch` library paths from `torch.utils.cpp_extension`.")
  endif()

  find_library(TORCH_LIB torch HINTS ${_torch_lib_dirs} REQUIRED)
  find_library(TORCH_CPU_LIB torch_cpu HINTS ${_torch_lib_dirs} REQUIRED)
  find_library(C10_LIB c10 HINTS ${_torch_lib_dirs} REQUIRED)
  set(TORCH_LIBRARIES ${TORCH_LIB} ${TORCH_CPU_LIB} ${C10_LIB})

  # Query the `CXX11` ABI setting that `torch` was compiled with.
  # A mismatch causes linker errors (e.g. undefined reference to
  # `c10::Device::Device(std::string const&)`).
  execute_process(
    COMMAND ${Python_EXECUTABLE} -c "import torch; print(int(torch.compiled_with_cxx11_abi()))"
    OUTPUT_VARIABLE TORCH_CXX11_ABI
    OUTPUT_STRIP_TRAILING_WHITESPACE
    RESULT_VARIABLE _torch_abi_result
  )

  # Best-effort: if the query fails (e.g. an unusual torch build), skip the
  # define rather than aborting — the link step will surface any mismatch.
  if(_torch_abi_result EQUAL 0)
    add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI})
    message(STATUS "PyTorch CXX11 ABI: ${TORCH_CXX11_ABI}")
  endif()

  message(STATUS "Found PyTorch: ${TORCH_INCLUDE_DIRS}")
endif()

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)

# Only one CUDA-like GPU backend can be enabled at a time.
Expand Down Expand Up @@ -110,11 +179,23 @@ if(WITH_ILUVATAR)
else()
set(CMAKE_CUDA_COMPILER "clang++" CACHE STRING "Iluvatar CUDA compiler (clang++)")
endif()
set(CMAKE_CUDA_FLAGS "-x ivcore -std=c++17 --cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags")
# `-x ivcore` must not be in `CMAKE_CUDA_FLAGS` — CMake passes those flags
# to both compile and link steps. During linking, `-x ivcore` causes
# `clang++` to re-parse `.o` files as source code.
set(CMAKE_CUDA_FLAGS "--cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags")
set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Iluvatar")
set(CMAKE_CUDA_ARCHITECTURES OFF CACHE STRING "Iluvatar CUDA architectures (passed via CMAKE_CUDA_FLAGS)")
# Iluvatar does not ship `libcudadevrt`, which CMake's compiler test
# tries to link. Disable automatic CUDA runtime linking and link
# manually via `find_package(CUDAToolkit)` instead.
set(CMAKE_CUDA_RUNTIME_LIBRARY NONE)
message(STATUS "Iluvatar: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${ILUVATAR_ARCH}")
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
# Add `-x ivcore` as a compile-only flag so it is not passed during
# linking, where it would cause `clang++` to re-parse `.o` files as
# source.
add_compile_options($<$<COMPILE_LANGUAGE:CUDA>:-x$<SEMICOLON>ivcore>)
endif()

if(WITH_METAX)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ install-dir = "infini"

# CMake cache definitions passed by scikit-build when building the wheel:
# device/backend detection is automatic and Python bindings are generated.
[tool.scikit-build.cmake.define]
AUTO_DETECT_DEVICES = "ON"
AUTO_DETECT_BACKENDS = "ON"
GENERATE_PYTHON_BINDINGS = "ON"

[tool.pytest.ini_options]
Expand Down
19 changes: 15 additions & 4 deletions scripts/generate_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,12 @@ def _snake_to_pascal(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))


def _get_all_ops(devices):
def _get_all_ops(devices, with_torch=False):
scan_dirs = set(devices)

if with_torch:
scan_dirs.add("torch")

ops = {}

for file_path in _BASE_DIR.iterdir():
Expand All @@ -424,8 +429,8 @@ def _get_all_ops(devices):

ops[op_name] = []

for file_path in _SRC_DIR.rglob("*"):
if not file_path.is_file() or file_path.parent.parent.name not in devices:
for file_path in _SRC_DIR.rglob("*.h"):
if file_path.parent.parent.name not in scan_dirs:
continue

if f"class Operator<{_snake_to_pascal(op_name)}" in file_path.read_text():
Expand All @@ -445,6 +450,12 @@ def _get_all_ops(devices):
help="Devices to use. Please pick from `cpu`, `nvidia`, `cambricon`, `ascend`, `metax`, `moore`, `iluvatar`, `kunlun`, `hygon`, and `qy`. (default: `cpu`)",
)

parser.add_argument(
"--with-torch",
action="store_true",
help="Include PyTorch C++ backend implementations.",
)

args = parser.parse_args()

_BINDINGS_DIR.mkdir(parents=True, exist_ok=True)
Expand All @@ -456,7 +467,7 @@ def _get_all_ops(devices):
if ops_json.exists():
ops = json.loads(ops_json.read_text())
else:
ops = _get_all_ops(args.devices)
ops = _get_all_ops(args.devices, with_torch=args.with_torch)

header_paths = []
bind_func_names = []
Expand Down
87 changes: 85 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ if(WITH_CPU)

target_compile_definitions(infiniops PUBLIC WITH_CPU=1)

find_package(OpenMP REQUIRED)
find_package(OpenMP REQUIRED COMPONENTS CXX)
target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX)

list(APPEND DEVICE_LIST "cpu")
Expand Down Expand Up @@ -218,6 +218,69 @@ if(WITH_ASCEND)
list(APPEND DEVICE_LIST "ascend")
endif()

if(WITH_TORCH)
  file(GLOB_RECURSE TORCH_SOURCES CONFIGURE_DEPENDS "torch/*.cc" "torch/*.cpp")

  target_compile_definitions(infiniops PUBLIC WITH_TORCH=1)
  target_link_libraries(infiniops PUBLIC ${TORCH_LIBRARIES})
  target_include_directories(infiniops PUBLIC ${TORCH_INCLUDE_DIRS})

  if(WITH_METAX OR WITH_MOORE)
    # Vendor compilers (`mxcc`/`mcc`) cannot compile vendor-forked `torch`
    # headers. Compile `torch` sources with the system C++ compiler instead.
    find_program(SYSTEM_CXX NAMES g++ c++)

    if(NOT SYSTEM_CXX)
      message(FATAL_ERROR "Could not find system `g++` for `torch` compilation.")
    endif()

    # `-isystem` suppresses warnings originating in the torch headers.
    set(_torch_include_flags "")
    foreach(_dir ${TORCH_INCLUDE_DIRS})
      list(APPEND _torch_include_flags "-isystem" "${_dir}")
    endforeach()

    # Vendor-specific defines required by forked `torch` headers, plus the
    # ABI define so these objects link against the same libstdc++ ABI as
    # the rest of the build.
    set(_torch_extra_flags "")
    if(WITH_METAX)
      list(APPEND _torch_extra_flags "-DUSE_MACA=1")
    endif()
    if(DEFINED TORCH_CXX11_ABI)
      list(APPEND _torch_extra_flags "-D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
    endif()

    set(TORCH_OBJECT_DIR "${CMAKE_CURRENT_BINARY_DIR}/torch_objs")
    file(MAKE_DIRECTORY "${TORCH_OBJECT_DIR}")

    set(TORCH_OBJECT_FILES)
    foreach(_src ${TORCH_SOURCES})
      # Flatten the relative path into a unique object name, replacing only
      # the trailing extension (an anchored regex avoids corrupting file
      # names that happen to contain `.cc`/`.cpp` mid-string).
      file(RELATIVE_PATH _rel ${CMAKE_CURRENT_SOURCE_DIR} ${_src})
      string(REPLACE "/" "_" _obj_name "${_rel}")
      string(REGEX REPLACE "\\.(cc|cpp)$" ".o" _obj_name "${_obj_name}")
      set(_obj "${TORCH_OBJECT_DIR}/${_obj_name}")

      # NOTE(review): header changes do not retrigger these commands (only
      # the `.cc`/`.cpp` is in DEPENDS); acceptable for vendor builds but
      # worth knowing when editing `torch/*.h`.
      add_custom_command(
        OUTPUT "${_obj}"
        COMMAND ${SYSTEM_CXX}
                -std=c++17 -fPIC -O2
                "-I${CMAKE_CURRENT_SOURCE_DIR}"
                ${_torch_include_flags}
                ${_torch_extra_flags}
                -c "${_src}" -o "${_obj}"
        DEPENDS "${_src}"
        COMMENT "Compiling ${_rel} with system C++ compiler"
        VERBATIM
      )
      list(APPEND TORCH_OBJECT_FILES "${_obj}")
    endforeach()

    # Feed the pre-built objects straight into the library link.
    set_source_files_properties(${TORCH_OBJECT_FILES}
      PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
    target_sources(infiniops PRIVATE ${TORCH_OBJECT_FILES})
  else()
    target_sources(infiniops PRIVATE ${TORCH_SOURCES})
  endif()
endif()

target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

if(GENERATE_PYTHON_BINDINGS)
Expand All @@ -226,8 +289,14 @@ if(GENERATE_PYTHON_BINDINGS)
# active device list. Stale generated files (e.g., committed for one
# platform) would omit specializations for other enabled backends,
# causing link-time or runtime failures.

set(GENERATOR_ARGS --devices ${DEVICE_LIST})
if(WITH_TORCH)
list(APPEND GENERATOR_ARGS --with-torch)
endif()

execute_process(
COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST}
COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS}
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
RESULT_VARIABLE script_result
)
Expand All @@ -246,6 +315,20 @@ if(GENERATE_PYTHON_BINDINGS)
endif()

find_package(Python COMPONENTS Interpreter Development)

# Auto-detect `pybind11_DIR` from the Python environment when the user has
# not provided one, so `find_package(pybind11 CONFIG)` below can succeed
# without manual configuration.
if(NOT pybind11_DIR)
  execute_process(
    COMMAND ${Python_EXECUTABLE} -m pybind11 --cmakedir
    OUTPUT_VARIABLE _pybind11_cmake_dir
    OUTPUT_STRIP_TRAILING_WHITESPACE
    RESULT_VARIABLE _pybind11_result
    ERROR_QUIET
  )

  # Only cache the path on success; on failure (pybind11 not installed in
  # this interpreter) fall through silently and let `find_package` search
  # its normal locations.
  if(_pybind11_result EQUAL 0)
    set(pybind11_DIR "${_pybind11_cmake_dir}" CACHE PATH "pybind11 CMake directory")
  endif()
endif()

find_package(pybind11 CONFIG)

if(PYBIND11_ENABLE_EXTRAS)
Expand Down
1 change: 0 additions & 1 deletion src/nvidia/gemm/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "cuda/gemm/blas.h"
#include "nvidia/blas.h"
#include "nvidia/gemm/registry.h"

namespace infini::ops {

Expand Down
1 change: 0 additions & 1 deletion src/nvidia/gemm/cublaslt.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

#include "base/gemm.h"
#include "nvidia/blas_utils.h"
#include "nvidia/gemm/registry.h"
#include "nvidia/runtime_.h"

namespace infini::ops {
Expand Down
15 changes: 0 additions & 15 deletions src/nvidia/gemm/registry.h

This file was deleted.

53 changes: 44 additions & 9 deletions src/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,9 @@ struct std::equal_to<infini::ops::detail::CacheKey> {

namespace infini::ops {

// Forward declaration — defined after `Operator` using SFINAE auto-detection.
template <typename Key, Device::Type kDev>
struct ActiveImplementationsImpl {
using type = List<0>;
};

template <typename Key, Device::Type kDev>
using ActiveImplementations =
typename ActiveImplementationsImpl<Key, kDev>::type;
struct ActiveImplementations;

class OperatorBase {
public:
Expand Down Expand Up @@ -153,7 +148,7 @@ class Operator : public OperatorBase {
}
},
"Operator::make(implementation_index)",
ActiveImplementations<Key, kDev>{});
typename ActiveImplementations<Key, kDev>::type{});
},
"Operator::make");

Expand Down Expand Up @@ -200,7 +195,8 @@ class Operator : public OperatorBase {
dev_type,
[&](auto device_tag) {
constexpr Device::Type kDev = decltype(device_tag)::value;
result = detail::ListToVector(ActiveImplementations<Key, kDev>{});
result = detail::ListToVector(
typename ActiveImplementations<Key, kDev>::type{});
},
"Operator::active_implementation_indices");
return result;
Expand All @@ -227,6 +223,45 @@ class Operator : public OperatorBase {
static constexpr std::size_t implementation_index_{implementation_index};
};

// Maximum number of implementation slots per (operator, device) pair.
// Increase this value when adding operators with more implementations.
constexpr std::size_t kMaxImplementations = 16;

// SFINAE-based implementation detection. A partial specialization
// `Operator<Key, kDev, N>` inherits from `Key` (the operator base class),
// while the unspecialized primary template inherits only from `OperatorBase`.
// `std::is_base_of` distinguishes the two at compile time, eliminating the
// need for manual `registry.h` files.
//
// Primary template: slot `N` has no specialization for this (operator,
// device) pair, so it contributes an empty list.
template <typename Key, Device::Type kDev, std::size_t N,
          bool = std::is_base_of_v<Key, Operator<Key, kDev, N>>>
struct ActiveImplementationsImpl {
    using type = List<>;
};

// Slot `N` is implemented (the `Operator` specialization derives from
// `Key`), so it contributes a single-element list holding its index.
template <typename Key, Device::Type kDev, std::size_t N>
struct ActiveImplementationsImpl<Key, kDev, N, true> {
    using type = List<N>;
};

namespace detail {

template <typename Key, Device::Type kDev, typename Seq>
struct ActiveImplementationsHelper;

// Probes every slot in the index sequence and concatenates the per-slot
// lists (empty or single-element) via `Flatten`, yielding the list of all
// implemented indices in ascending order.
template <typename Key, Device::Type kDev, std::size_t... ns>
struct ActiveImplementationsHelper<Key, kDev, std::index_sequence<ns...>> {
    using type = typename Flatten<
        typename ActiveImplementationsImpl<Key, kDev, ns>::type...>::type;
};

} // namespace detail

// Public entry point: `ActiveImplementations<Key, kDev>::type` is the
// compile-time list of implementation indices in [0, kMaxImplementations)
// that are defined for operator `Key` on device `kDev`.
template <typename Key, Device::Type kDev>
struct ActiveImplementations {
    using type = typename detail::ActiveImplementationsHelper<
        Key, kDev, std::make_index_sequence<kMaxImplementations>>::type;
};

} // namespace infini::ops

#endif
Loading
Loading