Skip to content

Commit 81bf68a

Browse files
feat: extend available output_types for multiply with binned grid types (#80)
1 parent 476da3d commit 81bf68a

2 files changed

Lines changed: 122 additions & 25 deletions

File tree

src/waveresponse/_core.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def check_type(grid, grid_type):
4343
raise ValueError("Grid objects have different wave conventions.")
4444

4545

46-
def multiply(grid1, grid2, output_type="grid"):
46+
def multiply(grid1, grid2, output_type="Grid"):
4747
"""
4848
Multiply values (element-wise).
4949
@@ -53,23 +53,29 @@ def multiply(grid1, grid2, output_type="grid"):
5353
Grid object.
5454
grid2 : obj
5555
Grid object.
56-
output_type : str {"grid", "rao", "directional_spectrum", "wave_spectrum"}
56+
output_type : {'Grid', 'RAO', 'DirectionalSpectrum', 'WaveSpectrum', 'DirectionalBinSpectrum', 'WaveBinSpectrum'}
5757
Output grid type.
5858
"""
5959

6060
TYPE_MAP = {
61-
"grid": Grid,
62-
"rao": RAO,
63-
"directional_spectrum": DirectionalSpectrum,
64-
"wave_spectrum": WaveSpectrum,
61+
"Grid": Grid,
62+
"RAO": RAO,
63+
"DirectionalSpectrum": DirectionalSpectrum,
64+
"DirectionalBinSpectrum": DirectionalBinSpectrum,
65+
"WaveSpectrum": WaveSpectrum,
66+
"WaveBinSpectrum": WaveBinSpectrum,
67+
"grid": Grid, # for backward compatibility
68+
"rao": RAO, # for backward compatibility
69+
"directional_spectrum": DirectionalSpectrum, # for backward compatibility
70+
"wave_spectrum": WaveSpectrum, # for backward compatibility
6571
}
6672

67-
if output_type not in TYPE_MAP:
68-
raise ValueError("The given `output_type` is not valid.")
73+
output_type_ = TYPE_MAP.get(output_type, output_type)
6974

70-
_check_is_similar(grid1, grid2, exact_type=False)
75+
if not (isinstance(output_type_, type) and issubclass(output_type_, Grid)):
76+
raise ValueError(f"Invalid `output_type`: {output_type_!r}")
7177

72-
type_ = TYPE_MAP.get(output_type)
78+
_check_is_similar(grid1, grid2, exact_type=False)
7379

7480
freq = grid1._freq
7581
dirs = grid1._dirs
@@ -85,7 +91,7 @@ def multiply(grid1, grid2, output_type="grid"):
8591
**convention,
8692
)
8793

88-
return type_.from_grid(new)
94+
return output_type_.from_grid(new)
8995

9096

9197
def _cast_to_grid(grid):

tests/test_core.py

Lines changed: 105 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,9 @@ def test_rao_and_rao_to_default_grid(self, rao):
189189
np.testing.assert_array_almost_equal(out._dirs, rao._dirs)
190190
np.testing.assert_array_almost_equal(out._vals, vals_expect)
191191

192-
def test_grid_and_grid_to_grid(self, grid):
193-
out = wr.multiply(grid, grid.copy(), output_type="grid")
192+
@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
193+
def test_grid_and_grid_to_grid(self, output_type, grid):
194+
out = wr.multiply(grid, grid.copy(), output_type=output_type)
194195

195196
vals_expect = grid._vals * grid._vals
196197

@@ -203,8 +204,9 @@ def test_grid_and_grid_to_grid(self, grid):
203204
np.testing.assert_array_almost_equal(out._dirs, grid._dirs)
204205
np.testing.assert_array_almost_equal(out._vals, vals_expect)
205206

206-
def test_rao_and_wave_to_grid(self, rao, wave):
207-
out = wr.multiply(rao, wave, output_type="grid")
207+
@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
208+
def test_rao_and_wave_to_grid(self, output_type, rao, wave):
209+
out = wr.multiply(rao, wave, output_type=output_type)
208210

