1616import numpy as np
1717import pytest
1818from cuda .core .experimental import Device
19- from cuda .core .experimental .utils import StridedMemoryView , _StridedLayout , args_viewable_as_strided_memory
19+ from cuda .core .experimental ._layout import _StridedLayout
20+ from cuda .core .experimental .utils import StridedMemoryView , args_viewable_as_strided_memory
2021
2122
2223def test_cast_to_3_tuple_success ():
@@ -234,44 +235,60 @@ def _dense_strides(shape, stride_order):
234235 return tuple (strides )
235236
236237
237- @pytest .mark .parametrize ("shape" , [tuple (), (2 , 3 ), (10 , 10 ), (10 , 13 , 11 )])
238- @pytest .mark .parametrize ("itemsize " , [1 , 4 ] )
238+ @pytest .mark .parametrize ("shape" , [tuple (), (2 , 3 ), (10 , 10 ), (10 , 13 , 11 )], ids = str )
239+ @pytest .mark .parametrize ("dtype " , [np . dtype ( np . int8 ), np . dtype ( np . uint32 )], ids = str )
239240@pytest .mark .parametrize ("stride_order" , ["C" , "F" ])
240241@pytest .mark .parametrize ("readonly" , [True , False ])
241- def test_from_buffer (shape , itemsize , stride_order , readonly ):
242+ def test_from_buffer (shape , dtype , stride_order , readonly ):
242243 dev = Device ()
243244 dev .set_current ()
244- layout = _StridedLayout .dense (shape = shape , itemsize = itemsize , stride_order = stride_order )
245+ layout = _StridedLayout .dense (shape = shape , itemsize = dtype . itemsize , stride_order = stride_order )
245246 required_size = layout .required_size_in_bytes ()
246- assert required_size == math .prod (shape ) * itemsize
247+ assert required_size == math .prod (shape ) * dtype . itemsize
247248 buffer = dev .memory_resource .allocate (required_size )
248- view = StridedMemoryView .from_buffer (buffer , layout , is_readonly = readonly )
249+ view = StridedMemoryView .from_buffer (buffer , shape = shape , strides = layout . strides , dtype = dtype , is_readonly = readonly )
249250 assert view .exporting_obj is buffer
250- assert view .layout is layout
251+ assert view ._layout == layout
251252 assert view .ptr == int (buffer .handle )
252253 assert view .shape == shape
253254 assert view .strides == _dense_strides (shape , stride_order )
254- assert view .dtype is None
255+ assert view .dtype == dtype
255256 assert view .device_id == dev .device_id
256257 assert view .is_device_accessible
257258 assert view .readonly == readonly
258259
259260
261+ @pytest .mark .parametrize (
262+ ("dtype" , "itemsize" , "msg" ),
263+ [
264+ (np .dtype ("int16" ), 1 , "itemsize .+ does not match dtype.itemsize .+" ),
265+ (None , None , "itemsize or dtype must be specified" ),
266+ ],
267+ )
268+ def test_from_buffer_incompatible_dtype_and_itemsize (dtype , itemsize , msg ):
269+ layout = _StridedLayout .dense ((5 ,), 2 )
270+ device = Device ()
271+ device .set_current ()
272+ buffer = device .memory_resource .allocate (layout .required_size_in_bytes ())
273+ with pytest .raises (ValueError , match = msg ):
274+ StridedMemoryView .from_buffer (buffer , (5 ,), dtype = dtype , itemsize = itemsize )
275+
276+
260277@pytest .mark .parametrize ("stride_order" , ["C" , "F" ])
261278def test_from_buffer_sliced (stride_order ):
262279 layout = _StridedLayout .dense ((5 , 7 ), 2 , stride_order = stride_order )
263280 device = Device ()
264281 device .set_current ()
265282 buffer = device .memory_resource .allocate (layout .required_size_in_bytes ())
266- view = StridedMemoryView .from_buffer (buffer , layout )
283+ view = StridedMemoryView .from_buffer (buffer , ( 5 , 7 ), dtype = np . dtype ( np . int16 ) )
267284 assert view .shape == (5 , 7 )
268285 assert int (buffer .handle ) == view .ptr
269286
270287 sliced_view = view .view (layout [:- 2 , 3 :])
271288 assert sliced_view .shape == (3 , 4 )
272289 expected_offset = 3 if stride_order == "C" else 3 * 5
273- assert sliced_view .layout .slice_offset == expected_offset
274- assert sliced_view .layout .slice_offset_in_bytes == expected_offset * 2
290+ assert sliced_view ._layout .slice_offset == expected_offset
291+ assert sliced_view ._layout .slice_offset_in_bytes == expected_offset * 2
275292 assert sliced_view .ptr == view .ptr + expected_offset * 2
276293 assert int (buffer .handle ) + expected_offset * 2 == sliced_view .ptr
277294
@@ -282,16 +299,26 @@ def test_from_buffer_too_small():
282299 d .set_current ()
283300 buffer = d .memory_resource .allocate (20 )
284301 with pytest .raises (ValueError , match = "Expected at least 40 bytes, got 20 bytes." ):
285- StridedMemoryView .from_buffer (buffer , layout )
302+ StridedMemoryView .from_buffer (
303+ buffer ,
304+ shape = layout .shape ,
305+ strides = layout .strides ,
306+ dtype = np .dtype ("int16" ),
307+ )
286308
287309
288310def test_from_buffer_disallowed_negative_offset ():
289311 layout = _StridedLayout ((5 , 4 ), (- 4 , 1 ), 1 )
290312 d = Device ()
291313 d .set_current ()
292314 buffer = d .memory_resource .allocate (20 )
293- with pytest .raises (ValueError , match = "please use _StridedLayout.to_dense()." ):
294- StridedMemoryView .from_buffer (buffer , layout )
315+ with pytest .raises (ValueError ):
316+ StridedMemoryView .from_buffer (
317+ buffer ,
318+ shape = layout .shape ,
319+ strides = layout .strides ,
320+ dtype = np .dtype ("uint8" ),
321+ )
295322
296323
297324class _EnforceCAIView :
@@ -331,7 +358,7 @@ def test_view_sliced_external(shape, slices, stride_order, view_as):
331358 pytest .skip ("CuPy is not installed" )
332359 a = cp .arange (math .prod (shape ), dtype = cp .int32 ).reshape (shape , order = stride_order )
333360 view = StridedMemoryView .from_cuda_array_interface (_EnforceCAIView (a ), - 1 )
334- layout = view .layout
361+ layout = view ._layout
335362 assert layout .is_dense
336363 assert layout .required_size_in_bytes () == a .nbytes
337364 assert view .ptr == _get_ptr (a )
@@ -344,11 +371,11 @@ def test_view_sliced_external(shape, slices, stride_order, view_as):
344371
345372 assert 0 <= sliced_layout .required_size_in_bytes () <= a .nbytes
346373 assert not sliced_layout .is_dense
347- assert sliced_view .layout is sliced_layout
374+ assert sliced_view ._layout is sliced_layout
348375 assert view .dtype == sliced_view .dtype
349- assert sliced_view .layout .itemsize == a_sliced .itemsize == layout .itemsize
376+ assert sliced_view ._layout .itemsize == a_sliced .itemsize == layout .itemsize
350377 assert sliced_view .shape == a_sliced .shape
351- assert sliced_view .layout .strides_in_bytes == a_sliced .strides
378+ assert sliced_view ._layout .strides_in_bytes == a_sliced .strides
352379
353380
354381@pytest .mark .parametrize (
@@ -369,7 +396,7 @@ def test_view_sliced_external_negative_offset(stride_order, view_as):
369396 a = cp .arange (math .prod (shape ), dtype = cp .int32 ).reshape (shape , order = stride_order )
370397 a = a [::- 1 ]
371398 view = StridedMemoryView .from_cuda_array_interface (_EnforceCAIView (a ), - 1 )
372- layout = view .layout
399+ layout = view ._layout
373400 assert not layout .is_dense
374401 assert layout .strides == (- 1 ,)
375402 assert view .ptr == _get_ptr (a )
@@ -381,8 +408,8 @@ def test_view_sliced_external_negative_offset(stride_order, view_as):
381408 assert sliced_view .ptr == view .ptr - 3 * a .itemsize
382409
383410 assert not sliced_layout .is_dense
384- assert sliced_view .layout is sliced_layout
411+ assert sliced_view ._layout is sliced_layout
385412 assert view .dtype == sliced_view .dtype
386- assert sliced_view .layout .itemsize == a_sliced .itemsize == layout .itemsize
413+ assert sliced_view ._layout .itemsize == a_sliced .itemsize == layout .itemsize
387414 assert sliced_view .shape == a_sliced .shape
388- assert sliced_view .layout .strides_in_bytes == a_sliced .strides
415+ assert sliced_view ._layout .strides_in_bytes == a_sliced .strides
0 commit comments