Skip to content

Commit 781d042

Browse files
committed
Refactor to allow derivative dof conversion
1 parent c7e2852 commit 781d042

4 files changed

Lines changed: 86 additions & 22 deletions

File tree

fuse/dof.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from FIAT.quadrature_schemes import create_quadrature
22
from FIAT.quadrature import FacetQuadratureRule
3-
from FIAT.functional import PointEvaluation, FrobeniusIntegralMoment
3+
from FIAT.functional import PointEvaluation, FrobeniusIntegralMoment, Functional
44
from fuse.utils import sympy_to_numpy
55
import numpy as np
66
import sympy as sp
@@ -40,12 +40,7 @@ def convert_to_fiat(self, ref_el, dof, interpolant_deg):
4040

4141
def get_pts(self, ref_el, total_degree):
4242
entity = ref_el.construct_subelement(self.entity.dim())
43-
Q_ref = create_quadrature(entity, total_deg)
44-
45-
ent_id = self.entity.id - ref_el.fe_cell.get_starter_ids()[self.entity.dim()]
46-
Q = FacetQuadratureRule(ref_el, self.entity.dim(), ent_id, Q_ref)
47-
qpts, qwts = Q.get_points(), Q.get_weights()
48-
return [(1,)], [(1,)]
43+
return [(0,) * entity.get_spatial_dimension()], [1], 1
4944

5045
def add_entity(self, entity):
5146
res = DeltaPairing()
@@ -101,12 +96,15 @@ def add_entity(self, entity):
10196

10297
def get_pts(self, ref_el, total_degree):
10398
entity = ref_el.construct_subelement(self.entity.dim())
104-
Q_ref = create_quadrature(entity, total_deg)
99+
Q_ref = create_quadrature(entity, total_degree)
105100

106101
ent_id = self.entity.id - ref_el.fe_cell.get_starter_ids()[self.entity.dim()]
107102
Q = FacetQuadratureRule(ref_el, self.entity.dim(), ent_id, Q_ref)
108-
qpts, qwts = Q.get_points(), Q.get_weights()
109-
return qpts, qwts
103+
Jdet = Q.jacobian_determinant()
104+
# TODO work out how to get J from attachment
105+
106+
qpts, qwts = Q_ref.get_points(), Q_ref.get_weights()
107+
return qpts, qwts, Jdet
110108

111109
def convert_to_fiat(self, ref_el, dof, interpolant_degree):
112110
total_deg = interpolant_degree + dof.kernel.degree()
@@ -168,8 +166,11 @@ def permute(self, g):
168166
def __call__(self, *args):
169167
return self.pt
170168

171-
def tabulate(self, Qpts):
172-
return np.array([self.pt for _ in Qpts]).astype(np.float64)
169+
def tabulate(self, Qpts, attachment=None):
170+
if attachment is None:
171+
return np.array([self.pt for _ in Qpts]).astype(np.float64)
172+
else:
173+
return np.array([attachment(*self.pt) for _ in Qpts]).astype(np.float64)
173174

174175
def _to_dict(self):
175176
o_dict = {"pt": self.pt}
@@ -209,8 +210,10 @@ def __call__(self, *args):
209210
return [res]
210211
return res
211212

212-
def tabulate(self, Qpts):
213-
return np.array([self(*pt) for pt in Qpts]).astype(np.float64)
213+
def tabulate(self, Qpts, attachment=None):
214+
if attachment is None:
215+
return np.array([self(*pt) for pt in Qpts]).astype(np.float64)
216+
return np.array([self(*attachment(*pt)) for pt in Qpts]).astype(np.float64)
214217

