-
Notifications
You must be signed in to change notification settings - Fork 189
Expand file tree
/
Copy pathcommon.py
More file actions
589 lines (492 loc) · 22.5 KB
/
common.py
File metadata and controls
589 lines (492 loc) · 22.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
import collections
import operator
import string
from functools import cached_property, reduce
from itertools import chain, product
import copy
from ufl.utils.sequences import max_degree
from ufl.domain import extract_unique_domain
import gem
import gem.impero_utils as impero_utils
import petsctools
import numpy
from FIAT.reference_element import TensorProductCell
from finat.cell_tools import max_complex
from finat.quadrature import AbstractQuadratureRule
from gem.node import traversal
from gem.optimise import constant_fold_zero
from gem.optimise import remove_componenttensors as prune
from numpy import asarray
from tsfc import fem
from finat.element_factory import as_fiat_cell, create_element
from finat.ufl import MixedElement
from tsfc.kernel_interface import KernelInterface
from tsfc.logging import logger
class KernelBuilderBase(KernelInterface):
    """Shared state and lookups for builders of local assembly kernels.

    Several accessors read private attributes (``_domain_integral_type_map``,
    ``_cell_orientations``, ``_cell_sizes``, ``_entity_ids``,
    ``_entity_numbers``, ``_entity_orientations``) that are not created in
    :meth:`__init__`; they have to be set on the instance before the
    corresponding accessor is used.
    """

    def __init__(self, scalar_type):
        """Initialise a kernel builder with empty bookkeeping.

        :arg scalar_type: stored unchanged as ``self.scalar_type``.
        """
        self.scalar_type = scalar_type

        # Glue code collected through apply_glue().
        self.prepare = []
        self.finalise = []

        # Coordinates, keyed by domain.
        self.domain_coordinate = {}

        # Coefficients, keyed by UFL coefficient (insertion ordered).
        self.coefficient_map = collections.OrderedDict()

        # Constants, keyed by constant (insertion ordered).
        self.constant_map = collections.OrderedDict()

    @cached_property
    def unsummed_coefficient_indices(self):
        """Empty frozenset in this base class."""
        return frozenset()

    def coordinate(self, domain):
        """Return the entry of ``domain_coordinate`` registered for ``domain``."""
        return self.domain_coordinate[domain]

    def coefficient(self, ufl_coefficient, restriction):
        """Map a :class:`ufl.Coefficient` to its GEM expression.

        :arg ufl_coefficient: key into ``coefficient_map``.
        :arg restriction: ``'+'`` or ``'-'``; only consulted when the
            integral type of the coefficient's domain starts with
            ``"interior_facet"`` and the element family is not ``'Real'``.
        """
        kernel_arg = self.coefficient_map[ufl_coefficient]
        domain = extract_unique_domain(ufl_coefficient)
        integral_type = self._domain_integral_type_map[domain]
        assert integral_type is not None

        # 'Real' coefficients are returned whole, whatever the integral type.
        if ufl_coefficient.ufl_element().family() == 'Real':
            return kernel_arg
        if integral_type.startswith("interior_facet"):
            # The kernel argument holds one entry per restriction.
            return kernel_arg[{'+': 0, '-': 1}[restriction]]
        return kernel_arg

    def constant(self, const):
        """Return the entry of ``constant_map`` registered for ``const``."""
        return self.constant_map[const]

    def cell_orientation(self, domain, restriction):
        """Cell orientation as a GEM expression.

        The expression evaluates to -1 where the stored value equals 1,
        to 1 where it equals 0, and to NaN otherwise.

        :raises RuntimeError: if ``_cell_orientations`` has not been set.
        """
        if not hasattr(self, "_cell_orientations"):
            raise RuntimeError("Haven't called set_cell_orientations")

        # No restriction is treated like the '+' side.
        side = {None: 0, '+': 0, '-': 1}[restriction]
        co_int = self._cell_orientations[domain][side]

        equals_one = gem.Comparison("==", co_int, gem.Literal(1))
        minus_one = gem.Literal(-1)
        equals_zero = gem.Comparison("==", co_int, gem.Zero())
        otherwise = gem.Conditional(equals_zero,
                                    gem.Literal(1),
                                    gem.Literal(numpy.nan))
        return gem.Conditional(equals_one, minus_one, otherwise)

    def cell_size(self, domain, restriction):
        """Return the cell sizes stored for ``domain``.

        For integral types starting with ``"interior_facet"`` the entry
        selected by ``restriction`` (``'+'`` or ``'-'``) is returned.

        :raises RuntimeError: if ``_cell_sizes`` has not been set.
        """
        if not hasattr(self, "_cell_sizes"):
            raise RuntimeError("Haven't called set_cell_sizes")

        on_interior_facet = self._domain_integral_type_map[domain].startswith("interior_facet")
        sizes = self._cell_sizes[domain]
        if on_interior_facet:
            return sizes[{'+': 0, '-': 1}[restriction]]
        return sizes

    def entity_ids(self, domain):
        """Target indices of entity_number.

        :raises RuntimeError: if ``_entity_ids`` has not been set.
        """
        if hasattr(self, "_entity_ids"):
            return self._entity_ids[domain]
        raise RuntimeError("Haven't called set_entity_numbers")

    def entity_number(self, domain, restriction):
        """Facet or vertex number as a GEM index.

        :raises RuntimeError: if ``_entity_numbers`` has not been set.
        """
        if hasattr(self, "_entity_numbers"):
            return self._entity_numbers[domain][restriction]
        raise RuntimeError("Haven't called set_entity_numbers")

    def entity_orientation(self, domain, restriction):
        """Facet orientation as a GEM index.

        :raises RuntimeError: if ``_entity_orientations`` has not been set.
        """
        if hasattr(self, "_entity_orientations"):
            return self._entity_orientations[domain][restriction]
        raise RuntimeError("Haven't called set_entity_orientations")

    def apply_glue(self, prepare=None, finalise=None):
        """Append glue code for operations that are not handled in the
        GEM abstraction.

        Current uses: mixed interior facet mess

        :arg prepare: code snippets to be prepended to the kernel
        :arg finalise: code snippets to be appended to the kernel
        """
        if prepare is not None:
            self.prepare += prepare
        if finalise is not None:
            self.finalise += finalise

    def register_requirements(self, ir):
        """Inspect what is referenced by the IR that needs to be
        provided by the kernel interface.

        Nothing is required by default, so this implementation does
        nothing and returns ``None``.

        :arg ir: multi-root GEM expression DAG
        """

    @property
    def domain_integral_type_map(self):
        """Read-only access to ``_domain_integral_type_map``."""
        return self._domain_integral_type_map
