Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3017,10 +3017,13 @@ def curve_field(self, order, permutation_tol=1e-8, cg_field=None):
fiat_element = new_coordinates.function_space().finat_element.fiat_equivalent
nodes = fiat_element.dual_basis()
ref_pts = []
for node in nodes:
# Assert singleton point for each node.
pt, = node.get_point_dict().keys()
ref_pts.append(pt)
entity_ids = fiat_element.entity_dofs()
for dim in sorted(entity_ids):
for entity in sorted(entity_ids[dim]):
for i in entity_ids[dim][entity]:
# Assert singleton point for each node.
pt, = nodes[i].get_point_dict().keys()
ref_pts.append(pt)
reference_points = np.array(ref_pts)

# Construct numpy arrays for physical domain data
Expand All @@ -3030,8 +3033,8 @@ def curve_field(self, order, permutation_tol=1e-8, cg_field=None):
curved_points = np.zeros(
(ng_dimension, reference_points.shape[0], self.geometric_dimension)
)
self.netgen_mesh.Curve(1)
self.netgen_mesh.CalcElementMapping(reference_points, physical_points)
# NOTE: This will segfault for MeshHierarchy on a netgen CSG geometry
self.netgen_mesh.Curve(order)
self.netgen_mesh.CalcElementMapping(reference_points, curved_points)
curved = ng_element.NumPy()["curved"]
Expand All @@ -3056,7 +3059,6 @@ def curve_field(self, order, permutation_tol=1e-8, cg_field=None):
permutation = find_permutation(
own_physical_points,
new_coordinates.dat.data_ro_with_halos[broken_indices].real,
tol=permutation_tol,
)
self.comm.Barrier()
# Apply the permutation to each cell in turn
Expand Down
16 changes: 13 additions & 3 deletions firedrake/netgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def netgen_distribute(V: firedrake.functionspaceimpl.WithGeometryBase,


@PETSc.Log.EventDecorator()
def find_permutation(points_a: np.ndarray, points_b: np.ndarray,
tol: float = 1e-5):
def find_permutation(points_a: np.ndarray, points_b: np.ndarray):
""" Find all permutations between a list of two sets of points.

Given two numpy arrays of shape (ncells, npoints, dim) containing
Expand All @@ -95,7 +94,18 @@ def find_permutation(points_a: np.ndarray, points_b: np.ndarray,
if points_a.shape != points_b.shape:
raise ValueError("`points_a` and `points_b` must have the same shape.")

p = [np.where(cdist(a, b).T < tol)[1] for a, b in zip(points_a, points_b)]
dim = points_a.shape[-1]
vids = list(range(dim+1))

bs = points_a[:, vids[0], :]
As = points_a[:, vids[1:], :]
As -= bs[:, None, :]
Ainvs = np.linalg.inv(As)

ref_points_a = np.matmul(points_a - bs[:, None, :], Ainvs)
ref_points_b = np.matmul(points_b - bs[:, None, :], Ainvs)

p = [np.argmin(cdist(a, b), axis=0) for a, b in zip(ref_points_a, ref_points_b)]

if len(p) == 0:
return p
Expand Down
Loading