Skip to content

Commit 9295e31

Browse files
fix: fixed to base-grid (#72)
1 parent 5392920 commit 9295e31

2 files changed

Lines changed: 28 additions & 80 deletions

File tree

src/waveresponse/_core.py

Lines changed: 14 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,9 @@ def __init__(
278278
clockwise=False,
279279
waves_coming_from=True,
280280
):
281-
self._freq = np.asarray_chkfinite(freq).ravel().copy()
282-
self._dirs = np.asarray_chkfinite(dirs).ravel().copy()
283-
self._vals = (
284-
np.asarray_chkfinite(vals)
285-
.reshape((len(self._freq), len(self._dirs)))
286-
.copy()
287-
)
281+
self._freq = np.asarray_chkfinite(freq).copy()
282+
self._dirs = np.asarray_chkfinite(dirs).copy()
283+
self._vals = np.asarray_chkfinite(vals).copy()
288284
self._clockwise = clockwise
289285
self._waves_coming_from = waves_coming_from
290286
self._freq_hz = freq_hz
@@ -298,6 +294,11 @@ def __init__(
298294

299295
self._check_freq(self._freq)
300296
self._check_dirs(self._dirs)
297+
if self._vals.shape != (len(self._freq), len(self._dirs)):
298+
raise ValueError(
299+
"Values must have shape shape (N, M), such that ``N=len(freq)`` "
300+
"and ``M=len(dirs)``."
301+
)
301302

302303
def __repr__(self):
303304
return "_BaseGrid"
@@ -328,6 +329,9 @@ def _check_freq(self, freq):
328329
"""
329330
Check frequency bins.
330331
"""
332+
if freq.ndim != 1:
333+
raise ValueError("`freq` must be 1 dimensional.")
334+
331335
if np.any(freq[:-1] >= freq[1:]) or freq[0] < 0:
332336
raise ValueError(
333337
"Frequencies must be positive and monotonically increasing."
@@ -337,6 +341,9 @@ def _check_dirs(self, dirs):
337341
"""
338342
Check direction bins.
339343
"""
344+
if dirs.ndim != 1:
345+
raise ValueError("`dirs` must be 1 dimensional.")
346+
340347
if np.any(dirs[:-1] >= dirs[1:]) or dirs[0] < 0 or dirs[-1] >= 2.0 * np.pi:
341348
raise ValueError(
342349
"Directions must be positive, monotonically increasing, and "
@@ -531,47 +538,6 @@ def rotate(self, angle, degrees=False):
531538
new._dirs, new._vals = _sort(dirs_new, new._vals)
532539
return new
533540

534-
def _interpolate_function(self, complex_convert="rectangular", **kw):
535-
"""
536-
Interpolation function based on ``scipy.interpolate.RegularGridInterpolator``.
537-
"""
538-
xp = np.concatenate(
539-
(self._dirs[-1:] - 2 * np.pi, self._dirs, self._dirs[:1] + 2.0 * np.pi)
540-
)
541-
542-
yp = self._freq
543-
zp = np.concatenate(
544-
(
545-
self._vals[:, -1:],
546-
self._vals,
547-
self._vals[:, :1],
548-
),
549-
axis=1,
550-
)
551-
552-
if np.all(np.isreal(zp)):
553-
return RGI((xp, yp), zp.T, **kw)
554-
elif complex_convert.lower() == "polar":
555-
amp, phase = complex_to_polar(zp, phase_degrees=False)
556-
phase_complex = np.cos(phase) + 1j * np.sin(phase)
557-
interp_amp = RGI((xp, yp), amp.T, **kw)
558-
interp_phase = RGI((xp, yp), phase_complex.T, **kw)
559-
return lambda *args, **kwargs: (
560-
polar_to_complex(
561-
interp_amp(*args, **kwargs),
562-
np.angle(interp_phase(*args, **kwargs)),
563-
phase_degrees=False,
564-
)
565-
)
566-
elif complex_convert.lower() == "rectangular":
567-
interp_real = RGI((xp, yp), np.real(zp.T), **kw)
568-
interp_imag = RGI((xp, yp), np.imag(zp.T), **kw)
569-
return lambda *args, **kwargs: (
570-
interp_real(*args, **kwargs) + 1j * interp_imag(*args, **kwargs)
571-
)
572-
else:
573-
raise ValueError("Unknown 'complex_convert' type")
574-
575541
def __mul__(self, other):
576542
"""
577543
Multiply values (element-wise).
@@ -717,38 +683,6 @@ class Grid(_BaseGrid):
717683
convention is assumed.
718684
"""
719685

720-
def __init__(
721-
self,
722-
freq,
723-
dirs,
724-
vals,
725-
freq_hz=False,
726-
degrees=False,
727-
clockwise=False,
728-
waves_coming_from=True,
729-
):
730-
self._freq = np.asarray_chkfinite(freq).copy()
731-
self._dirs = np.asarray_chkfinite(dirs).copy()
732-
self._vals = np.asarray_chkfinite(vals).copy()
733-
self._clockwise = clockwise
734-
self._waves_coming_from = waves_coming_from
735-
self._freq_hz = freq_hz
736-
self._degrees = degrees
737-
738-
if freq_hz:
739-
self._freq = 2.0 * np.pi * self._freq
740-
741-
if degrees:
742-
self._dirs = (np.pi / 180.0) * self._dirs
743-
744-
self._check_freq(self._freq)
745-
self._check_dirs(self._dirs)
746-
if self._vals.shape != (len(self._freq), len(self._dirs)):
747-
raise ValueError(
748-
"Values must have shape shape (N, M), such that ``N=len(freq)`` "
749-
"and ``M=len(dirs)``."
750-
)
751-
752686
def __repr__(self):
753687
return "Grid"
754688

tests/test_core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,20 @@ def test__init__raises_dirs_greater_than_2pi(self):
831831
vals = np.zeros((3, 4))
832832
_BaseGrid(freq, dirs, vals, degrees=False)
833833

834+
def test__init__raises_freq_not_1d(self):
835+
with pytest.raises(ValueError, match="`freq` must be 1 dimensional."):
836+
freq = np.array([[0, 1], [2, 3]])
837+
dirs = np.array([0.0, 1.0, 1.5, 2.0])
838+
vals = np.zeros((4, 4))
839+
_BaseGrid(freq, dirs, vals)
840+
841+
def test__init__raises_dirs_not_1d(self):
842+
with pytest.raises(ValueError, match="`dirs` must be 1 dimensional."):
843+
freq = np.array([0, 1, 2, 3])
844+
dirs = np.array([[0.0, 1.0], [1.5, 2.0]])
845+
vals = np.zeros((4, 4))
846+
_BaseGrid(freq, dirs, vals)
847+
834848
def test__init__raises_vals_shape(self):
835849
with pytest.raises(ValueError):
836850
freq = np.array([0, 1, 2])

0 commit comments

Comments
 (0)