class KernelBuilderMixin:
    """Mixin for KernelBuilder classes.

    The methods read ``integral_data_info``, ``argument_multiindices``,
    ``return_variables`` and ``fem_scalar_type`` from ``self`` and call
    ``self.coordinate`` / ``self.register_requirements``; the class that
    mixes this in must provide them.
    """

    def compile_integrand(self, integrand, params, ctx):
        """Compile UFL integrand to GEM.

        :arg integrand: UFL integrand.
        :arg params: a dict containing "quadrature_rule"; the entry is
            filled in by :func:`set_quad_rule` when it is not already a
            quadrature rule object.
        :arg ctx: context created with :meth:`create_context` method.
            The quadrature indices of the rule are appended to
            ``ctx['quadrature_indices']``.
        :returns: the GEM expressions produced by ``fem.compile_ufl``.

        See :meth:`create_context` for typical calling sequence.
        """
        info = self.integral_data_info
        domain = info.domain
        functions = [*info.arguments, self.coordinate(domain), *info.coefficients]
        set_quad_rule(params, domain.ufl_cell(), info.integral_type, functions)
        rule = params["quadrature_rule"]

        # ufl -> gem
        config = self.fem_config()
        config.update(argument_multiindices=self.argument_multiindices,
                      quadrature_rule=rule,
                      index_cache=ctx['index_cache'])
        point_set_context = fem.PointSetContext(**config)
        expressions = fem.compile_ufl(integrand, point_set_context)

        ctx['quadrature_indices'].extend(rule.point_set.indices)
        return expressions

    def construct_integrals(self, integrand_expressions, params):
        """Construct integrals from integrand expressions.

        :arg integrand_expressions: gem expressions for integrands.
        :arg params: a dict containing "mode" and "quadrature_rule".

        integrand_expressions must be indexed with
        :attr:`argument_multiindices`; these gem expressions are obtained
        by calling :meth:`compile_integrand` or by modifying the gem
        expressions it returned.

        See :meth:`create_context` for typical calling sequence.
        """
        mode = pick_mode(params["mode"])
        quadrature_multiindex = params["quadrature_rule"].point_set.indices
        return mode.Integrals(integrand_expressions,
                              quadrature_multiindex,
                              self.argument_multiindices,
                              params)

    def stash_integrals(self, reps, params, ctx):
        """Stash integral representations in ctx.

        Each representation is appended to the list kept in
        ``ctx['mode_irs'][mode][var]``, pairing ``reps`` with
        ``self.return_variables`` in order.

        :arg reps: integral representations.
        :arg params: a dict containing "mode".
        :arg ctx: context in which reps are stored.

        See :meth:`create_context` for typical calling sequence.
        """
        mode = pick_mode(params["mode"])
        per_variable = ctx['mode_irs'].setdefault(mode, collections.OrderedDict())
        for variable, rep in zip(self.return_variables, reps):
            per_variable.setdefault(variable, []).append(rep)

    def compile_gem(self, ctx):
        """Compile gem representation of integrals to impero_c.

        :arg ctx: the context containing the gem representation of integrals.
        :returns: a tuple ``(impero_c, oriented, needs_cell_sizes,
            tabulations, active_variables)`` required to finally construct
            a kernel in :meth:`construct_kernel`.  ``impero_c`` is ``None``
            when ``impero_utils.compile_gem`` raises ``NoopError``.

        See :meth:`create_context` for typical calling sequence.
        """
        mode_irs = ctx['mode_irs']

        # Finalise mode representations into a set of assignments.
        assignments = [
            assignment
            for mode, var_reps in mode_irs.items()
            for assignment in mode.flatten(var_reps.items(), ctx['index_cache'])
        ]
        if assignments:
            return_variables, expressions = zip(*assignments)
        else:
            return_variables, expressions = [], []
        expressions = constant_fold_zero(expressions)

        # Need optimised roots: keep only the finalise options on which
        # every mode in use agrees.
        option_items = [mode.finalise_options.items() for mode in mode_irs]
        options = dict(reduce(operator.and_, option_items))
        expressions = impero_utils.preprocess_gem(expressions, **options)

        # Let the kernel interface inspect the optimised IR to register
        # what kind of external data is required (e.g., cell orientations,
        # cell sizes, etc.).
        oriented, needs_cell_sizes, tabulations = self.register_requirements(expressions)

        # Extract Variables that are actually used.
        active_variables = gem.extract_type(expressions, gem.Variable)

        # Construct ImperoC.
        assignments = list(zip(return_variables, expressions))
        index_ordering = get_index_ordering(ctx['quadrature_indices'], return_variables)
        try:
            impero_c = impero_utils.compile_gem(assignments, index_ordering, remove_zeros=True)
        except impero_utils.NoopError:
            impero_c = None
        return impero_c, oriented, needs_cell_sizes, tabulations, active_variables

    def fem_config(self):
        """Return a dictionary used with fem.compile_ufl.

        One needs to update this dictionary with "argument_multiindices",
        "quadrature_rule", and "index_cache" before using this with
        fem.compile_ufl.
        """
        info = self.integral_data_info
        cell = info.domain.ufl_cell()
        # Only the integration dimension is needed; entity ids are dropped.
        integration_dim, _ = lower_integral_type(as_fiat_cell(cell), info.integral_type)
        return {
            'interface': self,
            'ufl_cell': cell,
            'integration_dim': integration_dim,
            'scalar_type': self.fem_scalar_type,
        }

    def create_context(self):
        """Create builder context.

        *index_cache*

        Map from UFL FiniteElement objects to multiindices.
        This is so we reuse Index instances when evaluating the same
        coefficient multiple times with the same table.

        We also use the same dict for the unconcatenate index cache,
        which maps index objects to tuples of multiindices. These two
        caches shall never conflict as their keys have different types
        (UFL finite elements vs. GEM index objects).

        *quadrature_indices*

        List of quadrature indices used.

        *mode_irs*

        Dict for mode representations.

        For each set of integrals to make a kernel for (i.e.,
        `integral_data.integrals`), one must first create a ctx object by
        calling :meth:`create_context` method.
        This ctx object collects objects associated with the integrals that
        are eventually used to construct the kernel.
        The following is a typical calling sequence:

        .. code-block:: python3

            builder = KernelBuilder(...)
            params = {"mode": "spectral"}
            ctx = builder.create_context()
            for integral in integral_data.integrals:
                integrand = integral.integrand()
                integrand_exprs = builder.compile_integrand(integrand, params, ctx)
                integral_exprs = builder.construct_integrals(integrand_exprs, params)
                builder.stash_integrals(integral_exprs, params, ctx)
            kernel = builder.construct_kernel(kernel_name, ctx)

        """
        context = {}
        context['index_cache'] = {}
        context['quadrature_indices'] = []
        context['mode_irs'] = collections.OrderedDict()
        return context
