Skip to content

Commit 7e3574b

Browse files
feat: relax the power of two check in StridedLayout (NVIDIA#1427) (NVIDIA#1471)
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
1 parent 083315c commit 7e3574b

3 files changed

Lines changed: 31 additions & 25 deletions

File tree

cuda_core/cuda/core/_layout.pxd

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ cdef class _StridedLayout:
111111
# ==============================
112112

113113
cdef inline int _init(_StridedLayout self, BaseLayout& base, int itemsize, bint divide_strides=False) except -1 nogil:
114-
_validate_itemsize(itemsize)
114+
if itemsize <= 0:
115+
raise ValueError("itemsize must be positive")
115116

116117
if base.strides != NULL and divide_strides:
117118
_divide_strides(base, itemsize)
@@ -123,7 +124,8 @@ cdef class _StridedLayout:
123124
return 0
124125

125126
cdef inline stride_t _init_dense(_StridedLayout self, BaseLayout& base, int itemsize, OrderFlag order_flag, axis_vec_t* stride_order=NULL) except -1 nogil:
126-
_validate_itemsize(itemsize)
127+
if itemsize <= 0:
128+
raise ValueError("itemsize must be positive")
127129

128130
cdef stride_t volume
129131
if order_flag == ORDER_C:
@@ -643,14 +645,6 @@ cdef inline bint _normalize_axis(integer_t& axis, integer_t extent) except -1 no
643645
return True
644646

645647

646-
cdef inline int _validate_itemsize(int itemsize) except -1 nogil:
647-
if itemsize <= 0:
648-
raise ValueError("itemsize must be positive")
649-
if itemsize & (itemsize - 1):
650-
raise ValueError("itemsize must be a power of two")
651-
return 0
652-
653-
654648
cdef inline bint _is_unique(BaseLayout& base, axis_vec_t& stride_order) except -1 nogil:
655649
if base.strides == NULL:
656650
return True

cuda_core/cuda/core/_layout.pyx

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ cdef class _StridedLayout:
2929
Otherwise, the strides are assumed to be implicitly C-contiguous and the resulting
3030
layout's :attr:`strides` will be None.
3131
itemsize : int
32-
The number of bytes per single element (dtype size). Must be a power of two.
32+
The number of bytes per single element (dtype size).
3333
divide_strides : bool, optional
3434
If True, the provided :attr:`strides` will be divided by the :attr:`itemsize`.
3535
@@ -40,7 +40,7 @@ cdef class _StridedLayout:
4040
Attributes
4141
----------
4242
itemsize : int
43-
The number of bytes per single element (dtype size). Must be a power of two.
43+
The number of bytes per single element (dtype size).
4444
slice_offset : int
4545
The offset (as a number of elements, not bytes) of the element at
4646
index ``(0,) * ndim``. See also :attr:`slice_offset_in_bytes`.
@@ -636,7 +636,6 @@ cdef class _StridedLayout:
636636
In either case, the ``volume * itemsize`` of the layout remains the same.
637637

638638
The conversion is subject to the following constraints:
639-
* The old and new itemsizes must be powers of two.
640639
* The extent at ``axis`` must be a positive integer.
641640
* The stride at ``axis`` must be 1.
642641

@@ -1214,10 +1213,10 @@ cdef inline int64_t gcd(int64_t a, int64_t b) except? -1 nogil:
12141213

12151214
cdef inline int pack_extents(BaseLayout& out_layout, stride_t& out_slice_offset, BaseLayout& in_layout, stride_t slice_offset, int itemsize, int new_itemsize, intptr_t data_ptr, bint keep_dim, int axis) except -1 nogil:
12161215
cdef int ndim = in_layout.ndim
1217-
if new_itemsize <= 0 or new_itemsize & (new_itemsize - 1):
1218-
raise ValueError(f"new itemsize must be a power of two, got {new_itemsize}.")
1219-
if itemsize <= 0 or itemsize & (itemsize - 1):
1220-
raise ValueError(f"itemsize must be a power of two, got {itemsize}.")
1216+
if new_itemsize <= 0:
1217+
raise ValueError(f"new itemsize must be greater than zero, got {new_itemsize}.")
1218+
if itemsize <= 0:
1219+
raise ValueError(f"itemsize must be greater than zero, got {itemsize}.")
12211220
if new_itemsize <= itemsize:
12221221
if new_itemsize == itemsize:
12231222
return 1
@@ -1270,10 +1269,10 @@ cdef inline int unpack_extents(BaseLayout &out_layout, BaseLayout &in_layout, in
12701269
cdef int ndim = in_layout.ndim
12711270
if not _normalize_axis(axis, ndim):
12721271
raise ValueError(f"Invalid axis: {axis} out of range for {ndim}D tensor")
1273-
if new_itemsize <= 0 or new_itemsize & (new_itemsize - 1):
1274-
raise ValueError(f"new itemsize must be a power of two, got {new_itemsize}.")
1275-
if itemsize <= 0 or itemsize & (itemsize - 1):
1276-
raise ValueError(f"itemsize must be a power of two, got {itemsize}.")
1272+
if new_itemsize <= 0:
1273+
raise ValueError(f"new itemsize must be greater than zero, got {new_itemsize}.")
1274+
if itemsize <= 0:
1275+
raise ValueError(f"itemsize must be greater than zero, got {itemsize}.")
12771276
if new_itemsize >= itemsize:
12781277
if new_itemsize == itemsize:
12791278
return 1
@@ -1301,10 +1300,10 @@ cdef inline int unpack_extents(BaseLayout &out_layout, BaseLayout &in_layout, in
13011300

13021301
cdef inline int max_compatible_itemsize(BaseLayout& layout, stride_t slice_offset, int itemsize, int max_itemsize, intptr_t data_ptr, int axis) except? -1 nogil:
13031302
cdef int ndim = layout.ndim
1304-
if max_itemsize <= 0 or max_itemsize & (max_itemsize - 1):
1305-
raise ValueError(f"max_itemsize must be a power of two, got {max_itemsize}.")
1306-
if itemsize <= 0 or itemsize & (itemsize - 1):
1307-
raise ValueError(f"itemsize must be a power of two, got {itemsize}.")
1303+
if max_itemsize <= 0:
1304+
raise ValueError(f"max_itemsize must be greater than zero, got {max_itemsize}.")
1305+
if itemsize <= 0:
1306+
raise ValueError(f"itemsize must be greater than zero, got {itemsize}.")
13081307
if not _normalize_axis(axis, ndim):
13091308
raise ValueError(f"Invalid axis: {axis} out of range for {ndim}D tensor")
13101309
if max_itemsize < itemsize:

cuda_core/tests/test_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,16 @@ def test_view_zero_size_array(api, shape, dtype):
433433
assert smv.size == 0
434434
assert smv.shape == shape
435435
assert smv.dtype == np.dtype(dtype)
436+
437+
438+
def test_from_buffer_with_non_power_of_two_itemsize():
439+
dev = Device()
440+
dev.set_current()
441+
dtype = np.dtype([("a", "int32"), ("b", "int8")])
442+
shape = (1,)
443+
layout = _StridedLayout(shape=shape, strides=None, itemsize=dtype.itemsize)
444+
required_size = layout.required_size_in_bytes()
445+
assert required_size == math.prod(shape) * dtype.itemsize
446+
buffer = dev.memory_resource.allocate(required_size)
447+
view = StridedMemoryView.from_buffer(buffer, shape=shape, strides=layout.strides, dtype=dtype, is_readonly=True)
448+
assert view.dtype == dtype

0 commit comments

Comments
 (0)