|
25 | 25 | from pyop2.mpi import ( |
26 | 26 | MPI, COMM_WORLD, temp_internal_comm |
27 | 27 | ) |
28 | | -from functools import cached_property |
| 28 | +from functools import cached_property, reduce |
29 | 29 | from pyop2.utils import as_tuple |
30 | 30 | import petsctools |
31 | 31 | from petsctools import OptionsManager, get_external_packages |
@@ -4954,14 +4954,19 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig |
4954 | 4954 | if isinstance(subdomain_id, str): |
4955 | 4955 | raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.") |
4956 | 4956 | elif isinstance(subdomain_id, Sequence): |
| 4957 | + label = plex.getLabel(label_name) |
| 4958 | + if subdim != dim: |
| 4959 | + plex.labelComplete(label) |
4957 | 4960 | # Take the union of the labels in the list |
4958 | 4961 | iset = PETSc.IS().createGeneral([], comm=mesh.comm) |
4959 | 4962 | for sub in subdomain_id: |
4960 | 4963 | if isinstance(sub, Sequence): |
4961 | 4964 | # Take the intersection of the labels from nested lists |
4962 | | - cur = dmcommon.create_label_intersection(plex, label_name, sub) |
| 4965 | + if len(sub) == 0: |
| 4966 | + continue |
| 4967 | + cur = reduce(dmcommon.intersectIS, map(label.getStratumIS, sub)) |
4963 | 4968 | else: |
4964 | | - cur = plex.getStratumIS(label_name, sub) |
| 4969 | + cur = label.getStratumIS(sub) |
4965 | 4970 | iset = iset.union(cur) |
4966 | 4971 | # Create a temporary label |
4967 | 4972 | label_name = "temp_label" |
|
0 commit comments