11import FIAT
22import numpy as np
33from firedrake import Function , TestFunction
4- from firedrake import LinearVariationalSolver as LVS
5- from firedrake import LinearVariationalProblem as LVP
6- from firedrake import NonlinearVariationalProblem as NLVP
7- from firedrake import NonlinearVariationalSolver as NLVS
84from ufl import Form , as_ufl , dx , inner , lhs , rhs
95
10- from .tableaux .ButcherTableaux import RadauIIA
11- from .ufl .deriv import TimeDerivative , expand_time_derivatives
12- from .stage_value import getFormStage
13- from .tools import AI , IA , reshape , replace , getNullspace , get_stage_space
6+ from .backend import get_backend
147from .bcs import bc2space
158from .constant import MeshConstant , vecconst
9+ from .stage_value import getFormStage
10+ from .tools import AI , IA , reshape , replace , getNullspace , get_stage_space
11+ from .tableaux .ButcherTableaux import RadauIIA
12+ from .ufl .deriv import TimeDerivative , expand_time_derivatives
1613
1714
1815def riia_explicit_coeffs (k ):
@@ -172,8 +169,10 @@ def __init__(self, F, Fexp, butcher_tableau,
172169 nullspace = None ,
173170 num_its_initial = 0 ,
174171 num_its_per_step = 0 ,
172+ backend = "firedrake" ,
175173 ** kwargs ):
176174 assert isinstance (butcher_tableau , RadauIIA )
175+ self ._backend = backend_cls = get_backend (backend )
177176
178177 self .u0 = u0
179178 self .t = t
@@ -197,7 +196,7 @@ def __init__(self, F, Fexp, butcher_tableau,
197196 # the update information on the floor.
198197 V = u0 .function_space ()
199198 Vbig = get_stage_space (V , self .num_stages )
200- UU = Function (Vbig )
199+ UU = backend_cls . Function (Vbig )
201200
202201 restrict = kwargs .pop ("restrict" , False )
203202 is_linear = kwargs .pop ("is_linear" , False )
@@ -219,28 +218,16 @@ def __init__(self, F, Fexp, butcher_tableau,
219218 Fit , Fprop = getFormExplicit (
220219 Fexp , butcher_tableau , u0 , UU_old , t , dt , splitting )
221220
222- F_linear = len (Fbig .arguments ()) == 2
223- if F_linear :
224- F1 = Fbig + Fit
225- F2 = Fbig + Fprop
226- itprob = LVP (lhs (F1 ), rhs (F1 ), UU , bcs = bigBCs ,
227- constant_jacobian = constant_jacobian ,
228- restrict = restrict )
229- propprob = LVP (lhs (F2 ), rhs (F2 ), UU , bcs = bigBCs ,
230- constant_jacobian = constant_jacobian ,
231- restrict = restrict )
232- create_solver = LVS
233- else :
234- itprob = NLVP (Fbig + Fit , UU , bcs = bigBCs ,
235- is_linear = is_linear , restrict = restrict )
236- propprob = NLVP (Fbig + Fprop , UU , bcs = bigBCs ,
237- is_linear = is_linear , restrict = restrict )
238- itprob ._constant_jacobian = constant_jacobian
239- propprob ._constant_jacobian = constant_jacobian
240- create_solver = NLVS
241-
242- self .itprob = itprob
243- self .propprob = propprob
221+ self .itprob = backend_cls .create_variational_problem (
222+ Fbig + Fit , UU , bcs = bigBCs ,
223+ is_linear = is_linear , restrict = restrict ,
224+ constant_jacobian = constant_jacobian ,
225+ )
226+ self .propprob = backend_cls .create_variational_problem (
227+ Fbig + Fprop , UU , bcs = bigBCs ,
228+ is_linear = is_linear , restrict = restrict ,
229+ constant_jacobian = constant_jacobian ,
230+ )
244231
245232 self .F = F
246233 self .orig_bcs = bcs
@@ -253,11 +240,11 @@ def __init__(self, F, Fexp, butcher_tableau,
253240 else :
254241 appctx = {** appctx , ** appctx_irksome }
255242
256- self .it_solver = create_solver (
243+ self .it_solver = backend_cls . create_variational_solver (
257244 self .itprob , appctx = appctx ,
258245 solver_parameters = it_solver_parameters ,
259246 nullspace = nsp , ** kwargs )
260- self .prop_solver = create_solver (
247+ self .prop_solver = backend_cls . create_variational_solver (
261248 self .propprob , appctx = appctx ,
262249 solver_parameters = prop_solver_parameters ,
263250 nullspace = nsp , ** kwargs )
@@ -464,31 +451,24 @@ def __init__(self, F, F_explicit, butcher_tableau, t, dt, u0, bcs=None,
464451 else :
465452 appctx = {** appctx , ** appctx_irksome }
466453
467- F_linear = len (stage_F .arguments ()) == 2
468- if F_linear :
469- problem = LVP (lhs (stage_F ), rhs (stage_F ), k , bcnew ,
470- constant_jacobian = constant_jacobian ,
471- restrict = restrict )
472- mass_problem = LVP (lhs (Fhat ), rhs (Fhat ), khat ,
473- constant_jacobian = constant_jacobian )
474- create_solver = LVS
475- else :
476- problem = NLVP (stage_F , k , bcnew ,
477- is_linear = is_linear ,
478- restrict = restrict )
479- problem ._constant_jacobian = constant_jacobian
480- mass_problem = NLVP (Fhat , khat , is_linear = True )
481- mass_problem ._constant_jacobian = True
482- create_solver = NLVS
483-
484- self .problem = problem
485- self .solver = create_solver (problem , appctx = appctx ,
486- solver_parameters = solver_parameters ,
487- nullspace = nullspace , ** kwargs )
488-
489- self .mass_problem = mass_problem
490- self .mass_solver = create_solver (mass_problem ,
491- solver_parameters = mass_parameters )
454+ self .problem = backend_cls .create_variational_problem (
455+ stage_F , k , bcnew ,
456+ is_linear = is_linear ,
457+ restrict = restrict
458+ )
459+ self .solver = backend_cls .create_variational_solver (
460+ self .problem , appctx = appctx ,
461+ solver_parameters = solver_parameters ,
462+ nullspace = nullspace , ** kwargs ,
463+ )
464+ self .mass_problem = backend_cls .create_variational_problem (
465+ Fhat , khat , is_linear = True ,
466+ constant_jacobian = constant_jacobian ,
467+ )
468+ self .mass_solver = backend_cls .create_variational_solver (
469+ self .mass_problem ,
470+ solver_parameters = mass_parameters ,
471+ )
492472
493473 self .kgac = k , g , a , c
494474 self .kgchat = khat , ghat , chat
0 commit comments