1010from firedrake import (assemble_expressions , matrix , parameters , solving ,
1111 tsfc_interface , utils )
1212from firedrake .adjoint import annotate_assemble
13- from firedrake .bcs import DirichletBC , EquationBCSplit
13+ from firedrake .bcs import DirichletBC , EquationBC , EquationBCSplit
1414from firedrake .slate import slac , slate
1515from firedrake .utils import ScalarType
1616from pyop2 import op2
@@ -162,7 +162,7 @@ def create_assembly_callable(expr, tensor=None, bcs=None, form_compiler_paramete
162162 raise ValueError ("Have to provide tensor to write to" )
163163 if mat_type == "matfree" :
164164 return tensor .assemble
165- loops = _assemble (expr , tensor = tensor , bcs = bcs ,
165+ loops = _assemble (expr , tensor = tensor , bcs = solving . _extract_bcs ( bcs ) ,
166166 form_compiler_parameters = form_compiler_parameters ,
167167 mat_type = mat_type ,
168168 sub_mat_type = sub_mat_type ,
@@ -200,6 +200,10 @@ def get_matrix(expr, mat_type, sub_mat_type, *, bcs=None,
200200 arguments = expr .arguments ()
201201 if bcs is None :
202202 bcs = ()
203+ else :
204+ if any (isinstance (bc , EquationBC ) for bc in bcs ):
205+ raise TypeError ("EquationBC objects not expected here. "
206+ "Preprocess by extracting the appropriate form with bc.extract_form('Jp') or bc.extract_form('J')" )
203207 if tensor is not None and tensor .a .arguments () != arguments :
204208 raise ValueError ("Form's arguments do not match provided result tensor" )
205209 if matfree :
@@ -425,9 +429,6 @@ def apply_bcs(tensor, bcs, *, assembly_rank=None, form_compiler_parameters=None,
425429 """
426430 dirichletbcs = tuple (bc for bc in bcs if isinstance (bc , DirichletBC ))
427431 equationbcs = tuple (bc for bc in bcs if isinstance (bc , EquationBCSplit ))
428- if any (not isinstance (bc , (DirichletBC , EquationBCSplit )) for bc in bcs ):
429- raise NotImplementedError ("Unhandled type of bc object" )
430-
431432 if assembly_rank == AssemblyRank .MATRIX :
432433 op2tensor = tensor .M
433434 shape = tuple (len (a .function_space ()) for a in tensor .a .arguments ())
@@ -669,11 +670,6 @@ def _assemble(expr, tensor=None, bcs=None, form_compiler_parameters=None,
669670 # building a functionspace (e.g. if integrating a constant)).
670671 m .init ()
671672
672- if bcs is None :
673- bcs = ()
674- else :
675- bcs = tuple (bcs )
676-
677673 for o in chain (expr .arguments (), expr .coefficients ()):
678674 domain = o .ufl_domain ()
679675 if domain is not None and domain .topology != topology :
@@ -689,6 +685,16 @@ def _assemble(expr, tensor=None, bcs=None, form_compiler_parameters=None,
689685 else :
690686 assembly_rank = AssemblyRank .SCALAR
691687
688+ if not isinstance (bcs , (tuple , list )):
689+ raise RuntimeError ("Expecting bcs to be a tuple or a list by this stage." )
690+ if assembly_rank == AssemblyRank .MATRIX :
691+ # Checks will take place in get_matrix.
692+ pass
693+ elif assembly_rank == AssemblyRank .VECTOR :
694+ # Might have gotten here without `EquationBC` objects preprocessed.
695+ if any (isinstance (bc , EquationBC ) for bc in bcs ):
696+ bcs = tuple (bc .extract_form ('F' ) for bc in bcs )
697+
692698 if assembly_rank == AssemblyRank .MATRIX :
693699 test , trial = expr .arguments ()
694700 tensor , zeros , result = get_matrix (expr , mat_type , sub_mat_type ,
0 commit comments