1- # Test that the Numba voidptr -> typed pointer caster works in ffcx utils
1+ # Test that the Numba voidptr -> typed pointer caster factory works in ffcx utils
22import ctypes
33
44import numpy as np
55import pytest
66
7- from ffcx .codegeneration .utils import (
8- numba_ufcx_kernel_signature ,
9- voidptr_to_float64_ptr ,
10- voidptr_to_int32_ptr ,
11- )
7+ import ffcx .codegeneration .utils as codegen_utils
128
139# Skip the tests if Numba is not available in the environment.
1410numba = pytest .importorskip ("numba" )
1511
16-
17- def test_numba_voidptr_caster_basic ():
18- """Simple test: Numba cfunc reads a double from custom_data via the caster."""
19- sig = numba_ufcx_kernel_signature (np .float64 , np .float64 )
20-
21- @numba .cfunc (sig , nopython = True )
22- def tabulate (b_ , w_ , c_ , coords_ , local_index , orientation , custom_data ):
23- b = numba .carray (b_ , (1 ,), dtype = np .float64 )
24- # Cast void* to float64*
25- typed = voidptr_to_float64_ptr (custom_data )
26- b [0 ] = typed [0 ]
27-
28- # Prepare arguments
29- b = np .zeros (1 , dtype = np .float64 )
30- w = np .zeros (1 , dtype = np .float64 )
31- c = np .zeros (1 , dtype = np .float64 )
32- coords = np .zeros (9 , dtype = np .float64 )
33- local_index = np .array ([0 ], dtype = np .int32 )
34- orientation = np .array ([0 ], dtype = np .uint8 )
35-
36- # custom_data: single double value
37- val = np .array ([2.5 ], dtype = np .float64 )
38- val_ptr = val .ctypes .data
39-
40- # Call the compiled cfunc via ctypes
41- tabulate .ctypes (
42- b .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
43- w .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
44- c .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
45- coords .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
46- local_index .ctypes .data_as (ctypes .POINTER (ctypes .c_int )),
47- orientation .ctypes .data_as (ctypes .POINTER (ctypes .c_uint8 )),
48- ctypes .c_void_p (val_ptr ),
49- )
50-
51- assert b [0 ] == pytest .approx (2.5 )
52-
53-
54- def test_numba_voidptr_caster_int32 ():
55- """Test casting void* to int32* and reading an integer value."""
56- sig = numba_ufcx_kernel_signature (np .float64 , np .float64 )
57-
58- @numba .cfunc (sig , nopython = True )
59- def tabulate (b_ , w_ , c_ , coords_ , local_index , orientation , custom_data ):
60- b = numba .carray (b_ , (1 ,), dtype = np .float64 )
61- typed = voidptr_to_int32_ptr (custom_data )
62- # Promote int32 to float64 for the output
63- b [0 ] = typed [0 ]
64-
65- b = np .zeros (1 , dtype = np .float64 )
66- w = np .zeros (1 , dtype = np .float64 )
67- c = np .zeros (1 , dtype = np .float64 )
68- coords = np .zeros (9 , dtype = np .float64 )
69- local_index = np .array ([0 ], dtype = np .int32 )
70- orientation = np .array ([0 ], dtype = np .uint8 )
71-
72- val = np .array ([7 ], dtype = np .int32 )
73- val_ptr = val .ctypes .data
74-
75- tabulate .ctypes (
76- b .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
77- w .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
78- c .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
79- coords .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
80- local_index .ctypes .data_as (ctypes .POINTER (ctypes .c_int )),
81- orientation .ctypes .data_as (ctypes .POINTER (ctypes .c_uint8 )),
82- ctypes .c_void_p (val_ptr ),
83- )
84-
85- assert b [0 ] == pytest .approx (7.0 )
86-
87-
88- def test_numba_voidptr_caster_multiple_params ():
89- """Test reading multiple float64 parameters from custom_data."""
90- sig = numba_ufcx_kernel_signature (np .float64 , np .float64 )
91-
92- @numba .cfunc (sig , nopython = True )
93- def tabulate (b_ , w_ , c_ , coords_ , local_index , orientation , custom_data ):
94- b = numba .carray (b_ , (1 ,), dtype = np .float64 )
95- typed = voidptr_to_float64_ptr (custom_data )
96- b [0 ] = typed [0 ] + typed [1 ] + typed [2 ]
97-
98- b = np .zeros (1 , dtype = np .float64 )
99- w = np .zeros (1 , dtype = np .float64 )
100- c = np .zeros (1 , dtype = np .float64 )
101- coords = np .zeros (9 , dtype = np .float64 )
102- local_index = np .array ([0 ], dtype = np .int32 )
103- orientation = np .array ([0 ], dtype = np .uint8 )
104-
105- vals = np .array ([1.5 , 2.0 , 3.0 ], dtype = np .float64 )
106- vals_ptr = vals .ctypes .data
107-
108- tabulate .ctypes (
109- b .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
110- w .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
111- c .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
112- coords .ctypes .data_as (ctypes .POINTER (ctypes .c_double )),
113- local_index .ctypes .data_as (ctypes .POINTER (ctypes .c_int )),
114- orientation .ctypes .data_as (ctypes .POINTER (ctypes .c_uint8 )),
115- ctypes .c_void_p (vals_ptr ),
116- )
117-
118- assert b [0 ] == pytest .approx (6.5 )
12+ float64_ptr_caster = codegen_utils ._create_voidptr_to_dtype_ptr_caster (
13+ numba .types .float64
14+ )
15+ int32_ptr_caster = codegen_utils ._create_voidptr_to_dtype_ptr_caster (numba .types .int32 )
11916
12017
12118def test_numba_voidptr_struct_like_mixed_types ():
@@ -125,13 +22,13 @@ def test_numba_voidptr_struct_like_mixed_types():
12522 ('id', int32) with padding to align to 16 bytes. The kernel casts the
12623 void* to float64* and int32* and reads the corresponding offsets.
12724 """
128- sig = numba_ufcx_kernel_signature (np .float64 , np .float64 )
25+ sig = codegen_utils . numba_ufcx_kernel_signature (np .float64 , np .float64 )
12926
13027 @numba .cfunc (sig , nopython = True )
13128 def tabulate (b_ , w_ , c_ , coords_ , local_index , orientation , custom_data ):
13229 b = numba .carray (b_ , (1 ,), dtype = np .float64 )
133- fptr = voidptr_to_float64_ptr (custom_data )
134- iptr = voidptr_to_int32_ptr (custom_data )
30+ fptr = float64_ptr_caster (custom_data )
31+ iptr = int32_ptr_caster (custom_data )
13532 scale = fptr [0 ]
13633 # int32 index for offset 8 bytes == 8/4 == 2
13734 id_val = iptr [2 ]
@@ -144,7 +41,7 @@ def tabulate(b_, w_, c_, coords_, local_index, orientation, custom_data):
14441 local_index = np .array ([0 ], dtype = np .int32 )
14542 orientation = np .array ([0 ], dtype = np .uint8 )
14643
147- # structured dtype: float64 at offset 0, int32 at offset 8, with C-compatible alignment
44+ # structured dtype with C-compatible alignment
14845 dtype = np .dtype ([("scale" , np .float64 ), ("id" , np .int32 )], align = True )
14946 arr = np .zeros (1 , dtype = dtype )
15047 arr ["scale" ][0 ] = 1.25
0 commit comments