Skip to content

Commit 66c2823

Browse files
committed
20% speedups
1 parent 81a50ba commit 66c2823

2 files changed

Lines changed: 128 additions & 56 deletions

File tree

miepython/mie_jit.py

Lines changed: 55 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -321,8 +321,8 @@ def _S1_S2_nb(m, x, mu, n_pole):
321321
N = len(a)
322322
pi = np.zeros(N)
323323
tau = np.zeros(N)
324-
n = np.arange(1, N + 1)
325-
scale = (2 * n + 1) / ((n + 1) * n)
324+
n = np.arange(1, N + 1, dtype=np.float64)
325+
scale = (2.0 * n + 1.0) / ((n + 1.0) * n)
326326

327327
nangles = len(mu)
328328
S1 = np.zeros(nangles, dtype=np.complex128)
@@ -331,8 +331,14 @@ def _S1_S2_nb(m, x, mu, n_pole):
331331
for k in range(nangles):
332332
_pi_tau_nb(mu[k], pi, tau)
333333
if n_pole == 0:
334-
S1[k] = np.sum(scale * (pi * a + tau * b))
335-
S2[k] = np.sum(scale * (tau * a + pi * b))
334+
s1 = 0.0 + 0.0j
335+
s2 = 0.0 + 0.0j
336+
for i in range(N):
337+
si = scale[i]
338+
s1 += si * (pi[i] * a[i] + tau[i] * b[i])
339+
s2 += si * (tau[i] * a[i] + pi[i] * b[i])
340+
S1[k] = s1
341+
S2[k] = s2
336342
else:
337343
S1[k] = scale[n_pole] * (pi[n_pole] * a[n_pole] + tau[n_pole] * b[n_pole])
338344
S2[k] = scale[n_pole] * (tau[n_pole] * a[n_pole] + pi[n_pole] * b[n_pole])
@@ -443,7 +449,7 @@ def _single_sphere_nb(m, x, n_pole, e_field):
443449
qback: the backscatter efficiency
444450
g: the average cosine of the scattering phase function
445451
"""
446-
e_field = not e_field # unused
452+
_ = e_field # unused
447453

448454
# case when sphere matches its environment
449455
if abs(m.real - 1) <= 1e-8 and abs(m.imag) < 1e-8:
@@ -461,33 +467,60 @@ def _single_sphere_nb(m, x, n_pole, e_field):
461467
m = 1 - 10000j
462468

463469
a, b = _an_bn_nb(m, x, n_pole)
470+
x2 = x * x
464471

465472
if n_pole == 0:
466-
n = np.arange(1, len(a) + 1)
467-
cn = 2.0 * n + 1.0
468-
469-
qext = 2 * np.sum(cn * (a.real + b.real)) / x**2
470-
471-
if m.imag == 0:
473+
n_terms = len(a)
474+
qext_acc = 0.0
475+
qsca_acc = 0.0
476+
qback_acc = 0.0 + 0.0j
477+
g_acc = 0.0
478+
479+
for i in range(n_terms):
480+
ni = i + 1
481+
cn = 2.0 * ni + 1.0
482+
483+
ai = a[i]
484+
bi = b[i]
485+
ai_re = ai.real
486+
ai_im = ai.imag
487+
bi_re = bi.real
488+
bi_im = bi.imag
489+
490+
qext_acc += cn * (ai_re + bi_re)
491+
492+
if m.imag != 0.0:
493+
qsca_acc += cn * (ai_re * ai_re + ai_im * ai_im + bi_re * bi_re + bi_im * bi_im)
494+
495+
# (-1)^n with n starting from 1.
496+
sign = -1.0 if (ni % 2) == 1 else 1.0
497+
qback_acc += sign * cn * (ai - bi)
498+
499+
if i < n_terms - 1:
500+
aip1 = a[i + 1]
501+
bip1 = b[i + 1]
502+
c1n = ni * (ni + 2.0) / (ni + 1.0)
503+
c2n = cn / ni / (ni + 1.0)
504+
g_acc += c1n * (ai * np.conjugate(aip1) + bi * np.conjugate(bip1)).real
505+
g_acc += c2n * (ai * np.conjugate(bi)).real
506+
507+
qext = 2.0 * qext_acc / x2
508+
509+
if m.imag == 0.0:
472510
qsca = qext
473511
else:
474-
qsca = 2 * np.sum(cn * (np.abs(a) ** 2 + np.abs(b) ** 2)) / x**2
475-
476-
qback = np.abs(np.sum((-1) ** n * cn * (a - b))) ** 2 / x**2
512+
qsca = 2.0 * qsca_acc / x2
477513

478-
c1n = n * (n + 2) / (n + 1)
479-
c2n = cn / n / (n + 1)
480-
asy1 = c1n[:-1] * (a[:-1] * a[1:].conjugate() + b[:-1] * b[1:].conjugate()).real
481-
asy2 = c2n[:-1] * (a[:-1] * b[:-1].conjugate()).real
482-
g = 4 * np.sum(asy1 + asy2) / qsca / x**2
514+
qback = np.abs(qback_acc) ** 2 / x2
515+
g = 4.0 * g_acc / qsca / x2
483516

484517
else:
485518
cn = 2.0 * n_pole + 1
486-
qext = 2 * cn * (a[-1].real + b[-1].real) / x**2
487-
qback = np.abs((-1) ** n_pole * cn * (a[-1] - b[-1])) ** 2 / x**2
519+
qext = 2.0 * cn * (a[-1].real + b[-1].real) / x2
520+
qback = np.abs((-1) ** n_pole * cn * (a[-1] - b[-1])) ** 2 / x2
488521
qsca = qext
489522
if m.imag < 0:
490-
qsca = 2 * cn * (np.abs(a[-1]) ** 2 + np.abs(b[-1]) ** 2) / x**2
523+
qsca = 2.0 * cn * (np.abs(a[-1]) ** 2 + np.abs(b[-1]) ** 2) / x2
491524
g = 0
492525

493526
return qext, qsca, qback, g

miepython/mie_nojit.py

Lines changed: 73 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
Low-level Mie calculations that do not use numba.
33
"""
44

