
feat: add PyTorch C++ backend for Add and Gemm#51

Open
voltjia wants to merge 14 commits into master from feat/torch-backend
Conversation

@voltjia
Collaborator

@voltjia voltjia commented Apr 13, 2026

Summary

  • Add slotted ActiveImplementationsImpl<Key, kDev, N> so add-on backends can register extra implementation indices without modifying existing device registries
  • Add WITH_TORCH CMake option that links against libtorch and compiles sources under src/torch/
  • Implement torch Add (index 1) and Gemm (index 2) using ATen's device-generic dispatch (at::add_out, at::addmm_out / at::baddbmm_out)
  • Extract TorchDeviceName<kDev> into src/torch/device_.h for reuse and add zero-copy ToAtenTensor<kDev>() conversion via at::from_blob()

voltjia added 13 commits April 15, 2026 02:41

Move the `TorchDeviceName<kDev>` template specializations from
`pybind11_utils.h` into a standalone header so they can be reused
by torch operator implementations without pulling in pybind11.

Add a third template parameter `N` (slot index, default 0) to
`ActiveImplementationsImpl`. Slot 0 holds the base/device-native
indices, and higher slots let add-on backends register extra indices
without conflicting with existing specializations. `ActiveImplementations`
flattens slots 0-3 via `Flatten`.
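
The slot mechanism described above can be sketched in Python (names like `register` and `active_implementations` are illustrative, not from the PR): each `(key, device)` pair owns several numbered slots, and the final ordered list of active implementation indices is the flattened concatenation of slots 0-3.

```python
# Hypothetical sketch of slotted implementation registration, assuming the
# semantics described in the commit: slot 0 holds device-native indices,
# higher slots let add-on backends register extra indices without conflict.
from collections import defaultdict

MAX_SLOTS = 4  # slots 0-3, mirroring the `Flatten` range in the commit

# registry[(key, device)][slot] -> list of implementation indices
registry = defaultdict(lambda: [[] for _ in range(MAX_SLOTS)])

def register(key, device, indices, slot=0):
    """Slot 0 is the base/device-native registry; add-ons use higher slots."""
    registry[(key, device)][slot] = list(indices)

def active_implementations(key, device):
    """Flatten slots 0..3 into a single ordered index list."""
    return [i for slot in registry[(key, device)] for i in slot]

# The base backend registers index 0 in slot 0; the torch add-on registers
# index 1 in slot 1 without touching the existing specialization.
register("Add", "cuda", [0])
register("Add", "cuda", [1], slot=1)
```

In the C++ version the slot index is a compile-time template parameter, so the "registry" is resolved at specialization time rather than at runtime; the flattening order is the same.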
Add `WITH_TORCH` option that finds PyTorch via pip, links against
libtorch, and compiles sources under `src/torch/`. Pass `--with-torch`
to `generate_wrappers.py` so it scans `src/torch/` for operator
specializations.

Add `ToAtenDtype()` and `ToAtenTensor<kDev>()` in `src/torch/tensor_.h`
for zero-copy conversion from `infini::ops::Tensor` to `at::Tensor`
via `at::from_blob()`.
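
A hedged Python analogue of `ToAtenDtype()` / `ToAtenTensor<kDev>()`: map a scalar-type name to a torch dtype, then wrap an existing buffer without copying. The C++ code uses `at::from_blob()`; `torch.frombuffer()` is the closest Python zero-copy equivalent. The names `SCALAR_TO_TORCH` and `to_aten_tensor` are illustrative, not from the PR.

```python
import torch

# Illustrative subset of a scalar-type mapping like `ToAtenDtype()`.
SCALAR_TO_TORCH = {
    "Float32": torch.float32,
    "Float64": torch.float64,
    "Int32": torch.int32,
    "Int64": torch.int64,
}

def to_aten_tensor(buffer, shape, scalar_type):
    # Zero-copy: the returned tensor aliases `buffer`, the way a tensor
    # built with `at::from_blob()` aliases the source allocation.
    return torch.frombuffer(
        buffer, dtype=SCALAR_TO_TORCH[scalar_type]
    ).reshape(shape)

buf = bytearray(4 * 6)                     # room for 6 float32 elements
t = to_aten_tensor(buf, (2, 3), "Float32")
t.fill_(1.0)                               # writes through to `buf`: no copy
```

As with `at::from_blob()`, the caller remains responsible for keeping the underlying buffer alive for the tensor's lifetime.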
Register torch `Add` via `ActiveImplementationsImpl` slot 1 for all
devices. The implementation uses `at::add_out()` through ATen's
device-generic dispatch.
Register torch `Gemm` via `ActiveImplementationsImpl` slot 1 for all
devices. The implementation uses `at::addmm_out()` / `at::baddbmm_out()`
through ATen's device-generic dispatch.
Auto-detect PyTorch by attempting `import torch`. When found,
`WITH_TORCH` is enabled automatically.

`find_package(Torch)` pulls in Caffe2's cmake config, which calls
`enable_language(CUDA)` and breaks on platforms with non-standard
CUDA toolchains (e.g. Iluvatar). Query include and library paths
directly via `torch.utils.cpp_extension` instead.
…(CUDA)`

CMake 4.3+ requires `CMAKE_CUDA_ARCHITECTURES` to be set before
`enable_language(CUDA)` when using non-standard CUDA compilers like
Iluvatar's `clang++`. Without it, CMake fails to detect a default
architecture.

When `pybind11` is installed via pip but not in a standard CMake search
path, `find_package(pybind11 CONFIG)` fails. Query `python -m pybind11
--cmakedir` as a fallback to locate the package.
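
The fallback query can be sketched like this: `python -m pybind11 --cmakedir` prints the directory containing `pybind11Config.cmake`, which can then be passed to CMake (e.g. as `-Dpybind11_DIR=...`) when the config-mode search fails.

```python
import subprocess
import sys

# Ask the pip-installed pybind11 package where its CMake config lives.
cmakedir = subprocess.check_output(
    [sys.executable, "-m", "pybind11", "--cmakedir"], text=True
).strip()
```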
- Iluvatar: set `CMAKE_CUDA_ARCHITECTURES` to `OFF` instead of
  `ivcore20` (CMake 4.3 rejects non-integer architecture names; the
  architecture is already passed via `CMAKE_CUDA_FLAGS`).

- MetaX/Moore: split torch operator headers into declaration-only `.h`
  files and `.cc` implementation files with explicit instantiations.
  Compile the `.cc` files with the system `g++` instead of the vendor
  compiler (`mxcc`/`mcc`), which cannot parse vendor-forked torch
  headers in C++ extension mode.

- Cambricon: guard `UInt16`/`UInt32`/`UInt64` scalar types in
  `ToAtenDtype()` with a `TORCH_VERSION` check (these types require
  PyTorch 2.4+; Cambricon ships torch 2.1).

- Wrapper generator: scan only `.h` files to avoid including `.cc`
  explicit-instantiation files in the generated `ops.cc`.

- Iluvatar: move `-x ivcore` from CMAKE_CUDA_FLAGS to compile-only
  options so it doesn't get passed during linking (which caused
  clang++ to re-parse .o files as source code)
- MetaX: add `-DUSE_MACA=1` to g++ flags for torch source
  compilation (MetaX torch fork headers require this define)
- Cambricon: query `torch.compiled_with_cxx11_abi()` and set
  `_GLIBCXX_USE_CXX11_ABI` globally to match torch's ABI setting
  (fixes undefined reference to `c10::Device::Device(std::string)`)
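
The ABI probe in the last bullet can be sketched as: query whether the installed torch was built with the new libstdc++ ABI, and derive the matching compile definition to apply globally.

```python
import torch

# torch.compiled_with_cxx11_abi() reports the _GLIBCXX_USE_CXX11_ABI value
# libtorch was built with; every TU linking against it must match, or symbols
# involving std::string (e.g. `c10::Device::Device(std::string)`) won't resolve.
abi_flag = 1 if torch.compiled_with_cxx11_abi() else 0
define = f"-D_GLIBCXX_USE_CXX11_ABI={abi_flag}"
```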
Use backtick-fenced Markdown syntax for identifiers in comments
and error messages, and ensure comments are complete sentences.

@voltjia voltjia force-pushed the feat/torch-backend branch from 465ac5d to 9c04549 on April 15, 2026 03:02
@voltjia
Collaborator Author

voltjia commented Apr 15, 2026

results.log
