Skip to content

Commit c059759

Browse files
committed
finat.ufl.HDivTraceElement, full support for tensor product elements
1 parent 2d802c9 commit c059759

11 files changed

Lines changed: 151 additions & 114 deletions

FIAT/barycentric_interpolation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(self, ref_el, pts):
7575
x = pts[ibfs]
7676
if ref_el.is_trace():
7777
verts = ref_el.get_vertices_of_subcomplex(ref_el.topology[1][cell])
78-
A = numpy.diff(verts, axis=0)[0]/2
78+
A, = numpy.diff(verts, axis=0)
7979
A /= numpy.linalg.norm(A)
8080
b = -numpy.dot(numpy.sum(verts, axis=0)/2, A.T)
8181
self.affine_mappings[cell] = (A, b)

FIAT/expansions.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,6 @@ def distance(alpha, beta):
345345

346346
def _tabulate(self, n, pts, order=0):
347347
"""A version of tabulate() that also works for a single point."""
348-
from FIAT.hdiv_trace import TraceError
349-
from FIAT.polynomial_set import mis
350348
pts = numpy.asarray(pts)
351349
unique = self.continuity is not None and order == 0
352350
cell_point_map = compute_cell_point_map(self.ref_el, pts, unique=unique)
@@ -357,26 +355,7 @@ def _tabulate(self, n, pts, order=0):
357355
return phis[0]
358356

359357
if self.ref_el.is_trace():
360-
parent = self.ref_el.get_parent()
361-
tdim = self.ref_el.get_spatial_dimension()
362-
gdim = parent.get_spatial_dimension()
363-
for cell in phis:
364-
# Promote facet keys to cell keys
365-
phi = phis[cell][(0,) * tdim]
366-
# Raise TraceError on gradient tabulations
367-
msg = "Gradients on trace elements are not well-defined."
368-
phis[cell] = {alpha: phi if sum(alpha) == 0 else TraceError(msg)
369-
for i in range(order+1)
370-
for alpha in mis(gdim, i)}
371-
if tdim == 0 and len(phis) == 0:
372-
# Hack for TensorProduct HDivTrace: do not raise TraceError on the interval
373-
for cell in parent.topology[gdim]:
374-
phis[cell] = {(0,)*gdim: numpy.zeros(())}
375-
elif sum(len(cell_point_map[cell]) for cell in cell_point_map) < len(pts):
376-
# Raise TraceError when interior points fail to be binned on facets
377-
for cell in parent.topology[gdim]:
378-
msg = "The HDivTrace element can only be tabulated on facets."
379-
phis[cell] = {(0,)*gdim: TraceError(msg)}
358+
phis = trace_tabulation(self.ref_el, cell_point_map, order, pts, phis)
380359

381360
if pts.dtype == object:
382361
# If binning is undefined, scale by the characteristic function of each subcell
@@ -402,7 +381,7 @@ def _tabulate(self, n, pts, order=0):
402381
result = {}
403382
base_phi = tuple(phis.values())[0]
404383
for alpha in base_phi:
405-
if isinstance(base_phi[alpha], TraceError):
384+
if isinstance(base_phi[alpha], Exception):
406385
result[alpha] = base_phi[alpha]
407386
continue
408387
dtype = base_phi[alpha].dtype
@@ -723,7 +702,10 @@ def compute_cell_point_map(ref_el, pts, unique=True, tol=1E-12):
723702
return {cell: Ellipsis for cell in sorted(top[sd])}
724703

725704
# The distance to the nearest cell is equal to the distance to the parent cell
726-
best = ref_el.get_parent().distance_to_point_l1(pts, rescale=True)
705+
parent = ref_el
706+
while parent.get_parent() is not None:
707+
parent = parent.get_parent()
708+
best = parent.distance_to_point_l1(pts, rescale=True)
727709
tol = best + tol
728710