5+
from functools import lru_cache
6+
57
import numpy as np
68

79
__all__ = (
@@ -16,6 +18,31 @@
1618
)
1719

1820

21+
@lru_cache(maxsize=128)
22+
def _series_scale_factors(n_terms):
23+
"""Return cached per-order scale factors for Mie series summations."""
24+
n = np.arange(1, n_terms + 1, dtype=np.float64)
25+
scale = (2.0 * n + 1.0) / ((n + 1.0) * n)
26+
scale.setflags(write=False)
27+
return scale
28+
29+
30+
@lru_cache(maxsize=128)
31+
def _single_sphere_factors(n_terms):
32+
"""Return cached per-order factors used by ``_single_sphere_py``."""
33+
n_int = np.arange(1, n_terms + 1, dtype=np.int64)
34+
n = n_int.astype(np.float64)
35+
cn = 2.0 * n + 1.0
36+
alt = np.where((n_int % 2) == 0, 1.0, -1.0) # (-1)^n with n starting at 1
37+
c1n = n * (n + 2.0) / (n + 1.0)
38+
c2n = cn / n / (n + 1.0)
39+
cn.setflags(write=False)
40+
alt.setflags(write=False)
41+
c1n.setflags(write=False)
42+
c2n.setflags(write=False)
43+
return cn, alt, c1n, c2n
44+
45+
1946
def _Lentz_Dn(z, N):
2047
"""
2148
Compute the logarithmic derivative of the Riccati-Bessel function.
@@ -139,7 +166,7 @@ def _an_bn_py(m, x, n_pole=0):
139166
Returns:
140167
a, b: arrays of Mie coefficients An and Bn
141168
"""
142-
if np.imag(m) > 0: # ensure imaginary part of refractive index is negative
169+
if m.imag > 0: # ensure imaginary part of refractive index is negative
143170
m = np.conj(m)
144171

145172
if n_pole == 0:
@@ -152,31 +179,35 @@ def _an_bn_py(m, x, n_pole=0):
152179
if x <= 0:
153180
return a, b
154181

182+
inv_x = 1.0 / x
155183
psi_nm1 = np.sin(x) # nm1 = n-1 = 0
156-
psi_n = psi_nm1 / x - np.cos(x)
184+
psi_n = psi_nm1 * inv_x - np.cos(x)
157185
xi_nm1 = np.complex128(psi_nm1 + 1j * np.cos(x))
158-
xi_n = np.complex128(psi_n + 1j * (np.cos(x) / x + np.sin(x)))
186+
xi_n = np.complex128(psi_n + 1j * (np.cos(x) * inv_x + np.sin(x)))
159187

160188
if m.real > 0.0:
161189
D = _D_calc_py(m, x, nstop + 1)
162190

