Skip to content

Commit e54aa74

Browse files
authored
Extend interpolate_geometry tests to GLL (#4223)
* Also lift to gll geometry * gll_isaac - what does it do on quads? * ruff
1 parent 3bb0e73 commit e54aa74

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

python/test/unit/fem/test_interpolate_geometry.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_interpolate_geometry_p1_roundtrip(dtype):
8080
np.testing.assert_allclose(x_new[dm_new[c]], x_old[dm_old[c]], atol=atol, rtol=0.0)
8181

8282

83-
def _curve_mesh_errors(N, degree, dtype, R, cell_type):
83+
def _curve_mesh_errors(N, degree, dtype, R, cell_type, lagrange_variant):
8484
"""Return (area_error, circumference_error) for a degree-p curved disk mesh with N cells."""
8585
mesh = create_rectangle(
8686
MPI.COMM_WORLD,
@@ -98,7 +98,7 @@ def transform(x):
9898
x_c[:, 1] = R * x[:, 1] * np.sqrt(1.0 - (x[:, 0] ** 2 / (2.0)))
9999
return x_c
100100

101-
cmap = coordinate_element(cell_type, degree, variant=LagrangeVariant.equispaced, dtype=dtype)
101+
cmap = coordinate_element(cell_type, degree, variant=lagrange_variant, dtype=dtype)
102102
curved_mesh = interpolate_geometry(mesh, cmap)
103103
curved_mesh.geometry.x[:] = transform(curved_mesh.geometry.x)
104104

@@ -131,11 +131,14 @@ def transform(x):
131131

132132
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
133133
@pytest.mark.parametrize("degree", [1, 2, 3])
134-
@pytest.mark.parametrize("R", [0.1, 1])
134+
@pytest.mark.parametrize("R", [0.1])
135135
@pytest.mark.parametrize("cell_type", [CellType.triangle, CellType.quadrilateral])
136-
def test_curve_mesh(degree, dtype, R, cell_type):
136+
@pytest.mark.parametrize(
137+
"lagrange_variant", [LagrangeVariant.equispaced, LagrangeVariant.gll_isaac]
138+
)
139+
def test_curve_mesh(degree, dtype, R, cell_type, lagrange_variant):
137140
Ns = [4, 8, 16, 32]
138-
errors = [_curve_mesh_errors(N, degree, dtype, R, cell_type) for N in Ns]
141+
errors = [_curve_mesh_errors(N, degree, dtype, R, cell_type, lagrange_variant) for N in Ns]
139142

140143
area_errors = np.array([e[0] for e in errors])
141144
circ_errors = np.array([e[1] for e in errors])

0 commit comments

Comments
 (0)