215218
def _to_dict(self):
216219
o_dict = {"fn": self.fn}
@@ -270,7 +273,23 @@ def add_context(self, dof_gen, cell, space, g, overall_id=None, generator_id=Non
270273

271274
def convert_to_fiat(self, ref_el, interpolant_degree):
272275
return self.pairing.convert_to_fiat(ref_el, self, interpolant_degree)
273-
raise NotImplementedError("Fiat conversion only implemented for Point eval")
276+
277+
def convert_to_fiat_new(self, ref_el, interpolant_degree):
278+
total_degree = self.kernel.degree() + interpolant_degree
279+
pts, wts, jdet = self.pairing.get_pts(ref_el, total_degree)
280+
f_pts = self.kernel.tabulate(pts).T / jdet
281+
# TODO need to work out how i can discover the shape in a better way
282+
if isinstance(self.pairing, DeltaPairing):
283+
shp = tuple()
284+
pt_dict = {tuple(p) : [(w, tuple())] for (p, w) in zip(f_pts.T, wts)}
285+
else:
286+
shp = tuple(f_pts.shape[:-1])
287+
weights = np.transpose(np.multiply(f_pts, wts), (-1,) + tuple(range(len(shp))))
288+
alphas = list(np.ndindex(shp))
289+
pt_dict = {tuple(pt): [(wt[alpha], alpha) for alpha in alphas] for pt, wt in zip(pts, weights)}
290+
291+
return Functional(ref_el, shp, pt_dict, {}, self.__repr__())
292+
274293

275294
def __repr__(self, fn="v"):
276295
return str(self.pairing).format(fn=fn, kernel=self.kernel)
@@ -309,13 +328,25 @@ def eval(self, fn, pullback=True):
309328
def tabulate(self, Qpts):
310329
immersion = self.target_space.tabulate(Qpts, self.trace_entity, self.g)
311330
res = self.kernel.tabulate(Qpts)
312-
# [self.attachment(*tuple(r)) for r in res]
313331
return immersion*res
314332

315-
def convert_to_fiat(self, ref_el):
316-
pts, wts = self.pairing.get_pts()
317-
self.target_space.convert_to_fiat(tabulated, wts)
333+
def convert_to_fiat_new(self, ref_el, interpolant_degree):
334+
total_degree = self.kernel.degree() + interpolant_degree
335+
pts, wts, jdet = self.pairing.get_pts(ref_el, total_degree)
336+
f_pts = self.kernel.tabulate(pts, self.attachment)
337+
attached_pts = [self.attachment(*p) for p in pts]
338+
immersion = self.target_space.tabulate(f_pts, self.trace_entity, self.g)
318339

340+
f_pts = (f_pts * immersion).T / jdet
341+
pt_dict, deriv_dict = self.target_space.convert_to_fiat(attached_pts, f_pts, wts)
342+
343+
# breakpoint()
344+
# TODO need to work out how i can discover the shape in a better way
345+
if isinstance(self.pairing, DeltaPairing):
346+
shp = tuple()
347+
else:
348+
shp = tuple(f_pts.shape[:-1])
349+
return Functional(ref_el, shp, pt_dict, deriv_dict, self.__repr__())
319350

320351
def __call__(self, g):
321352
permuted = self.cell.permute_entities(g, self.trace_entity.dim())

fuse/traces.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def plot(self, ax, coord, trace_entity, g, **kwargs):
5454
def tabulate(self, Qpts, trace_entity, g):
5555
return np.ones_like(Qpts)
5656

57+
def convert_to_fiat(self, qpts, pts, wts):
58+
pt_dict = {tuple(p) : [(w, tuple())] for (p, w) in zip(pts.T, wts)}
59+
return pt_dict, {}
60+
5761
def __repr__(self):
5862
return "H1"
5963

@@ -86,6 +90,15 @@ def plot(self, ax, coord, trace_entity, g, **kwargs):
8690
vec = np.cross(basis[0], basis[1])
8791
ax.quiver(*coord, *vec, **kwargs)
8892

93+
def convert_to_fiat(self, qpts, pts, wts):
94+
f_at_qpts = pts
95+
shp = tuple(f_at_qpts.shape[:-1])
96+
weights = np.transpose(np.multiply(f_at_qpts, wts), (-1,) + tuple(range(len(shp))))
97+
alphas = list(np.ndindex(shp))
98+
pt_dict = {tuple(pt): [(wt[alpha], alpha) for alpha in alphas] for pt, wt in zip(qpts, weights)}
99+
return pt_dict, {}
100+
101+
89102
def tabulate(self, Qpts, trace_entity, g):
90103
entityBasis = np.array(trace_entity.basis_vectors())
91104
cellEntityBasis = np.array(self.domain.basis_vectors(entity=trace_entity))
@@ -122,6 +135,14 @@ def tabulate(self, Qpts, trace_entity, g):
122135
subEntityBasis = np.array(self.domain.basis_vectors(entity=trace_entity))
123136
result = np.matmul(tangent, subEntityBasis)
124137
return result
138+
139+
def convert_to_fiat(self, qpts, pts, wts):
140+
f_at_qpts = pts
141+
shp = tuple(f_at_qpts.shape[:-1])
142+
weights = np.transpose(np.multiply(f_at_qpts, wts), (-1,) + tuple(range(len(shp))))
143+
alphas = list(np.ndindex(shp))
144+
pt_dict = {tuple(pt): [(wt[alpha], alpha) for alpha in alphas] for pt, wt in zip(qpts, weights)}
145+
return pt_dict, {}
125146

126147
def plot(self, ax, coord, trace_entity, g, **kwargs):
127148
permuted = self.domain.permute_entities(g, trace_entity.dimension)

fuse/triples.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,19 @@ def to_fiat(self):
113113
for i in range(len(dofs)):
114114
if entity[1] == dofs[i].trace_entity.id - min_ids[dim]:
115115
entity_ids[dim][dofs[i].trace_entity.id - min_ids[dim]].append(counter)
116-
nodes.append(dofs[i].convert_to_fiat(ref_el, degree))
116+
if hasattr(dofs[i], "convert_to_fiat_new"):
117+
nodes.append(dofs[i].convert_to_fiat_new(ref_el, degree))
118+
print("old")
119+
120+
print(dofs[i].convert_to_fiat(ref_el, degree).pt_dict)
121+
print(dofs[i].convert_to_fiat(ref_el, degree).target_shape)
122+
print("new")
123+
124+
print(dofs[i].convert_to_fiat_new(ref_el, degree).pt_dict)
125+
print(dofs[i].convert_to_fiat_new(ref_el, degree).target_shape)
126+
else:
127+
raise ValueError("using old")
128+
nodes.append(dofs[i].convert_to_fiat(ref_el, degree))
117129
counter += 1
118130
entity_perms, pure_perm = self.make_dof_perms(ref_el, entity_ids, nodes, poly_set)
119131

@@ -189,7 +201,7 @@ def compute_dense_matrix(self, ref_el, entity_ids, nodes, poly_set):
189201
A = dualmat.reshape((shp[0], -1))
190202
B = old_coeffs.reshape((shp[0], -1))
191203
V = np.dot(A, np.transpose(B))
192-
204+
print(V)
193205
with warnings.catch_warnings():
194206
warnings.filterwarnings("error")
195207
try:

test/test_convert_to_fiat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def test_project_3d(elem_gen, elem_code, deg):
564564

565565
assert np.allclose(out.dat.data, f.dat.data, rtol=1e-5)
566566

567-
567+
@pytest.mark.xfail(reason='Derivative nodes to fiat')
568568
def test_create_hermite():
569569
deg = 3
570570
cell = polygon(3)

0 commit comments

Comments
 (0)