@@ -30,7 +30,7 @@ def barycentric_interpolation(nodes, wts, dmat, pts, order=0):
3030 sp_simplify = numpy .vectorize (simplify )
3131 else :
3232 sp_simplify = lambda x : x
33- phi = numpy .add .outer (- nodes , pts .flatten ())
33+ phi = numpy .add .outer (- nodes . flatten () , pts .flatten ())
3434 with numpy .errstate (divide = 'ignore' , invalid = 'ignore' ):
3535 numpy .reciprocal (phi , out = phi )
3636 numpy .multiply (phi , wts [:, None ], out = phi )
@@ -49,7 +49,7 @@ def barycentric_interpolation(nodes, wts, dmat, pts, order=0):
4949def make_dmat (x ):
5050 """Returns Lagrange differentiation matrix and barycentric weights
5151 associated with x[j]."""
52- dmat = numpy .add .outer (- x , x )
52+ dmat = numpy .add .outer (- x . flatten () , x . flatten () )
5353 numpy .fill_diagonal (dmat , 1.0 )
5454 wts = numpy .prod (dmat , axis = 0 )
5555 numpy .reciprocal (wts , out = wts )
@@ -59,22 +59,38 @@ def make_dmat(x):
5959
6060
6161class LagrangeLineExpansionSet (expansions .LineExpansionSet ):
62- """Lagrange polynomial expansion set for given points the line."""
62+ """Lagrange polynomial expansion set for given points on the line."""
6363 def __init__ (self , ref_el , pts ):
64+ if ref_el .get_spatial_dimension () != 1 :
65+ raise Exception ("Must have a line" )
66+ pts = numpy .asarray (pts )
6467 self .points = pts
65- self . x = numpy . array ( pts , dtype = "d" ). flatten ()
68+
6669 self .cell_node_map = expansions .compute_cell_point_map (ref_el , pts , unique = False )
6770 self .dmats = [None for _ in self .cell_node_map ]
6871 self .weights = [None for _ in self .cell_node_map ]
6972 self .nodes = [None for _ in self .cell_node_map ]
73+ self .affine_mappings = {}
7074 for cell , ibfs in self .cell_node_map .items ():
71- self .nodes [cell ] = self .x [ibfs ]
72- self .dmats [cell ], self .weights [cell ] = make_dmat (self .nodes [cell ])
75+ x = pts [ibfs ]
76+ if ref_el .is_trace ():
77+ verts = ref_el .get_vertices_of_subcomplex (ref_el .topology [1 ][cell ])
78+ A = numpy .diff (verts , axis = 0 )[0 ]/ 2
79+ A /= numpy .linalg .norm (A )
80+ b = - numpy .dot (numpy .sum (verts , axis = 0 )/ 2 , A .T )
81+ self .affine_mappings [cell ] = (A , b )
82+ x = numpy .add (numpy .dot (x , A .T ), b )
83+ self .nodes [cell ] = x
84+ self .dmats [cell ], self .weights [cell ] = make_dmat (x )
7385
7486 self .degree = max (len (wts ) for wts in self .weights )- 1
7587 self .recurrence_order = self .degree + 1
76- super ().__init__ (ref_el )
77- self .continuity = None if len (self .x ) == sum (len (xk ) for xk in self .nodes ) else "C0"
88+ self .ref_el = ref_el
89+ self .variant = None
90+ self .scale = 1
91+ self .continuity = None if len (pts ) == sum (len (xk ) for xk in self .nodes ) else "C0"
92+ self ._dmats_cache = {}
93+ self ._cell_node_map_cache = {}
7894
7995 def get_num_members (self , n ):
8096 return len (self .points )
@@ -89,6 +105,12 @@ def get_dmats(self, degree, cell=0):
89105 return [self .dmats [cell ].T ]
90106
91107 def _tabulate_on_cell (self , n , pts , order = 0 , cell = 0 , direction = None ):
108+ try :
109+ A , b = self .affine_mappings [cell ]
110+ ref_pts = numpy .add (numpy .dot (pts .reshape (- 1 , A .shape [- 1 ]), A .T ), b )
111+ pts = ref_pts .reshape (* pts .shape [:- 1 ], - 1 )
112+ except KeyError :
113+ pass
92114 return barycentric_interpolation (self .nodes [cell ], self .weights [cell ], self .dmats [cell ], pts , order = order )
93115
94116
0 commit comments