Skip to content

Commit ad9bc92

Browse files
leofangclaude
andauthored
Fix torch-incompatible assertions in TestViewCudaArrayInterfaceGPU (#1999)
* Fix torch-incompatible assertions in TestViewCudaArrayInterfaceGPU The _check_view method in TestViewCudaArrayInterfaceGPU was missed during the tensor bridge refactor (#1894) and still used raw numpy attributes (in_arr.size, in_arr.strides, in_arr.flags, etc.) that don't work with torch tensors. Use the _arr_* helpers that #1894 added for torch/numpy compatibility. Caught by the nightly optional-dependency CI (#1987). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Fix strides assertion for torch CAI: allow explicit C-contiguous strides torch's __cuda_array_interface__ always reports strides, even for C-contiguous tensors. Use the same assertion pattern as the other _check_view methods: allow strides to equal the C-contiguous values instead of requiring None. Verified locally: 7/7 torch CAI tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Unify strides assertion pattern across all _check_view methods Use the same if/else pattern with `in (None, strides_in_counts)` in all three _check_view methods for consistency. Previously TestViewCPU and TestViewCudaArrayInterfaceGPU used a one-liner that was harder to read and behaved slightly differently. Verified locally: 66/66 tests pass across TestViewCPU, TestViewGPU, and TestViewCudaArrayInterfaceGPU (including all torch variants). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Address review: flip strides assertion, add _arr_dtype, merge main Per @rwgk's review: - Flip strides check to branch on view.strides (all 3 _check_view) - Add _arr_dtype helper using __cuda_array_interface__["typestr"] for torch tensors, restore dtype assertion in CAI _check_view - Merge main to pick up #1998 (numba flags fix) Verified locally: 76/76 tests pass across all three test classes. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 371fa42 commit ad9bc92

1 file changed

Lines changed: 20 additions & 10 deletions

File tree

cuda_core/tests/test_utils.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def _arr_is_writeable(arr):
111111
return arr.flags.writeable if hasattr(arr.flags, "writeable") else True
112112

113113

114+
def _arr_dtype(arr):
115+
if torch is not None and isinstance(arr, torch.Tensor):
116+
return np.dtype(arr.__cuda_array_interface__["typestr"])
117+
return arr.dtype
118+
119+
114120
def _cpu_array_samples():
115121
samples = [
116122
np.empty(3, dtype=np.int32),
@@ -171,7 +177,10 @@ def _check_view(self, view, in_arr):
171177
assert view.shape == expected_shape
172178
assert view.size == _arr_size(in_arr)
173179
strides_in_counts = _arr_strides_in_counts(in_arr)
174-
assert (_arr_is_c_contiguous(in_arr) and view.strides is None) or view.strides == strides_in_counts
180+
if view.strides is None:
181+
assert _arr_is_c_contiguous(in_arr)
182+
else:
183+
assert view.strides == strides_in_counts
175184
assert view.device_id == -1
176185
assert view.is_device_accessible is False
177186
assert view.exporting_obj is in_arr
@@ -277,8 +286,8 @@ def _check_view(self, view, in_arr, dev):
277286
assert view.shape == expected_shape
278287
assert view.size == _arr_size(in_arr)
279288
strides_in_counts = _arr_strides_in_counts(in_arr)
280-
if _arr_is_c_contiguous(in_arr):
281-
assert view.strides in (None, strides_in_counts)
289+
if view.strides is None:
290+
assert _arr_is_c_contiguous(in_arr)
282291
else:
283292
assert view.strides == strides_in_counts
284293
assert view.device_id == dev.device_id
@@ -343,15 +352,16 @@ def test_cuda_array_interface_gpu(self, in_arr, use_stream):
343352

344353
def _check_view(self, view, in_arr, dev):
345354
assert isinstance(view, StridedMemoryView)
346-
assert view.ptr == gpu_array_ptr(in_arr)
347-
assert view.shape == in_arr.shape
348-
assert view.size == in_arr.size
349-
strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize)
350-
if in_arr.flags["C_CONTIGUOUS"]:
351-
assert view.strides is None
355+
assert view.ptr == _arr_ptr(in_arr)
356+
expected_shape = tuple(in_arr.shape)
357+
assert view.shape == expected_shape
358+
assert view.size == _arr_size(in_arr)
359+
strides_in_counts = _arr_strides_in_counts(in_arr)
360+
if view.strides is None:
361+
assert _arr_is_c_contiguous(in_arr)
352362
else:
353363
assert view.strides == strides_in_counts
354-
assert view.dtype == in_arr.dtype
364+
assert view.dtype == _arr_dtype(in_arr)
355365
assert view.device_id == dev.device_id
356366
assert view.is_device_accessible is True
357367
assert view.exporting_obj is in_arr

0 commit comments

Comments
 (0)