729711
cell_point_map = {}
@@ -779,3 +761,30 @@ def compute_partition_of_unity(ref_el, pt, unique=True, tol=1E-12):
779761
mult = sum(masks)
780762
masks = [m / mult for m in masks]
781763
return masks
764+
765+
766+
def trace_tabulation(ref_el, cell_point_map, order, pts, phis):
767+
"""Lift trace tabulations into the cells and raise TraceError on invalid tabulations."""
768+
from FIAT.polynomial_set import mis
769+
from FIAT.hdiv_trace import TraceError
770+
parent = ref_el.get_parent()
771+
tdim = ref_el.get_spatial_dimension()
772+
gdim = parent.get_spatial_dimension()
773+
facet_key = (0, ) * tdim
774+
cell_key = (0, ) * gdim
775+
776+
for cell in phis:
777+
# Lift facet keys to cell keys
778+
phi = phis[cell][facet_key]
779+
# Raise TraceError on gradient tabulations
780+
msg = "Gradients on trace elements are not well-defined."
781+
phis[cell] = {alpha: phi if sum(alpha) == 0 else TraceError(msg)
782+
for i in range(order+1)
783+
for alpha in mis(gdim, i)}
784+
785+
if sum(len(cell_point_map[cell]) for cell in cell_point_map) < len(pts):
786+
# Raise TraceError when interior points fail to be binned on facets
787+
for cell in parent.topology[gdim]:
788+
msg = "The HDivTrace element can only be tabulated on facets."
789+
phis[cell] = {cell_key: TraceError(msg)}
790+
return phis

FIAT/hdiv_trace.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,8 @@ class HDivTrace(DiscontinuousLagrange):
100100
def __new__(cls, ref_el, degree, variant="equispaced_interior"):
101101
"""Constructor for the HDivTrace element.
102102
103-
:arg ref_el: A reference element, which may be a tensor product
104-
cell.
105-
:arg degree: The degree of approximation. If on a tensor product
106-
cell, then provide a tuple of degrees if you want
107-
varying degrees.
103+
:arg ref_el: A reference element.
104+
:arg degree: The degree of approximation.
108105
:arg variant: The point distribution variant passed on to recursivenodes.
109106
"""
110107
facets = TraceSimplicialComplex(ref_el)

FIAT/reference_element.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def compute_barycentric_coordinates(self, points, entity=None, rescale=False):
618618
if len(points) == 0:
619619
return points
620620
if entity is None:
621-
entity = (self.get_spatial_dimension(), 0)
621+
entity = (self.get_dimension(), 0)
622622
entity_dim, entity_id = entity
623623
sd = len(self.vertices[0])
624624

@@ -628,10 +628,9 @@ def compute_barycentric_coordinates(self, points, entity=None, rescale=False):
628628
top = self.get_topology()
629629
subcomplex = top[entity_dim][entity_id]
630630
if entity_dim != sd:
631-
parent = self.get_parent() if self.is_trace() else self
632-
cell_id = parent.connectivity[(entity_dim, sd)][0][0]
633-
top = parent.get_topology()
634-
subcell = top[sd][cell_id]
631+
parent = self.get_parent_complex() if self.is_trace() else self
632+
cell_id = min(parent.connectivity[(entity_dim, sd)][entity_id])
633+
subcell = parent.topology[sd][cell_id]
635634
while len(subcell) > sd + 1:
636635
# construct a simplex if we have a hypercube
637636
k = max(set(subcell) - set(subcomplex))
@@ -648,7 +647,7 @@ def compute_barycentric_coordinates(self, points, entity=None, rescale=False):
648647
A, b = make_affine_mapping(cell_verts, ref_verts)
649648
A, b = A[indices], b[indices]
650649
if rescale:
651-
# rescale barycentric coordinates by the height wrt. to the facet
650+
# rescale barycentric coordinates by the height w.r.t. the facet
652651
h = 1 / numpy.linalg.norm(A, axis=1)
653652
b *= h
654653
A *= h[:, None]

finat/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from .cube import FlattenedDimensions # noqa: F401
3333
from .discontinuous import DiscontinuousElement # noqa: F401
3434
from .enriched import EnrichedElement # noqa: F401
35-
from .hdivcurl import HCurlElement, HDivElement # noqa: F401
35+
from .hdivcurl import HCurlElement, HDivElement, HDivTraceElement # noqa: F401
3636
from .mixed import MixedElement # noqa: F401
3737
from .nodal_enriched import NodalEnrichedElement # noqa: F401
3838
from .quadrature_element import QuadratureElement, make_quadrature_element # noqa: F401