163191
for n in range(1, nstop):
164-
temp = D[n - 1] / m + n / x
192+
n_over_x = n * inv_x
193+
temp = D[n - 1] / m + n_over_x
165194
a[n - 1] = (temp * psi_n - psi_nm1) / (temp * xi_n - xi_nm1)
166-
temp = D[n - 1] * m + n / x
195+
temp = D[n - 1] * m + n_over_x
167196
b[n - 1] = (temp * psi_n - psi_nm1) / (temp * xi_n - xi_nm1)
168-
psi = (2 * n + 1) * psi_n / x - psi_nm1
169-
xi = (2 * n + 1) * xi_n / x - xi_nm1
197+
two_np1_over_x = (2 * n + 1) * inv_x
198+
psi = two_np1_over_x * psi_n - psi_nm1
199+
xi = two_np1_over_x * xi_n - xi_nm1
170200
xi_nm1 = xi_n
171201
xi_n = xi
172202
psi_nm1 = psi_n
173203
psi_n = psi
174204

175205
else:
176206
for n in range(1, nstop):
177-
a[n - 1] = (n * psi_n / x - psi_nm1) / (n * xi_n / x - xi_nm1)
207+
n_over_x = n * inv_x
208+
a[n - 1] = (n_over_x * psi_n - psi_nm1) / (n_over_x * xi_n - xi_nm1)
178209
b[n - 1] = psi_n / xi_n
179-
xi = (2 * n + 1) * xi_n / x - xi_nm1
210+
xi = (2 * n + 1) * inv_x * xi_n - xi_nm1
180211
xi_nm1 = xi_n
181212
xi_n = xi
182213
psi_nm1 = psi_n
@@ -216,35 +247,39 @@ def _cn_dn_py(m, x, n_pole):
216247
if x <= 0:
217248
return c, d
218249

250+
inv_x = 1.0 / x
219251
# no need to calculate anything when sphere is perfectly conducting
220252
if m.real > 0.0 and not np.isinf(m.real) or not np.isinf(m.imag):
221253
psi_nm1 = np.sin(x) # nm1 = n-1 = 0
222-
psi_n = psi_nm1 / x - np.cos(x)
254+
psi_n = psi_nm1 * inv_x - np.cos(x)
223255

256+
inv_mx = 1.0 / mx
224257
psi_nm1_mx = np.sin(mx) # nm1 = n-1 = 0
225-
psi_n_mx = psi_nm1_mx / mx - np.cos(mx)
258+
psi_n_mx = psi_nm1_mx * inv_mx - np.cos(mx)
226259

227260
xi_nm1 = np.complex128(psi_nm1 + 1j * np.cos(x))
228-
xi_n = np.complex128(psi_n + 1j * (np.cos(x) / x + np.sin(x)))
261+
xi_n = np.complex128(psi_n + 1j * (np.cos(x) * inv_x + np.sin(x)))
229262

230263
Dmx = _D_calc_py(np.complex128(m), x, nstop + 1)
231264
Dx = _D_calc_py(np.complex128(1), x, nstop + 1)
232265

233266
for n in range(1, nstop + 1):
234-
common = (psi_n / psi_n_mx) * ((Dx[n - 1] + n / x) * xi_n - xi_nm1)
267+
n_over_x = n * inv_x
268+
common = (psi_n / psi_n_mx) * ((Dx[n - 1] + n_over_x) * xi_n - xi_nm1)
235269

236-
c[n - 1] = m * common / ((m * Dmx[n - 1] + n / x) * xi_n - xi_nm1)
237-
d[n - 1] = common / ((Dmx[n - 1] / m + n / x) * xi_n - xi_nm1)
270+
c[n - 1] = m * common / ((m * Dmx[n - 1] + n_over_x) * xi_n - xi_nm1)
271+
d[n - 1] = common / ((Dmx[n - 1] / m + n_over_x) * xi_n - xi_nm1)
238272

239-
psi = (2 * n + 1) * psi_n / x - psi_nm1
273+
two_np1 = 2 * n + 1
274+
psi = two_np1 * inv_x * psi_n - psi_nm1
240275
psi_nm1 = psi_n
241276
psi_n = psi
242277

243-
psi_mx = (2 * n + 1) * psi_n_mx / mx - psi_nm1_mx
278+
psi_mx = two_np1 * inv_mx * psi_n_mx - psi_nm1_mx
244279
psi_nm1_mx = psi_n_mx
245280
psi_n_mx = psi_mx
246281

247-
xi = (2 * n + 1) * xi_n / x - xi_nm1
282+
xi = two_np1 * inv_x * xi_n - xi_nm1
248283
xi_nm1 = xi_n
249284
xi_n = xi
250285

