Skip to content

Commit d1c7302

Browse files
alexfiklinducer
authored andcommitted
tools: add some type annotations
1 parent e174032 commit d1c7302

1 file changed

Lines changed: 23 additions & 10 deletions

File tree

sumpy/tools.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,31 +107,35 @@
107107

108108
# {{{ multi_index helpers
109109

110-
def add_mi(mi1, mi2):
110+
def add_mi(mi1: Sequence[int], mi2: Sequence[int]) -> Tuple[int, ...]:
111111
return tuple([mi1i + mi2i for mi1i, mi2i in zip(mi1, mi2)])
112112

113113

114-
def mi_factorial(mi):
114+
def mi_factorial(mi: Sequence[int]) -> int:
115115
import math
116116
result = 1
117117
for mi_i in mi:
118118
result *= math.factorial(mi_i)
119119
return result
120120

121121

122-
def mi_increment_axis(mi, axis, increment):
122+
def mi_increment_axis(
123+
mi: Sequence[int], axis: int, increment: int
124+
) -> Tuple[int, ...]:
123125
new_mi = list(mi)
124126
new_mi[axis] += increment
125127
return tuple(new_mi)
126128

127129

128-
def mi_set_axis(mi, axis, value):
130+
def mi_set_axis(mi: Sequence[int], axis: int, value: int) -> Tuple[int, ...]:
129131
new_mi = list(mi)
130132
new_mi[axis] = value
131133
return tuple(new_mi)
132134

133135

134-
def mi_power(vector, mi, evaluate=True):
136+
def mi_power(
137+
vector: Sequence[T], mi: Sequence[int],
138+
evaluate: bool = True) -> T:
135139
result = 1
136140
for mi_i, vec_i in zip(mi, vector):
137141
if mi_i == 1:
@@ -147,8 +151,8 @@ def add_to_sac(sac, expr):
147151
if sac is None:
148152
return expr
149153

150-
if isinstance(expr, (numbers.Number, sym.Number, int,
151-
float, complex, sym.Symbol)):
154+
from numbers import Number
155+
if isinstance(expr, (Number, sym.Number, sym.Symbol)):
152156
return expr
153157

154158
name = sac.assign_temp("temp", expr)
@@ -280,7 +284,7 @@ def __init__(self, ctx: Any,
280284
target_kernels: List["Kernel"],
281285
source_kernels: List["Kernel"],
282286
strength_usage: Optional[List[int]] = None,
283-
value_dtypes: Optional[List["np.dtype"]] = None,
287+
value_dtypes: Optional[List["np.dtype[Any]"]] = None,
284288
name: Optional[str] = None,
285289
device: Optional[Any] = None) -> None:
286290
"""
@@ -913,7 +917,11 @@ def _get_fft_backend(queue) -> FFTBackend:
913917
return FFTBackend.pyvkfft
914918

915919

916-
def get_opencl_fft_app(queue, shape, dtype, inverse):
920+
def get_opencl_fft_app(
921+
queue: "cl.CommandQueue",
922+
shape: Tuple[int, ...],
923+
dtype: "np.dtype[Any]",
924+
inverse: bool) -> Any:
917925
"""Setup an object for out-of-place FFT on with given shape and dtype
918926
on given queue.
919927
"""
@@ -932,7 +940,12 @@ def get_opencl_fft_app(queue, shape, dtype, inverse):
932940
raise RuntimeError(f"Unsupported FFT backend {backend}")
933941

934942

935-
def run_opencl_fft(fft_app, queue, input_vec, inverse=False, wait_for=None):
943+
def run_opencl_fft(
944+
fft_app: Tuple[Any, FFTBackend],
945+
queue: "cl.CommandQueue",
946+
input_vec: Array,
947+
inverse: bool = False,
948+
wait_for: List["cl.Event"] = None) -> Tuple["cl.Event", Array]:
936949
"""Runs an FFT on input_vec and returns a :class:`MarkerBasedProfilingEvent`
937950
that indicate the end and start of the operations carried out and the output
938951
vector.

0 commit comments

Comments
 (0)