Skip to content

Commit 4f3010d

Browse files
committed
Some typing improvements in examples
1 parent 4880a83 commit 4f3010d

3 files changed

Lines changed: 15 additions & 12 deletions

File tree

examples/curve-pot.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import numpy.linalg as la
3+
from numpy.typing import NDArray
34

45
import pyopencl as cl
56

@@ -120,7 +121,7 @@ def draw_pot_figure(aspect_ratio,
120121
a = 1
121122
b = 1/aspect_ratio
122123

123-
def map_to_curve(t):
124+
def map_to_curve(t: NDArray[np.floating]):
124125
t = t*(2*np.pi)
125126

126127
x = a*np.cos(t)
@@ -185,8 +186,8 @@ def map_to_curve(t):
185186

186187
from sumpy.tools import build_matrix
187188

188-
def apply_lpot(x):
189-
xovsmp = np.dot(fim, x)
189+
def apply_lpot(x: NDArray[np.inexact]) -> NDArray[np.inexact]:
190+
xovsmp = fim @ x
190191
_evt, (y,) = lpot(actx.queue,
191192
sources,
192193
ovsmp_sources,
@@ -197,7 +198,7 @@ def apply_lpot(x):
197198

198199
return actx.to_numpy(y)
199200

200-
op = LinearOperator((nsrc, nsrc), apply_lpot)
201+
op = LinearOperator((nsrc, nsrc), np.dtype(np.complex128), apply_lpot)
201202
mat = build_matrix(op, dtype=np.complex128)
202203
w, _v = la.eig(mat)
203204
plt.plot(w.real, "o-")

examples/curve.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from typing import final
2+
13
import numpy as np
24
import scipy as sp
35
import scipy.fftpack
46

57

8+
@final
69
class CurveGrid:
710
def __init__(self, x, y):
811
self.pos = np.vstack([x, y]).copy()

examples/fourier.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import numpy as np
2+
from numpy.typing import NDArray
23

34

4-
def make_fourier_vdm(n, inverse):
5+
def make_fourier_vdm(n: int, inverse: bool) -> NDArray[np.complex128]:
56
i = np.arange(n, dtype=np.float64)
67
imat = i[:, np.newaxis]*i/n
78
result = np.exp((2j*np.pi)*imat)
89

910
if inverse:
10-
result = result.T.conj()/n
11+
result = np.conj(result.T)/n
1112
return result
1213

1314

@@ -30,9 +31,7 @@ def make_fourier_mode_extender(m, n, dtype):
3031
return result
3132

3233

33-
def make_fourier_interp_matrix(m, n):
34-
return np.dot(
35-
np.dot(
36-
make_fourier_vdm(m, inverse=False),
37-
make_fourier_mode_extender(m, n, np.float64)),
38-
make_fourier_vdm(n, inverse=True))
34+
def make_fourier_interp_matrix(m: int, n: int):
35+
return (make_fourier_vdm(m, inverse=False)
36+
@ make_fourier_mode_extender(m, n, np.float64)
37+
@ make_fourier_vdm(n, inverse=True))

0 commit comments

Comments
 (0)