Skip to content

Commit db6118e

Browse files
authored
Add Buffer.fill() method for cuMemsetAsync support (NVIDIA#1314) (NVIDIA#1318)
* Add Buffer.fill() method for cuMemsetAsync support

  Implements Buffer.fill(value, width, *, stream) method that wraps cuMemsetD8Async, cuMemsetD16Async, and cuMemsetD32Async based on the width parameter (1, 2, or 4 bytes).

  - Add fill() method to Buffer class in _buffer.pyx
  - Support width=1 (byte), width=2 (16-bit), width=4 (32-bit)
  - Validate width, value range, and buffer size divisibility
  - Add comprehensive tests in test_memory.py
  - Tests cover all widths, error cases, and verification

  Part of issue NVIDIA#1314: CUDA Graph phase 3 - memcpy nodes

* Add graph capture tests for Buffer.fill()

  Extend test_graph_alloc with 'fill' action parameter to test Buffer.fill() in graph capture mode. The test verifies graph capture for Buffer operations including copy_from, copy_to, fill, and kernel launch operations.

  Part of issue NVIDIA#1314

* Use cydriver directly in Buffer.fill() for efficiency

  - Replace Python driver module calls with direct cydriver calls
  - Use 'with nogil:' blocks around CUDA driver API calls
  - Use HANDLE_RETURN macro for error handling
  - Cast stream to Stream type to access _handle attribute
  - Improves performance by eliminating Python overhead

* Use cydriver directly in Buffer.copy_to() and copy_from() for efficiency

  - Replace Python driver module calls with direct cydriver calls
  - Use 'with nogil:' blocks around CUDA driver API calls
  - Use HANDLE_RETURN macro for error handling
  - Cast stream to Stream type to access _handle attribute
  - Remove unused raise_if_driver_error import
  - Improves performance by eliminating Python overhead

* Simplified argument validation logic in Buffer.fill.

* Refactor Buffer.fill() to use helper function for value validation

  - Add _validate_value_against_bitwidth helper function
  - Move helper function to end of file as cdef function
  - Use 64-bit platform integers (int64_t/uint64_t) instead of Python ints
  - Add assertion that bitwidth < 64
  - Remove magic numbers from fill() method
  - Update tests to match new error message format
1 parent 95d5844 commit db6118e

3 files changed

Lines changed: 228 additions & 22 deletions

File tree

cuda_core/cuda/core/experimental/_memory/_buffer.pyx

Lines changed: 117 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44

55
from __future__ import annotations
66

7-
from libc.stdint cimport uintptr_t
7+
from libc.stdint cimport uintptr_t, int64_t, uint64_t
88

9+
from cuda.bindings cimport cydriver
910
from cuda.core.experimental._memory._device_memory_resource cimport DeviceMemoryResource
1011
from cuda.core.experimental._memory._ipc cimport IPCBufferDescriptor, IPCDataForBuffer
1112
from cuda.core.experimental._memory cimport _ipc
1213
from cuda.core.experimental._stream cimport Stream_accept, Stream
13-
from cuda.core.experimental._utils.cuda_utils cimport (
14-
_check_driver_error as raise_if_driver_error,
15-
)
14+
from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN
1615

1716
import abc
1817
from typing import TypeVar, Union
@@ -137,6 +136,7 @@ cdef class Buffer:
137136

138137
"""
139138
stream = Stream_accept(stream)
139+
cdef Stream s_stream = <Stream>stream
140140
cdef size_t src_size = self._size
141141

142142
if dst is None:
@@ -150,8 +150,14 @@ cdef class Buffer:
150150
raise ValueError( "buffer sizes mismatch between src and dst (sizes "
151151
f"are: src={src_size}, dst={dst_size})"
152152
)
153-
err, = driver.cuMemcpyAsync(dst._ptr, self._ptr, src_size, stream.handle)
154-
raise_if_driver_error(err)
153+
cdef cydriver.CUstream s = s_stream._handle
154+
with nogil:
155+
HANDLE_RETURN(cydriver.cuMemcpyAsync(
156+
<cydriver.CUdeviceptr>dst._ptr,
157+
<cydriver.CUdeviceptr>self._ptr,
158+
src_size,
159+
s
160+
))
155161
return dst
156162

157163
def copy_from(self, src: Buffer, *, stream: Stream | GraphBuilder):
@@ -167,15 +173,78 @@ cdef class Buffer:
167173
168174
"""
169175
stream = Stream_accept(stream)
176+
cdef Stream s_stream = <Stream>stream
170177
cdef size_t dst_size = self._size
171178
cdef size_t src_size = src._size
172179

173180
if src_size != dst_size:
174181
raise ValueError( "buffer sizes mismatch between src and dst (sizes "
175182
f"are: src={src_size}, dst={dst_size})"
176183
)
177-
err, = driver.cuMemcpyAsync(self._ptr, src._ptr, dst_size, stream.handle)
178-
raise_if_driver_error(err)
184+
cdef cydriver.CUstream s = s_stream._handle
185+
with nogil:
186+
HANDLE_RETURN(cydriver.cuMemcpyAsync(
187+
<cydriver.CUdeviceptr>self._ptr,
188+
<cydriver.CUdeviceptr>src._ptr,
189+
dst_size,
190+
s
191+
))
192+
193+
def fill(self, value: int, width: int, *, stream: Stream | GraphBuilder):
    """Fill this buffer with a value pattern asynchronously on the given stream.

    The buffer is treated as a sequence of ``width``-byte elements and each
    element is set to ``value`` through ``cuMemsetD8Async``,
    ``cuMemsetD16Async`` or ``cuMemsetD32Async`` (selected by ``width``).

    Parameters
    ----------
    value : int
        Unsigned integer value to fill the buffer with; must fit in
        ``width`` bytes
    width : int
        Width in bytes for each element (must be 1, 2, or 4)
    stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
        Keyword argument specifying the stream for the asynchronous fill

    Raises
    ------
    ValueError
        If width is not 1, 2, or 4, if value is out of range for the width,
        or if buffer size is not divisible by width

    """
    cdef Stream s_stream = Stream_accept(stream)
    cdef cydriver.CUstream s
    cdef size_t buffer_size = self._size
    cdef size_t N
    cdef int bitwidth
    cdef unsigned char c_value8
    cdef unsigned short c_value16
    cdef unsigned int c_value32

    # Validate width
    if width not in (1, 2, 4):
        raise ValueError(f"width must be 1, 2, or 4, got {width}")

    # Validate buffer size modulus.
    if buffer_size % width != 0:
        raise ValueError(f"buffer size ({buffer_size}) must be divisible by width ({width})")

    # Map width (bytes) to bitwidth and validate value.
    # The helper takes an int64_t, so a Python int outside the 64-bit range
    # fails during argument conversion with OverflowError before the range
    # check runs; report that case as the documented ValueError as well.
    bitwidth = width * 8
    try:
        _validate_value_against_bitwidth(bitwidth, value, is_signed=False)
    except OverflowError:
        # ``width`` is a Python int here, so the shift below is done with
        # arbitrary-precision arithmetic and cannot overflow.
        raise ValueError(
            f"value must be in range [0, {(1 << (width * 8)) - 1}]"
        ) from None

    # Value is known to fit: narrow it to the element type and issue the
    # memset with the GIL released. N is the element count, not a byte count.
    s = s_stream._handle
    if width == 1:
        c_value8 = <unsigned char>value
        N = buffer_size
        with nogil:
            HANDLE_RETURN(cydriver.cuMemsetD8Async(<cydriver.CUdeviceptr>self._ptr, c_value8, N, s))
    elif width == 2:
        c_value16 = <unsigned short>value
        N = buffer_size // 2
        with nogil:
            HANDLE_RETURN(cydriver.cuMemsetD16Async(<cydriver.CUdeviceptr>self._ptr, c_value16, N, s))
    else:  # width == 4
        c_value32 = <unsigned int>value
        N = buffer_size // 4
        with nogil:
            HANDLE_RETURN(cydriver.cuMemsetD32Async(<cydriver.CUdeviceptr>self._ptr, c_value32, N, s))
179248

180249
def __dlpack__(
181250
self,
@@ -340,3 +409,43 @@ cdef class MemoryResource:
340409
and document the behavior.
341410
"""
342411
...
412+
413+
414+
# Helper Functions
415+
# ----------------
416+
cdef void _validate_value_against_bitwidth(int bitwidth, int64_t value, bint is_signed=False) except *:
    """Validate that a value fits within the representable range for a given bitwidth.

    Parameters
    ----------
    bitwidth : int
        Number of bits (e.g., 8, 16, 32); must be in the range [1, 63]
    value : int64_t
        Value to validate
    is_signed : bool, optional
        Whether the value is signed (default: False)

    Raises
    ------
    ValueError
        If bitwidth is outside [1, 63], or if value is outside the
        representable range for the bitwidth
    """
    cdef int64_t min_value
    cdef int64_t max_value

    # Checked explicitly rather than with ``assert`` so the guard cannot be
    # compiled out: the shifts below are only well defined for shift counts
    # in [0, 63], and the limits must fit in an int64_t.
    if bitwidth < 1 or bitwidth >= 64:
        raise ValueError(f"bitwidth ({bitwidth}) must be in range [1, 63]")

    if is_signed:
        min_value = -(<int64_t>1 << (bitwidth - 1))
        max_value = (<int64_t>1 << (bitwidth - 1)) - 1
    else:
        min_value = 0
        # Computed in unsigned arithmetic, then narrowed; bitwidth < 64
        # guarantees the result fits in an int64_t.
        max_value = <int64_t>((<uint64_t>1 << bitwidth) - 1)

    if not min_value <= value <= max_value:
        raise ValueError(
            f"value must be in range [{min_value}, {max_value}]"
        )

cuda_core/tests/test_graph_mem.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,14 @@ def free(self, buffers):
7575

7676

7777
@pytest.mark.parametrize("mode", ["no_graph", "global", "thread_local", "relaxed"])
78-
def test_graph_alloc(mempool_device, mode):
79-
"""Test basic graph capture with memory allocated and deallocated by GraphMemoryResource."""
78+
@pytest.mark.parametrize("action", ["incr", "fill"])
79+
def test_graph_alloc(mempool_device, mode, action):
80+
"""Test basic graph capture with memory allocated and deallocated by
81+
GraphMemoryResource.
82+
83+
This test verifies graph capture for Buffer operations including copy_from,
84+
copy_to, fill, and kernel launch operations.
85+
"""
8086
NBYTES = 64
8187
device = mempool_device
8288
stream = device.create_stream()
@@ -93,14 +99,22 @@ def test_graph_alloc(mempool_device, mode):
9399
config = LaunchConfig(grid=1, block=1)
94100
launch(stream, config, set_zero, out, NBYTES)
95101

96-
# Increments out by 3
97-
def apply_kernels(mr, stream, out):
98-
buffer = mr.allocate(NBYTES, stream=stream)
99-
buffer.copy_from(out, stream=stream)
100-
for kernel in [add_one, add_one, add_one]:
101-
launch(stream, config, kernel, buffer, NBYTES)
102-
out.copy_from(buffer, stream=stream)
103-
buffer.close()
102+
if action == "incr":
103+
# Increments out by 3
104+
def apply_kernels(mr, stream, out):
105+
buffer = mr.allocate(NBYTES, stream=stream)
106+
buffer.copy_from(out, stream=stream)
107+
for kernel in [add_one, add_one, add_one]:
108+
launch(stream, config, kernel, buffer, NBYTES)
109+
out.copy_from(buffer, stream=stream)
110+
buffer.close()
111+
elif action == "fill":
112+
# Fills out with 3
113+
def apply_kernels(mr, stream, out):
114+
buffer = mr.allocate(NBYTES, stream=stream)
115+
buffer.fill(3, width=1, stream=stream)
116+
out.copy_from(buffer, stream=stream)
117+
buffer.close()
104118

105119
# Apply kernels, with or without graph capture.
106120
if mode == "no_graph":
@@ -121,10 +135,11 @@ def apply_kernels(mr, stream, out):
121135
assert compare_buffer_to_constant(out, 3)
122136

123137
# Second launch.
124-
graph.upload(stream)
125-
graph.launch(stream)
126-
stream.sync()
127-
assert compare_buffer_to_constant(out, 6)
138+
if action == "incr":
139+
graph.upload(stream)
140+
graph.launch(stream)
141+
stream.sync()
142+
assert compare_buffer_to_constant(out, 6)
128143

129144

130145
@pytest.mark.skipif(IS_WINDOWS or IS_WSL, reason="auto_free_on_launch not supported on Windows")

cuda_core/tests/test_memory.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,88 @@ def test_buffer_copy_from():
219219
buffer_copy_from(DummyPinnedMemoryResource(device), device, check=True)
220220

221221

222+
def buffer_fill(dummy_mr: MemoryResource, device: Device, check=False):
    """Exercise ``Buffer.fill`` for widths 1, 2 and 4 on buffers from ``dummy_mr``.

    For each width this performs a valid fill and then verifies that invalid
    arguments (bad width, out-of-range value, buffer size not divisible by the
    width) raise ``ValueError``.

    Parameters
    ----------
    dummy_mr : MemoryResource
        Memory resource used to allocate the test buffers
    device : Device
        Device used to create the stream and to synchronize
    check : bool, optional
        When True, read the filled memory back through ``ctypes`` and compare
        it with the fill value (the caller passes True only for the pinned
        memory resource)
    """
    stream = device.create_stream()

    # Test width=1 (byte fill)
    buffer1 = dummy_mr.allocate(size=1024)
    buffer1.fill(0x42, width=1, stream=stream)
    device.sync()

    if check:
        ptr = ctypes.cast(buffer1.handle, ctypes.POINTER(ctypes.c_byte))
        for i in range(10):
            assert ptr[i] == 0x42

    # Test error: invalid width
    for bad_width in [w for w in range(-10, 10) if w not in (1, 2, 4)]:
        with pytest.raises(ValueError, match="width must be 1, 2, or 4"):
            buffer1.fill(0x42, width=bad_width, stream=stream)

    # Test error: value out of range for width=1
    for bad_value in [-42, -1, 256]:
        with pytest.raises(ValueError, match="value must be in range \\[0, 255\\]"):
            buffer1.fill(bad_value, width=1, stream=stream)

    # Test error: buffer size not divisible by width
    for bad_size in [1025, 1027, 1029, 1031]:  # Not divisible by 2
        # Allocate the size under test (previously always 1025, which left
        # the other sizes in the list unexercised).
        buffer_err = dummy_mr.allocate(size=bad_size)
        with pytest.raises(ValueError, match="must be divisible"):
            buffer_err.fill(0x1234, width=2, stream=stream)
        buffer_err.close()

    buffer1.close()

    # Test width=2 (16-bit fill)
    buffer2 = dummy_mr.allocate(size=1024)  # Divisible by 2
    buffer2.fill(0x1234, width=2, stream=stream)
    device.sync()

    if check:
        ptr = ctypes.cast(buffer2.handle, ctypes.POINTER(ctypes.c_uint16))
        for i in range(5):
            assert ptr[i] == 0x1234

    # Test error: value out of range for width=2
    for bad_value in [-42, -1, 65536, 65537, 100000]:
        with pytest.raises(ValueError, match="value must be in range \\[0, 65535\\]"):
            buffer2.fill(bad_value, width=2, stream=stream)

    buffer2.close()

    # Test width=4 (32-bit fill)
    buffer4 = dummy_mr.allocate(size=1024)  # Divisible by 4
    buffer4.fill(0xDEADBEEF, width=4, stream=stream)
    device.sync()

    if check:
        ptr = ctypes.cast(buffer4.handle, ctypes.POINTER(ctypes.c_uint32))
        for i in range(5):
            assert ptr[i] == 0xDEADBEEF

    # Test error: value out of range for width=4
    for bad_value in [-42, -1, 4294967296, 4294967297, 5000000000]:
        with pytest.raises(ValueError, match="value must be in range \\[0, 4294967295\\]"):
            buffer4.fill(bad_value, width=4, stream=stream)

    # Test error: buffer size not divisible by width
    for bad_size in [1025, 1026, 1027, 1029, 1030, 1031]:  # Not divisible by 4
        buffer_err2 = dummy_mr.allocate(size=bad_size)
        with pytest.raises(ValueError, match="must be divisible"):
            buffer_err2.fill(0xDEADBEEF, width=4, stream=stream)
        buffer_err2.close()

    buffer4.close()
294+
295+
296+
def test_buffer_fill():
    """Run the ``buffer_fill`` scenario against each dummy memory resource.

    Read-back verification is enabled only for the pinned memory resource.
    """
    device = Device()
    device.set_current()

    # (memory resource class, whether to verify the filled contents)
    scenarios = (
        (DummyDeviceMemoryResource, False),
        (DummyUnifiedMemoryResource, False),
        (DummyPinnedMemoryResource, True),
    )
    for resource_cls, verify in scenarios:
        buffer_fill(resource_cls(device), device, check=verify)
302+
303+
222304
def buffer_close(dummy_mr: MemoryResource):
223305
buffer = dummy_mr.allocate(size=1024)
224306
buffer.close()

0 commit comments

Comments
 (0)