Skip to content

Commit 9d39035

Browse files
committed
review comments
1 parent 75f5c95 commit 9d39035

4 files changed

Lines changed: 57 additions & 12 deletions

File tree

firedrake/cython/dmcommon.pyx

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2203,11 +2203,11 @@ def _get_expanded_dm_dg_coords(dm: PETSc.DM, ndofs: np.ndarray):
22032203

22042204
def _get_periodicity(dm: PETSc.DM) -> tuple[tuple[bool, bool], ...]:
22052205
"""Return mesh periodicity information.
2206-
2206+
22072207
This function returns a 2-tuple of bools per dimension where the first entry indicates
22082208
whether the mesh is periodic in that dimension, and the second indicates whether the
22092209
mesh is single-cell periodic in that dimension.
2210-
2210+
22112211
"""
22122212
cdef:
22132213
const PetscReal *maxCell, *L
@@ -4325,3 +4325,41 @@ def get_dm_cell_types(PETSc.DM dm):
43254325
return tuple(
43264326
polytope_type_enum for polytope_type_enum, found in enumerate(found_all) if found
43274327
)
4328+
4329+
4330+
def create_label_intersection(PETSc.DM dm, label_name, label_values):
4331+
"""Return the intersection of the closure of a subdomains of a DMPlex.
4332+
4333+
Parameters
4334+
----------
4335+
dm : PETSc.DM
4336+
The DMPlex.
4337+
label_name : str
4338+
The name of the label
4339+
label_values : Sequence[int]
4340+
The values of the subdomain label to intersect
4341+
4342+
Returns
4343+
-------
4344+
tuple
4345+
A PETSc.IS with the points in the intersection.
4346+
4347+
"""
4348+
cdef:
4349+
PETSc.DMLabel label
4350+
PETSc.PetscIS is1, is2
4351+
PetscInt val = label_values[0]
4352+
4353+
label = dm.getLabel(label_name)
4354+
CHKERR(DMPlexLabelComplete(dm.dm, label.dmlabel))
4355+
CHKERR(DMLabelGetStratumIS(<DMLabel>label.dmlabel, val, &is1))
4356+
4357+
for i in range(1, len(label_values)):
4358+
iout = PETSc.IS()
4359+
val = label_values[i]
4360+
CHKERR(DMLabelGetStratumIS(<DMLabel>label.dmlabel, val, &is2))
4361+
CHKERR(ISIntersect(is1, is2, &(<PETSc.IS?>iout).iset))
4362+
CHKERR(ISDestroy(&is1))
4363+
CHKERR(ISDestroy(&is2))
4364+
is1 = (<PETSc.IS?>iout).iset
4365+
return iout

firedrake/cython/petschdr.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ cdef extern from "petscis.h" nogil:
142142
PetscErrorCode ISLocalToGlobalMappingGetBlockIndices(PETSc.PetscLGMap, const PetscInt**)
143143
PetscErrorCode ISLocalToGlobalMappingRestoreBlockIndices(PETSc.PetscLGMap, const PetscInt**)
144144
PetscErrorCode ISDestroy(PETSc.PetscIS*)
145+
PetscErrorCode ISIntersect(PETSc.PetscIS, PETSc.PetscIS, PETSc.PetscIS*)
145146

146147
cdef extern from "petscsf.h" nogil:
147148
struct PetscSFNode_:

firedrake/mesh.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4863,6 +4863,8 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
48634863
raise NotImplementedError("Can not create a submesh of a ``VertexOnlyMesh``")
48644864
if subdim is None:
48654865
subdim = mesh.topological_dimension
4866+
if subdomain_id == "on_boundary":
4867+
subdim = subdim - 1
48664868
plex = mesh.topology_dm
48674869
dim = plex.getDimension()
48684870
if subdim not in {dim, dim - 1}:
@@ -4880,20 +4882,24 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
48804882
label_name = dmcommon.FACE_SETS_LABEL
48814883

48824884
# Parse non-integer subdomain_id
4883-
if subdomain_id == "on_boundary":
4884-
subdomain_id = tuple(mesh.exterior_facets.unique_markers)
4885+
if isinstance(subdomain_id, str):
4886+
if subdomain_id == "on_boundary":
4887+
subdomain_id = tuple(mesh.exterior_facets.unique_markers)
4888+
else:
4889+
raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.")
48854890

48864891
if isinstance(subdomain_id, Sequence):
48874892
# Create a temporary DMLabel with the union of the labels in the list
48884893
icomm = comm or mesh.comm
48894894
iset = PETSc.IS().createGeneral([], comm=icomm)
48904895
for sub in subdomain_id:
4896+
try:
4897+
sub, = sub
4898+
except ValueError:
4899+
pass
48914900
if isinstance(sub, Sequence):
48924901
# Take the intersection of the (closure of the) labels from nested lists
4893-
ises = [plex.getStratumIS(label_name, subi) for subi in sub]
4894-
closure = [[plex.getTransitiveClosure(p)[0] for p in i.indices] for i in ises]
4895-
indices = reduce(np.intersect1d, closure)
4896-
cur = PETSc.IS().createGeneral(indices, comm=icomm)
4902+
cur = dmcommon.create_label_intersection(plex, label_name, sub)
48974903
else:
48984904
cur = plex.getStratumIS(label_name, sub)
48994905
iset = iset.union(cur)

tests/firedrake/submesh/test_submesh_interface.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from firedrake import *
44

55

6-
def test_submesh_subdomain_id_tuple():
6+
def test_submesh_subdomain_id_union():
77
mesh = UnitSquareMesh(4, 4)
88
x, y = SpatialCoordinate(mesh)
99
M = FunctionSpace(mesh, "DG", 0)
@@ -25,7 +25,7 @@ def test_submesh_subdomain_id_tuple():
2525
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)
2626

2727

28-
def test_submesh_subdomain_id_nested_tuple():
28+
def test_submesh_subdomain_id_intersection():
2929
mesh = UnitSquareMesh(4, 4)
3030
x, y = SpatialCoordinate(mesh)
3131
M = FunctionSpace(mesh, "DG", 0)
@@ -48,7 +48,7 @@ def test_submesh_subdomain_id_nested_tuple():
4848

4949

5050
@pytest.mark.parametrize("subdomain_id", ["on_boundary", (1, 3, 6)])
51-
def test_submesh_facet_subdomain_id_tuple(subdomain_id):
51+
def test_submesh_facet_subdomain_id_union(subdomain_id):
5252
mesh = UnitCubeMesh(2, 2, 2)
5353
submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id)
5454
if subdomain_id == "on_boundary":
@@ -67,7 +67,7 @@ def test_submesh_facet_subdomain_id_tuple(subdomain_id):
6767
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)
6868

6969

70-
def test_submesh_facet_subdomain_id_nested_tuple():
70+
def test_submesh_facet_subdomain_id_intersection():
7171
mesh = UnitSquareMesh(4, 4)
7272
x, y = SpatialCoordinate(mesh)
7373
M = FunctionSpace(mesh, "DG", 0)

0 commit comments

Comments
 (0)