Skip to content

Commit b5ec10d

Browse files
leofang authored and claude committed
Add strided layout guard to tensor bridge, reject sparse tensors
Check aoti_torch_get_layout() before extracting metadata — reject non-strided tensors (sparse, mkldnn, etc.) whose shape/strides are not meaningful for dense memory access. We intentionally skip the other Python-level __dlpack__ guards (requires_grad, is_conj, is_neg, wrong-device) for the same reason PyTorch's own __dlpack_c_exchange_api__ C path skips them: the C-level exchange path is designed for performance-critical consumers. PyTorch's DLTensorFromPyObjectNoSync → toDLPackNonOwning performs zero safety checks (see aten/src/ATen/DLConvertor.cpp). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 615f984 commit b5ec10d

3 files changed

Lines changed: 29 additions & 0 deletions

File tree

cuda_core/cuda/core/_include/aoti_shim.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,6 @@ EXPORTS
2626
aoti_torch_get_device_index
2727
aoti_torch_device_type_cpu
2828
aoti_torch_device_type_cuda
29+
aoti_torch_get_layout
30+
aoti_torch_layout_strided
2931
aoti_torch_get_current_cuda_stream

cuda_core/cuda/core/_include/aoti_shim.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ AOTI_SHIM_API AOTITorchError aoti_torch_get_device_index(
9494
AOTI_SHIM_API int32_t aoti_torch_device_type_cpu(void);
9595
AOTI_SHIM_API int32_t aoti_torch_device_type_cuda(void);
9696

97+
/* ---- layout -------------------------------------------------------------- */
98+
99+
AOTI_SHIM_API AOTITorchError aoti_torch_get_layout(
100+
AtenTensorHandle tensor, int32_t* ret_layout);
101+
102+
AOTI_SHIM_API int32_t aoti_torch_layout_strided(void);
103+
97104
/* ---- stream -------------------------------------------------------------- */
98105

99106
AOTI_SHIM_API AOTITorchError aoti_torch_get_current_cuda_stream(

cuda_core/cuda/core/_tensor_bridge.pyx

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ cdef extern from "_include/aoti_shim.h":
100100
int32_t aoti_torch_device_type_cpu()
101101
int32_t aoti_torch_device_type_cuda()
102102

103+
# layout
104+
AOTITorchError aoti_torch_get_layout(AtenTensorHandle, int32_t*)
105+
int32_t aoti_torch_layout_strided()
106+
103107
# stream
104108
AOTITorchError aoti_torch_get_current_cuda_stream(int32_t, void**)
105109

@@ -115,6 +119,7 @@ import sys
115119

116120
cdef int32_t _DEVICE_TYPE_CPU = aoti_torch_device_type_cpu()
117121
cdef int32_t _DEVICE_TYPE_CUDA = aoti_torch_device_type_cuda()
122+
cdef int32_t _LAYOUT_STRIDED = aoti_torch_layout_strided()
118123
cdef dict _aoti_dtype_map = None
119124
cdef dict _aoti_itemsize_map = None
120125

@@ -310,11 +315,26 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None):
310315
cdef int64_t* strides_ptr
311316
cdef int32_t dtype_code
312317
cdef int32_t device_type, device_index
318+
cdef int32_t tensor_layout
313319
cdef StridedMemoryView buf
314320
cdef int itemsize
315321
cdef intptr_t _stream_ptr_int
316322
cdef _StridedLayout layout
317323

324+
# Reject non-strided (sparse, mkldnn, etc.) tensors whose shape/strides
325+
# are not meaningful for dense memory access. This mirrors the guard in
326+
# PyTorch's Python-level __dlpack__ ("layout other than torch.strided").
327+
# Note: we intentionally skip the other Python-level guards
328+
# (requires_grad, is_conj, is_neg, wrong-device) for the same reason
329+
# PyTorch's own __dlpack_c_exchange_api__ C path skips them — the C-level
330+
# exchange path is designed for performance-critical consumers.
331+
check_aoti(aoti_torch_get_layout(handle, &tensor_layout),
332+
b"aoti_torch_get_layout")
333+
if tensor_layout != _LAYOUT_STRIDED:
334+
raise BufferError(
335+
"Only strided tensors can be viewed via the tensor bridge "
336+
"(use tensor.to_dense() to convert sparse tensors first)")
337+
318338
check_aoti(aoti_torch_get_data_ptr(handle, &data_ptr),
319339
b"aoti_torch_get_data_ptr")
320340
check_aoti(aoti_torch_get_dim(handle, &ndim),

0 commit comments

Comments (0)