def set_quad_rule(params, cell, integral_type, functions):
    """Make sure ``params["quadrature_rule"]`` is a quadrature rule object.

    The degree is taken from ``params["quadrature_degree"]`` when present.
    Otherwise it comes from the single ``"Quadrature"`` / ``"Boundary
    Quadrature"`` element among ``functions`` (which also fixes the
    scheme), or, when there is no such element, from
    ``params["estimated_polynomial_degree"]``.

    :arg params: dict of parameters; ``"quadrature_rule"`` is overwritten
        when it was missing or given as a scheme name (a string).
    :arg cell: UFL cell, converted with ``as_fiat_cell``.
    :arg integral_type: integral type passed to :func:`lower_integral_type`.
    :arg functions: objects providing ``ufl_element()``.
    :raises ValueError: if several distinct quadrature elements are found,
        or if the resulting rule is not an ``AbstractQuadratureRule``.
    """
    quad_rule = params.get("quadrature_rule", "default")

    # Collect the elements; an element whose type is exactly MixedElement
    # is replaced by its sub-elements.
    elements = []
    for function in functions:
        element = function.ufl_element()
        if type(element) is MixedElement:
            elements.extend(element.sub_elements)
        else:
            elements.append(element)

    try:
        quadrature_degree = params["quadrature_degree"]
    except KeyError:
        # No explicit degree: look for quadrature elements first.
        quadrature_families = {"Quadrature", "Boundary Quadrature"}
        quad_data = {(el.degree(), el.quadrature_scheme() or "default")
                     for el in elements
                     if el.family() in quadrature_families}
        if quad_data:
            try:
                (quadrature_degree, quad_rule), = quad_data
            except ValueError:
                raise ValueError("The quadrature rule cannot be inferred from multiple Quadrature elements")
        else:
            quadrature_degree = params["estimated_polynomial_degree"]
            exceeds_tenfold = all(
                (asarray(quadrature_degree) > 10 * asarray(el.degree())).all()
                for el in elements
            )
            if exceeds_tenfold:
                logger.warning("Estimated quadrature degree %s more "
                               "than tenfold greater than any "
                               "argument/coefficient degree (max %s)",
                               quadrature_degree, max_degree([el.degree() for el in elements]))

    if isinstance(quad_rule, str):
        # quad_rule names a scheme: build the actual rule from it.
        scheme = quad_rule
        fiat_cell = as_fiat_cell(cell)
        finat_elements = {create_element(el) for el in elements if el.family() != "Real"}
        fiat_cells = [fiat_cell, *(finat_el.complex for finat_el in finat_elements)]
        if any(candidate.is_macrocell() for candidate in fiat_cells):
            fiat_cell = max_complex(fiat_cells)
        integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
        quad_rule = fem.get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme)
        params["quadrature_rule"] = quad_rule

    if not isinstance(quad_rule, AbstractQuadratureRule):
        raise ValueError("Expected to find a QuadratureRule object, not a %s" %
                         type(quad_rule))
