2121import os , sys
2222import cupy as cp
2323from cuda .core .experimental import Device , LaunchConfig , Program , ProgramOptions , launch
24+ from cuda .core .experimental .utils import StridedMemoryView
2425
2526# prepare include
2627cuda_path = os .environ .get ("CUDA_PATH" , os .environ .get ("CUDA_HOME" ))
4243code_verify = """
4344#include <cuda/std/mdspan>
4445
// Plain-data view used as the by-value kernel argument below: one pointer
// followed by two size_t fields. The member order fixes the byte layout the
// kernel inspects, so it must not be rearranged.
// NOTE(review): presumably mirrors the bytes of a 2-D mdspan argument packed
// by the host side -- confirm against the host argument preparation.
struct mdspan_view_t {
    void*  ptr;   // first member (offset 0)
    size_t ext1;  // second member
    size_t ext2;  // third member
};
51+
52+
// Kernel to verify layout_right (C-order) mdspan arguments.
//
// Debug helper: prints the raw contents of the by-value `arr` argument
// (its size, its pointer member, and its two size_t members) so the values
// received on the device can be compared with what the host passed.
// Only thread (0,0) of block (0,0) prints. The template parameter T is not
// used in the body; it is kept so existing instantiations keep working.
template<typename T>
__global__ void verify_mdspan_layout_right(
    mdspan_view_t arr
) {
    // Only thread 0 prints to avoid cluttered output
    if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
        printf("=== layout_right (C-order) mdspan ===\\n");

        // Device printf documents only the h/l/ll size modifiers, so every
        // size_t value is widened to unsigned long long and printed with %llu.
        printf("sizeof(mdspan_view_t): %llu\\n",
               static_cast<unsigned long long>(sizeof(arr)));

        // Pointer member, read through the struct field.
        printf("view - ptr: %p\\n", arr.ptr);

        // Same bytes read without going through the struct definition: the
        // first pointer-sized word of the argument. Should match the line above.
        printf("view2 : %p\\n", *(void**)((char*)(&arr) + 0));

        // ext1/ext2 are size_t, not pointers: print them as integers.
        printf("view - ext1: %llu\\n", static_cast<unsigned long long>(arr.ext1));
        printf("view - ext2: %llu\\n", static_cast<unsigned long long>(arr.ext2));
    }
}
7568
69+ // // Kernel to verify layout_right (C-order) mdspan arguments
70+ //
71+ // typedef struct {
72+ // void* ptr;
73+ // void* ext1;
74+ // void* ext2;
75+ // } mdspan_view_t;
76+ //
77+ // template<typename T>
78+ // __global__ void verify_mdspan_layout_right(
79+ // cuda::std::mdspan<T, cuda::std::extents<size_t, cuda::std::dynamic_extent, cuda::std::dynamic_extent>, cuda::std::layout_right> arr
80+ // ) {
81+ // // Only thread 0 prints to avoid cluttered output
82+ // if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
83+ // printf("=== layout_right (C-order) mdspan ===\\ n");
84+ // printf("sizeof(mdspan): %llu\\ n", sizeof(arr));
85+ // printf("view - ptr: %p\\ n", reinterpret_cast<mdspan_view_t*>(&arr)->ptr);
86+ // printf("view2 : %p\\ n", (void**)((char*)(&arr) + 0));
87+ // //printf("view - ext1: %llu\\ n", *((size_t*)(reinterpret_cast<mdspan_view_t*>(&arr)->ext1)));
88+ // //printf("view - ext2: %llu\\ n", *((size_t*)(reinterpret_cast<mdspan_view_t*>(&arr)->ext2)));
89+ // printf("view - ext1: %p\\ n", reinterpret_cast<mdspan_view_t*>(&arr)->ext1);
90+ // printf("view - ext2: %p\\ n", reinterpret_cast<mdspan_view_t*>(&arr)->ext2);
91+ //
92+ // printf("Data pointer: %p\\ n", arr.data_handle());
93+ // printf("Data pointer (actual): %p\\ n", (void*)((char*)(&arr) + 0));
94+ // printf("Data pointer (actual): %p\\ n", addressof(arr));
95+ // printf("Extent 0 (rows): %llu\\ n", arr.extent(0));
96+ // printf("Extent 1 (cols): %llu\\ n", arr.extent(1));
97+ // printf("Extent 0 (rows) (actual): %llu\\ n", (size_t)(*((char*)(&arr) + 8)));
98+ // printf("Extent 1 (cols) (actual): %llu\\ n", (size_t)(*((char*)(&arr) + 16)));
99+ // printf("Size: %zu\\ n", arr.size());
100+ //
101+ // // For layout_right, strides are implicit but we can query them
102+ // printf("Stride 0: %llu\\ n", arr.stride(0));
103+ // printf("Stride 1: %llu\\ n", arr.stride(1));
104+ // printf("Stride 0 (actual): %llu\\ n", (size_t)((char*)(&arr) + 24));
105+ // printf("Stride 1 (actual): %llu\\ n", (size_t)((char*)(&arr) + 32));
106+ //
107+ // // Verify memory layout: for layout_right (C-order)
108+ // // stride(0) should equal extent(1), stride(1) should be 1
109+ // printf("Expected stride(0) = extent(1): %s\\ n",
110+ // (arr.stride(0) == arr.extent(1)) ? "PASS" : "FAIL");
111+ // printf("Expected stride(1) = 1: %s\\ n",
112+ // (arr.stride(1) == 1) ? "PASS" : "FAIL");
113+ //
114+ // // Test element access
115+ // if (arr.extent(0) > 0 && arr.extent(1) > 0) {
116+ // printf("First element arr(0,0): %f\\ n", static_cast<float>(arr(0, 0)));
117+ // }
118+ // }
119+ // }
120+
76121// Kernel to verify layout_left (F-order) mdspan arguments
77122template<typename T>
78123__global__ void verify_mdspan_layout_left(
@@ -162,10 +207,13 @@ def prepare_mdspan_args_layout_right(arr, dtype, shape):
162207 tuple
163208 Arguments to pass to the kernel (needs investigation)
164209 """
165- data_ptr = arr .data .ptr
166- rows , cols = shape
167- # TODO: Determine exact argument structure
168- return (data_ptr , rows , cols )
210+ #obj = arr.mdspan
211+ #print(f"{hex(obj.ptr)=}, {obj.ptr=}")
212+ #return (obj.ptr,)
213+
214+ obj = StridedMemoryView (arr , stream_ptr = - 1 ).as_mdspan
215+ print (f"{ hex (obj ._ptr )= } , { obj ._ptr = } , type={ type (obj )} " )
216+ return (obj ,)
169217
170218
171219def prepare_mdspan_args_layout_left (arr , dtype , shape ):
@@ -266,7 +314,7 @@ def verify_layout_right():
266314
267315 # Verify array is in C-order
268316 assert arr .flags ['C_CONTIGUOUS' ]
269-
317+ print ( f"Array pointer: { hex ( arr . data . ptr ) } " )
270318 print (f"Array shape: { arr .shape } " )
271319 print (f"Array strides (bytes): { arr .strides } " )
272320 print (f"Array strides (elements): ({ arr .strides [0 ]// arr .itemsize } , { arr .strides [1 ]// arr .itemsize } )" )
@@ -282,8 +330,8 @@ def verify_layout_right():
282330 config = LaunchConfig (grid = 1 , block = 1 )
283331
284332 # TODO: Launch kernel with proper mdspan arguments
285- # launch(s, config, ker, *args)
286- # s.sync()
333+ launch (s , config , ker , * args )
334+ s .sync ()
287335
288336 print ("Verification kernel prepared (not executed)" )
289337 print ()
0 commit comments