Commit 11347ff

Authored by leofang, emcastillo, claude, and pre-commit-ci[bot]
Add torch.Tensor fast path for StridedMemoryView via AOTI tensor bridge (#1894)
* Add torch.Tensor fast path for StridedMemoryView via AOTI tensor bridge

  Provide a fast path for constructing a StridedMemoryView from a torch.Tensor
  by reading tensor metadata directly through PyTorch's AOT Inductor (AOTI)
  stable C ABI, avoiding DLPack/CAI protocol overhead (~10 ns per tensor via
  pointer arithmetic).

  Key design:
  - Vendored AOTI shim header (aoti_shim.h) with extern "C" wrapping
  - _tensor_bridge.pyx loaded lazily (only when a torch.Tensor is first
    passed) to avoid undefined AOTI symbols at import time
  - RTLD_GLOBAL bootstrap via sys.modules["torch._C"] before loading
    _tensor_bridge.so
  - torch detection via type(obj).__module__.startswith("torch")
  - PyTorch is NOT a build-time or run-time dependency of cuda.core

  Closes #749

* Clean up tensor bridge: remove unused AOTI decls, lazy dtype, drop empty .pxd

  - Remove unused aoti_torch_get_numel and aoti_torch_get_storage_offset
    declarations from aoti_shim.h and _tensor_bridge.pyx
  - Fix license headers on new files to 2026 (not 2024-2026)
  - Delete empty _tensor_bridge.pxd (nothing cimports from it)
  - Defer numpy dtype resolution for torch tensors: store the raw AOTI dtype
    code in metadata, compute itemsize from a cheap lookup table, and only
    resolve the full numpy dtype on first .dtype access via get_dtype()

* Move torch tensor fast path into each from_* classmethod

  Instead of short-circuiting in __init__ and from_any_interface, add the AOTI
  fast-path check to from_dlpack, from_cuda_array_interface, and
  from_array_interface. This ensures torch tensors always take the fast path
  regardless of which constructor the user calls. Simplify from_any_interface
  and _StridedMemoryViewProxy to just delegate to the from_* methods (which
  now handle torch internally).

* Add stream ordering for torch tensor bridge

  When stream_ptr is not -1, establish stream ordering between PyTorch's
  current CUDA stream (the producer) and the consumer stream, using the same
  event record + stream wait pattern as the CAI path. Uses
  aoti_torch_get_current_cuda_stream to get the producer stream, matching what
  PyTorch's own __dlpack__ does internally.

* Extract reusable sync_torch_stream and apply to CAI path

  Factor out stream ordering into a cpdef sync_torch_stream() helper in
  _tensor_bridge.pyx, callable from both C (view_as_torch_tensor) and Python
  (_memoryview.pyx). Apply the same stream ordering in view_as_cai for torch
  tensors: PyTorch's __cuda_array_interface__ reports version 2 and omits the
  "stream" field, so the standard CAI sync path is a no-op, leaving the
  consumer with no guarantee that the producer's work is visible. We now
  detect torch tensors in the CAI path and query PyTorch's current CUDA
  stream via AOTI to establish proper ordering.

* Nits: add check_aoti helper, size_t itemsize, 2D sliced test

  - Add check_aoti() inline helper to replace repetitive err/raise patterns
    for AOTI calls (one-liner per call)
  - Change itemsize type from int to size_t
  - Add test_torch_tensor_bridge_sliced_2d test case

* Revert itemsize to int, memoize int(stream_ptr)

  - Revert itemsize back to int (size_t was unnecessary for small values)
  - Memoize int(stream_ptr) to avoid a redundant Python conversion

* Use except?-1 instead of except* for check_aoti

  Better Cython 3 performance: except?-1 avoids the overhead of except*,
  which always checks for exceptions.

* Require PyTorch >= 2.3 for tensor bridge, move imports to module level

  The AOTI stable C ABI functions we use (get_dim, get_dtype,
  get_device_type, get_device_index, get_current_cuda_stream, complex dtype
  constants) were all introduced in PyTorch 2.3.0. Earlier versions are
  missing some or all of them. _is_torch_tensor now returns False when
  torch < 2.3, causing a graceful fallback to the standard DLPack/CAI paths.
  The version check result is memoized in a module-level variable. Also move
  `import ctypes, sys` from _get_tensor_bridge to module level.

* Add tensor bridge entry to 1.0.0 release notes

  Document the AOTI-based fast path for torch.Tensor in StridedMemoryView
  with ~10-20x speedup and stream ordering support.

* Update speedup range in release notes to match benchmarks

* Document THPVariable layout change across PyTorch versions

  The cdata field changed from MaybeOwned<at::Tensor> (2.3-2.9) to
  at::Tensor (2.10+). Both layouts are compatible with our offset trick.

* Cache type check in _is_torch_tensor for ~20% speedup

  Cache the result of the torch tensor type check (module + hasattr +
  version) keyed by type(obj). Subsequent calls for the same type are a
  single dict lookup (~76 ns) instead of the full check (~186 ns). Non-torch
  objects also benefit, as the cache returns False immediately after the
  first miss.

* Add upper bound to torch version check (cap at 2.11)

  The pyobj_to_aten_handle trick and the AtenTensorHandle == at::Tensor*
  identity are undocumented internals that could change. Cap at the latest
  tested version so unknown future versions fall back to the standard
  DLPack/CAI paths. Bump after verifying each new release.

* Update module docstring to document both THPVariable layouts

* Use except?-1 for sync_torch_stream instead of except*

* Fix linter errors

* Fix pyobj_to_aten_handle for PyTorch 2.3–2.9 MaybeOwned layout

  In PyTorch 2.3–2.9, THPVariable::cdata is c10::MaybeOwned<at::Tensor>,
  whose first member is bool isBorrowed_ (padded to 8 bytes) before the
  at::Tensor union member. The previous code always offset by
  sizeof(PyObject), which pointed at the bool tag (0x0), causing a segfault
  when AOTI functions dereferenced it as at::Tensor*. Add
  _get_cdata_extra_offset(), which checks the torch version at runtime and
  adds 8 bytes for torch < 2.10 (the MaybeOwned era). The result is memoized
  after the first call. Tested across PyTorch 2.3.1, 2.4.1, 2.5.1, 2.6.0,
  2.7.1, 2.8.0, 2.9.1, 2.10.0, and 2.11.0 with CPU tensors (9 dtypes, sliced
  tensors, 0d/1d/4d shapes).

* Consolidate torch tensor bridge tests into TestViewCPU/TestViewGPU

  Move the 9 standalone torch tensor bridge tests (1d, nd, scalar, empty,
  non-contiguous, sliced, sliced-2d, cpu, decorator) into the existing
  parametrized TestViewCPU and TestViewGPU classes. Each torch sample now
  runs through from_any_interface, the args_viewable_as_strided_memory
  decorator, and the deprecated __init__ path. Add helpers (_arr_ptr,
  _arr_strides_in_counts, _arr_is_c_contiguous, _arr_is_writeable) so
  _check_view works uniformly across numpy, cupy, numba, and torch arrays.
  Retain test_torch_tensor_bridge_dtypes and
  test_torch_tensor_bridge_bfloat16 as standalone tests, since they verify
  dtype mapping specifically.

* Extract _arr_size helper for torch/numpy size compatibility

* Fix ruff formatting in test_utils.py

* Add readonly comment and fix vendored header license to BSD-3-Clause

  - Document why readonly=False is correct for torch tensors: PyTorch always
    reports tensors as writable via both DLPack (flags=0) and CAI
    (data=(ptr, False)), and the AOTI C ABI has no readonly query.
  - Change the vendored aoti_shim.h SPDX from Apache-2.0 to BSD-3-Clause to
    match PyTorch's actual license.

* Merge bfloat16 test into test_torch_tensor_bridge_dtypes parametrization

  Add bfloat16 as a pytest.param with a skipif mark for ml_dtypes, removing
  the separate test_torch_tensor_bridge_bfloat16 function.

* Fix SPDX linter: use PyTorch copyright in vendored header

  Replace the NVIDIA SPDX header with PyTorch's original BSD-3-Clause
  copyright text (from PyTorch LICENSE lines 3-11), following the same
  pattern as the vendored dlpack.h. Add aoti_shim.h to .spdx-ignore to
  bypass the NVIDIA-specific copyright check.

* Fix Windows build: generate stub import library for AOTI symbols

  On Windows, MSVC requires a .lib to resolve __declspec(dllimport) symbols
  at link time. The AOTI symbols live in torch_cpu.dll (loaded by
  `import torch` at runtime), but torch is not a build-time dependency.

  Add:
  - aoti_shim.def: symbol list for generating the stub import library
  - AOTI_SHIM_API macro in aoti_shim.h: expands to __declspec(dllimport) on
    Windows, empty on Linux/macOS
  - build_hooks.py: on Windows, run `lib /DEF:... /OUT:...` to generate the
    stub .lib and link _tensor_bridge against it

  The stub .lib (~1KB) contains no code; it tells the linker that the
  symbols will come from torch_cpu.dll. At runtime, `import torch` loads the
  DLL before our extension is imported.

* Exclude torch DLLs from delvewheel repair on Windows

  The _tensor_bridge extension links against torch_cpu.dll via a stub import
  library. delvewheel tries to bundle this DLL into the wheel and fails
  because torch is not installed in the build environment. Exclude
  torch_cpu.dll and torch_python.dll so delvewheel skips them; they are
  provided by the user's torch install.

* Fix delvewheel flag: use --exclude instead of --no-dll

  delvewheel uses --exclude (not --no-dll) and semicolons as path separators
  on Windows.

* fix merge conflict resolution

* [pre-commit.ci] auto code formatting

* Add strided layout guard to tensor bridge, reject sparse tensors

  Check aoti_torch_get_layout() before extracting metadata; reject
  non-strided tensors (sparse, mkldnn, etc.) whose shape/strides are not
  meaningful for dense memory access. We intentionally skip the other
  Python-level __dlpack__ guards (requires_grad, is_conj, is_neg,
  wrong-device) for the same reason PyTorch's own __dlpack_c_exchange_api__
  C path skips them: the C-level exchange path is designed for
  performance-critical consumers. PyTorch's DLTensorFromPyObjectNoSync →
  toDLPackNonOwning performs zero safety checks (see
  aten/src/ATen/DLConvertor.cpp).

* Revert strided layout guard (symbols missing in torch 2.3–2.8)

  aoti_torch_get_layout was introduced in torch 2.9; referencing it in a
  cdef extern causes an ImportError on torch 2.3–2.8 at .so load time.
  Remove the layout check entirely. Like PyTorch's own
  __dlpack_c_exchange_api__ C path (DLTensorFromPyObjectNoSync →
  toDLPackNonOwning), we skip all Python-level export guards (requires_grad,
  is_conj, is_neg, non-strided, wrong-device). Document this as a known
  limitation matching upstream precedent. Verified: all 9 torch versions
  (2.3.1–2.11.0) pass again.

* Address review comments: dtypes, stale cache, stream_ptr, sync notes

  - Add uint16/uint32/uint64 to the AOTI dtype and itemsize maps (fixes a
    regression where these torch dtypes would raise TypeError instead of
    being handled by the bridge)
  - Clear buf._dtype when repopulating a reused StridedMemoryView to prevent
    returning a stale cached dtype
  - Reject stream_ptr=None for CUDA tensors with BufferError (matches DLPack
    semantics, where None is ambiguous)
  - Add "keep in sync" comments to aoti_shim.h and aoti_shim.def per rwgk's
    review suggestion

---------

Co-authored-by: Emilio Castillo <ecastillo@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7f97d90 commit 11347ff

9 files changed

Lines changed: 848 additions & 20 deletions

.spdx-ignore

Lines changed: 2 additions & 0 deletions
@@ -10,5 +10,7 @@ cuda_bindings/examples/*
 
 # Vendored
 cuda_core/cuda/core/_include/dlpack.h
+cuda_core/cuda/core/_include/aoti_shim.h
+cuda_core/cuda/core/_include/aoti_shim.def
 
 qa/ctk-next.drawio.svg

cuda_core/build_hooks.py

Lines changed: 21 additions & 1 deletion
@@ -11,6 +11,7 @@
 import glob
 import os
 import re
+import subprocess
 import sys
 import tempfile
 import zipfile
@@ -182,6 +183,25 @@ def get_sources(mod_name):
 # related to free-threading builds.
 extra_compile_args += ["-DCYTHON_TRACE_NOGIL=1", "-DCYTHON_USE_SYS_MONITORING=0"]
 
+# On Windows, _tensor_bridge.pyx needs a stub import library so the MSVC
+# linker can resolve the AOTI symbols (they live in torch_cpu.dll at
+# runtime). We generate the .lib from a .def file at build time.
+_aoti_extra_link_args = []
+if sys.platform == "win32":
+    _def_file = os.path.join("cuda", "core", "_include", "aoti_shim.def")
+    _lib_file = os.path.join("build", "aoti_shim.lib")
+    os.makedirs("build", exist_ok=True)
+    subprocess.check_call(  # noqa: S603
+        ["lib", f"/DEF:{_def_file}", f"/OUT:{_lib_file}", "/MACHINE:X64"],  # noqa: S607
+        stdout=subprocess.DEVNULL,
+    )
+    _aoti_extra_link_args = [_lib_file]
+
+def get_extra_link_args(mod_name):
+    if mod_name == "_tensor_bridge" and _aoti_extra_link_args:
+        return extra_link_args + _aoti_extra_link_args
+    return extra_link_args
+
 ext_modules = tuple(
     Extension(
         f"cuda.core.{mod.replace(os.path.sep, '.')}",
@@ -193,7 +213,7 @@ def get_sources(mod_name):
         + all_include_dirs,
         language="c++",
         extra_compile_args=extra_compile_args,
-        extra_link_args=extra_link_args,
+        extra_link_args=get_extra_link_args(mod),
     )
     for mod in module_names()
 )
cuda_core/cuda/core/_include/aoti_shim.def

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+; Stub import library definition for PyTorch's AOTI stable C ABI symbols.
+; Used on Windows only: 'lib /DEF:aoti_shim.def /OUT:aoti_shim.lib /MACHINE:X64'
+; generates a minimal import library that satisfies the MSVC linker.
+; At runtime the symbols resolve from torch_cpu.dll (loaded by 'import torch').
+;
+; IMPORTANT: Keep this export list in sync with the AOTI_SHIM_API declarations
+; in aoti_shim.h. build_hooks.py turns this file into the stub import library
+; that MSVC uses to link _tensor_bridge, so any added/removed/renamed AOTI
+; symbol must be updated in both files.
+LIBRARY torch_cpu.dll
+EXPORTS
+    aoti_torch_get_data_ptr
+    aoti_torch_get_dim
+    aoti_torch_get_sizes
+    aoti_torch_get_strides
+    aoti_torch_get_dtype
+    aoti_torch_dtype_float16
+    aoti_torch_dtype_float32
+    aoti_torch_dtype_float64
+    aoti_torch_dtype_bfloat16
+    aoti_torch_dtype_uint8
+    aoti_torch_dtype_uint16
+    aoti_torch_dtype_uint32
+    aoti_torch_dtype_uint64
+    aoti_torch_dtype_int8
+    aoti_torch_dtype_int16
+    aoti_torch_dtype_int32
+    aoti_torch_dtype_int64
+    aoti_torch_dtype_bool
+    aoti_torch_dtype_complex32
+    aoti_torch_dtype_complex64
+    aoti_torch_dtype_complex128
+    aoti_torch_get_device_type
+    aoti_torch_get_device_index
+    aoti_torch_device_type_cpu
+    aoti_torch_device_type_cuda
+    aoti_torch_get_current_cuda_stream
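Because the export list above must stay in sync with the AOTI_SHIM_API declarations in aoti_shim.h, a small consistency check could be wired into CI. The following is a hypothetical sketch (not part of this commit) that extracts the symbol sets from both files and compares them:

```python
# Hypothetical consistency check between aoti_shim.def and aoti_shim.h;
# not part of the commit, just an illustration of the "keep in sync" rule.
import re

def def_symbols(def_text):
    """Symbols listed under EXPORTS in a .def file (';' starts a comment)."""
    out = set()
    in_exports = False
    for line in def_text.splitlines():
        line = line.split(";")[0].strip()  # drop comments and whitespace
        if not line:
            continue
        if line.upper() == "EXPORTS":
            in_exports = True
            continue
        if in_exports:
            out.add(line)
    return out

def header_symbols(header_text):
    """Names declared with AOTI_SHIM_API in the vendored header."""
    return set(re.findall(r"AOTI_SHIM_API\s+\w+\s+(aoti_torch_\w+)\s*\(", header_text))
```

A CI job could then assert `def_symbols(...) == header_symbols(...)` and fail the build on any drift.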
cuda_core/cuda/core/_include/aoti_shim.h

Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
+/*
+ * Vendored subset of PyTorch's AOT Inductor (AOTI) stable C ABI.
+ * Original: torch/csrc/inductor/aoti_torch/c/shim.h
+ *
+ * These are declarations only -- no definitions are provided. The actual
+ * symbols are exported by libtorch (loaded via torch._C with RTLD_GLOBAL)
+ * and resolved at runtime by the dynamic linker. This means PyTorch is
+ * NOT required at compile time.
+ *
+ * From PyTorch:
+ *
+ * Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+ * Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+ * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+ * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+ * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+ * Copyright (c) 2011-2013 NYU (Clement Farabet)
+ * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+ * Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+ * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+ *
+ * SPDX-License-Identifier: BSD-3-Clause
+ * See https://github.com/pytorch/pytorch/blob/main/LICENSE
+ */
+
+#ifndef CUDA_CORE_AOTI_SHIM_H
+#define CUDA_CORE_AOTI_SHIM_H
+
+#include <stdint.h>
+
+/*
+ * On Windows the AOTI symbols live in torch_cpu.dll. We consume them
+ * via __declspec(dllimport) and a stub import library generated from
+ * aoti_shim.def at build time. On Linux/macOS the symbols are made
+ * visible at runtime through ctypes.CDLL(torch._C, RTLD_GLOBAL).
+ */
+#ifdef _WIN32
+#  define AOTI_SHIM_API __declspec(dllimport)
+#else
+#  define AOTI_SHIM_API
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef int32_t AOTITorchError;
+
+/* Opaque tensor handle -- corresponds to at::Tensor on the C++ side. */
+struct AtenTensorOpaque;
+typedef struct AtenTensorOpaque* AtenTensorHandle;
+
+/*
+ * IMPORTANT: Keep the AOTI_SHIM_API declaration list below in sync with
+ * aoti_shim.def. On Windows, build_hooks.py turns that .def file into the
+ * stub import library that MSVC needs to link _tensor_bridge without making
+ * PyTorch a build-time dependency. If you add, remove, or rename an imported
+ * AOTI symbol here, update aoti_shim.def in the same change.
+ */
+
+/* ---- tensor metadata --------------------------------------------------- */
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_data_ptr(
+    AtenTensorHandle tensor, void** ret_data_ptr);
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_dim(
+    AtenTensorHandle tensor, int64_t* ret_dim);
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_sizes(
+    AtenTensorHandle tensor, int64_t** ret_sizes);
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_strides(
+    AtenTensorHandle tensor, int64_t** ret_strides);
+
+/* ---- dtype ------------------------------------------------------------- */
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_dtype(
+    AtenTensorHandle tensor, int32_t* ret_dtype);
+
+AOTI_SHIM_API int32_t aoti_torch_dtype_float16(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_float32(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_float64(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_bfloat16(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_uint8(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_uint16(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_uint32(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_uint64(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_int8(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_int16(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_int32(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_int64(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_bool(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_complex32(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_complex64(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_complex128(void);
+
+/* ---- device ------------------------------------------------------------ */
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_device_type(
+    AtenTensorHandle tensor, int32_t* ret_device_type);
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_device_index(
+    AtenTensorHandle tensor, int32_t* ret_device_index);
+
+AOTI_SHIM_API int32_t aoti_torch_device_type_cpu(void);
+AOTI_SHIM_API int32_t aoti_torch_device_type_cuda(void);
+
+/* ---- stream ------------------------------------------------------------ */
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_current_cuda_stream(
+    int32_t device_index, void** ret_stream);
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
+
+#endif /* CUDA_CORE_AOTI_SHIM_H */
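The commit's "lazy dtype" cleanup builds on the dtype codes returned by the `aoti_torch_dtype_*()` functions declared above: the bridge stores the raw code, derives itemsize from a cheap table, and resolves the full dtype only on first access. Here is a standalone sketch of that idea (the numeric codes and tables are illustrative, not the real AOTI values, and this is not the actual cuda.core implementation):

```python
# Sketch of the deferred-dtype idea described in the commit message. The
# numeric codes below are placeholders; the real bridge obtains them by
# calling aoti_torch_dtype_*() at runtime.

class LazyDtypeView:
    # dtype code -> itemsize: cheap eager lookup (illustrative codes)
    _ITEMSIZE = {0: 1, 6: 4, 7: 8}
    # dtype code -> numpy-style typestr: resolved lazily on first access
    _TYPESTR = {0: "|u1", 6: "<f4", 7: "<f8"}

    def __init__(self, dtype_code):
        self._code = dtype_code
        self.itemsize = self._ITEMSIZE[dtype_code]  # one dict hit, no numpy
        self._dtype = None                          # full dtype not resolved yet

    @property
    def dtype(self):
        if self._dtype is None:  # resolve on first access only, then cache
            self._dtype = self._TYPESTR[self._code]
        return self._dtype
```

This keeps view construction on the fast path (two dict lookups) while still giving consumers a full dtype when they actually ask for one.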

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 94 additions & 0 deletions
@@ -10,7 +10,9 @@ from libc.stdint cimport intptr_t
 from cuda.core._layout cimport _StridedLayout, get_strides_ptr
 from cuda.core._stream import Stream
 
+import ctypes
 import functools
+import sys
 import warnings
 
 import numpy
@@ -29,6 +31,73 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
 from cuda.core._memory import Buffer
 
 
+# ---------------------------------------------------------------------------
+# Lazy tensor bridge (avoids loading _tensor_bridge.so until torch is used)
+# ---------------------------------------------------------------------------
+
+cdef object _tensor_bridge = None
+# Cache: type(obj) -> True/False for the torch tensor check.
+# Once a type is seen, we never re-check.
+cdef dict _torch_type_cache = {}
+# Tri-state: None = not checked, True/False = result of version check
+cdef object _torch_version_ok = None
+
+cdef inline bint _torch_version_check():
+    """Return True if 2.3 <= torch <= 2.11 (known AOTI ABI range). Memoized.
+
+    Lower bound: AOTI functions we use were introduced in PyTorch 2.3.
+    Upper bound: the ``pyobj_to_aten_handle`` trick relies on the
+    THPVariable struct layout (PyObject_HEAD followed by at::Tensor cdata)
+    and the identity ``AtenTensorHandle == at::Tensor*``. Both are
+    undocumented internals that could change in a future PyTorch version.
+    We cap at the latest version we have tested against; unknown versions
+    fall back to the standard DLPack/CAI paths. Bump the upper bound
+    after verifying a new PyTorch release.
+    """
+    global _torch_version_ok
+    if _torch_version_ok is not None:
+        return <bint>_torch_version_ok
+    torch = sys.modules.get("torch")
+    if torch is None:
+        _torch_version_ok = False
+        return False
+    try:
+        major, minor = int(torch.__version__.split(".")[0]), \
+            int(torch.__version__.split(".")[1])
+        _torch_version_ok = (2, 3) <= (major, minor) <= (2, 11)
+    except (ValueError, IndexError):
+        _torch_version_ok = False
+    return <bint>_torch_version_ok
+
+
+cdef inline bint _is_torch_tensor(object obj):
+    cdef type tp = type(obj)
+    cdef object cached = _torch_type_cache.get(tp)
+    if cached is not None:
+        return <bint>cached
+    cdef str mod = tp.__module__ or ""
+    cdef bint result = mod.startswith("torch") and hasattr(obj, "data_ptr") \
+        and _torch_version_check()
+    _torch_type_cache[tp] = result
+    return result
+
+
+cdef object _get_tensor_bridge():
+    """Bootstrap AOTI symbols, then import _tensor_bridge on first use."""
+    global _tensor_bridge
+    if _tensor_bridge is not None:
+        return _tensor_bridge
+    torch_C = sys.modules.get("torch._C")
+    if torch_C is None:
+        raise RuntimeError(
+            "torch._C is not loaded; cannot initialise the tensor bridge. "
+            "Make sure PyTorch is imported before passing a torch.Tensor.")
+    ctypes.CDLL(torch_C.__file__, mode=ctypes.RTLD_GLOBAL)
+    from cuda.core import _tensor_bridge as tb
+    _tensor_bridge = tb
+    return _tensor_bridge
+
+
 try:
     from ml_dtypes import bfloat16
 except ImportError:
@@ -150,6 +219,9 @@ cdef class StridedMemoryView:
             Stream pointer for synchronization. If ``None``, no synchronization is performed.
         """
         cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
+        if _is_torch_tensor(obj):
+            _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf)
+            return buf
         view_as_dlpack(obj, stream_ptr, buf)
         return buf
 
@@ -165,6 +237,9 @@ cdef class StridedMemoryView:
             Stream pointer for synchronization. If ``None``, no synchronization is performed.
         """
         cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
+        if _is_torch_tensor(obj):
+            _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf)
+            return buf
         view_as_cai(obj, stream_ptr, buf)
         return buf
 
@@ -178,6 +253,9 @@ cdef class StridedMemoryView:
            An object implementing the `__array_interface__ <https://numpy.org/doc/stable/reference/arrays.interface.html>`_ protocol (e.g., a numpy array).
        """
        cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
+        if _is_torch_tensor(obj):
+            _get_tensor_bridge().view_as_torch_tensor(obj, None, buf)
+            return buf
         view_as_array_interface(obj, buf)
         return buf
 
@@ -187,6 +265,8 @@ cdef class StridedMemoryView:
 
         Tries `DLPack <https://dmlc.github.io/dlpack/latest/>`_ first, then falls back to
         `__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_.
+        ``torch.Tensor`` objects are transparently handled via a fast AOTI path
+        regardless of which protocol is selected.
 
         Parameters
         ----------
@@ -480,6 +560,10 @@ cdef class StridedMemoryView:
         if self._dtype is None:
             if self.dl_tensor != NULL:
                 self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype)
+            elif isinstance(self.metadata, int):
+                # AOTI dtype code stored by the torch tensor bridge
+                self._dtype = _get_tensor_bridge().resolve_aoti_dtype(
+                    self.metadata)
             elif self.metadata is not None:
                 self._dtype = _typestr2dtype(self.metadata["typestr"])
         return self._dtype
@@ -1122,6 +1206,16 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
                 as_cu(h_event), <cydriver.CUstream>producer_s))
             HANDLE_RETURN(cydriver.cuStreamWaitEvent(
                 <cydriver.CUstream>consumer_s, as_cu(h_event), 0))
+    elif _is_torch_tensor(obj):
+        # PyTorch's __cuda_array_interface__ reports version 2 and
+        # omits the "stream" field, so the standard CAI sync path
+        # above is a no-op for torch tensors. This is unsafe: the
+        # consumer has no guarantee that the producer's work is
+        # visible. We fix this by querying PyTorch's current CUDA
+        # stream via the AOTI stable C ABI and performing the same
+        # event-based stream ordering.
+        _get_tensor_bridge().sync_torch_stream(
+            buf.device_id, <intptr_t>(stream_ptr))
 
     return buf
