Skip to content

Commit fa8f699

Browse files
committed
fix
1 parent ca36936 commit fa8f699

File tree

4 files changed

+119
-14
lines changed

4 files changed

+119
-14
lines changed

cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,23 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from cpython.mem cimport PyMem_Malloc, PyMem_Free
6-
from libc.stdint cimport (intptr_t,
6+
from libc.stdint cimport (intptr_t, uintptr_t,
77
int8_t, int16_t, int32_t, int64_t,
88
uint8_t, uint16_t, uint32_t, uint64_t,)
99
from libcpp cimport bool as cpp_bool
1010
from libcpp.complex cimport complex as cpp_complex
1111
from libcpp cimport nullptr
1212
from libcpp cimport vector
1313

14+
from cuda.bindings cimport cydriver
15+
from cuda.core.experimental._memoryview cimport _MDSPAN
16+
1417
import ctypes
1518

1619
import numpy
1720

1821
from cuda.core.experimental._memory import Buffer
1922
from cuda.core.experimental._utils.cuda_utils import driver
20-
from cuda.bindings cimport cydriver
2123

2224

2325
ctypedef cpp_complex.complex[float] cpp_single_complex
@@ -296,6 +298,12 @@ cdef class ParamHolder:
296298
elif arg_type is complex:
297299
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
298300
continue
301+
elif arg_type is _MDSPAN:
302+
# The mdspan struct is allocated on the host and owned by the CuPy mdspan object.
303+
# We pass a pointer to the struct so the driver can copy it by value to the kernel.
304+
# Access _ptr at C level to avoid creating a temporary Python object.
305+
self.data_addresses[i] = <void*>((<_MDSPAN>arg)._ptr)
306+
continue
299307

300308
not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
301309
if not_prepared:
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from libc.stdint cimport uintptr_t
2+
3+
4+
cdef class _MDSPAN:
5+
cdef:
6+
# this must be a pointer to a host mdspan object
7+
readonly uintptr_t _ptr
8+
# if the host mdspan is exported from any Python object,
9+
# we need to keep a reference to that object alive
10+
readonly object _exporting_obj

cuda_core/cuda/core/experimental/_memoryview.pyx

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from libc.stdint cimport uintptr_t
6+
57
from ._dlpack cimport *
8+
from cuda.core.experimental._utils cimport cuda_utils
69

710
import functools
811
import warnings
@@ -11,12 +14,26 @@ from typing import Optional
1114
import numpy
1215

1316
from cuda.core.experimental._utils.cuda_utils import handle_return, driver
14-
from cuda.core.experimental._utils cimport cuda_utils
1517

1618

1719
# TODO(leofang): support NumPy structured dtypes
1820

1921

22+
cdef class _MDSPAN:
23+
24+
def __cinit__(self):
25+
self._ptr = 0
26+
27+
def __init__(self, uintptr_t ptr, object obj=None):
28+
self._ptr = ptr
29+
self._exporting_obj = obj
30+
31+
def __dealloc__(self):
32+
self._ptr = 0
33+
self._exporting_obj = None
34+
35+
36+
2037
cdef class StridedMemoryView:
2138
"""A dataclass holding metadata of a strided dense array/tensor.
2239
@@ -98,6 +115,7 @@ cdef class StridedMemoryView:
98115
# this flag helps prevent unnecessary recomputation of _strides
99116
bint _strides_init
100117
object _dtype
118+
_MDSPAN _mdspan
101119

102120
def __init__(self, obj: object = None, stream_ptr: int | None = None) -> None:
103121
cdef str clsname = self.__class__.__name__
@@ -224,6 +242,27 @@ cdef class StridedMemoryView:
224242
self._dtype = numpy.dtype(self.metadata["typestr"])
225243
return self._dtype
226244

245+
@property
246+
def as_mdspan(self) -> _MDSPAN:
247+
"""A C++ mdspan view of the tensor.
248+
249+
Returns
250+
-------
251+
mdspan : _MDSPAN
252+
"""
253+
if self._mdspan is None:
254+
arr = self.exporting_obj
255+
module = self.exporting_obj.__class__.__module__.split(".")[0]
256+
if module == "cupy":
257+
mdspan = arr.mdspan
258+
#mdspan = arr.cstruct
259+
self._mdspan = _MDSPAN(<uintptr_t>(mdspan.ptr), mdspan)
260+
else:
261+
raise NotImplementedError(
262+
f"as_mdspan is not implemented for objects from module '{module}'"
263+
)
264+
return self._mdspan
265+
227266
def __repr__(self):
228267
return (f"StridedMemoryView(ptr={self.ptr},\n"
229268
+ f" shape={self.shape},\n"

cuda_core/examples/mdspan_verify_args.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os, sys
2222
import cupy as cp
2323
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch
24+
from cuda.core.experimental.utils import StridedMemoryView
2425

2526
# prepare include
2627
cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME"))
@@ -42,22 +43,66 @@
4243
code_verify = """
4344
#include <cuda/std/mdspan>
4445
46+
// typedef struct {
47+
// void* ptr;
48+
// size_t ext1;
49+
// size_t ext2;
50+
// } mdspan_view_t;
51+
//
52+
//
53+
// // Kernel to verify layout_right (C-order) mdspan arguments
54+
// template<typename T>
55+
// __global__ void verify_mdspan_layout_right(
56+
// mdspan_view_t arr
57+
// ) {
58+
// // Only thread 0 prints to avoid cluttered output
59+
// if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
60+
// printf("=== layout_right (C-order) mdspan ===\\n");
61+
// printf("sizeof(mdspan_view_t): %llu\\n", sizeof(arr));
62+
// printf("view - ptr: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ptr);
63+
// printf("view2 : %p\\n", *(void**)((char*)(&arr) + 0));
64+
// printf("view - ext1: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ext1);
65+
// printf("view - ext2: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ext2);
66+
// }
67+
// }
68+
4569
// Kernel to verify layout_right (C-order) mdspan arguments
70+
71+
typedef struct {
72+
void* ptr;
73+
void* ext1;
74+
void* ext2;
75+
} mdspan_view_t;
76+
4677
template<typename T>
4778
__global__ void verify_mdspan_layout_right(
4879
cuda::std::mdspan<T, cuda::std::extents<size_t, cuda::std::dynamic_extent, cuda::std::dynamic_extent>, cuda::std::layout_right> arr
4980
) {
5081
// Only thread 0 prints to avoid cluttered output
5182
if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
5283
printf("=== layout_right (C-order) mdspan ===\\n");
84+
printf("sizeof(mdspan): %llu\\n", sizeof(arr));
85+
printf("view - ptr: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ptr);
86+
printf("view2 : %p\\n", (void**)((char*)(&arr) + 0));
87+
//printf("view - ext1: %llu\\n", *((size_t*)(reinterpret_cast<mdspan_view_t*>(&arr)->ext1)));
88+
//printf("view - ext2: %llu\\n", *((size_t*)(reinterpret_cast<mdspan_view_t*>(&arr)->ext2)));
89+
printf("view - ext1: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ext1);
90+
printf("view - ext2: %p\\n", reinterpret_cast<mdspan_view_t*>(&arr)->ext2);
91+
5392
printf("Data pointer: %p\\n", arr.data_handle());
54-
printf("Extent 0 (rows): %zu\\n", arr.extent(0));
55-
printf("Extent 1 (cols): %zu\\n", arr.extent(1));
93+
printf("Data pointer (actual): %p\\n", (void*)((char*)(&arr) + 0));
94+
printf("Data pointer (actual): %p\\n", addressof(arr));
95+
printf("Extent 0 (rows): %llu\\n", arr.extent(0));
96+
printf("Extent 1 (cols): %llu\\n", arr.extent(1));
97+
printf("Extent 0 (rows) (actual): %llu\\n", (size_t)(*((char*)(&arr) + 8)));
98+
printf("Extent 1 (cols) (actual): %llu\\n", (size_t)(*((char*)(&arr) + 16)));
5699
printf("Size: %zu\\n", arr.size());
57100
58101
// For layout_right, strides are implicit but we can query them
59-
printf("Stride 0: %zu\\n", arr.stride(0));
60-
printf("Stride 1: %zu\\n", arr.stride(1));
102+
printf("Stride 0: %llu\\n", arr.stride(0));
103+
printf("Stride 1: %llu\\n", arr.stride(1));
104+
printf("Stride 0 (actual): %llu\\n", (size_t)((char*)(&arr) + 24));
105+
printf("Stride 1 (actual): %llu\\n", (size_t)((char*)(&arr) + 32));
61106
62107
// Verify memory layout: for layout_right (C-order)
63108
// stride(0) should equal extent(1), stride(1) should be 1
@@ -162,10 +207,13 @@ def prepare_mdspan_args_layout_right(arr, dtype, shape):
162207
tuple
163208
Arguments to pass to the kernel (needs investigation)
164209
"""
165-
data_ptr = arr.data.ptr
166-
rows, cols = shape
167-
# TODO: Determine exact argument structure
168-
return (data_ptr, rows, cols)
210+
#obj = arr.mdspan
211+
#print(f"{hex(obj.ptr)=}, {obj.ptr=}")
212+
#return (obj.ptr,)
213+
214+
obj = StridedMemoryView(arr, stream_ptr=-1).as_mdspan
215+
print(f"{hex(obj._ptr)=}, {obj._ptr=}, type={type(obj)}")
216+
return (obj,)
169217

170218

171219
def prepare_mdspan_args_layout_left(arr, dtype, shape):
@@ -266,7 +314,7 @@ def verify_layout_right():
266314

267315
# Verify array is in C-order
268316
assert arr.flags['C_CONTIGUOUS']
269-
317+
print(f"Array pointer: {hex(arr.data.ptr)}")
270318
print(f"Array shape: {arr.shape}")
271319
print(f"Array strides (bytes): {arr.strides}")
272320
print(f"Array strides (elements): ({arr.strides[0]//arr.itemsize}, {arr.strides[1]//arr.itemsize})")
@@ -282,8 +330,8 @@ def verify_layout_right():
282330
config = LaunchConfig(grid=1, block=1)
283331

284332
# Launch the verification kernel with the prepared mdspan args, then sync
285-
# launch(s, config, ker, *args)
286-
# s.sync()
333+
launch(s, config, ker, *args)
334+
s.sync()
287335

288336
print("Verification kernel prepared (not executed)")
289337
print()

0 commit comments

Comments
 (0)