Skip to content

Commit e1daada

Browse files
committed
Merge branch 'main' into pbrubeck/goal-adaptive-solver
2 parents a08b6a1 + 8d51705 commit e1daada

34 files changed

Lines changed: 970 additions & 814 deletions

docs/source/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@
157157
# Cofunction.ufl_domains references FormArgument but it isn't picked
158158
# up by Sphinx (see https://github.com/sphinx-doc/sphinx/issues/11225)
159159
('py:class', 'FormArgument'),
160+
# Some complex type hints confuse Sphinx (https://github.com/sphinx-doc/sphinx/issues/14159)
161+
("py:obj", r"typing\.Literal\[.*"),
160162
]
161163

162164
# Dodgy links

firedrake/assemble.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from pyop2.exceptions import MapValueError, SparsityFormatError
3333
from functools import cached_property
3434

35-
from pyop2.types.glob import Global
3635
from pyop2.types.mat import _GlobalMatPayload, _DatMatPayload
3736

3837

@@ -1863,11 +1862,9 @@ def _as_global_kernel_arg_coefficient(_, self):
18631862
if V.ufl_element().family() == "Real":
18641863
# Interior facet integrals double Real coefficients for the
18651864
# two sides of the facet, matching the TSFC-generated kernel.
1866-
if self._integral_type.startswith("interior_facet"):
1867-
shape = (2, V.value_size)
1868-
else:
1869-
shape = (V.value_size,)
1870-
return op2.GlobalKernelArg(shape)
1865+
return op2.GlobalKernelArg(
1866+
(V.value_size,), double=self._integral_type.startswith("interior_facet")
1867+
)
18711868
else:
18721869
return self._make_dat_global_kernel_arg(V, index=index)
18731870

@@ -2217,14 +2214,6 @@ def _as_parloop_arg_cell_sizes(_, self):
22172214
def _as_parloop_arg_coefficient(arg, self):
22182215
coeff = next(self._active_coefficients)
22192216
if coeff.ufl_element().family() == "Real":
2220-
if self._integral_type.startswith("interior_facet"):
2221-
# The TSFC kernel expects the Real value on both facet
2222-
# sides so we tile the underlying data into a new Global.
2223-
data = numpy.tile(coeff.dat.data_ro, 2)
2224-
return op2.GlobalParloopArg(
2225-
Global(data.shape, data, coeff.dat.dtype,
2226-
name=coeff.dat.name, comm=coeff.dat.comm)
2227-
)
22282217
return op2.GlobalParloopArg(coeff.dat)
22292218
else:
22302219
m = self._get_map(coeff.function_space())

firedrake/bcs.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
# A module implementing strong (Dirichlet) boundary conditions.
2-
import numpy as np
32

43
from functools import partial, reduce, cached_property
54
import itertools
65

6+
import numpy as np
7+
from mpi4py import MPI
8+
79
import ufl
810
from ufl import as_ufl, as_tensor
911
from finat.ufl import VectorElement
1012
import finat
1113

1214
import pyop2 as op2
1315
from pyop2 import exceptions
16+
from pyop2.mpi import temp_internal_comm
1417
from pyop2.utils import as_tuple
1518

1619
import firedrake
@@ -19,6 +22,7 @@
1922
from firedrake import slate
2023
from firedrake import solving
2124
from firedrake.formmanipulation import ExtractSubBlock
25+
from firedrake.logging import logger
2226
from firedrake.adjoint_utils.dirichletbc import DirichletBCMixin
2327
from firedrake.petsc import PETSc
2428

@@ -147,7 +151,7 @@ def hermite_stride(bcnodes):
147151
bcnodes = np.setdiff1d(bcnodes, deriv_ids)
148152
return bcnodes
149153

150-
sub_d = (self.sub_domain, ) if isinstance(self.sub_domain, str) else as_tuple(self.sub_domain)
154+
sub_d = (self.sub_domain,) if isinstance(self.sub_domain, str) else as_tuple(self.sub_domain)
151155
sub_d = [s if isinstance(s, str) else as_tuple(s) for s in sub_d]
152156
bcnodes = []
153157
for s in sub_d:
@@ -168,7 +172,15 @@ def hermite_stride(bcnodes):
168172
bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss)))
169173
bcnodes1 = reduce(np.intersect1d, bcnodes1)
170174
bcnodes.append(bcnodes1)
171-
return np.concatenate(bcnodes)
175+
bcnodes = np.concatenate(bcnodes)
176+
177+
with temp_internal_comm(self._function_space.mesh().comm) as icomm:
178+
num_global_nodes = icomm.reduce(len(bcnodes), MPI.SUM, root=0)
179+
if num_global_nodes == 0 and icomm.rank == 0:
180+
logger.warn(f"Subdomain {self.sub_domain} is empty. This is likely an error. "
181+
"Did you choose the right label?")
182+
183+
return bcnodes
172184

