Skip to content

Commit 951bcd2

Browse files
authored
Refactor periodic utility meshes (#4836)
* Use DMPlex.createBoxMesh to create our periodic meshes This is necessary because the previous approach to generating periodic meshes resulted in a DMPlex with a slightly invalid state that prevents subsequent transformations (in particular extrusion). We also avoid manually labelling the boundaries in favour of just renumbering the existing 'Face Sets' label produced by the DMPlex. * linting * vom fix * submesh fix * More fixes, all of them? * Apply suggestions from code review Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk> * fixup * Add warning for empty subdomains * Apply suggestions from code review Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk> * Try final PETSc fix * Apply suggestion from @connorjward * Remove breaking API change for this PR
1 parent 6c900cf commit 951bcd2

16 files changed

Lines changed: 820 additions & 652 deletions

docs/source/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@
157157
# Cofunction.ufl_domains references FormArgument but it isn't picked
158158
# up by Sphinx (see https://github.com/sphinx-doc/sphinx/issues/11225)
159159
('py:class', 'FormArgument'),
160+
# Some complex type hints confuse Sphinx (https://github.com/sphinx-doc/sphinx/issues/14159)
161+
("py:obj", r"typing\.Literal\[.*"),
160162
]
161163

162164
# Dodgy links

firedrake/bcs.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
# A module implementing strong (Dirichlet) boundary conditions.
2-
import numpy as np
32

43
from functools import partial, reduce, cached_property
54
import itertools
65

6+
import numpy as np
7+
from mpi4py import MPI
8+
79
import ufl
810
from ufl import as_ufl, as_tensor
911
from finat.ufl import VectorElement
1012
import finat
1113

1214
import pyop2 as op2
1315
from pyop2 import exceptions
16+
from pyop2.mpi import temp_internal_comm
1417
from pyop2.utils import as_tuple
1518

1619
import firedrake
@@ -19,6 +22,7 @@
1922
from firedrake import slate
2023
from firedrake import solving
2124
from firedrake.formmanipulation import ExtractSubBlock
25+
from firedrake.logging import logger
2226
from firedrake.adjoint_utils.dirichletbc import DirichletBCMixin
2327
from firedrake.petsc import PETSc
2428

@@ -147,7 +151,7 @@ def hermite_stride(bcnodes):
147151
bcnodes = np.setdiff1d(bcnodes, deriv_ids)
148152
return bcnodes
149153

150-
sub_d = (self.sub_domain, ) if isinstance(self.sub_domain, str) else as_tuple(self.sub_domain)
154+
sub_d = (self.sub_domain,) if isinstance(self.sub_domain, str) else as_tuple(self.sub_domain)
151155
sub_d = [s if isinstance(s, str) else as_tuple(s) for s in sub_d]
152156
bcnodes = []
153157
for s in sub_d:
@@ -168,7 +172,15 @@ def hermite_stride(bcnodes):
168172
bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss)))
169173
bcnodes1 = reduce(np.intersect1d, bcnodes1)
170174
bcnodes.append(bcnodes1)
171-
return np.concatenate(bcnodes)
175+
bcnodes = np.concatenate(bcnodes)
176+
177+
with temp_internal_comm(self._function_space.mesh().comm) as icomm:
178+
num_global_nodes = icomm.reduce(len(bcnodes), MPI.SUM, root=0)
179+
if num_global_nodes == 0 and icomm.rank == 0:
180+
logger.warn(f"Subdomain {self.sub_domain} is empty. This is likely an error. "
181+
"Did you choose the right label?")
182+
183+
return bcnodes
172184

173185
@cached_property
174186
def node_set(self):

firedrake/cython/dmcommon.pyx

Lines changed: 197 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,7 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
788788
PetscInt nclosure, p, vi, v, fi, i
789789
PetscInt start_v, off
790790
PetscInt *closure = NULL
791+
PetscInt closure_tmp[2*9]
791792
PetscInt c_vertices[4]
792793
PetscInt c_facets[4]
793794
PetscInt g_vertices[4]
@@ -804,13 +805,13 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
804805
ncells = cEnd - cStart
805806
entity_per_cell = 4 + 4 + 1
806807