def get_index_ordering(quadrature_indices, return_variables):
    """Return quadrature indices followed by the return variables' indices.

    :arg quadrature_indices: iterable of quadrature indices.
    :arg return_variables: objects whose ``index_ordering()`` results are
        concatenated in the given order.
    :returns: a flat tuple.
    """
    argument_part = tuple(
        chain.from_iterable(variable.index_ordering() for variable in return_variables)
    )
    quadrature_part = tuple(quadrature_indices)
    return quadrature_part + argument_part
def get_index_names(quadrature_indices, argument_multiindices, index_cache):
    """Assign names to quadrature and argument indices.

    Quadrature indices are named from ``'ip'``; the first two argument
    multiindices from ``'j'`` and ``'k'``.  A multiindex with exactly one
    index uses the base name as is, otherwise the position is appended.
    An index found in ``index_cache`` additionally names the multiindices
    cached for it, with a lowercase letter appended per multiindex.

    :returns: list of ``(index, name)`` pairs in visiting order.
    """
    named = []

    def visit_index(index, label):
        named.append((index, label))
        if index not in index_cache:
            return
        # Descend into the multiindices cached for this index.
        for sub_multiindex, letter in zip(index_cache[index],
                                          string.ascii_lowercase):
            visit_multiindex(sub_multiindex, label + letter)

    def visit_multiindex(multiindex, label):
        if len(multiindex) == 1:
            visit_index(multiindex[0], label)
            return
        for position, index in enumerate(multiindex):
            visit_index(index, label + str(position))

    visit_multiindex(quadrature_indices, 'ip')
    for multiindex, label in zip(argument_multiindices, ('j', 'k')):
        visit_multiindex(multiindex, label)
    return named