209211
vals_expect = rao._vals * wave._vals
210212

@@ -217,8 +219,9 @@ def test_rao_and_wave_to_grid(self, rao, wave):
217219
np.testing.assert_array_almost_equal(out._dirs, rao._dirs)
218220
np.testing.assert_array_almost_equal(out._vals, vals_expect)
219221

220-
def test_rao_and_rao_to_rao(self, rao):
221-
out = wr.multiply(rao, rao.copy(), output_type="rao")
222+
@pytest.mark.parametrize("output_type", ("rao", "RAO", RAO))
223+
def test_rao_and_rao_to_rao(self, output_type, rao):
224+
out = wr.multiply(rao, rao.copy(), output_type=output_type)
222225

223226
vals_expect = rao._vals * rao._vals
224227

@@ -231,8 +234,9 @@ def test_rao_and_rao_to_rao(self, rao):
231234
np.testing.assert_array_almost_equal(out._dirs, rao._dirs)
232235
np.testing.assert_array_almost_equal(out._vals, vals_expect)
233236

234-
def test_rao_and_rao_to_grid(self, rao):
235-
out = wr.multiply(rao, rao.copy(), output_type="grid")
237+
@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
238+
def test_rao_and_rao_to_grid(self, output_type, rao):
239+
out = wr.multiply(rao, rao.copy(), output_type=output_type)
236240

237241
vals_expect = rao._vals * rao._vals
238242

@@ -246,8 +250,11 @@ def test_rao_and_rao_to_grid(self, rao):
246250
np.testing.assert_array_almost_equal(out._dirs, rao._dirs)
247251
np.testing.assert_array_almost_equal(out._vals, vals_expect)
248252

249-
def test_wave_and_wave_to_wave(self, wave):
250-
out = wr.multiply(wave, wave.copy(), output_type="wave_spectrum")
253+
@pytest.mark.parametrize(
254+
"output_type", ("wave_spectrum", "WaveSpectrum", WaveSpectrum)
255+
)
256+
def test_wave_and_wave_to_wave(self, output_type, wave):
257+
out = wr.multiply(wave, wave.copy(), output_type=output_type)
251258

252259
vals_expect = wave._vals * wave._vals
253260

@@ -260,8 +267,12 @@ def test_wave_and_wave_to_wave(self, wave):
260267
np.testing.assert_array_almost_equal(out._dirs, wave._dirs)
261268
np.testing.assert_array_almost_equal(out._vals, vals_expect)
262269

263-
def test_wave_and_wave_to_dir_spectrum(self, wave):
264-
out = wr.multiply(wave, wave.copy(), output_type="directional_spectrum")
270+
@pytest.mark.parametrize(
271+
"output_type",
272+
("directional_spectrum", "DirectionalSpectrum", DirectionalSpectrum),
273+
)
274+
def test_wave_and_wave_to_dir_spectrum(self, output_type, wave):
275+
out = wr.multiply(wave, wave.copy(), output_type=output_type)
265276

266277
vals_expect = wave._vals * wave._vals
267278

@@ -275,8 +286,9 @@ def test_wave_and_wave_to_dir_spectrum(self, wave):
275286
np.testing.assert_array_almost_equal(out._dirs, wave._dirs)
276287
np.testing.assert_array_almost_equal(out._vals, vals_expect)
277288

278-
def test_wave_and_wave_to_grid(self, wave):
279-
out = wr.multiply(wave, wave.copy(), output_type="grid")
289+
@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
290+
def test_wave_and_wave_to_grid(self, output_type, wave):
291+
out = wr.multiply(wave, wave.copy(), output_type=output_type)
280292

281293
vals_expect = wave._vals * wave._vals
282294

@@ -290,6 +302,85 @@ def test_wave_and_wave_to_grid(self, wave):
290302
np.testing.assert_array_almost_equal(out._dirs, wave._dirs)
291303
np.testing.assert_array_almost_equal(out._vals, vals_expect)
292304