finat/element_factory.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,12 @@ def convert_tensorproductelement(element, **kwargs):
287287
return finat.TensorProductElement(elements), deps
288288

289289

290+
@convert.register(finat.ufl.HDivTraceElement)
291+
def convert_hdivtraceelement(element, **kwargs):
292+
finat_elem, deps = _create_element(element._element, **kwargs)
293+
return finat.HDivTraceElement(finat_elem), deps
294+
295+
290296
@convert.register(finat.ufl.HDivElement)
291297
def convert_hdivelement(element, **kwargs):
292298
finat_elem, deps = _create_element(element._element, **kwargs)

finat/fiat_elements.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
112112
index_shape = (self._element.space_dimension(),)
113113
for alpha, fiat_table in fiat_result.items():
114114
if isinstance(fiat_table, Exception):
115-
result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table)
115+
shape = ps.points.shape[:-1] + self.index_shape + self.value_shape
116+
result[alpha] = gem.partial_indexed(gem.Failure(shape, fiat_table), ps.indices)
116117
continue
117118

118119
derivative = sum(alpha)

finat/hdivcurl.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from FIAT.reference_element import LINE
33

44
import gem
5+
import numpy
56
from gem.utils import cached_property
67
from finat.finiteelementbase import FiniteElementBase
78
from finat.tensor_product import TensorProductElement
@@ -58,7 +59,7 @@ def index_shape(self):
5859
def value_shape(self):
5960
return (self.cell.get_spatial_dimension(),)
6061

61-
def _transform_evaluation(self, core_eval):
62+
def _transform_evaluation(self, core_eval, entity=None):
6263
beta = self.get_indices()
6364
zeta = self.get_value_indices()
6465

@@ -72,11 +73,11 @@ def promote(table):
7273

7374
def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
7475
core_eval = self.wrappee.basis_evaluation(order, ps, entity)
75-
return self._transform_evaluation(core_eval)
76+
return self._transform_evaluation(core_eval, entity)
7677

7778
def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None):
7879
core_eval = self.wrappee.point_evaluation(order, refcoords, entity)
79-
return self._transform_evaluation(core_eval)
80+
return self._transform_evaluation(core_eval, entity)
8081

8182
@property
8283
def dual_basis(self):
@@ -92,6 +93,59 @@ def dual_basis(self):
9293
return gem.ComponentTensor(Q[zeta], beta + zeta), x
9394

9495

96+
class HDivTraceElement(WrapperElementBase):
97+
"""HDiv Trace wrapper element for tensor product elements."""
98+
99+
def __init__(self, wrappee):
100+
assert isinstance(wrappee, TensorProductElement)
101+
if any(fe.formdegree is None for fe in wrappee.factors):
102+
raise ValueError("Form degree of subelement is None, cannot HDiv Trace!")
103+
104+
formdegree = sum(fe.formdegree for fe in wrappee.factors)
105+
if formdegree != wrappee.cell.get_spatial_dimension() - 1:
106+
raise ValueError("HDiv Trace requires (n-1)-form element!")
107+
108+
self.support_dim = tuple(fe.formdegree for fe in wrappee.factors)
109+
110+
super().__init__(wrappee, None)
111+
112+
def _transform_evaluation(self, core_eval, entity):
113+
if entity is None:
114+
entity = (self.cell.get_dimension(), 0)
115+
entity_dim, entity_id = entity
116+
117+
if entity_dim == self.support_dim or sum(entity_dim) != sum(self.support_dim):
118+
return core_eval
119+
120+
def zero_failure(expr):
121+
if isinstance(expr, gem.Failure):
122+
return gem.Literal(numpy.zeros(expr.shape, expr.dtype))
123+
return expr.reconstruct(*map(zero_failure, expr.children))
124+
125+
# Create a zero tabulation with the same tensor-product structure
126+
return {alpha: zero_failure(table) for alpha, table in core_eval.items()}
127+
128+
@property
129+
def formdegree(self):
130+
return self.wrappee.formdegree
131+
132+
@property
133+
def value_shape(self):
134+
return self.wrappee.value_shape
135+
136+
@cached_property
137+
def fiat_equivalent(self):
138+
return self.wrappee.fiat_equivalent
139+
140+
@property
141+
def mapping(self):
142+
return self.wrappee.mapping
143+
144+
@property
145+
def dual_basis(self):
146+
return self.wrappee.dual_basis
147+
148+
95149
class HDivElement(WrapperElementBase):
96150
"""H(div) wrapper element for tensor product elements."""
97151

