Commit 11347ff
Add torch.Tensor fast path for StridedMemoryView via AOTI tensor bridge (#1894)
* Add torch.Tensor fast path for StridedMemoryView via AOTI tensor bridge
Provide a fast path for constructing a StridedMemoryView from a
torch.Tensor by reading tensor metadata directly through PyTorch's
AOT Inductor (AOTI) stable C ABI, avoiding the overhead of the
DLPack/CAI protocols; the metadata reads are plain pointer arithmetic
and take roughly 10 ns per tensor.
Key design:
- Vendored AOTI shim header (aoti_shim.h) with extern "C" wrapping
- _tensor_bridge.pyx loaded lazily (only when a torch.Tensor is first
passed) to avoid undefined AOTI symbols at import time
- RTLD_GLOBAL bootstrap via sys.modules["torch._C"] before loading
_tensor_bridge.so
- torch detection via type(obj).__module__.startswith("torch")
- PyTorch is NOT a build-time or run-time dependency of cuda.core
Closes #749
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
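For illustration, a minimal sketch of the lazy-load-plus-detection pattern described above; the module path, helper names, and the ctypes re-open are assumptions of this sketch, not the exact implementation:

    import ctypes
    import importlib
    import sys

    _bridge = None  # loaded on first use, never at cuda.core import time

    def _get_tensor_bridge():
        # Hypothetical sketch: torch must already be imported by the caller.
        global _bridge
        if _bridge is None:
            torch_c = sys.modules.get("torch._C")
            if torch_c is None:
                raise RuntimeError("import torch before passing torch.Tensor objects")
            # Re-open torch's extension module with RTLD_GLOBAL so the AOTI
            # symbols in libtorch are visible when _tensor_bridge.so is loaded.
            ctypes.CDLL(torch_c.__file__, mode=ctypes.RTLD_GLOBAL)
            _bridge = importlib.import_module("cuda.core.experimental._tensor_bridge")
        return _bridge

    def _looks_like_torch(obj):
        # Cheap detection that avoids importing torch for non-torch objects.
        return type(obj).__module__.startswith("torch")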
* Clean up tensor bridge: remove unused AOTI decls, lazy dtype, drop empty .pxd
- Remove unused aoti_torch_get_numel and aoti_torch_get_storage_offset
declarations from aoti_shim.h and _tensor_bridge.pyx
- Fix license headers on new files to 2026 (not 2024-2026)
- Delete empty _tensor_bridge.pxd (nothing cimports from it)
- Defer numpy dtype resolution for torch tensors: store raw AOTI dtype
code in metadata, compute itemsize from a cheap lookup table, and only
resolve the full numpy dtype on first .dtype access via get_dtype()
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Move torch tensor fast path into each from_* classmethod
Instead of short-circuiting in __init__ and from_any_interface,
add the AOTI fast path check to from_dlpack, from_cuda_array_interface,
and from_array_interface. This ensures torch tensors always take the
fast path regardless of which constructor the user calls.
Simplify from_any_interface and _StridedMemoryViewProxy to just
delegate to the from_* methods (which now handle torch internally).
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
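Roughly the shape of the per-constructor check, as a sketch with illustrative names (the real classmethods carry more arguments and error handling):

    class StridedMemoryView:
        @classmethod
        def from_dlpack(cls, obj, stream_ptr=None):
            # Torch tensors short-circuit into the AOTI bridge; everything
            # else goes through the normal DLPack handshake.
            if _is_torch_tensor(obj):
                return _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr)
            return cls._from_dlpack_impl(obj, stream_ptr)

        # from_cuda_array_interface and from_array_interface gain the same
        # leading check, so the fast path applies no matter which constructor
        # the caller picks.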
* Add stream ordering for torch tensor bridge
When stream_ptr is not -1, establish stream ordering between
PyTorch's current CUDA stream (the producer) and the consumer stream,
using the same event record + stream wait pattern as the CAI path.
Uses aoti_torch_get_current_cuda_stream to get the producer stream,
matching what PyTorch's own __dlpack__ does internally.
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
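The event-record/stream-wait pattern referenced above, as a sketch; the helper names are hypothetical stand-ins for the driver calls actually used:

    def _order_streams(producer_stream, consumer_stream):
        # Record an event on the producer (PyTorch's current CUDA stream) and
        # make the consumer stream wait on it, so the producer's pending work
        # completes before the consumer touches the tensor's memory.
        event = create_cuda_event(disable_timing=True)   # hypothetical helper
        record_cuda_event(event, producer_stream)
        stream_wait_event(consumer_stream, event)
        destroy_cuda_event(event)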
* Extract reusable sync_torch_stream and apply to CAI path
Factor out stream ordering into a cpdef sync_torch_stream() helper in
_tensor_bridge.pyx, callable from both C (view_as_torch_tensor) and
Python (_memoryview.pyx).
Apply the same stream ordering in view_as_cai for torch tensors:
PyTorch's __cuda_array_interface__ reports version 2 and omits the
"stream" field, so the standard CAI sync path is a no-op — leaving the
consumer with no guarantee that the producer's work is visible. We now
detect torch tensors in the CAI path and query PyTorch's current CUDA
stream via AOTI to establish proper ordering.
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
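A sketch of where the CAI path might invoke the shared helper (names and signature are assumptions):

    # Inside the CUDA Array Interface consumption path: PyTorch's CAI export
    # is version 2 with no "stream" key, so ordering must be established here.
    if _is_torch_tensor(obj) and stream_ptr is not None and int(stream_ptr) != -1:
        _get_tensor_bridge().sync_torch_stream(obj, int(stream_ptr))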
* Nits: add check_aoti helper, size_t itemsize, 2D sliced test
- Add check_aoti() inline helper to replace repetitive
err/raise patterns for AOTI calls (one-liner per call)
- Change itemsize type from int to size_t
- Add test_torch_tensor_bridge_sliced_2d test case
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Revert itemsize to int, memoize int(stream_ptr)
- Revert itemsize back to int (size_t was unnecessary for small values)
- Memoize int(stream_ptr) to avoid a redundant Python integer conversion
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Use except?-1 instead of except* for check_aoti
Better Cython 3 performance: except? -1 avoids the overhead of
except*, which forces an exception check after every call.
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
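A Cython sketch of check_aoti with the except? -1 declaration (the error type and message format are assumptions based on the AOTI shim's integer error codes):

    cdef inline int check_aoti(int err, str what) except? -1:
        # AOTI shim calls return a nonzero error code on failure; raising here
        # replaces the repeated err/raise pattern at every call site. With
        # "except? -1" Cython only performs the exception check when the
        # return value is -1, unlike "except *" which checks after every call.
        if err != 0:
            raise RuntimeError(f"AOTI call failed: {what} (error {err})")
        return 0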
* Require PyTorch >= 2.3 for tensor bridge, move imports to module level
The AOTI stable C ABI functions we use (get_dim, get_dtype,
get_device_type, get_device_index, get_current_cuda_stream, complex
dtype constants) were all introduced in PyTorch 2.3.0. Earlier versions
are missing some or all of them.
_is_torch_tensor now returns False when torch < 2.3, causing a
graceful fallback to the standard DLPack/CAI paths. The version
check result is memoized in a module-level variable.
Also move `import ctypes, sys` from _get_tensor_bridge to module level.
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
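A sketch of the memoized version gate (variable and helper names are illustrative):

    _torch_bridge_supported = None  # module-level memo: None = not yet checked

    def _torch_version_ok():
        global _torch_bridge_supported
        if _torch_bridge_supported is None:
            import torch
            major, minor = (int(p) for p in torch.__version__.split(".")[:2])
            # The AOTI functions the bridge relies on first appeared in 2.3.0.
            _torch_bridge_supported = (major, minor) >= (2, 3)
        return _torch_bridge_supported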
* Add tensor bridge entry to 1.0.0 release notes
Document the AOTI-based fast path for torch.Tensor in
StridedMemoryView with ~10-20x speedup and stream ordering support.
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Update speedup range in release notes to match benchmarks
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Document THPVariable layout change across PyTorch versions
The cdata field changed from MaybeOwned<at::Tensor> (2.3-2.9) to
at::Tensor (2.10+). Both layouts are compatible with our offset trick.
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Cache type check in _is_torch_tensor for ~20% speedup
Cache the result of the torch tensor type check (module + hasattr +
version) keyed by type(obj). Subsequent calls for the same type are
a single dict lookup (~76 ns) instead of the full check (~186 ns).
Non-torch objects also benefit as the cache returns False immediately
after the first miss.
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
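A sketch of the per-type cache (names illustrative; the full check also folds in the attribute and version tests described earlier):

    _torch_type_cache = {}  # type(obj) -> bool, so repeat checks are one dict lookup

    def _is_torch_tensor(obj):
        tp = type(obj)
        cached = _torch_type_cache.get(tp)
        if cached is not None:
            return cached
        # Full check (module prefix + tensor-like attribute + torch version)
        # runs only once per concrete type; non-torch types cache False too.
        result = (
            tp.__module__.startswith("torch")
            and hasattr(tp, "__dlpack__")
            and _torch_version_ok()
        )
        _torch_type_cache[tp] = result
        return result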
* Add upper bound to torch version check (cap at 2.11)
The pyobj_to_aten_handle trick and AtenTensorHandle == at::Tensor*
identity are undocumented internals that could change. Cap at the
latest tested version so unknown future versions fall back to the
standard DLPack/CAI paths. Bump after verifying each new release.
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Update module docstring to document both THPVariable layouts
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Use except?-1 for sync_torch_stream instead of except*
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix linter errors
Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix pyobj_to_aten_handle for PyTorch 2.3–2.9 MaybeOwned layout
In PyTorch 2.3–2.9, THPVariable::cdata is c10::MaybeOwned<at::Tensor>,
whose first member is bool isBorrowed_ (padded to 8 bytes) before the
at::Tensor union member. The previous code offset by sizeof(PyObject)
only, which landed on the bool tag (0x0) and caused a segfault when
AOTI functions dereferenced it as at::Tensor*.
Add _get_cdata_extra_offset() that checks the torch version at runtime
and adds 8 bytes for torch < 2.10 (MaybeOwned era). The result is
memoized after the first call.
Tested across PyTorch 2.3.1, 2.4.1, 2.5.1, 2.6.0, 2.7.1, 2.8.0,
2.9.1, 2.10.0, and 2.11.0 with CPU tensors (9 dtypes, sliced tensors,
0d/1d/4d shapes).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
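A Cython sketch of the offset computation described above. These are undocumented PyTorch internals; the names, cimports, and exact layout assumptions shown here are illustrative only:

    from cpython.ref cimport PyObject

    _cdata_extra_offset = None  # memoized after the first call

    cdef size_t _get_cdata_extra_offset():
        global _cdata_extra_offset
        if _cdata_extra_offset is None:
            import torch
            major, minor = (int(p) for p in torch.__version__.split(".")[:2])
            # 2.3-2.9: THPVariable::cdata is c10::MaybeOwned<at::Tensor>; its
            #          leading bool tag is padded to 8 bytes.
            # 2.10+:   cdata is at::Tensor directly, so no extra padding.
            _cdata_extra_offset = 8 if (major, minor) < (2, 10) else 0
        return <size_t>_cdata_extra_offset

    cdef AtenTensorHandle pyobj_to_aten_handle(object obj):
        # The at::Tensor lives right after the PyObject header (plus the
        # MaybeOwned padding on older versions); its address is the handle,
        # relying on AtenTensorHandle being an at::Tensor*.
        return <AtenTensorHandle>((<char*><PyObject*>obj) + sizeof(PyObject)
                                  + _get_cdata_extra_offset())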
* Consolidate torch tensor bridge tests into TestViewCPU/TestViewGPU
Move the 9 standalone torch tensor bridge tests (1d, nd, scalar, empty,
non-contiguous, sliced, sliced-2d, cpu, decorator) into the existing
parametrized TestViewCPU and TestViewGPU classes. Each torch sample now
runs through from_any_interface, the args_viewable_as_strided_memory
decorator, and the deprecated __init__ path.
Add helpers (_arr_ptr, _arr_strides_in_counts, _arr_is_c_contiguous,
_arr_is_writeable) so _check_view works uniformly across numpy, cupy,
numba, and torch arrays.
Retain test_torch_tensor_bridge_dtypes and test_torch_tensor_bridge_bfloat16
as standalone tests since they verify dtype mapping specifically.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Extract _arr_size helper for torch/numpy size compatibility
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix ruff formatting in test_utils.py
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add readonly comment and fix vendored header license to BSD-3-Clause
- Document why readonly=False is correct for torch tensors: PyTorch
always reports tensors as writable via both DLPack (flags=0) and CAI
(data=(ptr, False)), and the AOTI C ABI has no readonly query.
- Change the vendored aoti_shim.h SPDX from Apache-2.0 to BSD-3-Clause
to match PyTorch's actual license.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Merge bfloat16 test into test_torch_tensor_bridge_dtypes parametrization
Add bfloat16 as a pytest.param with a skipif mark for ml_dtypes,
removing the separate test_torch_tensor_bridge_bfloat16 function.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix SPDX linter: use PyTorch copyright in vendored header
Replace the NVIDIA SPDX header with PyTorch's original BSD-3-Clause
copyright text (from PyTorch LICENSE lines 3-11), following the same
pattern as the vendored dlpack.h. Add aoti_shim.h to .spdx-ignore
to bypass the NVIDIA-specific copyright check.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix Windows build: generate stub import library for AOTI symbols
On Windows, MSVC requires a .lib to resolve __declspec(dllimport)
symbols at link time. The AOTI symbols live in torch_cpu.dll
(loaded by `import torch` at runtime) but torch is not a build-time
dependency.
Add:
- aoti_shim.def: symbol list for generating the stub import library
- AOTI_SHIM_API macro in aoti_shim.h: expands to __declspec(dllimport)
on Windows, empty on Linux/macOS
- build_hooks.py: on Windows, run `lib /DEF:... /OUT:...` to generate
the stub .lib and link _tensor_bridge against it
The stub .lib (~1KB) contains no code — it tells the linker that
the symbols will come from torch_cpu.dll. At runtime, `import torch`
loads the DLL before our extension is imported.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
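A sketch of the Windows-only step in build_hooks.py (paths, function name, and the machine flag are assumptions):

    import subprocess
    import sys

    def build_aoti_stub_lib(def_file, out_lib):
        # Only needed on Windows: MSVC's lib.exe turns the .def symbol list
        # into a tiny import library that tells the linker the symbols will be
        # provided by torch_cpu.dll at runtime (loaded by `import torch`).
        if sys.platform != "win32":
            return None
        subprocess.check_call(["lib", f"/DEF:{def_file}", f"/OUT:{out_lib}", "/MACHINE:X64"])
        return out_lib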
* Exclude torch DLLs from delvewheel repair on Windows
The _tensor_bridge extension links against torch_cpu.dll via a stub
import library. delvewheel tries to bundle this DLL into the wheel
and fails because torch is not installed in the build environment.
Exclude torch_cpu.dll and torch_python.dll with --no-dll so
delvewheel skips them — they are provided by the user's torch install.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix delvewheel flag: use --exclude instead of --no-dll
delvewheel uses --exclude (not --no-dll) and semicolons as path
separators on Windows.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix merge conflict resolution
* [pre-commit.ci] auto code formatting
* Add strided layout guard to tensor bridge, reject sparse tensors
Check aoti_torch_get_layout() before extracting metadata — reject
non-strided tensors (sparse, mkldnn, etc.) whose shape/strides are
not meaningful for dense memory access.
We intentionally skip the other Python-level __dlpack__ guards
(requires_grad, is_conj, is_neg, wrong-device) for the same reason
PyTorch's own __dlpack_c_exchange_api__ C path skips them: the
C-level exchange path is designed for performance-critical consumers.
PyTorch's DLTensorFromPyObjectNoSync → toDLPackNonOwning performs
zero safety checks (see aten/src/ATen/DLConvertor.cpp).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Revert strided layout guard (symbols missing in torch 2.3–2.8)
aoti_torch_get_layout was introduced in torch 2.9; referencing it in
cdef extern causes an ImportError on torch 2.3–2.8 at .so load time.
Remove the layout check entirely. Like PyTorch's own
__dlpack_c_exchange_api__ C path (DLTensorFromPyObjectNoSync →
toDLPackNonOwning), we skip all Python-level export guards
(requires_grad, is_conj, is_neg, non-strided, wrong-device).
Document this as a known limitation matching upstream precedent.
Verified: all 9 torch versions (2.3.1–2.11.0) pass again.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Address review comments: dtypes, stale cache, stream_ptr, sync notes
- Add uint16/uint32/uint64 to AOTI dtype and itemsize maps (fixes
regression where these torch dtypes would raise TypeError instead
of being handled by the bridge)
- Clear buf._dtype when repopulating a reused StridedMemoryView to
prevent returning a stale cached dtype
- Reject stream_ptr=None for CUDA tensors with BufferError (matches
DLPack semantics where None is ambiguous)
- Add "keep in sync" comments to aoti_shim.h and aoti_shim.def
per rwgk's review suggestion
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
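A sketch tying together the deferred dtype handling and the fixes above (illustrative names; the real tables live in the bridge and are keyed by the raw AOTI dtype codes):

    import numpy as np

    # Raw AOTI dtype code -> numpy dtype name. The review fix added uint16,
    # uint32 and uint64 here and to the matching itemsize table; populated
    # from the shim's dtype constants when the bridge loads.
    _AOTI_CODE_TO_NP_NAME = {}

    def get_dtype(code):
        # Resolved lazily on first .dtype access; the cached result on a view
        # must be cleared whenever the view is repopulated, otherwise a reused
        # StridedMemoryView would report a stale dtype.
        try:
            return np.dtype(_AOTI_CODE_TO_NP_NAME[code])
        except KeyError:
            raise TypeError(f"unsupported AOTI dtype code: {code}") from None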
---------
Co-authored-by: Emilio Castillo <ecastillo@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: 7f97d90
9 files changed: 848 additions, 20 deletions