Skip to content

Commit f52e8e8

Browse files
feat: improve phase interpolation when using "polar" interpolation method (#69)
1 parent bbeac62 commit f52e8e8

2 files changed

Lines changed: 22 additions & 21 deletions

File tree

src/waveresponse/_core.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -623,12 +623,13 @@ def _interpolate_function(self, complex_convert="rectangular", **kw):
623623
return RGI((xp, yp), zp.T, **kw)
624624
elif complex_convert.lower() == "polar":
625625
amp, phase = complex_to_polar(zp, phase_degrees=False)
626+
phase_complex = np.cos(phase) + 1j * np.sin(phase)
626627
interp_amp = RGI((xp, yp), amp.T, **kw)
627-
interp_phase = RGI((xp, yp), phase.T, **kw)
628+
interp_phase = RGI((xp, yp), phase_complex.T, **kw)
628629
return lambda *args, **kwargs: (
629630
polar_to_complex(
630631
interp_amp(*args, **kwargs),
631-
interp_phase(*args, **kwargs),
632+
np.angle(interp_phase(*args, **kwargs)),
632633
phase_degrees=False,
633634
)
634635
)
@@ -684,13 +685,6 @@ def interpolate(
684685
-------
685686
array :
686687
Interpolated grid values.
687-
688-
Notes
689-
-----
690-
Apply 'polar' interpolation with caution as phase values are not "unwraped"
691-
before interpolation. This may lead to some unexpected artifacts in the
692-
results.
693-
694688
"""
695689
freq = np.asarray_chkfinite(freq).reshape(-1)
696690
dirs = np.asarray_chkfinite(dirs).reshape(-1)

tests/test_core.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import pytest
88
from scipy.integrate import quad
9+
from scipy.interpolate import RegularGridInterpolator as RGI
910

1011
import waveresponse as wr
1112
from waveresponse import (
@@ -1498,16 +1499,19 @@ def test_interpolate_complex_polar(self):
14981499
vp = vp_amp * (np.cos(vp_phase) + 1j * np.sin(vp_phase))
14991500
grid = Grid(yp, xp, vp, freq_hz=True, degrees=True)
15001501

1501-
y = np.linspace(0.5, 1.0, 20)
1502-
x = np.linspace(5.0, 15.0, 10)
1502+
y = np.linspace(0.0, 2.0, 200)
1503+
x = np.linspace(0.0, 359.0, 100)
15031504
vals_amp_expect = np.array(
15041505
[[a_amp * x_i + b_amp * y_i for x_i in x] for y_i in y]
15051506
)
1506-
vals_phase_expect = np.array(
1507-
[[a_phase * x_i + b_phase * y_i for x_i in x] for y_i in y]
1508-
)
1509-
vals_expect = vals_amp_expect * (
1510-
np.cos(vals_phase_expect) + 1j * np.sin(vals_phase_expect)
1507+
x_, y_ = np.meshgrid(x, y, indexing="ij", sparse=True)
1508+
vals_phase_cos_expect = RGI((xp, yp), np.cos(vp_phase).T)((x_, y_)).T
1509+
vals_phase_sin_expect = RGI((xp, yp), np.sin(vp_phase).T)((x_, y_)).T
1510+
1511+
vals_expect = (
1512+
vals_amp_expect
1513+
* (vals_phase_cos_expect + 1j * vals_phase_sin_expect)
1514+
/ np.abs(vals_phase_cos_expect + 1j * vals_phase_sin_expect)
15111515
)
15121516

15131517
vals_out = grid.interpolate(
@@ -1692,11 +1696,14 @@ def test_reshape_complex_polar(self):
16921696
vals_amp_expect = np.array(
16931697
[[a_amp * x_i + b_amp * y_i for x_i in x] for y_i in y]
16941698
)
1695-
vals_phase_expect = np.array(
1696-
[[a_phase * x_i + b_phase * y_i for x_i in x] for y_i in y]
1697-
)
1698-
vals_expect = vals_amp_expect * (
1699-
np.cos(vals_phase_expect) + 1j * np.sin(vals_phase_expect)
1699+
x_, y_ = np.meshgrid(x, y, indexing="ij", sparse=True)
1700+
vals_phase_cos_expect = RGI((xp, yp), np.cos(vp_phase).T)((x_, y_)).T
1701+
vals_phase_sin_expect = RGI((xp, yp), np.sin(vp_phase).T)((x_, y_)).T
1702+
1703+
vals_expect = (
1704+
vals_amp_expect
1705+
* (vals_phase_cos_expect + 1j * vals_phase_sin_expect)
1706+
/ np.abs(vals_phase_cos_expect + 1j * vals_phase_sin_expect)
17001707
)
17011708

17021709
np.testing.assert_array_almost_equal(freq_out, freq_expect)

0 commit comments

Comments
 (0)