808+
CHKERR(PetscMalloc1(2*9, &closure))
809+
807810
cell_closure = np.empty((ncells, entity_per_cell), dtype=IntType)
808811
for c in range(cStart, cEnd):
809812
CHKERR(PetscSectionGetOffset(cell_numbering.sec, c, &cell))
810813
get_transitive_closure(plex.dm, c, PETSC_TRUE, &nclosure, &closure)
811814

812-
# First extract the facets (edges) and the vertices
813-
# from the transitive closure into c_facets and c_vertices.
814815
# Here we assume that DMPlex gives entities in the order:
815816
#
816817
# 8--3--7
@@ -821,7 +822,65 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
821822
#
822823
# where the starting vertex and order of traversal is arbitrary.
823824
# (We fix that later.)
825+
826+
# If we have a periodic mesh with only a single cell in the periodic
827+
# direction then the closure will look like
828+
#
829+
# 4--1--5
830+
# | |
831+
# 3 0 2 (vertical periodicity)
832+
# | |
833+
# 4--1--5
824834
#
835+
# or
836+
#
837+
# 5--3--5
838+
# | |
839+
# 2 0 2 (horizontal periodicity)
840+
# | |
841+
# 4--1--4
842+
#
843+
# and only have 6 entries instead of 9. For the following to work we have
844+
# to blow this out to a 9 entry array including the repeats.
845+
if nclosure == 4:
846+
raise NotImplementedError("Single-cell periodic quad meshes are "
847+
"not supported")
848+
elif nclosure == 6:
849+
horiz_periodicity, vert_periodicity = _get_periodicity(plex)
850+
(_, horiz_unit_periodic) = horiz_periodicity
851+
(_, vert_unit_periodic) = vert_periodicity
852+
if vert_unit_periodic:
853+
assert not horiz_unit_periodic
854+
closure_tmp[2*0] = closure[2*0]
855+
closure_tmp[2*1] = closure[2*1]
856+
closure_tmp[2*2] = closure[2*2]
857+
closure_tmp[2*3] = closure[2*1]
858+
closure_tmp[2*4] = closure[2*3]
859+
closure_tmp[2*5] = closure[2*4]
860+
closure_tmp[2*6] = closure[2*5]
861+
closure_tmp[2*7] = closure[2*5]
862+
closure_tmp[2*8] = closure[2*4]
863+
else:
864+
assert horiz_unit_periodic
865+
assert not vert_unit_periodic
866+
closure_tmp[2*0] = closure[2*0]
867+
closure_tmp[2*1] = closure[2*1]
868+
closure_tmp[2*2] = closure[2*2]
869+
closure_tmp[2*3] = closure[2*3]
870+
closure_tmp[2*4] = closure[2*2]
871+
closure_tmp[2*5] = closure[2*4]
872+
closure_tmp[2*6] = closure[2*4]
873+
closure_tmp[2*7] = closure[2*5]
874+
closure_tmp[2*8] = closure[2*5]
875+
876+
nclosure = 9
877+
for i in range(9):
878+
closure[2*i] = closure_tmp[2*i]
879+
else:
880+
assert nclosure == 9
881+
882+
# Extract the facets (edges) and the vertices
883+
# from the transitive closure into c_facets and c_vertices.
825884
# For the vertices, we also retrieve the global numbers into g_vertices.
826885
vi = 0
827886
fi = 0
@@ -923,8 +982,7 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
923982
cell_closure[cell, 4 + 3] = facets[1]
924983
cell_closure[cell, 8] = c
925984

926-
if closure != NULL:
927-
restore_transitive_closure(plex.dm, 0, PETSC_TRUE, &nclosure, &closure)
985+
CHKERR(PetscFree(closure))
928986

929987
return cell_closure
930988

