-
Notifications
You must be signed in to change notification settings - Fork 274
Add torch.Tensor fast path for StridedMemoryView via AOTI tensor bridge #1894
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
82ad598
f8f8d8c
af06e9b
44be580
6e6b8a6
85caaaf
9fad471
cc4558a
5f49e7a
b98fe71
30ba7d5
0f57646
74798e7
00b8ec9
0c31df1
8c20237
8c019b9
6682646
0b7245b
626736a
d1d3841
b9d80e7
7d46123
c7331a9
0e75229
37fce1a
7f5dda6
f6a3032
2748a52
d543be1
84ff2ec
a66c0d0
833bcf8
615f984
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| ; Stub import library definition for PyTorch's AOTI stable C ABI symbols. | ||
| ; Used on Windows only: 'lib /DEF:aoti_shim.def /OUT:aoti_shim.lib /MACHINE:X64' | ||
| ; generates a minimal import library that satisfies the MSVC linker. | ||
| ; At runtime the symbols resolve from torch_cpu.dll (loaded by 'import torch'). | ||
| LIBRARY torch_cpu.dll | ||
| EXPORTS | ||
| aoti_torch_get_data_ptr | ||
| aoti_torch_get_dim | ||
| aoti_torch_get_sizes | ||
| aoti_torch_get_strides | ||
| aoti_torch_get_dtype | ||
| aoti_torch_dtype_float16 | ||
| aoti_torch_dtype_float32 | ||
| aoti_torch_dtype_float64 | ||
| aoti_torch_dtype_bfloat16 | ||
| aoti_torch_dtype_uint8 | ||
| aoti_torch_dtype_int8 | ||
| aoti_torch_dtype_int16 | ||
| aoti_torch_dtype_int32 | ||
| aoti_torch_dtype_int64 | ||
| aoti_torch_dtype_bool | ||
| aoti_torch_dtype_complex32 | ||
| aoti_torch_dtype_complex64 | ||
| aoti_torch_dtype_complex128 | ||
| aoti_torch_get_device_type | ||
| aoti_torch_get_device_index | ||
| aoti_torch_device_type_cpu | ||
| aoti_torch_device_type_cuda | ||
| aoti_torch_get_current_cuda_stream | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| /* | ||
| * Vendored subset of PyTorch's AOT Inductor (AOTI) stable C ABI. | ||
| * Original: torch/csrc/inductor/aoti_torch/c/shim.h | ||
| * | ||
| * These are declarations only -- no definitions are provided. The actual | ||
| * symbols are exported by libtorch (loaded via torch._C with RTLD_GLOBAL) | ||
| * and resolved at runtime by the dynamic linker. This means PyTorch is | ||
| * NOT required at compile time. | ||
| * | ||
| * From PyTorch: | ||
| * | ||
| * Copyright (c) 2016- Facebook, Inc (Adam Paszke) | ||
| * Copyright (c) 2014- Facebook, Inc (Soumith Chintala) | ||
| * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) | ||
| * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) | ||
| * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) | ||
| * Copyright (c) 2011-2013 NYU (Clement Farabet) | ||
| * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) | ||
| * Copyright (c) 2006 Idiap Research Institute (Samy Bengio) | ||
| * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) | ||
| * | ||
| * SPDX-License-Identifier: BSD-3-Clause | ||
| * See https://github.com/pytorch/pytorch/blob/main/LICENSE | ||
| */ | ||
|
|
||
| #ifndef CUDA_CORE_AOTI_SHIM_H | ||
| #define CUDA_CORE_AOTI_SHIM_H | ||
|
|
||
| #include <stdint.h> | ||
|
|
||
| /* | ||
| * On Windows the AOTI symbols live in torch_cpu.dll. We consume them | ||
| * via __declspec(dllimport) and a stub import library generated from | ||
| * aoti_shim.def at build time. On Linux/macOS the symbols are made | ||
| * visible at runtime through ctypes.CDLL(torch._C, RTLD_GLOBAL). | ||
| */ | ||
| #ifdef _WIN32 | ||
| # define AOTI_SHIM_API __declspec(dllimport) | ||
| #else | ||
| # define AOTI_SHIM_API | ||
| #endif | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| typedef int32_t AOTITorchError; | ||
|
|
||
| /* Opaque tensor handle -- corresponds to at::Tensor on the C++ side. */ | ||
| struct AtenTensorOpaque; | ||
| typedef struct AtenTensorOpaque* AtenTensorHandle; | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To help future maintainers and agents: |
||
/* ---- tensor metadata --------------------------------------------------- */