finat/ufl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from finat.ufl.enrichedelement import EnrichedElement, NodalEnrichedElement # noqa: F401
1616
from finat.ufl.finiteelement import FiniteElement # noqa: F401
1717
from finat.ufl.finiteelementbase import FiniteElementBase # noqa: F401
18-
from finat.ufl.hdivcurl import HCurlElement, HDivElement, WithMapping, HDiv, HCurl # noqa: F401
18+
from finat.ufl.hdivcurl import HCurlElement, HDivElement, HDivTraceElement, WithMapping, HDiv, HCurl # noqa: F401
1919
from finat.ufl.mixedelement import MixedElement, TensorElement, VectorElement # noqa: F401
2020
from finat.ufl.restrictedelement import RestrictedElement # noqa: F401
2121
from finat.ufl.tensorproductelement import TensorProductElement # noqa: F401

finat/ufl/finiteelement.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __new__(cls,
3939
from finat.ufl.enrichedelement import EnrichedElement
4040
from finat.ufl.hdivcurl import HCurlElement as HCurl
4141
from finat.ufl.hdivcurl import HDivElement as HDiv
42+
from finat.ufl.hdivcurl import HDivTraceElement as HDivTrace
4243
from finat.ufl.tensorproductelement import TensorProductElement
4344

4445
family, short_name, degree, reference_value_shape, sobolev_space, mapping, embedded_degree = \
@@ -96,18 +97,21 @@ def __new__(cls,
9697

9798
elif family == "HDiv Trace":
9899
cell_h, cell_v = cell.sub_cells()
99-
cell_h, cell_v = cell.sub_cells()
100+
if not isinstance(degree, tuple):
101+
degree = (degree, degree)
102+
hdegree, vdegree = degree
103+
104+
dg_family = lambda cell: "DG" if cell.is_simplex() else "DQ"
105+
is_interval = lambda cell: cell.cellname() == "interval"
100106

101-
hdegree = 0 if cell_h.cellname() == "interval" else degree
102-
vdegree = 0 if cell_v.cellname() == "interval" else degree
103-
tr_h = FiniteElement("HDiv Trace", cell_h, hdegree, variant=variant)
104-
tr_v = FiniteElement("HDiv Trace", cell_v, vdegree, variant=variant)
107+
dg_h = FiniteElement(dg_family(cell_h), cell_h, hdegree, variant=variant)
108+
tr_h = FiniteElement("HDiv Trace", cell_h, 0 if is_interval(cell_h) else hdegree, variant=variant)
105109

106-
dg_h = FiniteElement("DG", cell_h, degree, variant=variant)
107-
dg_v = FiniteElement("DG", cell_v, degree, variant=variant)
110+
dg_v = FiniteElement(dg_family(cell_v), cell_v, vdegree, variant=variant)
111+
tr_v = FiniteElement("HDiv Trace", cell_v, 0 if is_interval(cell_v) else vdegree, variant=variant)
108112

109-
return EnrichedElement(TensorProductElement(tr_h, dg_v, cell=cell),
110-
TensorProductElement(dg_h, tr_v, cell=cell))
113+
return EnrichedElement(HDivTrace(TensorProductElement(tr_h, dg_v, cell=cell)),
114+
HDivTrace(TensorProductElement(dg_h, tr_v, cell=cell)))
111115

112116
elif family == "Q":
113117
return TensorProductElement(*[FiniteElement("CG", c, degree, variant=variant)

0 commit comments

Comments
 (0)