Skip to content

Commit 33424fe

Browse files
More specific Callable type annotations in mobject_update_utils (#4728)
* More specific Callable type annotations in mobject_update_utils * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1b33900 commit 33424fe

1 file changed

Lines changed: 9 additions & 6 deletions

File tree

manim/animation/updaters/mobject_update_utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import inspect
1818
from collections.abc import Callable
19-
from typing import TYPE_CHECKING
19+
from typing import TYPE_CHECKING, TypeVar
2020

2121
import numpy as np
2222

@@ -29,6 +29,9 @@
2929
from manim.animation.animation import Animation
3030

3131

32+
M = TypeVar("M", bound=Mobject)
33+
34+
3235
def assert_is_mobject_method(method: Callable) -> None:
3336
assert inspect.ismethod(method)
3437
mobject = method.__self__
@@ -43,7 +46,7 @@ def always(method: Callable, *args, **kwargs) -> Mobject:
4346
return mobject
4447

4548

46-
def f_always(method: Callable[[Mobject], None], *arg_generators, **kwargs) -> Mobject:
49+
def f_always(method: Callable[[M], None], *arg_generators, **kwargs) -> M:
4750
"""
4851
More functional version of always, where instead
4952
of taking in args, it takes in functions which output
@@ -61,7 +64,7 @@ def updater(mob):
6164
return mobject
6265

6366

64-
def always_redraw(func: Callable[[], Mobject]) -> Mobject:
67+
def always_redraw(func: Callable[[], M]) -> M:
6568
"""Redraw the mobject constructed by a function every frame.
6669
6770
This function returns a mobject with an attached updater that
@@ -107,8 +110,8 @@ def construct(self):
107110

108111

109112
def always_shift(
110-
mobject: Mobject, direction: np.ndarray[np.float64] = RIGHT, rate: float = 0.1
111-
) -> Mobject:
113+
mobject: M, direction: np.ndarray[np.float64] = RIGHT, rate: float = 0.1
114+
) -> M:
112115
"""A mobject which is continuously shifted along some direction
113116
at a certain rate.
114117
@@ -145,7 +148,7 @@ def construct(self):
145148
return mobject
146149

147150

148-
def always_rotate(mobject: Mobject, rate: float = 20 * DEGREES, **kwargs) -> Mobject:
151+
def always_rotate(mobject: M, rate: float = 20 * DEGREES, **kwargs) -> M:
149152
"""A mobject which is continuously rotated at a certain rate.
150153
151154
Parameters

0 commit comments

Comments
 (0)