/* Stores the tensor's raw data pointer in *ret_data_ptr. */
AOTI_SHIM_API AOTITorchError aoti_torch_get_data_ptr(
    AtenTensorHandle tensor, void** ret_data_ptr);

/* Stores the tensor's number of dimensions in *ret_dim. */
AOTI_SHIM_API AOTITorchError aoti_torch_get_dim(
    AtenTensorHandle tensor, int64_t* ret_dim);

/* Stores a pointer to the tensor's sizes array in *ret_sizes.
 * NOTE(review): the array presumably aliases storage owned by the tensor
 * and is valid only while the tensor is alive -- confirm against shim.h. */
AOTI_SHIM_API AOTITorchError aoti_torch_get_sizes(
    AtenTensorHandle tensor, int64_t** ret_sizes);

/* Stores a pointer to the tensor's strides array in *ret_strides.
 * Same lifetime caveat as aoti_torch_get_sizes. */
AOTI_SHIM_API AOTITorchError aoti_torch_get_strides(
    AtenTensorHandle tensor, int64_t** ret_strides);

/* ---- dtype ------------------------------------------------------------- */

/* Stores the tensor's dtype code in *ret_dtype. Compare the result with
 * the values returned by the aoti_torch_dtype_* functions below. */
AOTI_SHIM_API AOTITorchError aoti_torch_get_dtype(
    AtenTensorHandle tensor, int32_t* ret_dtype);

/* Each function returns the runtime dtype code for the named type;
 * the codes are queried at runtime rather than hard-coded here. */
