Skip to content

Commit 6d5be09

Browse files
committed
Add typing generics Function
1 parent dff1b7b commit 6d5be09

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

python/dolfinx/fem/function.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def dtype(self) -> npt.DTypeLike:
311311
return np.dtype(self._cpp_object.dtype)
312312

313313

314-
class Function(ufl.Coefficient):
314+
class Function(ufl.Coefficient, Generic[_S]):
315315
"""A finite element function.
316316
317317
A finite element function is represented by a function space
@@ -327,10 +327,12 @@ class Function(ufl.Coefficient):
327327
| _cpp.fem.Function_float64
328328
)
329329

330+
_x: la.Vector[_S]
331+
330332
def __init__(
331333
self,
332334
V: FunctionSpace,
333-
x: la.Vector | None = None,
335+
x: la.Vector[_S] | None = None,
334336
name: str | None = None,
335337
dtype: npt.DTypeLike | None = None,
336338
):
@@ -395,7 +397,7 @@ def function_space(self) -> FunctionSpace:
395397
"""FunctionSpace that the Function is defined on."""
396398
return self._V
397399

398-
def eval(self, x: npt.ArrayLike, cells: npt.ArrayLike, u=None) -> np.ndarray:
400+
def eval(self, x: npt.ArrayLike, cells: npt.NDArray[np.int32], u=None) -> npt.NDArray[_S]:
399401
"""Evaluate Function at points x.
400402
401403
Points where x has shape (num_points, 3), and cells has shape
@@ -431,7 +433,7 @@ def eval(self, x: npt.ArrayLike, cells: npt.ArrayLike, u=None) -> np.ndarray:
431433
return u
432434

433435
def interpolate_nonmatching(
434-
self, u0: Function, cells: npt.NDArray[np.int32], interpolation_data: PointOwnershipData
436+
self, u0: Function[_S], cells: npt.NDArray[np.int32], interpolation_data: PointOwnershipData
435437
) -> None:
436438
"""Interpolate a Function on a non-matching mesh.
437439
@@ -447,9 +449,9 @@ def interpolate_nonmatching(
447449

448450
def interpolate(
449451
self,
450-
u0: Callable | Expression | Function,
451-
cells0: np.ndarray | None = None,
452-
cells1: np.ndarray | None = None,
452+
u0: Callable | Expression[_S] | Function[_S],
453+
cells0: npt.NDArray[np.int32] | None = None,
454+
cells1: npt.NDArray[np.int32] | None = None,
453455
) -> None:
454456
"""Interpolate an expression.
455457
@@ -495,7 +497,7 @@ def _(e0: Expression):
495497
)
496498
self._cpp_object.interpolate_f(np.asarray(u0(x), dtype=self.dtype), cells0)
497499

498-
def copy(self) -> Function:
500+
def copy(self) -> Function[_S]:
499501
"""Create a copy of the Function.
500502
501503
The function space is shared and the degree-of-freedom vector is
@@ -511,12 +513,12 @@ def copy(self) -> Function:
511513
)
512514

513515
@property
514-
def x(self) -> la.Vector:
516+
def x(self) -> la.Vector[_S]:
515517
"""Vector holding the degrees-of-freedom."""
516518
return self._x
517519

518520
@property
519-
def dtype(self) -> np.dtype:
521+
def dtype(self) -> npt.DTypeLike:
520522
"""Function value dtype."""
521523
return np.dtype(self._cpp_object.x.array.dtype)
522524

@@ -529,11 +531,11 @@ def name(self) -> str:
529531
def name(self, name):
530532
self._cpp_object.name = name
531533

532-
def __str__(self):
534+
def __str__(self) -> str:
533535
"""Pretty print representation."""
534536
return self.name
535537

536-
def sub(self, i: int) -> Function:
538+
def sub(self, i: int) -> Function[_S]:
537539
"""Return a sub-function (a view into the ``Function``).
538540
539541
Sub-functions are indexed ``i = 0, ..., N-1``, where ``N`` is
@@ -552,7 +554,7 @@ def sub(self, i: int) -> Function:
552554
"""
553555
return Function(self._V.sub(i), self.x, name=f"{self!s}_{i}")
554556

555-
def split(self) -> tuple[Function, ...]:
557+
def split(self) -> tuple[Function[_S], ...]:
556558
"""Extract (any) sub-functions.
557559
558560
A sub-function can be extracted from a discrete function that is
@@ -567,7 +569,7 @@ def split(self) -> tuple[Function, ...]:
567569
raise RuntimeError("No subfunctions to extract")
568570
return tuple(self.sub(i) for i in range(num_sub_spaces))
569571

570-
def collapse(self) -> Function:
572+
def collapse(self) -> Function[_S]:
571573
"""Create a collapsed version of this Function."""
572574
u_collapsed = self._cpp_object.collapse() # type: ignore
573575
V_collapsed = FunctionSpace(

0 commit comments

Comments
 (0)