Skip to content

Commit fabfefc

Browse files
committed
get variational problem/solver from backend
1 parent 6c2f797 commit fabfefc

6 files changed

Lines changed: 130 additions & 138 deletions

File tree

irksome/backends/firedrake.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,35 @@ def Constant(self, val=0.0) -> ufl.Coefficient:
3636

3737
def get_mesh_constant(MC: MeshConstant | None):
3838
return MC.Constant if MC else firedrake.Constant
39+
40+
41+
def create_variational_problem(F, u, bcs=None, J=None, Jp=None, **kwargs):
42+
if len(F.arguments()) == 2:
43+
a = ufl.lhs(F)
44+
L = ufl.rhs(F)
45+
kwargs.pop("is_linear", None)
46+
problem = firedrake.LinearVariationalProblem(a, L, u, bcs=bcs, aP=Jp, **kwargs)
47+
else:
48+
constant_jacobian = kwargs.pop("constant_jacobian", False)
49+
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp, **kwargs)
50+
if constant_jacobian:
51+
problem._constant_jacobian = constant_jacobian
52+
return problem
53+
54+
55+
def create_variational_solver(problem, **kwargs):
56+
if isinstance(problem, firedrake.LinearVariationalProblem):
57+
return firedrake.LinearVariationalSolver(problem, **kwargs)
58+
else:
59+
return firedrake.NonlinearVariationalSolver(problem, **kwargs)
60+
61+
62+
def invalidate_jacobian(solver):
63+
return firedrake.LinearVariationalSolver.invalidate_jacobian(solver)
64+
65+
66+
derivative = firedrake.derivative
67+
norm = firedrake.norm
68+
Function = firedrake.Function
69+
TestFunction = firedrake.TestFunction
70+
TrialFunction = firedrake.TrialFunction

