Skip to content

Commit 9152bbb

Browse files
authored
Simplify Buffer.fill() API (NVIDIA#1366)
1 parent 2d39bec commit 9152bbb

File tree

3 files changed

+150
-138
lines changed

3 files changed

+150
-138
lines changed

cuda_core/cuda/core/experimental/_memory/_buffer.pyx

Lines changed: 56 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
from __future__ import annotations
66

7-
from libc.stdint cimport uintptr_t, int64_t, uint64_t
7+
from libc.stdint cimport uintptr_t
88

99
from cuda.bindings cimport cydriver
1010
from cuda.core.experimental._memory._device_memory_resource cimport DeviceMemoryResource
@@ -14,8 +14,14 @@ from cuda.core.experimental._stream cimport Stream_accept, Stream
1414
from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN
1515

1616
import abc
17+
import sys
1718
from typing import TypeVar, Union
1819

20+
if sys.version_info >= (3, 12):
21+
from collections.abc import Buffer as BufferProtocol
22+
else:
23+
BufferProtocol = object
24+
1925
from cuda.core.experimental._dlpack import DLDeviceType, make_py_capsule
2026
from cuda.core.experimental._utils.cuda_utils import driver
2127

@@ -190,58 +196,85 @@ cdef class Buffer:
190196
s
191197
))
192198

193-
def fill(self, value: int, width: int, *, stream: Stream | GraphBuilder):
194-
"""Fill this buffer with a value pattern asynchronously on the given stream.
199+
def fill(self, value: int | BufferProtocol, *, stream: Stream | GraphBuilder):
200+
"""Fill this buffer with a repeating byte pattern.
195201
196202
Parameters
197203
----------
198-
value : int
199-
Integer value to fill the buffer with
200-
width : int
201-
Width in bytes for each element (must be 1, 2, or 4)
204+
value : int | :obj:`collections.abc.Buffer`
205+
- int: Must be in range [0, 256). Converted to 1 byte.
206+
- :obj:`collections.abc.Buffer`: Must be 1, 2, or 4 bytes.
202207
stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
203-
Keyword argument specifying the stream for the asynchronous fill
208+
Stream for the asynchronous fill operation.
204209
205210
Raises
206211
------
212+
TypeError
213+
If value is not an int and does not support the buffer protocol.
207214
ValueError
208-
If width is not 1, 2, or 4, if value is out of range for the width,
209-
or if buffer size is not divisible by width
215+
If value byte length is not 1, 2, or 4.
216+
If buffer size is not divisible by value byte length.
217+
OverflowError
218+
If int value is outside [0, 256).
210219
211220
"""
212221
cdef Stream s_stream = Stream_accept(stream)
213222
cdef unsigned char c_value8
214223
cdef unsigned short c_value16
215224
cdef unsigned int c_value32
216225
cdef size_t N
217-
218-
# Validate width
219-
if width not in (1, 2, 4):
220-
raise ValueError(f"width must be 1, 2, or 4, got {width}")
226+
cdef size_t width
227+
cdef unsigned int int_value
228+
229+
# Get fill pattern from value
230+
if isinstance(value, int):
231+
# We define the int input to mean a 1-byte pattern.
232+
# Match int.to_bytes(1, "little") behavior: raise OverflowError if not in [0, 256).
233+
if value < 0 or value >= 256:
234+
raise OverflowError("int value must be in range [0, 256)")
235+
width = 1
236+
int_value = <unsigned int>value
237+
else:
238+
try:
239+
mv = memoryview(value)
240+
except TypeError:
241+
raise TypeError(
242+
f"value must be an int or support the buffer protocol, got {type(value).__name__}"
243+
) from None
244+
width = mv.nbytes
245+
246+
# Validate width early to avoid copying/processing large invalid inputs.
247+
if width not in (1, 2, 4):
248+
raise ValueError(f"value must be 1, 2, or 4 bytes, got {width}")
249+
250+
# Convert to a 1-D view of bytes.
251+
#
252+
# Note: NumPy scalar memoryviews are 0-D, and int.from_bytes(mv, ...) errors with
253+
# "0-dim memory has no length". Casting to 'B' gives us a byte-addressable view.
254+
try:
255+
int_value = int.from_bytes(mv.cast("B"), "little")
256+
except TypeError:
257+
int_value = int.from_bytes(mv.tobytes(), "little")
221258

222259
# Validate buffer size modulus.
223260
cdef size_t buffer_size = self._size
224261
if buffer_size % width != 0:
225-
raise ValueError(f"buffer size ({buffer_size}) must be divisible by width ({width})")
226-
227-
# Map width (bytes) to bitwidth and validate value
228-
cdef int bitwidth = width * 8
229-
_validate_value_against_bitwidth(bitwidth, value, is_signed=False)
262+
raise ValueError(f"buffer size ({buffer_size}) must be divisible by {width}")
230263

231-
# Validate value fits in width and perform fill
264+
# Perform fill based on width
232265
cdef cydriver.CUstream s = s_stream._handle
233266
if width == 1:
234-
c_value8 = <unsigned char>value
267+
c_value8 = <unsigned char>int_value
235268
N = buffer_size
236269
with nogil:
237270
HANDLE_RETURN(cydriver.cuMemsetD8Async(<cydriver.CUdeviceptr>self._ptr, c_value8, N, s))
238271
elif width == 2:
239-
c_value16 = <unsigned short>value
272+
c_value16 = <unsigned short>int_value
240273
N = buffer_size // 2
241274
with nogil:
242275
HANDLE_RETURN(cydriver.cuMemsetD16Async(<cydriver.CUdeviceptr>self._ptr, c_value16, N, s))
243276
else: # width == 4
244-
c_value32 = <unsigned int>value
277+
c_value32 = <unsigned int>int_value
245278
N = buffer_size // 4
246279
with nogil:
247280
HANDLE_RETURN(cydriver.cuMemsetD32Async(<cydriver.CUdeviceptr>self._ptr, c_value32, N, s))
@@ -409,43 +442,3 @@ cdef class MemoryResource:
409442
and document the behavior.
410443
"""
411444
...
412-
413-
414-
# Helper Functions
415-
# ----------------
416-
cdef void _validate_value_against_bitwidth(int bitwidth, int64_t value, bint is_signed=False) except *:
417-
"""Validate that a value fits within the representable range for a given bitwidth.
418-
419-
Parameters
420-
----------
421-
bitwidth : int
422-
Number of bits (e.g., 8, 16, 32)
423-
value : int64_t
424-
Value to validate
425-
is_signed : bool, optional
426-
Whether the value is signed (default: False)
427-
428-
Raises
429-
------
430-
ValueError
431-
If value is outside the representable range for the bitwidth
432-
"""
433-
cdef int max_bits = bitwidth
434-
assert max_bits < 64, f"bitwidth ({max_bits}) must be less than 64"
435-
436-
cdef int64_t min_value
437-
cdef uint64_t max_value_unsigned
438-
cdef int64_t max_value
439-
440-
if is_signed:
441-
min_value = -(<int64_t>1 << (max_bits - 1))
442-
max_value = (<int64_t>1 << (max_bits - 1)) - 1
443-
else:
444-
min_value = 0
445-
max_value_unsigned = (<uint64_t>1 << max_bits) - 1
446-
max_value = <int64_t>max_value_unsigned
447-
448-
if not min_value <= value <= max_value:
449-
raise ValueError(
450-
f"value must be in range [{min_value}, {max_value}]"
451-
)

