Skip to content

Commit cb24fa8

Browse files
committed
cleanup
1 parent 95b1219 commit cb24fa8

2 files changed

Lines changed: 17 additions & 27 deletions

File tree

firedrake/cython/dmcommon.pyx

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4327,37 +4327,22 @@ def get_dm_cell_types(PETSc.DM dm):
43274327
)
43284328

43294329

4330-
def create_label_intersection(PETSc.DM dm, label_name, label_values):
4331-
"""Return the intersection of the closure of subdomains of a DMPlex.
4330+
def intersectIS(PETSc.IS i1, PETSc.IS i2):
4331+
"""Return the intersection of two IS objects.
43324332
43334333
Parameters
43344334
----------
4335-
dm : PETSc.DM
4336-
The DMPlex.
4337-
label_name : str
4338-
The name of the label
4339-
label_values : Sequence[int]
4340-
The values of the subdomain label to intersect
4335+
i1 : PETSc.IS
4336+
The first IS.
4337+
i2 : PETSc.IS
4338+
The second IS.
43414339
43424340
Returns
43434341
-------
43444342
PETSc.IS
4345-
A PETSc.IS with the points in the intersection.
4343+
A PETSc.IS with the intersection.
43464344
43474345
"""
4348-
cdef:
4349-
PETSc.IS iout, i1, i2
4350-
PETSc.DMLabel label
4351-
4352-
if len(label_values) == 0:
4353-
return PETSc.IS().createGeneral([], comm=dm.comm)
4354-
4355-
label = dm.getLabel(label_name)
4356-
CHKERR(DMPlexLabelComplete(dm.dm, label.dmlabel))
4357-
iout = label.getStratumIS(label_values[0])
4358-
for val in label_values[1:]:
4359-
i1 = iout
4360-
i2 = label.getStratumIS(val)
4361-
iout = PETSc.IS()
4362-
CHKERR(ISIntersect(i1.iset, i2.iset, &iout.iset))
4346+
cdef PETSc.IS iout = PETSc.IS()
4347+
CHKERR(ISIntersect(i1.iset, i2.iset, &iout.iset))
43634348
return iout

firedrake/mesh.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pyop2.mpi import (
2626
MPI, COMM_WORLD, temp_internal_comm
2727
)
28-
from functools import cached_property
28+
from functools import cached_property, reduce
2929
from pyop2.utils import as_tuple
3030
import petsctools
3131
from petsctools import OptionsManager, get_external_packages
@@ -4954,14 +4954,19 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
49544954
if isinstance(subdomain_id, str):
49554955
raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.")
49564956
elif isinstance(subdomain_id, Sequence):
4957+
label = plex.getLabel(label_name)
4958+
if subdim != dim:
4959+
plex.labelComplete(label)
49574960
# Take the union of the labels in the list
49584961
iset = PETSc.IS().createGeneral([], comm=mesh.comm)
49594962
for sub in subdomain_id:
49604963
if isinstance(sub, Sequence):
49614964
# Take the intersection of the labels from nested lists
4962-
cur = dmcommon.create_label_intersection(plex, label_name, sub)
4965+
if len(sub) == 0:
4966+
continue
4967+
cur = reduce(dmcommon.intersectIS, map(label.getStratumIS, sub))
49634968
else:
4964-
cur = plex.getStratumIS(label_name, sub)
4969+
cur = label.getStratumIS(sub)
49654970
iset = iset.union(cur)
49664971
# Create a temporary label
49674972
label_name = "temp_label"

0 commit comments

Comments
 (0)