Skip to content

Commit 75f5c95

Browse files
committed
Take the intersection for nested lists of subdomain_ids
1 parent 926925d commit 75f5c95

3 files changed

Lines changed: 115 additions & 23 deletions

File tree

firedrake/mesh.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pyop2.mpi import (
2626
MPI, COMM_WORLD, temp_internal_comm
2727
)
28-
from functools import cached_property
28+
from functools import cached_property, reduce
2929
from pyop2.utils import as_tuple
3030
import petsctools
3131
from petsctools import OptionsManager, get_external_packages
@@ -4809,6 +4809,7 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
48094809
subdomain_id : int | Sequence | None
48104810
Subdomain ID representing the submesh.
48114811
If multiple subdomain IDs are provided, their union is taken.
4812+
If nested lists of subdomain IDs are provided, their intersection is taken.
48124813
If `None` the submesh will cover the entire domain,
48134814
this is useful to obtain a codim-1 submesh over all facets or
48144815
a submesh over a different communicator.
@@ -4878,20 +4879,33 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
48784879
elif subdim == dim - 1:
48794880
label_name = dmcommon.FACE_SETS_LABEL
48804881

4881-
if isinstance(subdomain_id, (tuple, list)):
4882-
# A list of subdomain ids requires us to build an internal DM label with the union
4883-
iset = PETSc.IS().createGeneral([], comm=comm or mesh.comm)
4882+
# Parse non-integer subdomain_id
4883+
if subdomain_id == "on_boundary":
4884+
subdomain_id = tuple(mesh.exterior_facets.unique_markers)
4885+
4886+
if isinstance(subdomain_id, Sequence):
4887+
# Create a temporary DMLabel with the union of the labels in the list
4888+
icomm = comm or mesh.comm
4889+
iset = PETSc.IS().createGeneral([], comm=icomm)
48844890
for sub in subdomain_id:
4885-
iset = iset.union(plex.getStratumIS(label_name, sub))
4886-
label_name = "temp_union"
4891+
if isinstance(sub, Sequence):
4892+
# Take the intersection of the (closure of the) labels from nested lists
4893+
ises = [plex.getStratumIS(label_name, subi) for subi in sub]
4894+
closure = [[plex.getTransitiveClosure(p)[0] for p in i.indices] for i in ises]
4895+
indices = reduce(np.intersect1d, closure)
4896+
cur = PETSc.IS().createGeneral(indices, comm=icomm)
4897+
else:
4898+
cur = plex.getStratumIS(label_name, sub)
4899+
iset = iset.union(cur)
4900+
label_name = "temp_label"
48874901
subdomain_id = 1
48884902
plex.createLabel(label_name)
48894903
label = plex.getLabel(label_name)
48904904
label.setStratumIS(subdomain_id, iset)
48914905

48924906
subplex = dmcommon.submesh_create(plex, subdim, label_name, subdomain_id, ignore_halo, comm=comm)
48934907

4894-
if label_name == "temp_union":
4908+
if label_name == "temp_label":
48954909
plex.removeLabel(label_name)
48964910

48974911
comm = comm or mesh.comm
@@ -4900,7 +4914,7 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
49004914
if subplex.getDimension() != subdim:
49014915
raise RuntimeError(f"Found subplex dim ({subplex.getDimension()}) != expected ({subdim})")
49024916
if reorder is None:
4903-
# Ideally we should set perm_is = mesh.dm_reordering[label_indices]
4917+
# Ideally we should set perm_is = mesh._dm_renumbering[label_indices]
49044918
reorder = mesh._did_reordering
49054919