cuda_core/tests/test_graph_mem.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,7 @@ def apply_kernels(mr, stream, out):
112112
# Fills out with 3
113113
def apply_kernels(mr, stream, out):
114114
buffer = mr.allocate(NBYTES, stream=stream)
115-
buffer.fill(3, width=1, stream=stream)
115+
buffer.fill(3, stream=stream)
116116
out.copy_from(buffer, stream=stream)
117117
buffer.close()
118118

cuda_core/tests/test_memory.py

Lines changed: 93 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -219,86 +219,105 @@ def test_buffer_copy_from():
219219
buffer_copy_from(DummyPinnedMemoryResource(device), device, check=True)
220220

221221

222-
def buffer_fill(dummy_mr: MemoryResource, device: Device, check=False):
223-
stream = device.create_stream()
224-
225-
# Test width=1 (byte fill)
226-
buffer1 = dummy_mr.allocate(size=1024)
227-
buffer1.fill(0x42, width=1, stream=stream)
228-
device.sync()
229-
230-
if check:
231-
ptr = ctypes.cast(buffer1.handle, ctypes.POINTER(ctypes.c_byte))
232-
for i in range(10):
233-
assert ptr[i] == 0x42
222+
def _bytes_repeat(pattern: bytes, size: int) -> bytes:
223+
assert len(pattern) > 0
224+
assert size % len(pattern) == 0
225+
return pattern * (size // len(pattern))
234226

235-
# Test error: invalid width
236-
for bad_width in [w for w in range(-10, 10) if w not in (1, 2, 4)]:
237-
with pytest.raises(ValueError, match="width must be 1, 2, or 4"):
238-
buffer1.fill(0x42, width=bad_width, stream=stream)
239227

240-
# Test error: value out of range for width=1
241-
for bad_value in [-42, -1, 256]:
242-
with pytest.raises(ValueError, match="value must be in range \\[0, 255\\]"):
243-
buffer1.fill(bad_value, width=1, stream=stream)
228+
def _pattern_bytes(value) -> bytes:
229+
if isinstance(value, int):
230+
return bytes([value])
231+
return bytes(memoryview(value).cast("B"))
244232

245-
# Test error: buffer size not divisible by width
246-
for bad_size in [1025, 1027, 1029, 1031]: # Not divisible by 2
247-
buffer_err = dummy_mr.allocate(size=1025)
248-
with pytest.raises(ValueError, match="must be divisible"):
249-
buffer_err.fill(0x1234, width=2, stream=stream)
250-
buffer_err.close()
251-
252-
buffer1.close()
253-
254-
# Test width=2 (16-bit fill)
255-
buffer2 = dummy_mr.allocate(size=1024) # Divisible by 2
256-
buffer2.fill(0x1234, width=2, stream=stream)
257-
device.sync()
258-
259-
if check:
260-
ptr = ctypes.cast(buffer2.handle, ctypes.POINTER(ctypes.c_uint16))
261-
for i in range(5):
262-
assert ptr[i] == 0x1234
263-
264-
# Test error: value out of range for width=2
265-
for bad_value in [-42, -1, 65536, 65537, 100000]:
266-
with pytest.raises(ValueError, match="value must be in range \\[0, 65535\\]"):
267-
buffer2.fill(bad_value, width=2, stream=stream)
268-
269-
buffer2.close()
270233

271-
# Test width=4 (32-bit fill)
272-
buffer4 = dummy_mr.allocate(size=1024) # Divisible by 4
273-
buffer4.fill(0xDEADBEEF, width=4, stream=stream)
274-
device.sync()
275-
276-
if check:
277-
ptr = ctypes.cast(buffer4.handle, ctypes.POINTER(ctypes.c_uint32))
278-
for i in range(5):
279-
assert ptr[i] == 0xDEADBEEF
280-
281-
# Test error: value out of range for width=4
282-
for bad_value in [-42, -1, 4294967296, 4294967297, 5000000000]:
283-
with pytest.raises(ValueError, match="value must be in range \\[0, 4294967295\\]"):
284-
buffer4.fill(bad_value, width=4, stream=stream)
285-
286-
# Test error: buffer size not divisible by width
287-
for bad_size in [1025, 1026, 1027, 1029, 1030, 1031]: # Not divisible by 4
288-
buffer_err2 = dummy_mr.allocate(size=bad_size)
289-
with pytest.raises(ValueError, match="must be divisible"):
290-
buffer_err2.fill(0xDEADBEEF, width=4, stream=stream)
291-
buffer_err2.close()
292-
293-
buffer4.close()
294-
295-
296-
def test_buffer_fill():
234+
@pytest.fixture(params=["device", "unified", "pinned"])
235+
def fill_env(request):
297236
device = Device()
298237
device.set_current()
299-
buffer_fill(DummyDeviceMemoryResource(device), device)
300-
buffer_fill(DummyUnifiedMemoryResource(device), device)
301-
buffer_fill(DummyPinnedMemoryResource(device), device, check=True)
238+
if request.param == "device":
239+
mr = DummyDeviceMemoryResource(device)
240+
elif request.param == "unified":
241+
mr = DummyUnifiedMemoryResource(device)
242+
else:
243+
mr = DummyPinnedMemoryResource(device)
244+
return device, mr
245+
246+
247+
_FILL_SIZE = 64 # Keep small; divisible by 1/2/4.
248+
249+
_FILL_CASES = [
250+
# int -> 1-byte pattern
251+
pytest.param(0x42, _FILL_SIZE, None, id="int-0x42"),
252+
pytest.param(-1, _FILL_SIZE, OverflowError, id="int-neg"),
253+
pytest.param(256, _FILL_SIZE, OverflowError, id="int-256"),
254+
pytest.param(1000, _FILL_SIZE, OverflowError, id="int-1000"),
255+
# bad type
256+
pytest.param("invalid", _FILL_SIZE, TypeError, id="bad-type-str"),
257+
# bytes-like patterns
258+
pytest.param(b"\x7f", _FILL_SIZE, None, id="bytes-1"),
259+
pytest.param(b"\x34\x12", _FILL_SIZE, None, id="bytes-2"),
260+
pytest.param(b"\xef\xbe\xad\xde", _FILL_SIZE, None, id="bytes-4"),
261+
pytest.param(b"\x34\x12", _FILL_SIZE + 1, ValueError, id="bytes-2-bad-size"),
262+
pytest.param(b"\xef\xbe\xad\xde", _FILL_SIZE + 2, ValueError, id="bytes-4-bad-size"),
263+
pytest.param(b"", _FILL_SIZE, ValueError, id="bytes-0"),
264+
pytest.param(b"\x01\x02\x03", _FILL_SIZE, ValueError, id="bytes-3"),
265+
]
266+
267+
if np is not None:
268+
_FILL_CASES.extend(
269+
[
270+
# 8-bit patterns
271+
pytest.param(np.uint8(0), _FILL_SIZE, None, id="np-uint8-0"),
272+
pytest.param(np.uint8(255), _FILL_SIZE, None, id="np-uint8-255"),
273+
pytest.param(np.int8(-1), _FILL_SIZE, None, id="np-int8--1"),
274+
pytest.param(np.int8(127), _FILL_SIZE, None, id="np-int8-127"),
275+
pytest.param(np.int8(-128), _FILL_SIZE, None, id="np-int8--128"),
276+
# 16-bit patterns
277+
pytest.param(np.uint16(0x1234), _FILL_SIZE, None, id="np-uint16-0x1234"),
278+
pytest.param(np.uint16(0xFFFF), _FILL_SIZE, None, id="np-uint16-0xFFFF"),
279+
pytest.param(np.int16(-1), _FILL_SIZE, None, id="np-int16--1"),
280+
pytest.param(np.int16(32767), _FILL_SIZE, None, id="np-int16-max"),
281+
pytest.param(np.int16(-32768), _FILL_SIZE, None, id="np-int16-min"),
282+
pytest.param(np.uint16(0x1234), _FILL_SIZE + 1, ValueError, id="np-uint16-bad-size"),
283+
# 32-bit patterns
284+
pytest.param(np.uint32(0xDEADBEEF), _FILL_SIZE, None, id="np-uint32-0xDEADBEEF"),
285+
pytest.param(np.uint32(0xFFFFFFFF), _FILL_SIZE, None, id="np-uint32-0xFFFFFFFF"),
286+
pytest.param(np.int32(-1), _FILL_SIZE, None, id="np-int32--1"),
287+
pytest.param(np.int32(2147483647), _FILL_SIZE, None, id="np-int32-max"),
288+
pytest.param(np.int32(-2147483648), _FILL_SIZE, None, id="np-int32-min"),
289+
pytest.param(np.uint32(0xDEADBEEF), _FILL_SIZE + 2, ValueError, id="np-uint32-bad-size"),
290+
# float32 (bit-pattern fill)
291+
pytest.param(np.float32(1.0), _FILL_SIZE, None, id="np-float32-1.0"),
292+
# 64-bit patterns should error (8-byte pattern)
293+
pytest.param(np.uint64(0), _FILL_SIZE, ValueError, id="np-uint64-err"),
294+
pytest.param(np.int64(0), _FILL_SIZE, ValueError, id="np-int64-err"),
295+
pytest.param(np.float64(0), _FILL_SIZE, ValueError, id="np-float64-err"),
296+
]
297+
)
298+
299+
300+
@pytest.mark.parametrize("value,size,exc", _FILL_CASES)
301+
def test_buffer_fill(fill_env, value, size, exc):
302+
device, mr = fill_env
303+
stream = device.create_stream()
304+
buffer = mr.allocate(size=size)
305+
try:
306+
if exc is not None:
307+
with pytest.raises(exc):
308+
buffer.fill(value, stream=stream)
309+
return
310+
311+
buffer.fill(value, stream=stream)
312+
device.sync()
313+
314+
# Verify contents only for host-accessible buffers.
315+
if buffer.is_host_accessible:
316+
pat = _pattern_bytes(value)
317+
got = ctypes.string_at(int(buffer.handle), size)
318+
assert got == _bytes_repeat(pat, size)
319+
finally:
320+
buffer.close()
302321

303322

304323
def buffer_close(dummy_mr: MemoryResource):

0 commit comments

Comments
 (0)