Skip to content

Commit b5f116f

Browse files
committed
Add typing generics to CoordinateElement
1 parent efd17c2 commit b5f116f

1 file changed

Lines changed: 10 additions & 11 deletions

File tree

python/dolfinx/fem/element.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""Finite elements."""
77

88
from functools import singledispatch
9+
from typing import Generic, TypeVar
910

1011
import numpy as np
1112
import numpy.typing as npt
@@ -14,8 +15,10 @@
1415
import basix.ufl
1516
from dolfinx import cpp as _cpp
1617

18+
_T_CE = TypeVar("_T_CE", np.float32, np.float64)
1719

18-
class CoordinateElement:
20+
21+
class CoordinateElement(Generic[_T_CE]):
1922
"""Coordinate element describing the geometry map for mesh cells."""
2023

2124
_cpp_object: _cpp.fem.CoordinateElement_float32 | _cpp.fem.CoordinateElement_float64
@@ -61,10 +64,8 @@ def create_dof_layout(self) -> _cpp.fem.ElementDofLayout:
6164
return self._cpp_object.create_dof_layout()
6265

6366
def push_forward(
64-
self,
65-
X: npt.NDArray[np.float32] | npt.NDArray[np.float64],
66-
cell_geometry: npt.NDArray[np.float32] | npt.NDArray[np.float64],
67-
) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]:
67+
self, X: npt.NDArray[_T_CE], cell_geometry: npt.NDArray[_T_CE]
68+
) -> npt.NDArray[_T_CE]:
6869
"""Push points on the reference cell forward to the physical cell.
6970
7071
Args:
@@ -81,10 +82,8 @@ def push_forward(
8182
return self._cpp_object.push_forward(X, cell_geometry)
8283

8384
def pull_back(
84-
self,
85-
x: npt.NDArray[np.float32] | npt.NDArray[np.float64],
86-
cell_geometry: npt.NDArray[np.float32] | npt.NDArray[np.float64],
87-
) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]:
85+
self, x: npt.NDArray[_T_CE], cell_geometry: npt.NDArray[_T_CE]
86+
) -> npt.NDArray[_T_CE]:
8887
"""Pull points on the physical cell back to the reference cell.
8988
9089
For non-affine cells, the pull-back is a nonlinear operation.
@@ -124,7 +123,7 @@ def coordinate_element(
124123
degree: int,
125124
variant=int(basix.LagrangeVariant.unset),
126125
dtype: npt.DTypeLike = np.float64,
127-
):
126+
) -> CoordinateElement:
128127
"""Create a Lagrange CoordinateElement from element metadata.
129128
130129
Coordinate elements are typically used to create meshes.
@@ -147,7 +146,7 @@ def coordinate_element(
147146

148147

149148
@coordinate_element.register(basix.finite_element.FiniteElement)
150-
def _(e: basix.finite_element.FiniteElement):
149+
def _(e: basix.finite_element.FiniteElement) -> CoordinateElement:
151150
"""Create a Lagrange CoordinateElement from a Basix finite element.
152151
153152
Coordinate elements are typically used when creating meshes.

0 commit comments

Comments
 (0)