irksome/base_time_stepper.py

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
11
from abc import abstractmethod
2-
from firedrake import (
3-
derivative, lhs, rhs, Function, TrialFunction,
4-
LinearVariationalProblem, LinearVariationalSolver,
5-
NonlinearVariationalProblem, NonlinearVariationalSolver,
6-
)
72
from firedrake.petsc import PETSc
83
from .tools import AI, getNullspace, flatten_dats, split_stages
94
from .labeling import as_form
@@ -122,44 +117,30 @@ def __init__(self, F, t, dt, u0, num_stages,
122117
Vbig = stages.function_space()
123118

124119
F_linear = len(as_form(F).arguments()) == 2
125-
stages_F = TrialFunction(Vbig) if F_linear else stages
120+
stages_F = self._backend.TrialFunction(Vbig) if F_linear else stages
126121
Fbig, bigBCs = self.get_form_and_bcs(stages_F)
127122

128123
Jpbig = None
129124
if Fp is not None:
130125
Fp_linear = len(as_form(Fp).arguments()) == 2
131-
stages_Fp = TrialFunction(Vbig) if Fp_linear else stages
126+
stages_Fp = self._backend.TrialFunction(Vbig) if Fp_linear else stages
132127
Fpbig, _ = self.get_form_and_bcs(stages_Fp, F=Fp, bcs=())
133-
Jpbig = lhs(Fpbig) if Fp_linear else derivative(Fpbig, stages_Fp)
128+
Jpbig = ufl.lhs(Fpbig) if Fp_linear else self._backed.derivative(Fpbig, stages_Fp)
134129

135130
nullspace = getNullspace(V, Vbig, num_stages, nullspace)
136131
transpose_nullspace = getNullspace(V, Vbig, num_stages, transpose_nullspace)
137132
near_nullspace = getNullspace(V, Vbig, num_stages, near_nullspace)
138133

139134
self.bigBCs = bigBCs
140135

141-
if F_linear:
142-
abig = lhs(Fbig)
143-
Lbig = rhs(Fbig)
144-
problem = LinearVariationalProblem(
145-
abig, Lbig, stages, bcs=bigBCs, aP=Jpbig,
146-
form_compiler_parameters=kwargs.pop("form_compiler_parameters", None),
147-
constant_jacobian=kwargs.pop("constant_jacobian", False),
148-
restrict=kwargs.pop("restrict", False),
149-
)
150-
solver_constructor = LinearVariationalSolver
151-
else:
152-
problem = NonlinearVariationalProblem(
153-
Fbig, stages, bcs=bigBCs, Jp=Jpbig,
154-
form_compiler_parameters=kwargs.pop("form_compiler_parameters", None),
155-
is_linear=kwargs.pop("is_linear", False),
156-
restrict=kwargs.pop("restrict", False),
157-
)
158-
problem._constant_jacobian = kwargs.pop("constant_jacobian", False)
159-
solver_constructor = NonlinearVariationalSolver
160-
161-
self.problem = problem
162-
self.solver = solver_constructor(
136+
self.problem = self._backend.create_variational_problem(
137+
Fbig, stages, bcs=bigBCs, Jp=Jpbig,
138+
form_compiler_parameters=kwargs.pop("form_compiler_parameters", None),
139+
is_linear=kwargs.pop("is_linear", False),
140+
restrict=kwargs.pop("restrict", False),
141+
constant_jacobian=kwargs.pop("constant_jacobian", False),
142+
)
143+
self.solver = self._backend.create_variational_solver(
163144
self.problem, appctx=self.appctx,
164145
nullspace=nullspace,
165146
transpose_nullspace=transpose_nullspace,
@@ -208,30 +189,30 @@ def get_stage_bounds(self, bounds=None):
208189
Vbig = self.stages.function_space()
209190
bounds_type, lower, upper = bounds
210191
if lower is None:
211-
slb = Function(Vbig).assign(PETSc.NINFINITY)
192+
slb = self._backend.Function(Vbig).assign(PETSc.NINFINITY)
212193
if upper is None:
213-
sub = Function(Vbig).assign(PETSc.INFINITY)
194+
sub = self._backend.Function(Vbig).assign(PETSc.INFINITY)
214195

215196
if bounds_type == "stage":
216197
if lower is not None:
217198
dats = [lower.dat] * (self.num_stages)
218-
slb = Function(Vbig, val=flatten_dats(dats))
199+
slb = self._backend.Function(Vbig, val=flatten_dats(dats))
219200
if upper is not None:
220201
dats = [upper.dat] * (self.num_stages)
221-
sub = Function(Vbig, val=flatten_dats(dats))
202+
sub = self._backend.Function(Vbig, val=flatten_dats(dats))
222203

223204
elif bounds_type == "last_stage":
224205
V = self.u0.function_space()
225206
if lower is not None:
226-
ninfty = Function(V).assign(PETSc.NINFINITY)
207+
ninfty = self._backend.Function(V).assign(PETSc.NINFINITY)
227208
dats = [ninfty.dat] * (self.num_stages-1)
228209
dats.append(lower.dat)
229-
slb = Function(Vbig, val=flatten_dats(dats))
210+
slb = self._backend.Function(Vbig, val=flatten_dats(dats))
230211
if upper is not None:
231-
infty = Function(V).assign(PETSc.INFINITY)
212+
infty = self._backend.Function(V).assign(PETSc.INFINITY)
232213
dats = [infty.dat] * (self.num_stages-1)
233214
dats.append(upper.dat)
234-
sub = Function(Vbig, val=flatten_dats(dats))
215+
sub = self._backend.Function(Vbig, val=flatten_dats(dats))
235216

236217
else:
237218
raise ValueError("Unknown bounds type")
@@ -253,7 +234,7 @@ def build_poly(self):
253234
pts = numpy.reshape(self.sample_points, (-1, 1))
254235
vander = self.tabulate_poly(pts)
255236

256-
self.u_old = Function(self.u0)
237+
self.u_old = self._backend.Function(self.u0)
257238
ks = [self.u_old]
258239
ks.extend(split_stages(self.u0.function_space(), self.stages))
259240
num_samples = vander.shape[1]
@@ -264,4 +245,4 @@ def invalidate_jacobian(self):
264245
"""
265246
Forces the matrix to be reassembled next time it is required.
266247
"""
267-
LinearVariationalSolver.invalidate_jacobian(self.solver)
248+
self._backend.invalidate_jacobian(self.solver)

irksome/dirk_stepper.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
import numpy
2-
from firedrake import (derivative, Function,
3-
LinearVariationalSolver,
4-
NonlinearVariationalProblem,
5-
NonlinearVariationalSolver)
6-
from ufl.constantvalue import as_ufl
2+
from ufl import as_ufl, lhs
73

8-
from .ufl.deriv import TimeDerivative, expand_time_derivatives
94
from .constant import vecconst
10-
from .tools import replace
11-
from .constant import MeshConstant
5+
from .backend import get_backend
126
from .bcs import bc2space
7+
from .ufl.deriv import TimeDerivative, expand_time_derivatives
8+
from .tools import replace
139

1410

15-
def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None, kgac=None):
11+
def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None, kgac=None, backend="firedrake"):
12+
backend_cls = get_backend(backend)
1613
if bcs is None:
1714
bcs = []
1815

@@ -33,10 +30,10 @@ def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None, kgac=None):
3330
# variational form and BC's, and we update it for each stage in
3431
# the loop over stages in the advance method. The Constant a is
3532
# used similarly in the variational form
36-
MC = MeshConstant(V.mesh())
33+
MC = backend_cls.MeshConstant(V.mesh())
3734
if kgac is None:
38-
k = Function(V)
39-
g = Function(V)
35+
k = backend_cls.Function(V)
36+
g = backend_cls.Function(V)
4037
a = MC.Constant(1.0)
4138
c = MC.Constant(1.0)
4239
else:
@@ -79,7 +76,9 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None,
7976
solver_parameters=None,
8077
appctx=None, nullspace=None,
8178
transpose_nullspace=None, near_nullspace=None,
79+
backend="firedrake",
8280
**kwargs):
81+
self._backend = backend_cls = get_backend(backend)
8382
assert butcher_tableau.is_diagonally_implicit
8483