def lower_integral_type(fiat_cell, integral_type):
    """Lower integral type into the dimension of the integration
    subentity and a list of entity numbers for that dimension.

    :arg fiat_cell: FIAT reference cell
    :arg integral_type: integral type (string)
    :returns: ``(integration_dim, entity_ids)``
    :raises ValueError: if the integral type does not fit the kind of cell.
    :raises NotImplementedError: for an unrecognised integral type.
    """
    vertical_facets = ('exterior_facet_vert', 'interior_facet_vert')
    horizontal_facets = ('exterior_facet_bottom', 'exterior_facet_top', 'interior_facet_horiz')

    dim = fiat_cell.get_dimension()
    if integral_type == 'cell':
        integration_dim = dim
    elif integral_type in ('exterior_facet', 'interior_facet'):
        if isinstance(fiat_cell, TensorProductCell):
            raise ValueError("{} integral cannot be used with a TensorProductCell; need to distinguish between vertical and horizontal contributions.".format(integral_type))
        integration_dim = dim - 1
    elif integral_type == 'vertex':
        integration_dim = 0
    elif integral_type in vertical_facets or integral_type in horizontal_facets:
        # Extrusion case: the dimension is a (base, extruded) pair.
        if not isinstance(fiat_cell, TensorProductCell):
            raise ValueError("{} integral requires a TensorProductCell.".format(integral_type))
        basedim, extrdim = dim
        assert extrdim == 1
        if integral_type in vertical_facets:
            integration_dim = (basedim - 1, 1)
        else:
            integration_dim = (basedim, 0)
    else:
        raise NotImplementedError("integral type %s not supported" % integral_type)

    # Bottom and top facets each refer to one fixed entity; every other
    # type uses all entities of the integration dimension.
    if integral_type == 'exterior_facet_bottom':
        return integration_dim, [0]
    if integral_type == 'exterior_facet_top':
        return integration_dim, [1]
    return integration_dim, list(fiat_cell.get_topology()[integration_dim])
