Skip to content

Commit 2a8aea3

Browse files
authored
Merge pull request #1726 from firedrakeproject/fix_equation_bc
fix direct assembly with equation_bcs
2 parents 5fe3cc2 + 004e90b commit 2a8aea3

5 files changed

Lines changed: 92 additions & 18 deletions

File tree

firedrake/assemble.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from firedrake import (assemble_expressions, matrix, parameters, solving,
1111
tsfc_interface, utils)
1212
from firedrake.adjoint import annotate_assemble
13-
from firedrake.bcs import DirichletBC, EquationBCSplit
13+
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
1414
from firedrake.slate import slac, slate
1515
from firedrake.utils import ScalarType
1616
from pyop2 import op2
@@ -162,7 +162,7 @@ def create_assembly_callable(expr, tensor=None, bcs=None, form_compiler_paramete
162162
raise ValueError("Have to provide tensor to write to")
163163
if mat_type == "matfree":
164164
return tensor.assemble
165-
loops = _assemble(expr, tensor=tensor, bcs=bcs,
165+
loops = _assemble(expr, tensor=tensor, bcs=solving._extract_bcs(bcs),
166166
form_compiler_parameters=form_compiler_parameters,
167167
mat_type=mat_type,
168168
sub_mat_type=sub_mat_type,
@@ -200,6 +200,10 @@ def get_matrix(expr, mat_type, sub_mat_type, *, bcs=None,
200200
arguments = expr.arguments()
201201
if bcs is None:
202202
bcs = ()
203+
else:
204+
if any(isinstance(bc, EquationBC) for bc in bcs):
205+
raise TypeError("EquationBC objects not expected here. "
206+
"Preprocess by extracting the appropriate form with bc.extract_form('Jp') or bc.extract_form('J')")
203207
if tensor is not None and tensor.a.arguments() != arguments:
204208
raise ValueError("Form's arguments do not match provided result tensor")
205209
if matfree:
@@ -425,9 +429,6 @@ def apply_bcs(tensor, bcs, *, assembly_rank=None, form_compiler_parameters=None,
425429
"""
426430
dirichletbcs = tuple(bc for bc in bcs if isinstance(bc, DirichletBC))
427431
equationbcs = tuple(bc for bc in bcs if isinstance(bc, EquationBCSplit))
428-
if any(not isinstance(bc, (DirichletBC, EquationBCSplit)) for bc in bcs):
429-
raise NotImplementedError("Unhandled type of bc object")
430-
431432
if assembly_rank == AssemblyRank.MATRIX:
432433
op2tensor = tensor.M
433434
shape = tuple(len(a.function_space()) for a in tensor.a.arguments())
@@ -669,11 +670,6 @@ def _assemble(expr, tensor=None, bcs=None, form_compiler_parameters=None,
669670
# building a functionspace (e.g. if integrating a constant)).
670671
m.init()
671672

672-
if bcs is None:
673-
bcs = ()
674-
else:
675-
bcs = tuple(bcs)
676-
677673
for o in chain(expr.arguments(), expr.coefficients()):
678674
domain = o.ufl_domain()
679675
if domain is not None and domain.topology != topology:
@@ -689,6 +685,16 @@ def _assemble(expr, tensor=None, bcs=None, form_compiler_parameters=None,
689685
else:
690686
assembly_rank = AssemblyRank.SCALAR
691687

688+
if not isinstance(bcs, (tuple, list)):
689+
raise RuntimeError("Expecting bcs to be a tuple or a list by this stage.")
690+
if assembly_rank == AssemblyRank.MATRIX:
691+
# Checks will take place in get_matrix.
692+
pass
693+
elif assembly_rank == AssemblyRank.VECTOR:
694+
# Might have gotten here without `EquationBC` objects preprocessed.
695+
if any(isinstance(bc, EquationBC) for bc in bcs):
696+
bcs = tuple(bc.extract_form('F') for bc in bcs)
697+
692698
if assembly_rank == AssemblyRank.MATRIX:
693699
test, trial = expr.arguments()
694700
tensor, zeros, result = get_matrix(expr, mat_type, sub_mat_type,

firedrake/bcs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,10 @@ def increment_bc_depth(self):
240240
for bc in itertools.chain(*self.bcs):
241241
bc._bc_depth += 1
242242

243+
def extract_forms(self, form_type):
244+
# Return boundary condition objects actually used in assembly.
245+
raise NotImplementedError("Method to extract form objects not implemented.")
246+
243247

244248
class DirichletBC(BCBase, DirichletBCMixin):
245249
r'''Implementation of a strong Dirichlet boundary condition.
@@ -437,6 +441,10 @@ def apply(self, r, u=None):
437441
def integrals(self):
438442
return []
439443

444+
def extract_form(self, form_type):
445+
# DirichletBC is directly used in assembly.
446+
return self
447+
440448

441449
class EquationBC(object):
442450
r'''Construct and store EquationBCSplit objects (for `F`, `J`, and `Jp`).
@@ -516,6 +524,16 @@ def dirichlet_bcs(self):
516524
# _F, _J, and _Jp all have the same DirichletBCs
517525
yield from self._F.dirichlet_bcs()
518526

527+
def extract_form(self, form_type):
528+
r"""Return :class:`.EquationBCSplit` associated with the given 'form_type'.
529+
530+
:arg form_type: Form to extract; 'F', 'J', or 'Jp'.
531+
"""
532+
if form_type not in {"F", "J", "Jp"}:
533+
raise ValueError("Unknown form_type: 'form_type' must be 'F', 'J', or 'Jp'.")
534+
else:
535+
return getattr(self, f"_{form_type}")
536+
519537
def reconstruct(self, V, subu, u, field):
520538
_F = self._F.reconstruct(field=field, V=V, subu=subu, u=u)
521539
_J = self._J.reconstruct(field=field, V=V, subu=subu, u=u)

firedrake/solving.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,10 +318,11 @@ def _extract_bcs(bcs):
318318
from firedrake.bcs import BCBase, EquationBC
319319
if bcs is None:
320320
return ()
321-
try:
322-
bcs = tuple(bcs)
323-
except TypeError:
324-
bcs = (bcs,)
321+
if isinstance(bcs, (BCBase, EquationBC)):
322+
return (bcs, )
323+
else:
324+
if not isinstance(bcs, (tuple, list)):
325+
raise TypeError("bcs must be BCBase, EquationBC, tuple, or list, not '%s'." % type(bcs).__name__)
325326
for bc in bcs:
326327
if not isinstance(bc, (BCBase, EquationBC)):
327328
raise TypeError("Provided boundary condition is a '%s', not a BCBase" % type(bc).__name__)

firedrake/solving_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,
8282
options_prefix=None,
8383
transfer_manager=None):
8484
from firedrake.assemble import create_assembly_callable
85-
from firedrake.bcs import DirichletBC
8685
if pmat_type is None:
8786
pmat_type = mat_type
8887
self.mat_type = mat_type
@@ -132,9 +131,9 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,
132131
# pmat_type == mat_type and Jp_eq_J
133132
self.Jp = None
134133

135-
self.bcs_F = [bc if isinstance(bc, DirichletBC) else bc._F for bc in problem.bcs]
136-
self.bcs_J = [bc if isinstance(bc, DirichletBC) else bc._J for bc in problem.bcs]
137-
self.bcs_Jp = [bc if isinstance(bc, DirichletBC) else bc._Jp for bc in problem.bcs]
134+
self.bcs_F = tuple(bc.extract_form('F') for bc in problem.bcs)
135+
self.bcs_J = tuple(bc.extract_form('J') for bc in problem.bcs)
136+
self.bcs_Jp = tuple(bc.extract_form('Jp') for bc in problem.bcs)
138137
self._assemble_residual = create_assembly_callable(self.F,
139138
tensor=self._F,
140139
bcs=self.bcs_F,
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
3+
from firedrake import *
4+
import numpy as np
5+
6+
7+
def test_equation_bcs_direct_assemble_one_form():
8+
mesh = UnitSquareMesh(1, 1, quadrilateral=True)
9+
V = FunctionSpace(mesh, "CG", 1)
10+
u = Function(V).assign(1.)
11+
v = TestFunction(V)
12+
F = inner(grad(u), grad(v)) * dx
13+
F1 = inner(u, v) * ds(1)
14+
bc = EquationBC(F1 == 0, u, 1)
15+
16+
g = assemble(F, bcs=bc.extract_form('F'))
17+
assert(np.allclose(g.dat.data, [0.5, 0.5, 0, 0]))
18+
g = assemble(F, bcs=bc)
19+
assert(np.allclose(g.dat.data, [0.5, 0.5, 0, 0]))
20+
21+
22+
def test_equation_bcs_direct_assemble_two_form():
23+
mesh = UnitSquareMesh(1, 1, quadrilateral=True)
24+
V = FunctionSpace(mesh, "CG", 1)
25+
u = TrialFunction(V)
26+
v = TestFunction(V)
27+
a = inner(grad(u), grad(v)) * dx
28+
a1 = inner(u, v) * ds(1)
29+
L1 = inner(Constant(0), v) * ds(1)
30+
sol = Function(V)
31+
bc = EquationBC(a1 == L1, sol, 1, Jp=2 * inner(u, v) * ds(1))
32+
33+
# Must preprocess bc to extract appropriate
34+
# `EquationBCSplit` object.
35+
A = assemble(a, bcs=bc.extract_form('J'))
36+
assert(np.allclose(A.M.values, [[1 / 3, 1 / 6, 0, 0],
37+
[1 / 6, 1 / 3, 0, 0],
38+
[-1 / 3, -1 / 6, 2 / 3, -1 / 6],
39+
[-1 / 6, -1 / 3, -1 / 6, 2 / 3]]))
40+
A = assemble(a, bcs=bc.extract_form('Jp'))
41+
assert(np.allclose(A.M.values, [[2 / 3, 2 / 6, 0, 0],
42+
[2 / 6, 2 / 3, 0, 0],
43+
[-1 / 3, -1 / 6, 2 / 3, -1 / 6],
44+
[-1 / 6, -1 / 3, -1 / 6, 2 / 3]]))
45+
with pytest.raises(TypeError) as excinfo:
46+
# Unable to use raw `EquationBC` object, as
47+
# assembler can not infer merely from the rank
48+
# which form ('J' or 'Jp') should be assembled.
49+
assemble(a, bcs=bc)
50+
assert "EquationBC objects not expected here" in str(excinfo.value)

0 commit comments

Comments
 (0)