Commit 299e858
* refactor: extract `TorchDeviceName` into `src/torch/device_.h`
Move the `TorchDeviceName<kDev>` template specializations from
`pybind11_utils.h` into a standalone header so they can be reused
by torch operator implementations without pulling in pybind11.
* feat: add slotted `ActiveImplementationsImpl` for composable registries
Add a third template parameter `N` (slot index, default 0) to
`ActiveImplementationsImpl`. Slot 0 holds the base/device-native
indices, and higher slots let add-on backends register extra indices
without conflicting with existing specializations. `ActiveImplementations`
flattens slots 0-3 via `Flatten`.
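
A minimal sketch of how such a slotted registry could fit together, assuming C++17. The `List` type, the `Device` enum, the `Flatten` implementation, and the native indices 0 and 1 are illustrative assumptions, not the repository's actual definitions.

```cpp
#include <type_traits>

// A compile-time list of implementation indices (hypothetical helper).
template <int... kIndices>
struct List {};

// Concatenates any number of `List`s into a single `List`.
template <typename... Lists>
struct Flatten;

template <>
struct Flatten<> {
  using type = List<>;
};

template <int... kIndices>
struct Flatten<List<kIndices...>> {
  using type = List<kIndices...>;
};

template <int... kA, int... kB, typename... Rest>
struct Flatten<List<kA...>, List<kB...>, Rest...> {
  using type = typename Flatten<List<kA..., kB...>, Rest...>::type;
};

// Hypothetical operator key and device tags.
struct Gemm {};
enum class Device { kCpu, kNvidia };

// Primary template: every slot is empty unless it is specialized.
template <typename Key, Device kDev, int N = 0>
struct ActiveImplementationsImpl {
  using type = List<>;
};

// Slot 0: device-native indices (assumed here to be 0 and 1).
template <>
struct ActiveImplementationsImpl<Gemm, Device::kNvidia, 0> {
  using type = List<0, 1>;
};

// Slot 1: an add-on backend registers index 2 for every device
// without touching the slot 0 specialization above.
template <Device kDev>
struct ActiveImplementationsImpl<Gemm, kDev, 1> {
  using type = List<2>;
};

// Flattens slots 0-3 into one list of active indices.
template <typename Key, Device kDev>
using ActiveImplementations = typename Flatten<
    typename ActiveImplementationsImpl<Key, kDev, 0>::type,
    typename ActiveImplementationsImpl<Key, kDev, 1>::type,
    typename ActiveImplementationsImpl<Key, kDev, 2>::type,
    typename ActiveImplementationsImpl<Key, kDev, 3>::type>::type;
```

Because each backend specializes a different slot, the partial specialization for slot 1 never collides with a full specialization for slot 0.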
* feat: add `WITH_TORCH` CMake option and wrapper generator support
Add `WITH_TORCH` option that finds PyTorch via pip, links against
libtorch, and compiles sources under `src/torch/`. Pass `--with-torch`
to `generate_wrappers.py` so it scans `src/torch/` for operator
specializations.
* feat: add torch tensor conversion utilities
Add `ToAtenDtype()` and `ToAtenTensor<kDev>()` in `src/torch/tensor_.h`
for zero-copy conversion from `infini::ops::Tensor` to `at::Tensor`
via `at::from_blob()`.
* feat: add torch `Add` operator (implementation index 1)
Register torch `Add` via `ActiveImplementationsImpl` slot 1 for all
devices. The implementation uses `at::add_out()` through ATen's
device-generic dispatch.
* feat: add torch `Gemm` operator (implementation index 2)
Register torch `Gemm` via `ActiveImplementationsImpl` slot 1 for all
devices. The implementation uses `at::addmm_out()` / `at::baddbmm_out()`
through ATen's device-generic dispatch.
* feat: add `AUTO_DETECT_BACKENDS` CMake option
Auto-detect PyTorch by attempting `import torch`. When found,
`WITH_TORCH` is enabled automatically.
* fix: replace `find_package(Torch)` with direct path queries
`find_package(Torch)` pulls in Caffe2's cmake config, which calls
`enable_language(CUDA)` and breaks on platforms with non-standard
CUDA toolchains (e.g. Iluvatar). Query include and library paths
directly via `torch.utils.cpp_extension` instead.
* fix(iluvatar): set `CMAKE_CUDA_ARCHITECTURES` before `enable_language(CUDA)`
CMake 4.3+ requires `CMAKE_CUDA_ARCHITECTURES` to be set before
`enable_language(CUDA)` when using non-standard CUDA compilers like
Iluvatar's `clang++`. Without it, CMake fails to detect a default
architecture.
* fix: auto-detect `pybind11_DIR` from Python when not set
When `pybind11` is installed via pip but not in a standard CMake search
path, `find_package(pybind11 CONFIG)` fails. Query `python -m pybind11
--cmakedir` as a fallback to locate the package.
* fix: cross-platform torch compilation and version compatibility
- Iluvatar: set `CMAKE_CUDA_ARCHITECTURES` to `OFF` instead of
`ivcore20` (CMake 4.3 rejects non-integer architecture names; the
architecture is already passed via `CMAKE_CUDA_FLAGS`).
- MetaX/Moore: split torch operator headers into declaration-only `.h`
files and `.cc` implementation files with explicit instantiations.
Compile the `.cc` files with the system `g++` instead of the vendor
compiler (`mxcc`/`mcc`), which cannot parse vendor-forked torch
headers in C++ extension mode.
- Cambricon: guard `UInt16`/`UInt32`/`UInt64` scalar types in
`ToAtenDtype()` with a `TORCH_VERSION` check (these types require
PyTorch 2.4+; Cambricon ships torch 2.1).
- Wrapper generator: scan only `.h` files to avoid including `.cc`
explicit-instantiation files in the generated `ops.cc`.
* fix: resolve build failures on iluvatar, metax, and cambricon
- Iluvatar: move `-x ivcore` from `CMAKE_CUDA_FLAGS` to compile-only
options so it is not passed during linking (which caused
`clang++` to re-parse `.o` files as source code).
- MetaX: add `-DUSE_MACA=1` to the `g++` flags for torch source
compilation (the MetaX torch fork headers require this define).
- Cambricon: query `torch.compiled_with_cxx11_abi()` and set
`_GLIBCXX_USE_CXX11_ABI` globally to match torch's ABI setting
(fixes an undefined reference to `c10::Device::Device(std::string)`).
* style: fix comments to comply with `CONTRIBUTING.md`
Use backtick-fenced Markdown syntax for identifiers in comments
and error messages, and ensure comments are complete sentences.
* style: apply `clang-format`
* refactor: auto-detect operator implementations via SFINAE
Replace the manual `ActiveImplementationsImpl` slot system with
`std::is_base_of`-based compile-time detection. A real `Operator`
specialization inherits from `Key` (e.g., `Gemm`), while the primary
template inherits only from `OperatorBase` — SFINAE distinguishes the
two automatically, eliminating the need for `registry.h` files.
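
The `std::is_base_of` check described above can be sketched as follows, assuming C++17. The `Device` enum, the `kIndex` parameter, and the `kIsImplemented` name are illustrative assumptions and may differ from the actual code.

```cpp
#include <type_traits>

struct OperatorBase {};

// Operator key: real implementations inherit from it.
struct Gemm : OperatorBase {};

enum class Device { kCpu, kNvidia };

// Primary template: stands for "no implementation here" and inherits
// only from `OperatorBase`, never from the key.
template <typename Key, Device kDev, int kIndex = 0>
struct Operator : OperatorBase {};

// A real implementation is a specialization that inherits from the key.
template <>
struct Operator<Gemm, Device::kNvidia, 0> : Gemm {};

// True only when `Operator<Key, kDev, kIndex>` was specialized, because
// only a specialization has `Key` as a base class.
template <typename Key, Device kDev, int kIndex>
inline constexpr bool kIsImplemented =
    std::is_base_of_v<Key, Operator<Key, kDev, kIndex>>;
```

With this trait, adding an implementation only requires writing the specialization; no separate registration list has to be kept in sync.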
* refactor: use `std::index_sequence` for implementation auto-detection
Replace the hand-unrolled `Flatten<..., 0>::type, ..., 3>::type>` with
`std::index_sequence<0..kMaxImplementations>` expansion. Increase
`kMaxImplementations` from 4 to 16.
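
A rough sketch of the `std::index_sequence` expansion, assuming C++17. It probes every index below `kMaxImplementations` in one pack expansion and records the result in a boolean mask; the `ProbeAll`, `kActiveMask`, and `CountActive` names are hypothetical, and the real code may build a type list instead of a mask.

```cpp
#include <array>
#include <cstddef>
#include <type_traits>
#include <utility>

constexpr std::size_t kMaxImplementations = 16;

struct OperatorBase {};
struct Add : OperatorBase {};

// Primary template: no implementation at this index.
template <typename Key, std::size_t kIndex>
struct Operator : OperatorBase {};

// Two illustrative implementations, at indices 0 and 2.
template <>
struct Operator<Add, 0> : Add {};
template <>
struct Operator<Add, 2> : Add {};

// Expands the index pack so each candidate index is probed once.
template <typename Key, std::size_t... kIs>
constexpr std::array<bool, sizeof...(kIs)> ProbeAll(
    std::index_sequence<kIs...>) {
  return {{std::is_base_of_v<Key, Operator<Key, kIs>>...}};
}

// Mask of active indices in the range [0, kMaxImplementations).
template <typename Key>
inline constexpr auto kActiveMask =
    ProbeAll<Key>(std::make_index_sequence<kMaxImplementations>{});

// Counts how many indices are active for `Key`.
template <typename Key>
constexpr std::size_t CountActive() {
  std::size_t count = 0;
  for (bool active : kActiveMask<Key>) {
    count += active ? 1 : 0;
  }
  return count;
}
```

Raising the limit then only means changing `kMaxImplementations`, since the expansion is generated rather than written out by hand.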
* refactor: replace preprocessor conditionals with `constexpr` in `tensor_.h`
Use `constexpr int kTorchVersion` and `if constexpr` instead of `#if`
macros for PyTorch version checks. Extract unsigned dtype handling into
`detail::ToAtenUnsignedDataType`.
* docs: add TODO comments for dynamic implementation index parametrization
* fix: use dependent type alias to support `if constexpr` on PyTorch < 2.4
`c10::ScalarType::UInt16` is a non-dependent name resolved at template
definition time. Introduce `DependentScalarType<kVersion>::type` so
the enum member access becomes dependent and is properly discarded by
`if constexpr` on older PyTorch versions.
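
The following self-contained sketch illustrates the dependent-name trick without requiring PyTorch, assuming C++17. The local `ScalarType` enum stands in for `c10::ScalarType` as shipped by an older release (no `UInt16` member), and the `major * 100 + minor` version encoding and the `UnsignedCode` function are assumptions made for illustration.

```cpp
// Stand-in for `c10::ScalarType` on an older PyTorch: no `UInt16` member.
enum class ScalarType { Float, Int };

// Routing the type through a template parameter makes any member access
// on `type` a dependent name, so it is looked up only on instantiation.
template <int kVersion>
struct DependentScalarType {
  using type = ScalarType;
};

// Returns the numeric code of `UInt16` when the version supports it,
// or -1 otherwise. Version is encoded as major * 100 + minor (assumed).
template <int kVersion>
constexpr int UnsignedCode() {
  if constexpr (kVersion >= 204) {
    // Writing `ScalarType::UInt16` here would fail at definition time,
    // even in a discarded branch, because the name is non-dependent.
    using T = typename DependentScalarType<kVersion>::type;
    return static_cast<int>(T::UInt16);
  } else {
    return -1;
  }
}
```

Instantiating `UnsignedCode<201>()` compiles because the first branch is discarded before `T::UInt16` is ever looked up.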
* fix(tests): skip torch Gemm on CPU half-precision
ATen `addmm`/`baddbmm` do not support `float16`/`bfloat16` on CPU.
* refactor: rename `ToAtenDtype` to `ToAtenDataType`
1 parent e5571b4 commit 299e858
17 files changed, +593 -91 lines changed
File tree
- scripts
- src
- nvidia/gemm
- torch
- add
- gemm
- tests