66"""Finite elements."""
77
88from functools import singledispatch
9+ from typing import Generic , TypeVar
910
1011import numpy as np
1112import numpy .typing as npt
1415import basix .ufl
1516from dolfinx import cpp as _cpp
1617
18+ _T_CE = TypeVar ("_T_CE" , np .float32 , np .float64 )
1719
18- class CoordinateElement :
20+
21+ class CoordinateElement (Generic [_T_CE ]):
1922 """Coordinate element describing the geometry map for mesh cells."""
2023
2124 _cpp_object : _cpp .fem .CoordinateElement_float32 | _cpp .fem .CoordinateElement_float64
@@ -61,10 +64,8 @@ def create_dof_layout(self) -> _cpp.fem.ElementDofLayout:
6164 return self ._cpp_object .create_dof_layout ()
6265
6366 def push_forward (
64- self ,
65- X : npt .NDArray [np .float32 ] | npt .NDArray [np .float64 ],
66- cell_geometry : npt .NDArray [np .float32 ] | npt .NDArray [np .float64 ],
67- ) -> npt .NDArray [np .float32 ] | npt .NDArray [np .float64 ]:
67+ self , X : npt .NDArray [_T_CE ], cell_geometry : npt .NDArray [_T_CE ]
68+ ) -> npt .NDArray [_T_CE ]:
6869 """Push points on the reference cell forward to the physical cell.
6970
7071 Args:
@@ -81,10 +82,8 @@ def push_forward(
8182 return self ._cpp_object .push_forward (X , cell_geometry )
8283
8384 def pull_back (
84- self ,
85- x : npt .NDArray [np .float32 ] | npt .NDArray [np .float64 ],
86- cell_geometry : npt .NDArray [np .float32 ] | npt .NDArray [np .float64 ],
87- ) -> npt .NDArray [np .float32 ] | npt .NDArray [np .float64 ]:
85+ self , x : npt .NDArray [_T_CE ], cell_geometry : npt .NDArray [_T_CE ]
86+ ) -> npt .NDArray [_T_CE ]:
8887 """Pull points on the physical cell back to the reference cell.
8988
9089 For non-affine cells, the pull-back is a nonlinear operation.
@@ -124,7 +123,7 @@ def coordinate_element(
124123 degree : int ,
125124 variant = int (basix .LagrangeVariant .unset ),
126125 dtype : npt .DTypeLike = np .float64 ,
127- ):
126+ ) -> CoordinateElement :
128127 """Create a Lagrange CoordinateElement from element metadata.
129128
130129 Coordinate elements are typically used to create meshes.
@@ -147,7 +146,7 @@ def coordinate_element(
147146
148147
149148@coordinate_element .register (basix .finite_element .FiniteElement )
150- def _ (e : basix .finite_element .FiniteElement ):
149+ def _ (e : basix .finite_element .FiniteElement ) -> CoordinateElement :
151150 """Create a Lagrange CoordinateElement from a Basix finite element.
152151
153152 Coordinate elements are typically used when creating meshes.
0 commit comments