Skip to content

Commit b8376af

Browse files
committed
Fix Netgen permutation
1 parent 74ecb96 commit b8376af

2 files changed

Lines changed: 21 additions & 9 deletions

File tree

firedrake/mesh.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3017,10 +3017,13 @@ def curve_field(self, order, permutation_tol=1e-8, cg_field=None):
30173017
fiat_element = new_coordinates.function_space().finat_element.fiat_equivalent
30183018
nodes = fiat_element.dual_basis()
30193019
ref_pts = []
3020-
for node in nodes:
3021-
# Assert singleton point for each node.
3022-
pt, = node.get_point_dict().keys()
3023-
ref_pts.append(pt)
3020+
entity_ids = fiat_element.entity_dofs()
3021+
for dim in sorted(entity_ids):
3022+
for entity in sorted(entity_ids[dim]):
3023+
for i in entity_ids[dim][entity]:
3024+
# Assert singleton point for each node.
3025+
pt, = nodes[i].get_point_dict().keys()
3026+
ref_pts.append(pt)
30243027
reference_points = np.array(ref_pts)
30253028

30263029
# Construct numpy arrays for physical domain data
@@ -3030,8 +3033,8 @@ def curve_field(self, order, permutation_tol=1e-8, cg_field=None):
30303033
curved_points = np.zeros(
30313034
(ng_dimension, reference_points.shape[0], self.geometric_dimension)
30323035
)
3036+
self.netgen_mesh.Curve(1)
30333037
self.netgen_mesh.CalcElementMapping(reference_points, physical_points)
3034-
# NOTE: This will segfault for MeshHierarchy on a netgen CSG geometry
30353038
self.netgen_mesh.Curve(order)
30363039
self.netgen_mesh.CalcElementMapping(reference_points, curved_points)
30373040
curved = ng_element.NumPy()["curved"]
@@ -3056,7 +3059,6 @@ def curve_field(self, order, permutation_tol=1e-8, cg_field=None):
30563059
permutation = find_permutation(
30573060
own_physical_points,
30583061
new_coordinates.dat.data_ro_with_halos[broken_indices].real,
3059-
tol=permutation_tol,
30603062
)
30613063
self.comm.Barrier()
30623064
# Apply the permutation to each cell in turn

firedrake/netgen.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ def netgen_distribute(V: firedrake.functionspaceimpl.WithGeometryBase,
8080

8181

8282
@PETSc.Log.EventDecorator()
83-
def find_permutation(points_a: np.ndarray, points_b: np.ndarray,
84-
tol: float = 1e-5):
83+
def find_permutation(points_a: np.ndarray, points_b: np.ndarray):
8584
""" Find all permutations between a list of two sets of points.
8685
8786
Given two numpy arrays of shape (ncells, npoints, dim) containing
@@ -95,7 +94,18 @@ def find_permutation(points_a: np.ndarray, points_b: np.ndarray,
9594
if points_a.shape != points_b.shape:
9695
raise ValueError("`points_a` and `points_b` must have the same shape.")
9796

98-
p = [np.where(cdist(a, b).T < tol)[1] for a, b in zip(points_a, points_b)]
97+
dim = points_a.shape[-1]
98+
vids = list(range(dim+1))
99+
100+
bs = points_a[:, vids[0], :]
101+
As = points_a[:, vids[1:], :]
102+
As -= bs[:, None, :]
103+
Ainvs = np.linalg.inv(As)
104+
105+
ref_points_a = np.matmul(points_a - bs[:, None, :], Ainvs)
106+
ref_points_b = np.matmul(points_b - bs[:, None, :], Ainvs)
107+
108+
p = [np.argmin(cdist(a, b), axis=0) for a, b in zip(ref_points_a, ref_points_b)]
99109

100110
if len(p) == 0:
101111
return p

0 commit comments

Comments
 (0)