Skip to content

Commit af10d49

Browse files
committed
fix
1 parent ca36936 commit af10d49

File tree

5 files changed

+173
-31
lines changed

5 files changed

+173
-31
lines changed

cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,23 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from cpython.mem cimport PyMem_Malloc, PyMem_Free
6-
from libc.stdint cimport (intptr_t,
6+
from libc.stdint cimport (intptr_t, uintptr_t,
77
int8_t, int16_t, int32_t, int64_t,
88
uint8_t, uint16_t, uint32_t, uint64_t,)
99
from libcpp cimport bool as cpp_bool
1010
from libcpp.complex cimport complex as cpp_complex
1111
from libcpp cimport nullptr
1212
from libcpp cimport vector
1313

14+
from cuda.bindings cimport cydriver
15+
from cuda.core.experimental._memoryview cimport _MDSPAN
16+
1417
import ctypes
1518

1619
import numpy
1720

1821
from cuda.core.experimental._memory import Buffer
1922
from cuda.core.experimental._utils.cuda_utils import driver
20-
from cuda.bindings cimport cydriver
2123

2224

2325
ctypedef cpp_complex.complex[float] cpp_single_complex
@@ -265,6 +267,8 @@ cdef class ParamHolder:
265267
cdef size_t i
266268
cdef int not_prepared
267269
cdef object arg_type
270+
cdef _MDSPAN mdspan_obj
271+
cdef uintptr_t mdspan_ptr
268272
self.data = vector.vector[voidptr](n_args, nullptr)
269273
self.data_addresses = vector.vector[voidptr](n_args)
270274
for i, arg in enumerate(kernel_args):
@@ -296,6 +300,14 @@ cdef class ParamHolder:
296300
elif arg_type is complex:
297301
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
298302
continue
303+
elif arg_type is _MDSPAN:
304+
# The mdspan struct is allocated on the host and owned by the CuPy mdspan object.
305+
# We pass a pointer to the struct so the driver can copy it by value to the kernel.
306+
# Access _ptr at C level to avoid creating a temporary Python object.
307+
mdspan_obj = <_MDSPAN>arg
308+
mdspan_ptr = mdspan_obj._ptr
309+
self.data_addresses[i] = <void*>mdspan_ptr
310+
continue
299311

300312
not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
301313
if not_prepared:
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from libc.stdint cimport uintptr_t
2+
3+
4+
cdef class _MDSPAN:
    # Address of the mdspan object; it must refer to an mdspan that lives in
    # host memory.
    cdef readonly uintptr_t _ptr
    # Python object the host mdspan was exported from, if any. Holding the
    # reference here keeps that object alive for as long as this wrapper is.
    cdef readonly object _exporting_obj

cuda_core/cuda/core/experimental/_memoryview.pyx

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from libc.stdint cimport uintptr_t
6+
57
from ._dlpack cimport *
8+
from cuda.core.experimental._utils cimport cuda_utils
69

710
import functools
811
import warnings
@@ -11,12 +14,26 @@ from typing import Optional
1114
import numpy
1215

1316
from cuda.core.experimental._utils.cuda_utils import handle_return, driver
14-
from cuda.core.experimental._utils cimport cuda_utils
1517

1618

1719
# TODO(leofang): support NumPy structured dtypes
1820

1921

22+
cdef class _MDSPAN:
    """Thin handle pairing the address of an mdspan object with its owner.

    Parameters
    ----------
    ptr : int
        Address of the mdspan object; stored as ``_ptr``.
    obj : object, optional
        Object that the mdspan was exported from; stored as
        ``_exporting_obj`` so that it stays referenced for as long as this
        handle exists.
    """

    def __cinit__(self):
        # Guarantee a well-defined (null) address even if __init__ is
        # never run for this instance.
        self._ptr = 0

    def __init__(self, uintptr_t ptr, object obj=None):
        self._ptr = ptr
        self._exporting_obj = obj

    # NOTE: intentionally no __dealloc__. This class owns no C resources, and
    # the reference held in _exporting_obj is released automatically when the
    # instance is destroyed; touching Python object attributes from
    # __dealloc__ is discouraged because the object may already be partially
    # torn down at that point.
34+
35+
36+
2037
cdef class StridedMemoryView:
2138
"""A dataclass holding metadata of a strided dense array/tensor.
2239
@@ -98,6 +115,7 @@ cdef class StridedMemoryView:
98115
# this flag helps prevent unnecessary recomputation of _strides
99116
bint _strides_init
100117
object _dtype
118+
_MDSPAN _mdspan
101119

102120
def __init__(self, obj: object = None, stream_ptr: int | None = None) -> None:
103121
cdef str clsname = self.__class__.__name__
@@ -224,6 +242,27 @@ cdef class StridedMemoryView:
224242
self._dtype = numpy.dtype(self.metadata["typestr"])
225243
return self._dtype
226244

245+
@property
def as_mdspan(self) -> _MDSPAN:
    """A C++ mdspan view of the tensor.

    The handle is built on first access and cached on the instance, so
    later accesses return the same ``_MDSPAN`` object. Only exporting
    objects whose class lives in the ``cupy`` package are supported; for
    those, the object's ``mdspan`` attribute supplies the address (its
    ``ptr``) and is retained by the returned handle.

    Returns
    -------
    mdspan : _MDSPAN

    Raises
    ------
    NotImplementedError
        If the exporting object does not come from the ``cupy`` package.
    """
    # Fast path: reuse the handle created by an earlier access.
    if self._mdspan is not None:
        return self._mdspan

    arr = self.exporting_obj
    # Top-level package name of the exporter's class, e.g. "cupy".
    module = arr.__class__.__module__.split(".")[0]
    if module != "cupy":
        raise NotImplementedError(
            f"as_mdspan is not implemented for objects from module '{module}'"
        )

    mdspan = arr.mdspan
    # Pass the mdspan object along as the owner so it is kept referenced
    # together with the address taken from it.
    self._mdspan = _MDSPAN(<uintptr_t>(mdspan.ptr), mdspan)
    return self._mdspan
265+
227266
def __repr__(self):
228267
return (f"StridedMemoryView(ptr={self.ptr},\n"
229268
+ f" shape={self.shape},\n"
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import cupy as cp
from cuda.core.experimental import Program, Device, LaunchConfig, launch
from cuda.core.experimental.utils import StridedMemoryView


# Debug kernel: prints the value it receives as its only argument.
code = r"""
extern "C"
__global__ void debug_cupy_arr(const float* arr) {
printf("ptr: %p\n", arr);
}