def pick_mode(mode):
    """Return one of the specialized optimisation modules from a mode string.

    The citations associated with the mode are registered through
    ``petsctools.cite`` before the module is imported; a ``KeyError``
    raised while doing so is swallowed.

    :arg mode: one of "vanilla", "coffee", "spectral" or "tensor".
    :raises ValueError: for any other mode.
    """
    citations = {
        "vanilla": ("Homolya2017", ),
        "coffee": ("Luporini2016", "Homolya2017", ),
        "spectral": ("Luporini2016", "Homolya2017", "Homolya2017a"),
        "tensor": ("Kirby2006", "Homolya2017", ),
    }
    # Every supported mode has a citation entry, so anything else is unknown.
    if mode not in citations:
        raise ValueError("Unknown mode: {}".format(mode))

    try:
        for key in citations[mode]:
            petsctools.cite(key)
    except KeyError:
        pass

    # The mode modules are imported lazily, only when requested.
    if mode == "vanilla":
        import tsfc.vanilla as m
    elif mode == "coffee":
        import tsfc.coffee_mode as m
    elif mode == "spectral":
        import tsfc.spectral as m
    else:
        import tsfc.tensor as m
    return m
def check_requirements(ir):
    """Look for cell orientations, cell sizes, and collect tabulations
    in one pass.

    Only ``gem.Variable`` nodes are inspected: the names
    ``"cell_orientations_0"`` and ``"cell_sizes_0"`` set the respective
    flags, and variables whose name starts with ``"rt_"`` are recorded
    with their shape.

    :arg ir: GEM expressions handed to ``traversal``.
    :returns: ``(cell_orientations, cell_sizes, tabulations)`` where
        ``tabulations`` is a tuple of ``(name, shape)`` pairs sorted by name.
    """
    needs_orientations = False
    needs_sizes = False
    tabulations = {}
    for node in traversal(ir):
        if not isinstance(node, gem.Variable):
            continue
        name = node.name
        if name == "cell_orientations_0":
            needs_orientations = True
        elif name == "cell_sizes_0":
            needs_sizes = True
        elif name.startswith("rt_"):
            tabulations[name] = node.shape
    return needs_orientations, needs_sizes, tuple(sorted(tabulations.items()))
def prepare_constant(constant, number):
    """Bridges the kernel interface and the GEM abstraction for
    Constants.

    :arg constant: Firedrake Constant
    :arg number: Value to uniquely identify the constant; it is used to
        name the variable ``c_<number>``.
    :returns: GEM expression referring to the Constant value(s): a flat
        variable reshaped to ``constant.ufl_shape``.
    """
    shape = constant.ufl_shape
    flat_size = numpy.prod(shape, dtype=int)
    flat_variable = gem.Variable(f"c_{number}", (flat_size,))
    return gem.reshape(flat_variable, shape)
def prepare_coefficient(coefficient, name, domain_integral_type_map):
    """Bridges the kernel interface and the GEM abstraction for
    Coefficients.

    Parameters
    ----------
    coefficient : ufl.Coefficient
        UFL Coefficient.
    name : str
        Unique name to refer to the Coefficient in the kernel.
    domain_integral_type_map : dict
        Map from domain to integral_type.

    Returns
    -------
    gem.Node, tuple or None
        GEM expression referring to the Coefficient values.  For integral
        types starting with ``"interior_facet"`` a ``(plus, minus)`` pair
        of expressions is returned.  ``None`` is returned when the
        coefficient's domain is not a key of ``domain_integral_type_map``.

    """
    ufl_element = coefficient.ufl_element()
    if ufl_element.family() == 'Real':
        # Constant: a flat variable reshaped to the UFL shape.
        value_size = coefficient.ufl_function_space().value_size
        flat = gem.Variable(name, (value_size,))
        return gem.reshape(flat, coefficient.ufl_shape)

    shape = create_element(coefficient.ufl_element()).index_shape
    size = numpy.prod(shape, dtype=int)

    domain = extract_unique_domain(coefficient)
    try:
        integral_type = domain_integral_type_map[domain]
    except KeyError:
        # This means that this coefficient does not exist in the DAG,
        # so corresponding gem expression will never be needed.
        return None

    if not integral_type.startswith("interior_facet"):
        return gem.reshape(gem.Variable(name, (size,)), shape)

    # Interior facet: one variable of twice the size, with the first
    # half viewed as the '+' side and the second half as the '-' side.
    both_sides = gem.Variable(name, (2 * size,))
    plus = gem.view(both_sides, slice(size))
    minus = gem.view(both_sides, slice(size, 2 * size))
    return (gem.reshape(plus, shape), gem.reshape(minus, shape))