8584
self.num_steps = 0
@@ -117,7 +116,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None,
117116
self.dt = dt
118117
self.orig_bcs = bcs
119118
self.num_fields = len(u0.function_space())
120-
self.ks = [Function(V) for _ in range(num_stages)]
119+
self.ks = [backend_cls.Function(V) for _ in range(num_stages)]
121120

122121
# "k" is a generic function for which we will solve the
123122
# NVLP for the next stage value
@@ -136,7 +135,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None,
136135
Fp_linear = len(Fp.arguments()) == 2
137136
ks_Fp = Fp.arguments()[1] if Fp_linear else self.ks
138137
stage_Fp, *_ = self.get_form_and_bcs(ks_Fp, F=Fp, bcs=())
139-
stage_Jp = stage_Fp if Fp_linear else derivative(stage_Fp, k)
138+
stage_Jp = lhs(stage_Fp) if Fp_linear else backend_cls.derivative(stage_Fp, k)
140139

141140
appctx_irksome = {"stepper": self}
142141
if appctx is None:
@@ -145,14 +144,14 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None,
145144
appctx = {**appctx, **appctx_irksome}
146145
self.appctx = appctx
147146

148-
self.problem = NonlinearVariationalProblem(
147+
self.problem = backend_cls.create_variational_problem(
149148
stage_F, k, bcs=bcnew, Jp=stage_Jp,
150149
form_compiler_parameters=kwargs.pop("form_compiler_parameters", None),
151150
is_linear=kwargs.pop("is_linear", False),
152151
restrict=kwargs.pop("restrict", False),
152+
constant_jacobian=kwargs.pop("constant_jacobian", False),
153153
)
154-
self.problem._constant_jacobian = kwargs.pop("constant_jacobian", False)
155-
self.solver = NonlinearVariationalSolver(
154+
self.solver = backend_cls.create_variational_solver(
156155
self.problem, appctx=appctx,
157156
nullspace=nullspace,
158157
transpose_nullspace=transpose_nullspace,
@@ -222,4 +221,4 @@ def invalidate_jacobian(self):
222221
"""
223222
Forces the matrix to be reassembled next time it is required.
224223
"""
225-
LinearVariationalSolver.invalidate_jacobian(self.solver)
224+
self._backend.invalidate_jacobian(self.solver)

irksome/imex.py

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
import FIAT
22
import numpy as np
33
from firedrake import Function, TestFunction
4-
from firedrake import LinearVariationalSolver as LVS
5-
from firedrake import LinearVariationalProblem as LVP
6-
from firedrake import NonlinearVariationalProblem as NLVP
7-
from firedrake import NonlinearVariationalSolver as NLVS
84
from ufl import Form, as_ufl, dx, inner, lhs, rhs
95

10-
from .tableaux.ButcherTableaux import RadauIIA
11-
from .ufl.deriv import TimeDerivative, expand_time_derivatives
12-
from .stage_value import getFormStage
13-
from .tools import AI, IA, reshape, replace, getNullspace, get_stage_space
6+
from .backend import get_backend
147
from .bcs import bc2space
158
from .constant import MeshConstant, vecconst
9+
from .stage_value import getFormStage
10+
from .tools import AI, IA, reshape, replace, getNullspace, get_stage_space
11+
from .tableaux.ButcherTableaux import RadauIIA
12+
from .ufl.deriv import TimeDerivative, expand_time_derivatives
1613

1714

1815
def riia_explicit_coeffs(k):
@@ -172,8 +169,10 @@ def __init__(self, F, Fexp, butcher_tableau,
172169
nullspace=None,
173170
num_its_initial=0,
174171
num_its_per_step=0,
172+
backend="firedrake",
175173
**kwargs):
176174
assert isinstance(butcher_tableau, RadauIIA)
175+
self._backend = backend_cls = get_backend(backend)
177176

178177
self.u0 = u0
179178
self.t = t
@@ -197,7 +196,7 @@ def __init__(self, F, Fexp, butcher_tableau,
197196
# the update information on the floor.
198197
V = u0.function_space()
199198
Vbig = get_stage_space(V, self.num_stages)
200-
UU = Function(Vbig)
199+
UU = backend_cls.Function(Vbig)
201200

202201
restrict = kwargs.pop("restrict", False)
203202
is_linear = kwargs.pop("is_linear", False)
@@ -219,28 +218,16 @@ def __init__(self, F, Fexp, butcher_tableau,
219218
Fit, Fprop = getFormExplicit(
220219
Fexp, butcher_tableau, u0, UU_old, t, dt, splitting)
221220

222-
F_linear = len(Fbig.arguments()) == 2
223-
if F_linear:
224-
F1 = Fbig + Fit
225-
F2 = Fbig + Fprop
226-
itprob = LVP(lhs(F1), rhs(F1), UU, bcs=bigBCs,
227-
constant_jacobian=constant_jacobian,
228-
restrict=restrict)
229-
propprob = LVP(lhs(F2), rhs(F2), UU, bcs=bigBCs,
230-
constant_jacobian=constant_jacobian,
231-
restrict=restrict)
232-
create_solver = LVS
233-
else:
234-
itprob = NLVP(Fbig + Fit, UU, bcs=bigBCs,
235-
is_linear=is_linear, restrict=restrict)
236-
propprob = NLVP(Fbig + Fprop, UU, bcs=bigBCs,
237-
is_linear=is_linear, restrict=restrict)
238-
itprob._constant_jacobian = constant_jacobian
239-
propprob._constant_jacobian = constant_jacobian
240-
create_solver = NLVS
241-
242-
self.itprob = itprob
243-
self.propprob = propprob
221+
self.itprob = backend_cls.create_variational_problem(
222+
Fbig + Fit, UU, bcs=bigBCs,
223+
is_linear=is_linear, restrict=restrict,
224+
constant_jacobian=constant_jacobian,
225+
)
226+
self.propprob = backend_cls.create_variational_problem(
227+
Fbig + Fprop, UU, bcs=bigBCs,
228+
is_linear=is_linear, restrict=restrict,
229+
constant_jacobian=constant_jacobian,
230+
)
244231

245232
self.F = F
246233
self.orig_bcs = bcs
@@ -253,11 +240,11 @@ def __init__(self, F, Fexp, butcher_tableau,
253240
else:
254241
appctx = {**appctx, **appctx_irksome}
255242

256-
self.it_solver = create_solver(
243+
self.it_solver = backend_cls.create_variational_solver(
257244
self.itprob, appctx=appctx,
258245
solver_parameters=it_solver_parameters,
259246
nullspace=nsp, **kwargs)
260-
self.prop_solver = create_solver(
247+
self.prop_solver = backend_cls.create_variational_solver(
261248
self.propprob, appctx=appctx,
262249
solver_parameters=prop_solver_parameters,
263250
nullspace=nsp, **kwargs)
@@ -464,31 +451,24 @@ def __init__(self, F, F_explicit, butcher_tableau, t, dt, u0, bcs=None,
464451
else:
465452
appctx = {**appctx, **appctx_irksome}
466453

467-
F_linear = len(stage_F.arguments()) == 2
468-
if F_linear:
469-
problem = LVP(lhs(stage_F), rhs(stage_F), k, bcnew,
470-
constant_jacobian=constant_jacobian,
471-
restrict=restrict)
472-
mass_problem = LVP(lhs(Fhat), rhs(Fhat), khat,
473-
constant_jacobian=constant_jacobian)
474-
create_solver = LVS
475-
else:
476-
problem = NLVP(stage_F, k, bcnew,
477-
is_linear=is_linear,
478-
restrict=restrict)
479-
problem._constant_jacobian = constant_jacobian
480-
mass_problem = NLVP(Fhat, khat, is_linear=True)
481-
mass_problem._constant_jacobian = True
482-
create_solver = NLVS
483-
484-
self.problem = problem
485-
self.solver = create_solver(problem, appctx=appctx,
486-
solver_parameters=solver_parameters,
487-
nullspace=nullspace, **kwargs)
488-
489-
self.mass_problem = mass_problem
490-
self.mass_solver = create_solver(mass_problem,
491-
solver_parameters=mass_parameters)
454+
self.problem = backend_cls.create_variational_problem(
455+
stage_F, k, bcnew,
456+
is_linear=is_linear,
457+
restrict=restrict
458+
)
459+
self.solver = backend_cls.create_variational_solver(
460+
self.problem, appctx=appctx,
461+
solver_parameters=solver_parameters,
462+
nullspace=nullspace, **kwargs,
463+
)
464+
self.mass_problem = backend_cls.create_variational_problem(
465+
Fhat, khat, is_linear=True,
466+
constant_jacobian=constant_jacobian,
467+
)
468+
self.mass_solver = backend_cls.create_variational_solver(
469+
self.mass_problem,
470+
solver_parameters=mass_parameters,
471+
)
492472

493473
self.kgac = k, g, a, c
494474
self.kgchat = khat, ghat, chat

0 commit comments

Comments
 (0)