173185
@cached_property
174186
def node_set(self):

firedrake/cython/dmcommon.pyx

Lines changed: 197 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,7 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
788788
PetscInt nclosure, p, vi, v, fi, i
789789
PetscInt start_v, off
790790
PetscInt *closure = NULL
791+
PetscInt closure_tmp[2*9]
791792
PetscInt c_vertices[4]
792793
PetscInt c_facets[4]
793794
PetscInt g_vertices[4]
@@ -804,13 +805,13 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
804805
ncells = cEnd - cStart
805806
entity_per_cell = 4 + 4 + 1
806807

808+
CHKERR(PetscMalloc1(2*9, &closure))
809+
807810
cell_closure = np.empty((ncells, entity_per_cell), dtype=IntType)
808811
for c in range(cStart, cEnd):
809812
CHKERR(PetscSectionGetOffset(cell_numbering.sec, c, &cell))
810813
get_transitive_closure(plex.dm, c, PETSC_TRUE, &nclosure, &closure)
811814

812-
# First extract the facets (edges) and the vertices
813-
# from the transitive closure into c_facets and c_vertices.
814815
# Here we assume that DMPlex gives entities in the order:
815816
#
816817
# 8--3--7
@@ -821,7 +822,65 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
821822
#
822823
# where the starting vertex and order of traversal is arbitrary.
823824
# (We fix that later.)
825+
826+
# If we have a periodic mesh with only a single cell in the periodic
827+
# direction then the closure will look like
828+
#
829+
# 4--1--5
830+
# | |
831+
# 3 0 2 (vertical periodicity)
832+
# | |
833+
# 4--1--5
824834
#
835+
# or
836+
#
837+
# 5--3--5
838+
# | |
839+
# 2 0 2 (horizontal periodicity)
840+
# | |
841+
# 4--1--4
842+
#
843+
# and only have 6 entries instead of 9. For the following to work we have
844+
# to blow this out to a 9 entry array including the repeats.
845+
if nclosure == 4:
846+
raise NotImplementedError("Single-cell periodic quad meshes are "
847+
"not supported")
848+
elif nclosure == 6:
849+
horiz_periodicity, vert_periodicity = _get_periodicity(plex)
850+
(_, horiz_unit_periodic) = horiz_periodicity
851+
(_, vert_unit_periodic) = vert_periodicity
852+
if vert_unit_periodic:
853+
assert not horiz_unit_periodic
854+
closure_tmp[2*0] = closure[2*0]
855+
closure_tmp[2*1] = closure[2*1]
856+
closure_tmp[2*2] = closure[2*2]
857+
closure_tmp[2*3] = closure[2*1]
858+
closure_tmp[2*4] = closure[2*3]
859+
closure_tmp[2*5] = closure[2*4]
860+
closure_tmp[2*6] = closure[2*5]
861+
closure_tmp[2*7] = closure[2*5]
862+
closure_tmp[2*8] = closure[2*4]
863+
else:
864+
assert horiz_unit_periodic
865+
assert not vert_unit_periodic
866+
closure_tmp[2*0] = closure[2*0]
867+
closure_tmp[2*1] = closure[2*1]
868+
closure_tmp[2*2] = closure[2*2]
869+
closure_tmp[2*3] = closure[2*3]
870+
closure_tmp[2*4] = closure[2*2]
871+
closure_tmp[2*5] = closure[2*4]
872+
closure_tmp[2*6] = closure[2*4]
873+
closure_tmp[2*7] = closure[2*5]
874+
closure_tmp[2*8] = closure[2*5]
875+
876+
nclosure = 9
877+
for i in range(9):
878+
closure[2*i] = closure_tmp[2*i]
879+
else:
880+
assert nclosure == 9
881+
882+
# Extract the facets (edges) and the vertices
883+
# from the transitive closure into c_facets and c_vertices.
825884
# For the vertices, we also retrieve the global numbers into g_vertices.
826885
vi = 0
827886
fi = 0
@@ -923,8 +982,7 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
923982
cell_closure[cell, 4 + 3] = facets[1]
924983
cell_closure[cell, 8] = c
925984

926-
if closure != NULL:
927-
restore_transitive_closure(plex.dm, 0, PETSC_TRUE, &nclosure, &closure)
985+
CHKERR(PetscFree(closure))
928986

929987
return cell_closure
930988

