33import ufl
44import numpy .typing as npt
55import numpy as np
6+ from packaging .version import Version
7+ from ffcx .ir .elementtables import (
8+ permute_quadrature_interval ,
9+ permute_quadrature_triangle ,
10+ permute_quadrature_quadrilateral ,
11+ )
612
713__all__ = ["interpolate_function_onto_facet_dofs" ]
814
915
16+ def build_quadrature_permutations (facet_type , points ):
17+ if Version (dolfinx .__version__ ) < Version ("0.11.0.dev0" ):
18+ # In older versions of dolfinx, the permutation is handled internally
19+ # in the C++ code, so we can just return the original points
20+ if facet_type == basix .CellType .interval :
21+ num_permutations = 2
22+ elif facet_type == basix .CellType .triangle :
23+ num_permutations = 6
24+ elif facet_type == basix .CellType .quadrilateral :
25+ num_permutations = 8
26+ else :
27+ raise ValueError (f"Unsupported { facet_type = } " )
28+ return [points for _ in range (num_permutations )]
29+ else :
30+ if facet_type == basix .CellType .interval :
31+ return [permute_quadrature_interval (points , ref ) for ref in range (2 )]
32+ elif facet_type == basix .CellType .triangle :
33+ perms = []
34+ # FFCx order: rot is outer loop, ref is inner loop
35+ for rot in range (3 ):
36+ for ref in range (2 ):
37+ # Counteract the mapping with the inverse permutation
38+ rot_inv = (3 - rot ) % 3 if ref == 0 else rot
39+ perms .append (permute_quadrature_triangle (points , ref , rot_inv ))
40+ return perms
41+
42+ elif facet_type == basix .CellType .quadrilateral :
43+ perms = []
44+ # FFCx order: rot is outer loop, ref is inner loop
45+ for rot in range (4 ):
46+ for ref in range (2 ):
47+ # Counteract the mapping with the inverse permutation
48+ rot_inv = (4 - rot ) % 4 if ref == 0 else rot
49+ perms .append (permute_quadrature_quadrilateral (points , ref , rot_inv ))
50+ return perms
51+
52+ else :
53+ raise ValueError (f"Unsupported { facet_type = } " )
54+
55+
1056def interpolate_function_onto_facet_dofs (
1157 Q : dolfinx .fem .FunctionSpace ,
1258 expr : ufl .core .expr .Expr ,
@@ -49,17 +95,17 @@ def interpolate_function_onto_facet_dofs(
4995 )
5096 ref_top = ref_cmap .reference_topology
5197 ref_geom = ref_cmap .reference_geometry
98+ facet_type = facet_types .pop ()
5299 facet_cmap = basix .ufl .element (
53100 "Lagrange" ,
54- facet_types . pop () ,
101+ facet_type ,
55102 1 ,
56103 shape = (domain .topology .dim ,),
57104 dtype = np .float64 ,
58105 )
59106 facet_cel = dolfinx .fem .CoordinateElement (
60107 dolfinx .cpp .fem .CoordinateElement_float64 (facet_cmap .basix_element ._e )
61108 )
62-
63109 reference_facet_points = None
64110 for i , points in enumerate (interpolation_points [fdim ]):
65111 geom = ref_geom [ref_top [fdim ][i ]]
@@ -71,9 +117,10 @@ def interpolate_function_onto_facet_dofs(
71117 else :
72118 assert np .allclose (reference_facet_points , ref_points )
73119 assert reference_facet_points is not None
74- # Create expression for BC
75- normal_expr = dolfinx .fem .Expression (expr , reference_facet_points )
76-
120+ facet_points = build_quadrature_permutations (facet_type , reference_facet_points )
121+ expressions = []
122+ for i , perm in enumerate (facet_points ):
123+ expressions .append (dolfinx .fem .Expression (expr , perm ))
77124 points_per_entity = [sum (ip .shape [0 ] for ip in ips ) for ips in interpolation_points ]
78125 offsets = np .zeros (domain .topology .dim + 2 , dtype = np .int32 )
79126 offsets [1 :] = np .cumsum (points_per_entity [: domain .topology .dim + 1 ])
@@ -85,14 +132,22 @@ def interpolate_function_onto_facet_dofs(
85132 all_connected_cells = dolfinx .mesh .compute_incident_entities (
86133 domain .topology , facets , domain .topology .dim - 1 , domain .topology .dim
87134 )
88- values = np .zeros (len (all_connected_cells ) * offsets [- 1 ] * domain .geometry .dim )
135+ expr_value_size = expressions [0 ].value_size
136+
137+ # Update array allocations to use the exact expression value size
138+ values_per_entity = np .zeros ((offsets [- 1 ], expr_value_size ), dtype = dolfinx .default_scalar_type )
139+ values = np .zeros (len (all_connected_cells ) * offsets [- 1 ] * expr_value_size )
140+
89141 domain .topology .create_connectivity (domain .topology .dim , fdim )
90142 c_to_f = domain .topology .connectivity (domain .topology .dim , fdim )
91143 num_facets_on_process = (
92144 domain .topology .index_map (fdim ).size_local + domain .topology .index_map (fdim ).num_ghosts
93145 )
94146 is_marked = np .zeros (num_facets_on_process , dtype = np .int8 )
95147 is_marked [facets ] = 1
148+ domain .topology .create_entity_permutations ()
149+ num_facets_per_cell = dolfinx .cpp .mesh .cell_num_entities (domain .topology .cell_type , fdim )
150+ facet_permutations = domain .topology .get_facet_permutations ().reshape (- 1 , num_facets_per_cell )
96151 for i , cell in enumerate (all_connected_cells ):
97152 values_per_entity [:] = 0.0
98153 local_facets = c_to_f .links (cell )
@@ -102,24 +157,24 @@ def interpolate_function_onto_facet_dofs(
102157 insert_pos = offsets [fdim ] + reference_facet_points .shape [0 ] * j
103158 # Backwards compatibility
104159 entity = np .array ([[cell , j ]], dtype = np .int32 )
160+ perm = facet_permutations [cell , j ]
105161 try :
106- normal_on_facet = normal_expr .eval (domain , entity )
162+ normal_on_facet = expressions [ perm ] .eval (domain , entity )
107163 except (AttributeError , AssertionError ):
108- normal_on_facet = normal_expr .eval (domain , entity .flatten ())
164+ normal_on_facet = expressions [ perm ] .eval (domain , entity .flatten ())
109165 # NOTE: evaluate within loop to avoid large memory requirements
110166 values_per_entity [insert_pos : insert_pos + reference_facet_points .shape [0 ]] = (
111- normal_on_facet .reshape (- 1 , domain . geometry . dim )
167+ normal_on_facet .reshape (- 1 , expr_value_size )
112168 )
113- values [
114- i * offsets [ - 1 ] * domain . geometry . dim : ( i + 1 ) * offsets [ - 1 ] * domain . geometry . dim
115- ] = values_per_entity . reshape ( - 1 )
169+ values [i * offsets [ - 1 ] * expr_value_size : ( i + 1 ) * offsets [ - 1 ] * expr_value_size ] = (
170+ values_per_entity . reshape ( - 1 )
171+ )
116172
117173 qh = dolfinx .fem .Function (Q )
118174 if hasattr (qh ._cpp_object , "interpolate_f" ):
119175 interpolate_func = qh ._cpp_object .interpolate_f
120176 else :
121177 interpolate_func = qh ._cpp_object .interpolate
122- interpolate_func (values .reshape (- 1 , domain . geometry . dim ).T .copy (), all_connected_cells )
178+ interpolate_func (values .reshape (- 1 , expr_value_size ).T .copy (), all_connected_cells )
123179 qh .x .scatter_forward ()
124-
125180 return qh
0 commit comments