Skip to content

Commit d667d3e

Browse files
committed
address comments
1 parent c2ac738 commit d667d3e

2 files changed

Lines changed: 11 additions & 120 deletions

File tree

ffcx/codegeneration/utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def codegen(context, builder, signature, args):
160160
sig = numba.types.voidptr(arr)
161161
return sig, codegen
162162

163-
def create_voidptr_to_dtype_ptr_caster(
163+
def _create_voidptr_to_dtype_ptr_caster(
164164
target_dtype: numba.types.Type,
165165
) -> numba.extending.intrinsic:
166166
"""Factory that creates a Numba intrinsic casting void* to CPointer(target_dtype).
@@ -212,9 +212,3 @@ def codegen(context, builder, signature, args):
212212
return sig, codegen
213213

214214
return voidptr_to_dtype_ptr
215-
216-
# Pre-built casters for common types (convenience wrappers)
217-
voidptr_to_float64_ptr = create_voidptr_to_dtype_ptr_caster(numba.types.float64)
218-
voidptr_to_float32_ptr = create_voidptr_to_dtype_ptr_caster(numba.types.float32)
219-
voidptr_to_int32_ptr = create_voidptr_to_dtype_ptr_caster(numba.types.int32)
220-
voidptr_to_int64_ptr = create_voidptr_to_dtype_ptr_caster(numba.types.int64)

test/test_numba_custom_data.py

Lines changed: 10 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,18 @@
1-
# Test that the Numba voidptr -> typed pointer caster works in ffcx utils
1+
# Test that the Numba voidptr -> typed pointer caster factory works in ffcx utils
22
import ctypes
33

44
import numpy as np
55
import pytest
66

7-
from ffcx.codegeneration.utils import (
8-
numba_ufcx_kernel_signature,
9-
voidptr_to_float64_ptr,
10-
voidptr_to_int32_ptr,
11-
)
7+
import ffcx.codegeneration.utils as codegen_utils
128

139
# Skip the tests if Numba is not available in the environment.
1410
numba = pytest.importorskip("numba")
1511

16-
17-
def test_numba_voidptr_caster_basic():
18-
"""Simple test: Numba cfunc reads a double from custom_data via the caster."""
19-
sig = numba_ufcx_kernel_signature(np.float64, np.float64)
20-
21-
@numba.cfunc(sig, nopython=True)
22-
def tabulate(b_, w_, c_, coords_, local_index, orientation, custom_data):
23-
b = numba.carray(b_, (1,), dtype=np.float64)
24-
# Cast void* to float64*
25-
typed = voidptr_to_float64_ptr(custom_data)
26-
b[0] = typed[0]
27-
28-
# Prepare arguments
29-
b = np.zeros(1, dtype=np.float64)
30-
w = np.zeros(1, dtype=np.float64)
31-
c = np.zeros(1, dtype=np.float64)
32-
coords = np.zeros(9, dtype=np.float64)
33-
local_index = np.array([0], dtype=np.int32)
34-
orientation = np.array([0], dtype=np.uint8)
35-
36-
# custom_data: single double value
37-
val = np.array([2.5], dtype=np.float64)
38-
val_ptr = val.ctypes.data
39-
40-
# Call the compiled cfunc via ctypes
41-
tabulate.ctypes(
42-
b.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
43-
w.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
44-
c.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
45-
coords.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
46-
local_index.ctypes.data_as(ctypes.POINTER(ctypes.c_int)),
47-
orientation.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)),
48-
ctypes.c_void_p(val_ptr),
49-
)
50-
51-
assert b[0] == pytest.approx(2.5)
52-
53-
54-
def test_numba_voidptr_caster_int32():
55-
"""Test casting void* to int32* and reading an integer value."""
56-
sig = numba_ufcx_kernel_signature(np.float64, np.float64)
57-
58-
@numba.cfunc(sig, nopython=True)
59-
def tabulate(b_, w_, c_, coords_, local_index, orientation, custom_data):
60-
b = numba.carray(b_, (1,), dtype=np.float64)
61-
typed = voidptr_to_int32_ptr(custom_data)
62-
# Promote int32 to float64 for the output
63-
b[0] = typed[0]
64-
65-
b = np.zeros(1, dtype=np.float64)
66-
w = np.zeros(1, dtype=np.float64)
67-
c = np.zeros(1, dtype=np.float64)
68-
coords = np.zeros(9, dtype=np.float64)
69-
local_index = np.array([0], dtype=np.int32)
70-
orientation = np.array([0], dtype=np.uint8)
71-
72-
val = np.array([7], dtype=np.int32)
73-
val_ptr = val.ctypes.data
74-
75-
tabulate.ctypes(
76-
b.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
77-
w.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
78-
c.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
79-
coords.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
80-
local_index.ctypes.data_as(ctypes.POINTER(ctypes.c_int)),
81-
orientation.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)),
82-
ctypes.c_void_p(val_ptr),
83-
)
84-
85-
assert b[0] == pytest.approx(7.0)
86-
87-
88-
def test_numba_voidptr_caster_multiple_params():
89-
"""Test reading multiple float64 parameters from custom_data."""
90-
sig = numba_ufcx_kernel_signature(np.float64, np.float64)
91-
92-
@numba.cfunc(sig, nopython=True)
93-
def tabulate(b_, w_, c_, coords_, local_index, orientation, custom_data):
94-
b = numba.carray(b_, (1,), dtype=np.float64)
95-
typed = voidptr_to_float64_ptr(custom_data)
96-
b[0] = typed[0] + typed[1] + typed[2]
97-
98-
b = np.zeros(1, dtype=np.float64)
99-
w = np.zeros(1, dtype=np.float64)
100-
c = np.zeros(1, dtype=np.float64)
101-
coords = np.zeros(9, dtype=np.float64)
102-
local_index = np.array([0], dtype=np.int32)
103-
orientation = np.array([0], dtype=np.uint8)
104-
105-
vals = np.array([1.5, 2.0, 3.0], dtype=np.float64)
106-
vals_ptr = vals.ctypes.data
107-
108-
tabulate.ctypes(
109-
b.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
110-
w.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
111-
c.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
112-
coords.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
113-
local_index.ctypes.data_as(ctypes.POINTER(ctypes.c_int)),
114-
orientation.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)),
115-
ctypes.c_void_p(vals_ptr),
116-
)
117-
118-
assert b[0] == pytest.approx(6.5)
12+
float64_ptr_caster = codegen_utils._create_voidptr_to_dtype_ptr_caster(
13+
numba.types.float64
14+
)
15+
int32_ptr_caster = codegen_utils._create_voidptr_to_dtype_ptr_caster(numba.types.int32)
11916

12017

12118
def test_numba_voidptr_struct_like_mixed_types():
@@ -125,13 +22,13 @@ def test_numba_voidptr_struct_like_mixed_types():
12522
('id', int32) with padding to align to 16 bytes. The kernel casts the
12623
void* to float64* and int32* and reads the corresponding offsets.
12724
"""
128-
sig = numba_ufcx_kernel_signature(np.float64, np.float64)
25+
sig = codegen_utils.numba_ufcx_kernel_signature(np.float64, np.float64)
12926

13027
@numba.cfunc(sig, nopython=True)
13128
def tabulate(b_, w_, c_, coords_, local_index, orientation, custom_data):
13229
b = numba.carray(b_, (1,), dtype=np.float64)
133-
fptr = voidptr_to_float64_ptr(custom_data)
134-
iptr = voidptr_to_int32_ptr(custom_data)
30+
fptr = float64_ptr_caster(custom_data)
31+
iptr = int32_ptr_caster(custom_data)
13532
scale = fptr[0]
13633
# int32 index for offset 8 bytes == 8/4 == 2
13734
id_val = iptr[2]
@@ -144,7 +41,7 @@ def tabulate(b_, w_, c_, coords_, local_index, orientation, custom_data):
14441
local_index = np.array([0], dtype=np.int32)
14542
orientation = np.array([0], dtype=np.uint8)
14643

147-
# structured dtype: float64 at offset 0, int32 at offset 8, with C-compatible alignment
44+
# structured dtype with C-compatible alignment
14845
dtype = np.dtype([("scale", np.float64), ("id", np.int32)], align=True)
14946
arr = np.zeros(1, dtype=dtype)
15047
arr["scale"][0] = 1.25

0 commit comments

Comments
 (0)