@@ -1987,7 +2045,7 @@ def reordered_coords(PETSc.DM dm, PETSc.Section global_numbering, shape, referen
19872045
get_depth_stratum(dm.dm, 0, &vStart, &vEnd)
19882046
if isinstance(dm, PETSc.DMPlex):
19892047
if not dm.getCoordinatesLocalized():
1990-
# Use CG coordiantes.
2048+
# Use CG coordinates.
19912049
dm_sec = dm.getCoordinateSection()
19922050
dm_coords = dm.getCoordinatesLocal().array.reshape(shape)
19932051
coords = np.empty_like(dm_coords)
@@ -1998,12 +2056,11 @@ def reordered_coords(PETSc.DM dm, PETSc.Section global_numbering, shape, referen
19982056
for i in range(dim):
19992057
coords[offset, i] = dm_coords[dm_offset, i]
20002058
else:
2001-
# Use DG coordiantes.
2059+
# Use DG coordinates.
20022060
get_height_stratum(dm.dm, 0, &cStart, &cEnd)
20032061
dim = dm.getCoordinateDim()
20042062
ndofs, perm, perm_offsets = _get_firedrake_plex_permutation_dg_transitive_closure(dm)
2005-
dm_sec = dm.getCellCoordinateSection()
2006-
dm_coords = dm.getCellCoordinatesLocal().array.reshape(((cEnd - cStart) * ndofs[0], dim))
2063+
dm_coords, dm_sec = _get_expanded_dm_dg_coords(dm, ndofs)
20072064
coords = np.empty_like(dm_coords)
20082065
for c in range(cStart, cEnd):
20092066
CHKERR(PetscSectionGetOffset(global_numbering.sec, c, &offset)) # scalar offset
@@ -2031,6 +2088,138 @@ def reordered_coords(PETSc.DM dm, PETSc.Section global_numbering, shape, referen
20312088
raise ValueError("Only DMPlex and DMSwarm are supported.")
20322089
return coords
20332090

2091+
2092+
def _get_expanded_dm_dg_coords(dm: PETSc.DM, ndofs: np.ndarray):
2093+
"""Return the DM DG coordinates expanded to the full closure size.
2094+
2095+
This transformation accounts for the fact that single-cell periodic
2096+
domains have closures that are smaller than expected (due to repeated
2097+
points).
2098+
2099+
"""
2100+
cdef:
2101+
const PetscReal *L
2102+
2103+
PETSc.Section dm_sec_expanded
2104+
2105+
cStart, cEnd = dm.getHeightStratum(0)
2106+
dim = dm.getCoordinateDim()
2107+
coords_shape = ((cEnd-cStart) * ndofs[0], dim)
2108+
2109+
if dm.getCellCoordinateSection().getDof(cStart) < ndofs[0] * dim:
2110+
# Fewer cell coordinates available, we must be single-cell periodic
2111+
if dm.getCellType(cStart) == PETSc.DM.PolytopeType.QUADRILATERAL:
2112+
# If we have a periodic mesh with only a single cell in the periodic
2113+
# direction then the cell coordinates will be
2114+
#
2115+
# 1-----2
2116+
# | |
2117+
# | | (vertical periodicity)
2118+
# | |
2119+
# 1-----2
2120+
#
2121+
# or
2122+
#
2123+
# 2-----2
2124+
# | |
2125+
# | | (horizontal periodicity)
2126+
# | |
2127+
# 1-----1
2128+
#
2129+
# when the standard layout is
2130+
#
2131+
# 4-----3
2132+
# | |
2133+
# | |
2134+
# | |
2135+
# 1-----2
2136+
assert ndofs[0] == 4, "Not expecting high order coords here"
2137+
dm_coords_orig = dm.getCellCoordinatesLocal().array_r.reshape(((cEnd-cStart) * 2, dim))
2138+
dm_coords_expanded = np.empty(coords_shape, dtype=dm_coords_orig.dtype)
2139+
2140+
# Create a new cell coordinate section
2141+
dm_sec_orig = dm.getCellCoordinateSection()
2142+
dm_sec_expanded = PETSc.Section().create(comm=dm_sec_orig.comm)
2143+
dm_sec_expanded.setChart(*dm_sec_orig.getChart())
2144+
dm_sec_expanded.setPermutation(dm_sec_orig.getPermutation())
2145+
2146+
horiz_periodicity, vert_periodicity = _get_periodicity(dm)
2147+
(_, horiz_unit_periodic) = horiz_periodicity
2148+
(_, vert_unit_periodic) = vert_periodicity
2149+
2150+
# Find the domain sizes
2151+
CHKERR(DMGetPeriodicity(dm.dm, NULL, NULL, &L))
2152+
2153+
if horiz_unit_periodic:
2154+
if vert_unit_periodic:
2155+
raise NotImplementedError("Single-cell periodic quad meshes are "
2156+
"not supported")
2157+
else:
2158+
cell_width = L[0]
2159+
2160+
for c in range(cStart, cEnd):
2161+
CHKERR(PetscSectionSetDof(dm_sec_expanded.sec, c, 8))
2162+
2163+
dm_coords_expanded[4*c+0, 0] = dm_coords_orig[2*c+0, 0]
2164+
dm_coords_expanded[4*c+1, 0] = dm_coords_orig[2*c+0, 0] + cell_width
2165+
dm_coords_expanded[4*c+2, 0] = dm_coords_orig[2*c+1, 0] + cell_width
2166+
dm_coords_expanded[4*c+3, 0] = dm_coords_orig[2*c+1, 0]
2167+
dm_coords_expanded[4*c+0, 1] = dm_coords_orig[2*c+0, 1]
2168+
dm_coords_expanded[4*c+1, 1] = dm_coords_orig[2*c+0, 1]
2169+
dm_coords_expanded[4*c+2, 1] = dm_coords_orig[2*c+1, 1]
2170+
dm_coords_expanded[4*c+3, 1] = dm_coords_orig[2*c+1, 1]
2171+
2172+
else:
2173+
assert vert_unit_periodic
2174+
cell_height = L[1]
2175+
2176+
for c in range(cStart, cEnd):
2177+
CHKERR(PetscSectionSetDof(dm_sec_expanded.sec, c, 8))
2178+
2179+
dm_coords_expanded[4*c+0, 0] = dm_coords_orig[2*c+0, 0]
2180+
dm_coords_expanded[4*c+1, 0] = dm_coords_orig[2*c+1, 0]
2181+
dm_coords_expanded[4*c+2, 0] = dm_coords_orig[2*c+1, 0]
2182+
dm_coords_expanded[4*c+3, 0] = dm_coords_orig[2*c+0, 0]
2183+
dm_coords_expanded[4*c+0, 1] = dm_coords_orig[2*c+0, 1]
2184+
dm_coords_expanded[4*c+1, 1] = dm_coords_orig[2*c+1, 1]
2185+
dm_coords_expanded[4*c+2, 1] = dm_coords_orig[2*c+1, 1] + cell_height
2186+
dm_coords_expanded[4*c+3, 1] = dm_coords_orig[2*c+0, 1] + cell_height
2187+
2188+
dm_sec_expanded.setUp()
2189+
2190+
dm_coords = dm_coords_expanded
2191+
dm_sec = dm_sec_expanded
2192+
2193+
else:
2194+
raise NotImplementedError("Single cell periodicity for cell type "
2195+
f"{dm.getCellType(cStart)} is not supported")
2196+
2197+
else:
2198+
dm_coords = dm.getCellCoordinatesLocal().array_r.reshape(coords_shape)
2199+
dm_sec = dm.getCellCoordinateSection()
2200+
2201+
return dm_coords, dm_sec
2202+
2203+
2204+
def _get_periodicity(dm: PETSc.DM) -> tuple[tuple[bool, bool], ...]:
2205+
"""Return mesh periodicity information.
2206+
2207+
This function returns a 2-tuple of bools per dimension where the first entry indicates
2208+
whether the mesh is periodic in that dimension, and the second indicates whether the
2209+
mesh is single-cell periodic in that dimension.
2210+
2211+
"""
2212+
cdef:
2213+
const PetscReal *maxCell, *L
2214+
2215+
dim = dm.getCoordinateDim()
2216+
CHKERR(DMGetPeriodicity(dm.dm, &maxCell, NULL, &L))
2217+
return tuple(
2218+
(L[d] >= 0, maxCell[d] >= L[d])
2219+
for d in range(dim)
2220+
)
2221+
2222+
20342223
@cython.boundscheck(False)
20352224
@cython.wraparound(False)
20362225
def mark_entity_classes(PETSc.DM dm):

firedrake/cython/petschdr.pxi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ cdef extern from "petscdm.h" nogil:
103103
PetscErrorCode DMSetLabelValue(PETSc.PetscDM,char[],PetscInt,PetscInt)
104104
PetscErrorCode DMGetLabelValue(PETSc.PetscDM,char[],PetscInt,PetscInt*)
105105

106+
PetscErrorCode DMGetPeriodicity(PETSc.PetscDM,PetscReal *[], PetscReal *[], PetscReal *[])
107+
PetscErrorCode DMGetSparseLocalize(PETSc.PetscDM,PetscBool *)
108+
PetscErrorCode DMSetSparseLocalize(PETSc.PetscDM,PetscBool)
109+
106110
cdef extern from "petscdmswarm.h" nogil:
107111
PetscErrorCode DMSwarmGetLocalSize(PETSc.PetscDM,PetscInt*)
108112
PetscErrorCode DMSwarmGetCellDM(PETSc.PetscDM, PETSc.PetscDM*)

0 commit comments

Comments
 (0)