2121import os , sys
2222import cupy as cp
2323from cuda .core .experimental import Device , LaunchConfig , Program , ProgramOptions , launch
24+ from cuda .core .experimental .utils import StridedMemoryView
2425
2526# prepare include
2627cuda_path = os .environ .get ("CUDA_PATH" , os .environ .get ("CUDA_HOME" ))
4243code_verify = """
4344#include <cuda/std/mdspan>
4445
46+ // typedef struct {
47+ // void* ptr;
48+ // size_t ext1;
49+ // size_t ext2;
50+ // } mdspan_view_t;
51+ //
52+ //
53+ // // Kernel to verify layout_right (C-order) mdspan arguments
54+ // template<typename T>
55+ // __global__ void verify_mdspan_layout_right(
56+ // mdspan_view_t arr
57+ // ) {
58+ // // Only thread 0 prints to avoid cluttered output
59+ // if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
60+ // printf("=== layout_right (C-order) mdspan ===\\ n");
61+ // printf("sizeof(mdspan_view_t): %llu\\ n", sizeof(arr));
62+ // printf("view - ptr: %p\\ n", reinterpret_cast<mdspan_view_t*>(&arr)->ptr);
63+ // printf("view2 : %p\\ n", *(void**)((char*)(&arr) + 0));
64+ // printf("view - ext1: %p\\ n", reinterpret_cast<mdspan_view_t*>(&arr)->ext1);
65+ // printf("view - ext2: %p\\ n", reinterpret_cast<mdspan_view_t*>(&arr)->ext2);
66+ // }
67+ // }
68+
4569// Kernel to verify layout_right (C-order) mdspan arguments
70+
71+ typedef struct {
72+ void* ptr;
73+ void* ext1;
74+ void* ext2;
75+ } mdspan_view_t;
76+
4677template<typename T>
4778__global__ void verify_mdspan_layout_right(
4879 cuda::std::mdspan<T, cuda::std::extents<size_t, cuda::std::dynamic_extent, cuda::std::dynamic_extent>, cuda::std::layout_right> arr
4980) {
5081 // Only thread 0 prints to avoid cluttered output
5182 if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
5283 printf("=== layout_right (C-order) mdspan ===\\ n");
84+ printf("sizeof(mdspan): %llu\\ n", sizeof(arr));
85+ printf("view - ptr: %p\\ n", reinterpret_cast<mdspan_view_t*>(&arr)->ptr);
86+ printf("view2 : %p\\ n", (void**)((char*)(&arr) + 0));
87+ //printf("view - ext1: %llu\\ n", *((size_t*)(reinterpret_cast<mdspan_view_t*>(&arr)->ext1)));
88+ //printf("view - ext2: %llu\\ n", *((size_t*)(reinterpret_cast<mdspan_view_t*>(&arr)->ext2)));
89+ printf("view - ext1: %p\\ n", reinterpret_cast<mdspan_view_t*>(&arr)->ext1);
90+ printf("view - ext2: %p\\ n", reinterpret_cast<mdspan_view_t*>(&arr)->ext2);
91+
5392 printf("Data pointer: %p\\ n", arr.data_handle());
54- printf("Extent 0 (rows): %zu\\ n", arr.extent(0));
55- printf("Extent 1 (cols): %zu\\ n", arr.extent(1));
93+ printf("Data pointer (actual): %p\\ n", (void*)((char*)(&arr) + 0));
94+ printf("Data pointer (actual): %p\\ n", addressof(arr));
95+ printf("Extent 0 (rows): %llu\\ n", arr.extent(0));
96+ printf("Extent 1 (cols): %llu\\ n", arr.extent(1));
97+ printf("Extent 0 (rows) (actual): %llu\\ n", (size_t)(*((char*)(&arr) + 8)));
98+ printf("Extent 1 (cols) (actual): %llu\\ n", (size_t)(*((char*)(&arr) + 16)));
5699 printf("Size: %zu\\ n", arr.size());
57100
58101 // For layout_right, strides are implicit but we can query them
59- printf("Stride 0: %zu\\ n", arr.stride(0));
60- printf("Stride 1: %zu\\ n", arr.stride(1));
102+ printf("Stride 0: %llu\\ n", arr.stride(0));
103+ printf("Stride 1: %llu\\ n", arr.stride(1));
104+ printf("Stride 0 (actual): %llu\\ n", (size_t)((char*)(&arr) + 24));
105+ printf("Stride 1 (actual): %llu\\ n", (size_t)((char*)(&arr) + 32));
61106
62107 // Verify memory layout: for layout_right (C-order)
63108 // stride(0) should equal extent(1), stride(1) should be 1
@@ -162,10 +207,13 @@ def prepare_mdspan_args_layout_right(arr, dtype, shape):
162207 tuple
163208 Arguments to pass to the kernel (needs investigation)
164209 """
165- data_ptr = arr .data .ptr
166- rows , cols = shape
167- # TODO: Determine exact argument structure
168- return (data_ptr , rows , cols )
210+ #obj = arr.mdspan
211+ #print(f"{hex(obj.ptr)=}, {obj.ptr=}")
212+ #return (obj.ptr,)
213+
214+ obj = StridedMemoryView (arr , stream_ptr = - 1 ).as_mdspan
215+ print (f"{ hex (obj ._ptr )= } , { obj ._ptr = } , type={ type (obj )} " )
216+ return (obj ,)
169217
170218
171219def prepare_mdspan_args_layout_left (arr , dtype , shape ):
@@ -266,7 +314,7 @@ def verify_layout_right():
266314
267315 # Verify array is in C-order
268316 assert arr .flags ['C_CONTIGUOUS' ]
269-
317+ print ( f"Array pointer: { hex ( arr . data . ptr ) } " )
270318 print (f"Array shape: { arr .shape } " )
271319 print (f"Array strides (bytes): { arr .strides } " )
272320 print (f"Array strides (elements): ({ arr .strides [0 ]// arr .itemsize } , { arr .strides [1 ]// arr .itemsize } )" )
@@ -282,8 +330,8 @@ def verify_layout_right():
282330 config = LaunchConfig (grid = 1 , block = 1 )
283331
284332 # TODO: Launch kernel with proper mdspan arguments
285- # launch(s, config, ker, *args)
286- # s.sync()
333+ launch (s , config , ker , * args )
334+ s .sync ()
287335
288336 print ("Verification kernel prepared (not executed)" )
289337 print ()
0 commit comments