diff --git a/firedrake/mesh.py b/firedrake/mesh.py index f0ec0debdf..268b38b348 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -3017,10 +3017,13 @@ def curve_field(self, order, permutation_tol=1e-8, cg_field=None): fiat_element = new_coordinates.function_space().finat_element.fiat_equivalent nodes = fiat_element.dual_basis() ref_pts = [] - for node in nodes: - # Assert singleton point for each node. - pt, = node.get_point_dict().keys() - ref_pts.append(pt) + entity_ids = fiat_element.entity_dofs() + for dim in sorted(entity_ids): + for entity in sorted(entity_ids[dim]): + for i in entity_ids[dim][entity]: + # Assert singleton point for each node. + pt, = nodes[i].get_point_dict().keys() + ref_pts.append(pt) reference_points = np.array(ref_pts) # Construct numpy arrays for physical domain data @@ -3030,8 +3033,8 @@ def curve_field(self, order, permutation_tol=1e-8, cg_field=None): curved_points = np.zeros( (ng_dimension, reference_points.shape[0], self.geometric_dimension) ) + self.netgen_mesh.Curve(1) self.netgen_mesh.CalcElementMapping(reference_points, physical_points) - # NOTE: This will segfault for MeshHierarchy on a netgen CSG geometry self.netgen_mesh.Curve(order) self.netgen_mesh.CalcElementMapping(reference_points, curved_points) curved = ng_element.NumPy()["curved"] @@ -3056,7 +3059,6 @@ def curve_field(self, order, permutation_tol=1e-8, cg_field=None): permutation = find_permutation( own_physical_points, new_coordinates.dat.data_ro_with_halos[broken_indices].real, - tol=permutation_tol, ) self.comm.Barrier() # Apply the permutation to each cell in turn diff --git a/firedrake/netgen.py b/firedrake/netgen.py index 57b442b83a..178dbb6a6d 100644 --- a/firedrake/netgen.py +++ b/firedrake/netgen.py @@ -80,8 +80,7 @@ def netgen_distribute(V: firedrake.functionspaceimpl.WithGeometryBase, @PETSc.Log.EventDecorator() -def find_permutation(points_a: np.ndarray, points_b: np.ndarray, - tol: float = 1e-5): +def find_permutation(points_a: np.ndarray, points_b: np.ndarray): """ Find all permutations between a list of two sets of points. Given two numpy arrays of shape (ncells, npoints, dim) containing @@ -95,7 +94,18 @@ def find_permutation(points_a: np.ndarray, points_b: np.ndarray, if points_a.shape != points_b.shape: raise ValueError("`points_a` and `points_b` must have the same shape.") - p = [np.where(cdist(a, b).T < tol)[1] for a, b in zip(points_a, points_b)] + dim = points_a.shape[-1] + vids = list(range(dim+1)) + + bs = points_a[:, vids[0], :] + As = points_a[:, vids[1:], :] + As -= bs[:, None, :] + Ainvs = np.linalg.inv(As) + + ref_points_a = np.matmul(points_a - bs[:, None, :], Ainvs) + ref_points_b = np.matmul(points_b - bs[:, None, :], Ainvs) + + p = [np.argmin(cdist(a, b), axis=0) for a, b in zip(ref_points_a, ref_points_b)] if len(p) == 0: return p