AOTI_SHIM_API int32_t aoti_torch_dtype_float16(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_float32(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_float64(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_bfloat16(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_uint8(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_int8(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_int16(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_int32(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_int64(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_bool(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_complex32(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_complex64(void);
AOTI_SHIM_API int32_t aoti_torch_dtype_complex128(void);

/* ---- device ------------------------------------------------------------ */

/* Stores the tensor's device type code in *ret_device_type. Compare with
 * aoti_torch_device_type_cpu() / aoti_torch_device_type_cuda(). */
AOTI_SHIM_API AOTITorchError aoti_torch_get_device_type(
    AtenTensorHandle tensor, int32_t* ret_device_type);

/* Stores the tensor's device index in *ret_device_index. */
AOTI_SHIM_API AOTITorchError aoti_torch_get_device_index(
    AtenTensorHandle tensor, int32_t* ret_device_index);

/* Runtime device-type codes, analogous to the dtype functions above. */
AOTI_SHIM_API int32_t aoti_torch_device_type_cpu(void);
AOTI_SHIM_API int32_t aoti_torch_device_type_cuda(void);

/* ---- stream -------------------------------------------------------------- */

/* Stores PyTorch's current CUDA stream for device_index in *ret_stream.
 * NOTE(review): presumably a cudaStream_t handle -- confirm in shim.h. */
AOTI_SHIM_API AOTITorchError aoti_torch_get_current_cuda_stream(
    int32_t device_index, void** ret_stream);

#ifdef __cplusplus
} /* extern "C" */
#endif

#endif /* CUDA_CORE_AOTI_SHIM_H */
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,9 @@ from libc.stdint cimport intptr_t | |
| from cuda.core._layout cimport _StridedLayout, get_strides_ptr | ||
| from cuda.core._stream import Stream | ||
|
|
||
| import ctypes | ||
| import functools | ||
| import sys | ||
| import warnings | ||
|
|
||
| import numpy | ||
|
|
@@ -29,6 +31,73 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN | |
| from cuda.core._memory import Buffer | ||
|
|
||
|
|
||
# ---------------------------------------------------------------------------
# Lazy tensor bridge (avoids loading _tensor_bridge.so until torch is used)
# ---------------------------------------------------------------------------

# Cached module handle for cuda.core._tensor_bridge; populated on first use
# by _get_tensor_bridge() and reused for the lifetime of the process.
cdef object _tensor_bridge = None
# Cache: type(obj) -> True/False for the torch tensor check.
# Once a type is seen, we never re-check.
cdef dict _torch_type_cache = {}
# Tri-state: None = not checked, True/False = result of version check
cdef object _torch_version_ok = None
|
|
||
cdef inline bint _torch_version_check():
    """Return True if 2.3 <= torch <= 2.11 (known AOTI ABI range). Memoized.

    Lower bound: AOTI functions we use were introduced in PyTorch 2.3.
    Upper bound: the ``pyobj_to_aten_handle`` trick relies on the
    THPVariable struct layout (PyObject_HEAD followed by at::Tensor cdata)
    and the identity ``AtenTensorHandle == at::Tensor*``. Both are
    undocumented internals that could change in a future PyTorch version.
    We cap at the latest version we have tested against; unknown versions
    fall back to the standard DLPack/CAI paths. Bump the upper bound
    after verifying a new PyTorch release.
    """
    global _torch_version_ok
    if _torch_version_ok is not None:
        return <bint>_torch_version_ok
    torch = sys.modules.get("torch")
    if torch is None:
        # torch not imported yet: report False but do NOT memoize, so a
        # later ``import torch`` can still enable the fast path.
        return False
    try:
        # Split once; version strings look like "2.5.1", "2.3.0a0" or
        # "2.5.1+cu121" -- only the first two components matter here.
        parts = torch.__version__.split(".")
        major, minor = int(parts[0]), int(parts[1])
        _torch_version_ok = (2, 3) <= (major, minor) <= (2, 11)
    except (ValueError, IndexError):
        # Unparseable version string: permanently disable the fast path.
        _torch_version_ok = False
    return <bint>_torch_version_ok
|
|
||
|
|
||
cdef inline bint _is_torch_tensor(object obj):
    """Memoized per-type check for torch tensors eligible for the fast path.

    A type qualifies when its module name starts with ``torch``, instances
    expose ``data_ptr``, and the installed torch version passes
    ``_torch_version_check()``. The verdict is cached per type in
    ``_torch_type_cache`` and never re-evaluated.
    """
    cdef type tp = type(obj)
    cdef object hit = _torch_type_cache.get(tp)
    cdef str module_name
    cdef bint verdict
    if hit is not None:
        return <bint>hit
    module_name = tp.__module__ or ""
    verdict = (module_name.startswith("torch")
               and hasattr(obj, "data_ptr")
               and _torch_version_check())
    _torch_type_cache[tp] = verdict
    return verdict
|
|
||
|
|
||
cdef object _get_tensor_bridge():
    """Bootstrap AOTI symbols, then import _tensor_bridge on first use.

    Re-opens the already-loaded ``torch._C`` extension with RTLD_GLOBAL so
    the ``aoti_torch_*`` symbols it pulls in become visible to the dynamic
    linker, then imports ``cuda.core._tensor_bridge``, whose unresolved
    references bind against them at load time. The module is cached in the
    ``_tensor_bridge`` global, so the CDLL call runs at most once per
    successful bootstrap.

    Raises
    ------
    RuntimeError
        If PyTorch has not been imported yet (``torch._C`` absent from
        ``sys.modules``).
    """
    global _tensor_bridge
    if _tensor_bridge is not None:
        return _tensor_bridge
    torch_C = sys.modules.get("torch._C")
    if torch_C is None:
        raise RuntimeError(
            "torch._C is not loaded; cannot initialise the tensor bridge. "
            "Make sure PyTorch is imported before passing a torch.Tensor.")
    # Per the ctypes docs the mode parameter is ignored on Windows, so
    # RTLD_GLOBAL is a harmless no-op there; Windows resolves the symbols
    # through the aoti_shim.def import library instead.
    ctypes.CDLL(torch_C.__file__, mode=ctypes.RTLD_GLOBAL)
    from cuda.core import _tensor_bridge as tb
    _tensor_bridge = tb
    return _tensor_bridge
|
|
||
|
|
||
| try: | ||
| from ml_dtypes import bfloat16 | ||
| except ImportError: | ||
|
|
@@ -150,6 +219,9 @@ cdef class StridedMemoryView: | |
| Stream pointer for synchronization. If ``None``, no synchronization is performed. | ||
| """ | ||
| cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) | ||
| if _is_torch_tensor(obj): | ||
| _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf) | ||
| return buf | ||
| view_as_dlpack(obj, stream_ptr, buf) | ||
| return buf | ||
|
|
||
|
|
@@ -165,6 +237,9 @@ cdef class StridedMemoryView: | |
| Stream pointer for synchronization. If ``None``, no synchronization is performed. | ||
| """ | ||
| cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) | ||
| if _is_torch_tensor(obj): | ||
| _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf) | ||
| return buf | ||
| view_as_cai(obj, stream_ptr, buf) | ||
| return buf | ||
|
|
||
|
|
@@ -178,6 +253,9 @@ cdef class StridedMemoryView: | |
| An object implementing the `__array_interface__ <https://numpy.org/doc/stable/reference/arrays.interface.html>`_ protocol (e.g., a numpy array). | ||
| """ | ||
| cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) | ||
| if _is_torch_tensor(obj): | ||
| _get_tensor_bridge().view_as_torch_tensor(obj, None, buf) | ||
| return buf | ||
| view_as_array_interface(obj, buf) | ||
| return buf | ||
|
|
||
|
|
@@ -187,6 +265,8 @@ cdef class StridedMemoryView: | |
|
|
||
| Tries `DLPack <https://dmlc.github.io/dlpack/latest/>`_ first, then falls back to | ||
| `__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_. | ||
| ``torch.Tensor`` objects are transparently handled via a fast AOTI path | ||
| regardless of which protocol is selected. | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -480,6 +560,10 @@ cdef class StridedMemoryView: | |
| if self._dtype is None: | ||
| if self.dl_tensor != NULL: | ||
| self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype) | ||
| elif isinstance(self.metadata, int): | ||
| # AOTI dtype code stored by the torch tensor bridge | ||
| self._dtype = _get_tensor_bridge().resolve_aoti_dtype( | ||
| self.metadata) | ||
| elif self.metadata is not None: | ||
| self._dtype = _typestr2dtype(self.metadata["typestr"]) | ||
| return self._dtype | ||
|
|
@@ -1122,6 +1206,16 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None): | |
| as_cu(h_event), <cydriver.CUstream>producer_s)) | ||
| HANDLE_RETURN(cydriver.cuStreamWaitEvent( | ||
| <cydriver.CUstream>consumer_s, as_cu(h_event), 0)) | ||
| elif _is_torch_tensor(obj): | ||
| # PyTorch's __cuda_array_interface__ reports version 2 and | ||
| # omits the "stream" field, so the standard CAI sync path | ||
| # above is a no-op for torch tensors. This is unsafe: the | ||
| # consumer has no guarantee that the producer's work is | ||
| # visible. We fix this by querying PyTorch's current CUDA | ||
| # stream via the AOTI stable C ABI and performing the same | ||
| # event-based stream ordering. | ||
| _get_tensor_bridge().sync_torch_stream( | ||
| buf.device_id, <intptr_t>(stream_ptr)) | ||
|
|
||
| return buf | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the suggested comment in aoti_shim.h: