107107
108108# {{{ multi_index helpers
109109
110- def add_mi (mi1 , mi2 ) :
110+ def add_mi (mi1 : Sequence [ int ] , mi2 : Sequence [ int ]) -> Tuple [ int , ...] :
111111 return tuple ([mi1i + mi2i for mi1i , mi2i in zip (mi1 , mi2 )])
112112
113113
114- def mi_factorial (mi ) :
114+ def mi_factorial (mi : Sequence [ int ]) -> int :
115115 import math
116116 result = 1
117117 for mi_i in mi :
118118 result *= math .factorial (mi_i )
119119 return result
120120
121121
122- def mi_increment_axis (mi , axis , increment ):
122+ def mi_increment_axis (
123+ mi : Sequence [int ], axis : int , increment : int
124+ ) -> Tuple [int , ...]:
123125 new_mi = list (mi )
124126 new_mi [axis ] += increment
125127 return tuple (new_mi )
126128
127129
128- def mi_set_axis (mi , axis , value ) :
130+ def mi_set_axis (mi : Sequence [ int ] , axis : int , value : int ) -> Tuple [ int , ...] :
129131 new_mi = list (mi )
130132 new_mi [axis ] = value
131133 return tuple (new_mi )
132134
133135
134- def mi_power (vector , mi , evaluate = True ):
136+ def mi_power (
137+ vector : Sequence [T ], mi : Sequence [int ],
138+ evaluate : bool = True ) -> T :
135139 result = 1
136140 for mi_i , vec_i in zip (mi , vector ):
137141 if mi_i == 1 :
@@ -147,8 +151,8 @@ def add_to_sac(sac, expr):
147151 if sac is None :
148152 return expr
149153
150- if isinstance ( expr , ( numbers . Number , sym . Number , int ,
151- float , complex , sym .Symbol )):
154+ from numbers import Number
155+ if isinstance ( expr , ( Number , sym . Number , sym .Symbol )):
152156 return expr
153157
154158 name = sac .assign_temp ("temp" , expr )
@@ -280,7 +284,7 @@ def __init__(self, ctx: Any,
280284 target_kernels : List ["Kernel" ],
281285 source_kernels : List ["Kernel" ],
282286 strength_usage : Optional [List [int ]] = None ,
283- value_dtypes : Optional [List ["np.dtype" ]] = None ,
287+ value_dtypes : Optional [List ["np.dtype[Any] " ]] = None ,
284288 name : Optional [str ] = None ,
285289 device : Optional [Any ] = None ) -> None :
286290 """
@@ -913,7 +917,11 @@ def _get_fft_backend(queue) -> FFTBackend:
913917 return FFTBackend .pyvkfft
914918
915919
916- def get_opencl_fft_app (queue , shape , dtype , inverse ):
920+ def get_opencl_fft_app (
921+ queue : "cl.CommandQueue" ,
922+ shape : Tuple [int , ...],
923+ dtype : "np.dtype[Any]" ,
924+ inverse : bool ) -> Any :
917925 """Setup an object for out-of-place FFT on with given shape and dtype
918926 on given queue.
919927 """
@@ -932,7 +940,12 @@ def get_opencl_fft_app(queue, shape, dtype, inverse):
932940 raise RuntimeError (f"Unsupported FFT backend { backend } " )
933941
934942
935- def run_opencl_fft (fft_app , queue , input_vec , inverse = False , wait_for = None ):
943+ def run_opencl_fft (
944+ fft_app : Tuple [Any , FFTBackend ],
945+ queue : "cl.CommandQueue" ,
946+ input_vec : Array ,
947+ inverse : bool = False ,
948+ wait_for : List ["cl.Event" ] = None ) -> Tuple ["cl.Event" , Array ]:
936949 """Runs an FFT on input_vec and returns a :class:`MarkerBasedProfilingEvent`
937950 that indicate the end and start of the operations carried out and the output
938951 vector.
0 commit comments