@@ -310,8 +345,9 @@ def _S1_S2_py(m, x, mu, n_pole):
310345
N = len(a)
311346
pi = np.zeros(N)
312347
tau = np.zeros(N)
313-
n = np.arange(1, N + 1)
314-
scale = (2 * n + 1) / ((n + 1) * n)
348+
scale = _series_scale_factors(N)
349+
scale_a = scale * a
350+
scale_b = scale * b
315351

316352
nangles = len(mu)
317353
S1 = np.zeros(nangles, dtype=np.complex128)
@@ -320,8 +356,8 @@ def _S1_S2_py(m, x, mu, n_pole):
320356
for k in range(nangles):
321357
_pi_tau_py(mu[k], pi, tau)
322358
if n_pole == 0:
323-
S1[k] = np.sum(scale * (pi * a + tau * b))
324-
S2[k] = np.sum(scale * (tau * a + pi * b))
359+
S1[k] = np.dot(pi, scale_a) + np.dot(tau, scale_b)
360+
S2[k] = np.dot(tau, scale_a) + np.dot(pi, scale_b)
325361
else:
326362
S1[k] = scale[n_pole] * (pi[n_pole] * a[n_pole] + tau[n_pole] * b[n_pole])
327363
S2[k] = scale[n_pole] * (tau[n_pole] * a[n_pole] + pi[n_pole] * b[n_pole])
@@ -429,7 +465,7 @@ def _single_sphere_py(m, x, n_pole, e_field):
429465
qback: the backscatter efficiency
430466
g: the average cosine of the scattering phase function
431467
"""
432-
e_field = not e_field # unused
468+
_ = e_field # currently unused in scalar aggregate efficiencies
433469

434470
# case when sphere matches its environment
435471
if abs(m.real - 1) <= 1e-8 and abs(m.imag) < 1e-8:
@@ -447,33 +483,36 @@ def _single_sphere_py(m, x, n_pole, e_field):
447483
m = 1 - 10000j
448484

449485
a, b = _an_bn_py(m, x, n_pole)
486+
x2 = x * x
450487

451488
if n_pole == 0:
452-
n = np.arange(1, len(a) + 1)
453-
cn = 2.0 * n + 1.0
489+
n_terms = len(a)
490+
cn, alt, c1n, c2n = _single_sphere_factors(n_terms)
491+
a_re = a.real
492+
b_re = b.real
493+
a_abs2 = a_re * a_re + a.imag * a.imag
494+
b_abs2 = b_re * b_re + b.imag * b.imag
454495

455-
qext = 2 * np.sum(cn * (a.real + b.real)) / x**2
496+
qext = 2.0 * np.dot(cn, a_re + b_re) / x2
456497

457498
if m.imag == 0:
458499
qsca = qext
459500
else:
460-
qsca = 2 * np.sum(cn * (np.abs(a) ** 2 + np.abs(b) ** 2)) / x**2
501+
qsca = 2.0 * np.dot(cn, a_abs2 + b_abs2) / x2
461502

462-
qback = np.abs(np.sum((-1) ** n * cn * (a - b))) ** 2 / x**2
503+
qback = np.abs(np.dot(alt * cn, a - b)) ** 2 / x2
463504

464-
c1n = n * (n + 2) / (n + 1)
465-
c2n = cn / n / (n + 1)
466505
asy1 = c1n[:-1] * (a[:-1] * a[1:].conjugate() + b[:-1] * b[1:].conjugate()).real
467506
asy2 = c2n[:-1] * (a[:-1] * b[:-1].conjugate()).real
468-
g = 4 * np.sum(asy1 + asy2) / qsca / x**2
507+
g = 4.0 * np.sum(asy1 + asy2) / qsca / x2
469508

470509
else:
471510
cn = 2.0 * n_pole + 1
472-
qback = np.abs((-1) ** n_pole * cn * (a[-1] - b[-1])) ** 2 / x**2
473-
qext = 2 * cn * (a[-1].real + b[-1].real) / x**2
511+
qback = np.abs((-1) ** n_pole * cn * (a[-1] - b[-1])) ** 2 / x2
512+
qext = 2.0 * cn * (a[-1].real + b[-1].real) / x2
474513
qsca = qext
475514
if m.imag < 0:
476-
qsca = 2 * cn * (np.abs(a[-1]) ** 2 + np.abs(b[-1]) ** 2) / x**2
515+
qsca = 2.0 * cn * (np.abs(a[-1]) ** 2 + np.abs(b[-1]) ** 2) / x2
477516
g = None
478517

479518
return qext, qsca, qback, g

0 commit comments

Comments
 (0)