@@ -1987,7 +2045,7 @@ def reordered_coords(PETSc.DM dm, PETSc.Section global_numbering, shape, referen
19872045
get_depth_stratum(dm.dm, 0, &vStart, &vEnd)
19882046
if isinstance(dm, PETSc.DMPlex):
19892047
if not dm.getCoordinatesLocalized():
1990-
# Use CG coordiantes.
2048+
# Use CG coordinates.
19912049
dm_sec = dm.getCoordinateSection()
19922050
dm_coords = dm.getCoordinatesLocal().array.reshape(shape)
19932051
coords = np.empty_like(dm_coords)
@@ -1998,12 +2056,11 @@ def reordered_coords(PETSc.DM dm, PETSc.Section global_numbering, shape, referen
19982056
for i in range(dim):
19992057
coords[offset, i] = dm_coords[dm_offset, i]
20002058
else:
2001-
# Use DG coordiantes.
2059+
# Use DG coordinates.
20022060
get_height_stratum(dm.dm, 0, &cStart, &cEnd)
20032061
dim = dm.getCoordinateDim()
20042062
ndofs, perm, perm_offsets = _get_firedrake_plex_permutation_dg_transitive_closure(dm)
2005-
dm_sec = dm.getCellCoordinateSection()
2006-
dm_coords = dm.getCellCoordinatesLocal().array.reshape(((cEnd - cStart) * ndofs[0], dim))
2063+
dm_coords, dm_sec = _get_expanded_dm_dg_coords(dm, ndofs)
20072064
coords = np.empty_like(dm_coords)
20082065
for c in range(cStart, cEnd):
20092066
CHKERR(PetscSectionGetOffset(global_numbering.sec, c, &offset)) # scalar offset
@@ -2031,6 +2088,138 @@ def reordered_coords(PETSc.DM dm, PETSc.Section global_numbering, shape, referen
20312088
raise ValueError("Only DMPlex and DMSwarm are supported.")
20322089
return coords
20332090

2091+
2092+
def _get_expanded_dm_dg_coords(dm: PETSc.DM, ndofs: np.ndarray):
2093+
"""Return the DM DG coordinates expanded to the full closure size.
2094+
2095+
This transformation accounts for the fact that single-cell periodic
2096+
domains have closures that are smaller than expected (due to repeated
2097+
points).
2098+
2099+
"""
2100+
cdef:
2101+
const PetscReal *L
2102+
2103+
PETSc.Section dm_sec_expanded
2104+
2105+
cStart, cEnd = dm.getHeightStratum(0)
2106+
dim = dm.getCoordinateDim()
2107+
coords_shape = ((cEnd-cStart) * ndofs[0], dim)
2108+
2109+
if dm.getCellCoordinateSection().getDof(cStart) < ndofs[0] * dim:
2110+
# Fewer cell coordinates available, we must be single-cell periodic
2111+
if dm.getCellType(cStart) == PETSc.DM.PolytopeType.QUADRILATERAL:
2112+
# If we have a periodic mesh with only a single cell in the periodic
2113+
# direction then the cell coordinates will be
2114+
#
2115+
# 1-----2
2116+
# | |
2117+
# | | (vertical periodicity)
2118+
# | |
2119+
# 1-----2
2120+
#
2121+
# or
2122+
#
2123+
# 2-----2
2124+
# | |
2125+
# | | (horizontal periodicity)
2126+
# | |
2127+
# 1-----1
2128+
#
2129+
# when the standard layout is
2130+
#
2131+
# 4-----3
2132+
# | |
2133+
# | |
2134+
# | |
2135+
# 1-----2
2136+
assert ndofs[0] == 4, "Not expecting high order coords here"
2137+
dm_coords_orig = dm.getCellCoordinatesLocal().array_r.reshape(((cEnd-cStart) * 2, dim))
2138+
dm_coords_expanded = np.empty(coords_shape, dtype=dm_coords_orig.dtype)
2139+
2140+
# Create a new cell coordinate section
2141+
dm_sec_orig = dm.getCellCoordinateSection()
2142+
dm_sec_expanded = PETSc.Section().create(comm=dm_sec_orig.comm)
2143+
dm_sec_expanded.setChart(*dm_sec_orig.getChart())
2144+
dm_sec_expanded.setPermutation(dm_sec_orig.getPermutation())
2145+
2146+
horiz_periodicity, vert_periodicity = _get_periodicity(dm)
2147+
(_, horiz_unit_periodic) = horiz_periodicity
2148+
(_, vert_unit_periodic) = vert_periodicity
2149+
2150+
# Find the domain sizes
2151+
CHKERR(DMGetPeriodicity(dm.dm, NULL, NULL, &L))
2152+
2153+
if horiz_unit_periodic:
2154+
if vert_unit_periodic:
2155+
raise NotImplementedError("Single-cell periodic quad meshes are "
2156+
"not supported")
2157+
else:
2158+
cell_width = L[0]
2159+
2160+
for c in range(cStart, cEnd):
2161+
CHKERR(PetscSectionSetDof(dm_sec_expanded.sec, c, 8))
2162+
2163+
dm_coords_expanded[4*c+0, 0] = dm_coords_orig[2*c+0, 0]
2164+
dm_coords_expanded[4*c+1, 0] = dm_coords_orig[2*c+0, 0] + cell_width
2165+
dm_coords_expanded[4*c+2, 0] = dm_coords_orig[2*c+1, 0] + cell_width
2166+
dm_coords_expanded[4*c+3, 0] = dm_coords_orig[2*c+1, 0]
2167+
dm_coords_expanded[4*c+0, 1] = dm_coords_orig[2*c+0, 1]
2168+
dm_coords_expanded[4*c+1, 1] = dm_coords_orig[2*c+0, 1]
2169+
dm_coords_expanded[4*c+2, 1] = dm_coords_orig[2*c+1, 1]
2170+
dm_coords_expanded[4*c+3, 1] = dm_coords_orig[2*c+1, 1]
2171+
2172+
else:
2173+
assert vert_unit_periodic
2174+
cell_height = L[1]
2175+
2176+
for c in range(cStart, cEnd):
2177+
CHKERR(PetscSectionSetDof(dm_sec_expanded.sec, c, 8))
2178+
2179+
dm_coords_expanded[4*c+0, 0] = dm_coords_orig[2*c+0, 0]
2180+
dm_coords_expanded[4*c+1, 0] = dm_coords_orig[2*c+1, 0]
2181+
dm_coords_expanded[4*c+2, 0] = dm_coords_orig[2*c+1, 0]
2182+
dm_coords_expanded[4*c+3, 0] = dm_coords_orig[2*c+0, 0]
2183+
dm_coords_expanded[4*c+0, 1] = dm_coords_orig[2*c+0, 1]
2184+
dm_coords_expanded[4*c+1, 1] = dm_coords_orig[2*c+1, 1]
2185+
dm_coords_expanded[4*c+2, 1] = dm_coords_orig[2*c+1, 1] + cell_height
2186+
dm_coords_expanded[4*c+3, 1] = dm_coords_orig[2*c+0, 1] + cell_height
2187+
2188+
dm_sec_expanded.setUp()
2189+
2190+
dm_coords = dm_coords_expanded
2191+
dm_sec = dm_sec_expanded
2192+
2193+
else:
2194+
raise NotImplementedError("Single cell periodicity for cell type "
2195+
f"{dm.getCellType(cStart)} is not supported")
2196+
2197+
else:
2198+
dm_coords = dm.getCellCoordinatesLocal().array_r.reshape(coords_shape)
2199+
dm_sec = dm.getCellCoordinateSection()
2200+
2201+
return dm_coords, dm_sec
2202+
2203+
2204+
def _get_periodicity(dm: PETSc.DM) -> tuple[tuple[bool, bool], ...]:
2205+
"""Return mesh periodicity information.
2206+
2207+
This function returns a 2-tuple of bools per dimension where the first entry indicates
2208+
whether the mesh is periodic in that dimension, and the second indicates whether the
2209+
mesh is single-cell periodic in that dimension.
2210+
2211+
"""
2212+
cdef:
2213+
const PetscReal *maxCell, *L
2214+
2215+
dim = dm.getCoordinateDim()
2216+
CHKERR(DMGetPeriodicity(dm.dm, &maxCell, NULL, &L))
2217+
return tuple(
2218+
(L[d] >= 0, maxCell[d] >= L[d])
2219+
for d in range(dim)
2220+
)
2221+
2222+
20342223
@cython.boundscheck(False)
20352224
@cython.wraparound(False)
20362225
def mark_entity_classes(PETSc.DM dm):

firedrake/cython/petschdr.pxi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ cdef extern from "petscdm.h" nogil:
103103
PetscErrorCode DMSetLabelValue(PETSc.PetscDM,char[],PetscInt,PetscInt)
104104
PetscErrorCode DMGetLabelValue(PETSc.PetscDM,char[],PetscInt,PetscInt*)
105105

106+
PetscErrorCode DMGetPeriodicity(PETSc.PetscDM,PetscReal *[], PetscReal *[], PetscReal *[])
107+
PetscErrorCode DMGetSparseLocalize(PETSc.PetscDM,PetscBool *)
108+
PetscErrorCode DMSetSparseLocalize(PETSc.PetscDM,PetscBool)
109+
106110
cdef extern from "petscdmswarm.h" nogil:
107111
PetscErrorCode DMSwarmGetLocalSize(PETSc.PetscDM,PetscInt*)
108112
PetscErrorCode DMSwarmGetCellDM(PETSc.PetscDM, PETSc.PetscDM*)

firedrake/mesh.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
import firedrake.cython.spatialindex as spatialindex
3838
import firedrake.utils as utils
3939
from firedrake.utils import as_cstr, IntType, RealType
40-
from firedrake.logging import info_red
40+
from firedrake.logging import info_red, logger
4141
from firedrake.parameters import parameters
4242
from firedrake.petsc import PETSc, DEFAULT_PARTITIONER
4343
from firedrake.adjoint_utils import MeshGeometryMixin
@@ -201,14 +201,6 @@ def __init__(self, mesh, facets, classes, set_, kind, facet_cell, local_facet_nu
201201
self.unique_markers = [] if unique_markers is None else unique_markers
202202
self._subsets = {}
203203

204-
@cached_property
205-
def _null_subset(self):
206-
'''Empty subset for the case in which there are no facets with
207-
a given marker value. This is required because not all
208-
markers need be represented on all processors.'''
209-
210-
return op2.Subset(self.set, [])
211-
212204
@PETSc.Log.EventDecorator()
213205
def measure_set(self, integral_type, subdomain_id,
214206
all_integer_subdomain_ids=None):
@@ -283,9 +275,16 @@ def subset(self, markers):
283275
marked_points_list.append(self.mesh.topology_dm.getStratumIS(dmcommon.FACE_SETS_LABEL, i).indices)
284276
if marked_points_list:
285277
_, indices, _ = np.intersect1d(self.facets, np.concatenate(marked_points_list), return_indices=True)
286-
return self._subsets.setdefault(markers, op2.Subset(self.set, indices))
287278
else:
288-
return self._subsets.setdefault(markers, self._null_subset)
279+
indices = np.empty(0, dtype=IntType)
280+
281+
with temp_internal_comm(self.mesh.comm) as icomm:
282+
num_global_indices = icomm.reduce(len(indices), MPI.SUM, root=0)
283+
if num_global_indices == 0 and icomm.rank == 0:
284+
logger.warn(f"Subdomain {markers} is empty. This is likely an error. "
285+
"Did you choose the right label?")
286+
287+
return self._subsets.setdefault(markers, op2.Subset(self.set, indices))
289288

290289
def _collect_unmarked_points(self, markers):
291290
"""Collect points that are not marked by markers."""

0 commit comments

Comments
 (0)