49064920
submesh = Mesh(

tests/firedrake/submesh/test_submesh_facet.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -134,18 +134,3 @@ def test_submesh_facet_all_facets():
134134
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
135135
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
136136
assert submesh2.cell_set.size == submesh1.cell_set.size
137-
138-
139-
def test_submesh_facet_subdomain_id_tuple():
140-
mesh = UnitCubeMesh(2, 2, 2)
141-
subdomain_id = (1, 3, 6)
142-
submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id)
143-
assert abs(assemble(1*dx(domain=submesh1)) - len(subdomain_id)) < 1E-12
144-
145-
V = FunctionSpace(mesh, "HDiv Trace", 0)
146-
facet_function = Function(V)
147-
DirichletBC(V, 1, subdomain_id).apply(facet_function)
148-
facet_value = 999
149-
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
150-
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
151-
assert submesh2.cell_set.size == submesh1.cell_set.size
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import pytest
2+
import numpy as np
3+
from firedrake import *
4+
5+
6+
def test_submesh_subdomain_id_tuple():
7+
mesh = UnitSquareMesh(4, 4)
8+
x, y = SpatialCoordinate(mesh)
9+
M = FunctionSpace(mesh, "DG", 0)
10+
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
11+
m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0))
12+
mesh.mark_entities(m1, 111)
13+
mesh.mark_entities(m2, 222)
14+
15+
subdomain_id = [111, 222]
16+
submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id)
17+
18+
m3 = Function(M).interpolate(m1 + m2 - m1 * m2)
19+
expected = assemble(m3*dx)
20+
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12
21+
22+
mesh.mark_entities(m3, 333)
23+
submesh2 = Submesh(mesh, mesh.topological_dimension, 333)
24+
assert submesh2.cell_set.size == submesh1.cell_set.size
25+
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)
26+
27+
28+
def test_submesh_subdomain_id_nested_tuple():
29+
mesh = UnitSquareMesh(4, 4)
30+
x, y = SpatialCoordinate(mesh)
31+
M = FunctionSpace(mesh, "DG", 0)
32+
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
33+
m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0))
34+
mesh.mark_entities(m1, 111)
35+
mesh.mark_entities(m2, 222)
36+
37+
subdomain_id = [(111, 222)]
38+
submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id)
39+
40+
m3 = Function(M).interpolate(m1 * m2)
41+
expected = assemble(m3*dx)
42+
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12
43+
44+
mesh.mark_entities(m3, 333)
45+
submesh2 = Submesh(mesh, mesh.topological_dimension, 333)
46+
assert submesh2.cell_set.size == submesh1.cell_set.size
47+
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)
48+
49+
50+
@pytest.mark.parametrize("subdomain_id", ["on_boundary", (1, 3, 6)])
51+
def test_submesh_facet_subdomain_id_tuple(subdomain_id):
52+
mesh = UnitCubeMesh(2, 2, 2)
53+
submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id)
54+
if subdomain_id == "on_boundary":
55+
area = assemble(1*ds(domain=mesh))
56+
else:
57+
area = assemble(1*ds(subdomain_id, domain=mesh))
58+
assert abs(assemble(1*dx(domain=submesh1)) - area) < 1E-12
59+
60+
V = FunctionSpace(mesh, "HDiv Trace", 0)
61+
facet_function = Function(V)
62+
DirichletBC(V, 1, subdomain_id).apply(facet_function)
63+
facet_value = 999
64+
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
65+
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
66+
assert submesh2.cell_set.size == submesh1.cell_set.size
67+
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)
68+
69+
70+
def test_submesh_facet_subdomain_id_nested_tuple():
71+
mesh = UnitSquareMesh(4, 4)
72+
x, y = SpatialCoordinate(mesh)
73+
M = FunctionSpace(mesh, "DG", 0)
74+
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
75+
m2 = Function(M).interpolate(conditional(lt(x, 0.5), 0, 1))
76+
mesh.mark_entities(m1, 111)
77+
mesh.mark_entities(m2, 222)
78+
79+
subdomain_id = [(111, 222)]
80+
submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id, label_name="Cell Sets")
81+
82+
expected = 1
83+
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12
84+
85+
x, y = SpatialCoordinate(mesh)
86+
V = FunctionSpace(mesh, "HDiv Trace", 0)
87+
facet_function = Function(V)
88+
facet_function.interpolate(conditional(lt(abs(x-0.5), 1E-8), 1, 0))
89+
facet_value = 999
90+
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
91+
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
92+
assert submesh2.cell_set.size == submesh1.cell_set.size
93+
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)

0 commit comments

Comments
 (0)