"""


# Reference run: build and launch the kernel through CuPy, passing the array.
cupy_kernel = cp.RawKernel(code, 'debug_cupy_arr')
arr = cp.array([1, 2, 3], dtype=cp.float32)
print(f"arr device ptr: {arr.data.ptr:#x}")
cupy_kernel((1,), (1,), (arr,))
cp.cuda.Device().synchronize()


# Second run: compile the same source with cuda.core and launch it with the
# mdspan handle obtained from a StridedMemoryView of the same array.
program = Program(code, code_type='c++')
cubin = program.compile(target_type='cubin')
core_kernel = cubin.get_kernel('debug_cupy_arr')
dev = Device()
dev.set_current()

stream = dev.default_stream
config = LaunchConfig(grid=1, block=1)
mdspan_arg = StridedMemoryView(arr, stream_ptr=-1).as_mdspan
launch(stream, config, core_kernel, mdspan_arg)
stream.sync()

print("Done.")

cuda_core/examples/mdspan_verify_args.py

Lines changed: 76 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os, sys
2222
import cupy as cp
2323
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch
24+
from cuda.core.experimental.utils import StridedMemoryView
2425

2526
# prepare include
2627
cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME"))
@@ -42,37 +43,81 @@
4243
code_verify = """
4344
#include <cuda/std/mdspan>
4445
46+
typedef struct {
47+
void* ptr;
48+
size_t ext1;
49+
size_t ext2;
50+
} mdspan_view_t;
51+
52+
4553
// Kernel to verify layout_right (C-order) mdspan arguments
4654
template<typename T>
4755
__global__ void verify_mdspan_layout_right(
48-
cuda::std::mdspan<T, cuda::std::extents<size_t, cuda::std::dynamic_extent, cuda::std::dynamic_extent>, cuda::std::layout_right> arr
56+
mdspan_view_t arr
4957
) {
5058
// Only thread 0 prints to avoid cluttered output
5159
if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
5260
printf("=== layout_right (C-order) mdspan ===\\n");
53-
printf("Data pointer: %p\\n", arr.data_handle());
54-
printf("Extent 0 (rows): %zu\\n", arr.extent(0));
55-
printf("Extent 1 (cols): %zu\\n", arr.extent(1));
56-
printf("Size: %zu\\n", arr.size());
57-
58-
// For layout_right, strides are implicit but we can query them
59-
printf("Stride 0: %zu\\n", arr.stride(0));
60-
printf("Stride 1: %zu\\n", arr.stride(1));
61-
62-
// Verify memory layout: for layout_right (C-order)
63-
// stride(0) should equal extent(1), stride(1) should be 1
64-
printf("Expected stride(0) = extent(1): %s\\n",
65-
(arr.stride(0) == arr.extent(1)) ? "PASS" : "FAIL");
66-
printf("Expected stride(1) = 1: %s\\n",
67-
(arr.stride(1) == 1) ? "PASS" : "FAIL");
68-
69-
// Test element access
70-
if (arr.extent(0) > 0 && arr.extent(1) > 0) {
71-
printf("First element arr(0,0): %f\\n", static_cast<float>(arr(0, 0)));
72-
}
61+
printf("sizeof(mdspan_view_t): %llu\\n", sizeof(arr));
62+
printf("view - ptr: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ptr);
63+
printf("view2 : %p\\n", *(void**)((char*)(&arr) + 0));
64+
printf("view - ext1: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ext1);
65+
printf("view - ext2: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ext2);
7366
}
7467
}
7568
69+
// // Kernel to verify layout_right (C-order) mdspan arguments
70+
//
71+
// typedef struct {
72+
// void* ptr;
73+
// void* ext1;
74+
// void* ext2;
75+
// } mdspan_view_t;
76+
//
77+
// template<typename T>
78+
// __global__ void verify_mdspan_layout_right(
79+
// cuda::std::mdspan<T, cuda::std::extents<size_t, cuda::std::dynamic_extent, cuda::std::dynamic_extent>, cuda::std::layout_right> arr
80+
// ) {
81+
// // Only thread 0 prints to avoid cluttered output
82+
// if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
83+
// printf("=== layout_right (C-order) mdspan ===\\n");
84+
// printf("sizeof(mdspan): %llu\\n", sizeof(arr));
85+
// printf("view - ptr: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ptr);
86+
// printf("view2 : %p\\n", (void**)((char*)(&arr) + 0));
87+
// //printf("view - ext1: %llu\\n", *((size_t*)(reinterpret_cast<mdspan_view_t*>(&arr)->ext1)));
88+
// //printf("view - ext2: %llu\\n", *((size_t*)(reinterpret_cast<mdspan_view_t*>(&arr)->ext2)));
89+
// printf("view - ext1: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ext1);
90+
// printf("view - ext2: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ext2);
91+
//
92+
// printf("Data pointer: %p\\n", arr.data_handle());
93+
// printf("Data pointer (actual): %p\\n", (void*)((char*)(&arr) + 0));
94+
// printf("Data pointer (actual): %p\\n", addressof(arr));
95+
// printf("Extent 0 (rows): %llu\\n", arr.extent(0));
96+
// printf("Extent 1 (cols): %llu\\n", arr.extent(1));
97+
// printf("Extent 0 (rows) (actual): %llu\\n", (size_t)(*((char*)(&arr) + 8)));
98+
// printf("Extent 1 (cols) (actual): %llu\\n", (size_t)(*((char*)(&arr) + 16)));
99+
// printf("Size: %zu\\n", arr.size());
100+
//
101+
// // For layout_right, strides are implicit but we can query them
102+
// printf("Stride 0: %llu\\n", arr.stride(0));
103+
// printf("Stride 1: %llu\\n", arr.stride(1));
104+
// printf("Stride 0 (actual): %llu\\n", (size_t)((char*)(&arr) + 24));
105+
// printf("Stride 1 (actual): %llu\\n", (size_t)((char*)(&arr) + 32));
106+
//
107+
// // Verify memory layout: for layout_right (C-order)
108+
// // stride(0) should equal extent(1), stride(1) should be 1
109+
// printf("Expected stride(0) = extent(1): %s\\n",
110+
// (arr.stride(0) == arr.extent(1)) ? "PASS" : "FAIL");
111+
// printf("Expected stride(1) = 1: %s\\n",
112+
// (arr.stride(1) == 1) ? "PASS" : "FAIL");
113+
//
114+
// // Test element access
115+
// if (arr.extent(0) > 0 && arr.extent(1) > 0) {
116+
// printf("First element arr(0,0): %f\\n", static_cast<float>(arr(0, 0)));
117+
// }
118+
// }
119+
// }
120+
76121
// Kernel to verify layout_left (F-order) mdspan arguments
77122
template<typename T>
78123
__global__ void verify_mdspan_layout_left(
@@ -162,10 +207,13 @@ def prepare_mdspan_args_layout_right(arr, dtype, shape):
162207
tuple
163208
Arguments to pass to the kernel (needs investigation)
164209
"""
165-
data_ptr = arr.data.ptr
166-
rows, cols = shape
167-
# TODO: Determine exact argument structure
168-
return (data_ptr, rows, cols)
210+
#obj = arr.mdspan
211+
#print(f"{hex(obj.ptr)=}, {obj.ptr=}")
212+
#return (obj.ptr,)
213+
214+
obj = StridedMemoryView(arr, stream_ptr=-1).as_mdspan
215+
print(f"{hex(obj._ptr)=}, {obj._ptr=}, type={type(obj)}")
216+
return (obj,)
169217

170218

171219
def prepare_mdspan_args_layout_left(arr, dtype, shape):
@@ -266,7 +314,7 @@ def verify_layout_right():
266314

267315
# Verify array is in C-order
268316
assert arr.flags['C_CONTIGUOUS']
269-
317+
print(f"Array pointer: {hex(arr.data.ptr)}")
270318
print(f"Array shape: {arr.shape}")
271319
print(f"Array strides (bytes): {arr.strides}")
272320
print(f"Array strides (elements): ({arr.strides[0]//arr.itemsize}, {arr.strides[1]//arr.itemsize})")
@@ -282,8 +330,8 @@ def verify_layout_right():
282330
config = LaunchConfig(grid=1, block=1)
283331

284332
# TODO: Launch kernel with proper mdspan arguments
285-
# launch(s, config, ker, *args)
286-
# s.sync()
333+
launch(s, config, ker, *args)
334+
s.sync()
287335

288336
print("Verification kernel prepared (not executed)")
289337
print()

0 commit comments

Comments
 (0)