305+
@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
306+
def test_wavebin_and_wavebin_to_grid(self, output_type, wavebin):
307+
grid1 = wavebin.copy()
308+
grid2 = wavebin.copy()
309+
out = wr.multiply(grid1, grid2, output_type=output_type)
310+
311+
vals_expect = wavebin._vals * wavebin._vals
312+
313+
assert isinstance(out, Grid)
314+
assert not isinstance(out, WaveSpectrum)
315+
assert out._freq_hz is False
316+
assert out._degrees is False
317+
assert out._clockwise == grid1._clockwise
318+
assert out._waves_coming_from == grid1._waves_coming_from
319+
np.testing.assert_array_almost_equal(out._freq, grid1._freq)
320+
np.testing.assert_array_almost_equal(out._dirs, grid1._dirs)
321+
np.testing.assert_array_almost_equal(out._vals, vals_expect)
322+
323+
@pytest.mark.parametrize("output_type", ("WaveBinSpectrum", WaveBinSpectrum))
324+
def test_wavebin_and_wavebin_to_wavebin(self, output_type, wavebin):
325+
grid1 = wavebin.copy()
326+
grid2 = wavebin.copy()
327+
out = wr.multiply(grid1, grid2, output_type=output_type)
328+
329+
vals_expect = wavebin._vals * wavebin._vals
330+
331+
assert isinstance(out, Grid)
332+
assert not isinstance(out, WaveSpectrum)
333+
assert out._freq_hz is False
334+
assert out._degrees is False
335+
assert out._clockwise == grid1._clockwise
336+
assert out._waves_coming_from == grid1._waves_coming_from
337+
np.testing.assert_array_almost_equal(out._freq, grid1._freq)
338+
np.testing.assert_array_almost_equal(out._dirs, grid1._dirs)
339+
np.testing.assert_array_almost_equal(out._vals, vals_expect)
340+
341+
@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
342+
def test_binspectrum_and_binspectrum_to_grid(
343+
self, output_type, directional_bin_spectrum
344+
):
345+
grid1 = directional_bin_spectrum.copy()
346+
grid2 = directional_bin_spectrum.copy()
347+
out = wr.multiply(grid1, grid2, output_type=output_type)
348+
349+
vals_expect = grid1._vals * grid2._vals
350+
351+
assert isinstance(out, Grid)
352+
assert not isinstance(out, WaveSpectrum)
353+
assert out._freq_hz is False
354+
assert out._degrees is False
355+
assert out._clockwise == grid1._clockwise
356+
assert out._waves_coming_from == grid1._waves_coming_from
357+
np.testing.assert_array_almost_equal(out._freq, grid1._freq)
358+
np.testing.assert_array_almost_equal(out._dirs, grid1._dirs)
359+
np.testing.assert_array_almost_equal(out._vals, vals_expect)
360+
361+
@pytest.mark.parametrize(
362+
"output_type",
363+
("DirectionalBinSpectrum", DirectionalBinSpectrum),
364+
)
365+
def test_binspectrum_and_binspectrum_to_binspectrum(
366+
self, output_type, directional_bin_spectrum
367+
):
368+
grid1 = directional_bin_spectrum.copy()
369+
grid2 = directional_bin_spectrum.copy()
370+
out = wr.multiply(grid1, grid2, output_type=output_type)
371+
372+
vals_expect = grid1._vals * grid2._vals
373+
374+
assert isinstance(out, Grid)
375+
assert not isinstance(out, WaveSpectrum)
376+
assert out._freq_hz is False
377+
assert out._degrees is False
378+
assert out._clockwise == grid1._clockwise
379+
assert out._waves_coming_from == grid1._waves_coming_from
380+
np.testing.assert_array_almost_equal(out._freq, grid1._freq)
381+
np.testing.assert_array_almost_equal(out._dirs, grid1._dirs)
382+
np.testing.assert_array_almost_equal(out._vals, vals_expect)
383+
293384
def test_raises_output_type(self, grid):
294385
with pytest.raises(ValueError):
295386
wr.multiply(grid, grid.copy(), output_type="invalid-type")

0 commit comments

Comments
 (0)