def prepare_arguments(arguments, multiindices, domain_integral_type_map, diagonal=False):
    """Bridges the kernel interface and the GEM abstraction for
    Arguments.  Vector Arguments are rearranged here for interior
    facet integrals.

    Parameters
    ----------
    arguments : tuple
        UFL Arguments.
    multiindices : tuple
        Argument multiindices.
    domain_integral_type_map : dict
        Map from domain to integral_type.
    diagonal : bool
        Are we assembling the diagonal of a rank-2 element tensor?

    Returns
    -------
    tuple
        GEM expressions referring to the argument tensor ``A``, one per
        combination of restrictions.

    Raises
    ------
    ValueError
        If ``arguments`` and ``multiindices`` differ in length, or if
        ``diagonal`` is requested for anything but two arguments sharing
        one element.
    RuntimeError
        If the integral type of an argument's domain is ``None``.

    """
    if len(multiindices) != len(arguments):
        raise ValueError(f"Got inconsistent lengths of arguments ({len(arguments)}) and multiindices ({len(multiindices)})")

    if len(arguments) == 0:
        # No arguments: the element tensor is a single scalar entry.
        scalar = gem.Indexed(gem.Variable("A", (1,)), (0,))
        return (scalar, )

    elements = tuple(create_element(arg.ufl_element()) for arg in arguments)
    shapes = tuple(element.index_shape for element in elements)

    if diagonal:
        if len(arguments) != 2:
            raise ValueError("Diagonal only for 2-forms")
        try:
            element, = set(elements)
        except ValueError:
            raise ValueError("Diagonal only for diagonal blocks (test and trial spaces the same)")
        # Only the first argument is kept from here on.
        elements = (element, )
        shapes = (element.index_shape, )
        multiindices = multiindices[:1]
        arguments = arguments[:1]

    def index_block(restricted):
        # Reshape a view of the tensor to the element shapes and index it
        # with the flattened argument multiindices.
        return gem.Indexed(gem.reshape(restricted, *shapes),
                           tuple(chain.from_iterable(multiindices)))

    # Flattened size per argument (unrestricted), and size of the kernel
    # tensor, which doubles along interior-facet arguments.
    u_shape = numpy.array([numpy.prod(shape, dtype=int) for shape in shapes])
    c_shape = u_shape.copy()
    restriction_choices = []
    for position, argument in enumerate(arguments):
        integral_type = domain_integral_type_map[extract_unique_domain(argument)]
        if integral_type is None:
            raise RuntimeError(f"Can not determine integral_type on {argument}")
        if integral_type.startswith("interior_facet"):
            restriction_choices.append((0, 1))
            c_shape[position] *= 2
        else:
            restriction_choices.append((0, ))

    # One list of slices for every combination of restrictions.
    slice_sets = [
        [slice(side * extent, (side + 1) * extent)
         for side, extent in zip(restrictions, u_shape)]
        for restrictions in product(*restriction_choices)
    ]

    tensor = gem.Variable("A", tuple(c_shape))
    blocks = [index_block(gem.view(tensor, *slices)) for slices in slice_sets]
    return tuple(prune(blocks))