Skip to content

Commit f744f61

Browse files
authored
Merge pull request #20 from John-Ragland/add_tests
added tests
2 parents daa659d + 569c6c6 commit f744f61

8 files changed

Lines changed: 872 additions & 1 deletion

File tree

.github/workflows/tests.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
8+
jobs:
9+
test:
10+
runs-on: ubuntu-latest
11+
12+
steps:
13+
- uses: actions/checkout@v4
14+
15+
- name: Install uv
16+
uses: astral-sh/setup-uv@v5
17+
18+
- name: Set up Python
19+
uses: actions/setup-python@v5
20+
with:
21+
python-version: "3.12"
22+
23+
- name: Install dependencies
24+
run: uv sync --group dev
25+
26+
- name: Run tests
27+
run: uv run pytest tests/

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ __pycache__/
33
.DS_Store
44
docs/_build/
55
dist/
6-
.vscode/
6+
.vscode/
7+
uv.lock

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,9 @@ exclude_lines = [
7474
"if __name__ == .__main__.:",
7575
"if TYPE_CHECKING:",
7676
]
77+
78+
[dependency-groups]
79+
dev = [
80+
"pytest>=8.3.5",
81+
"pytest-cov>=5.0.0",
82+
]

tests/conftest.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
Shared test fixtures for pygenray tests.
3+
"""
4+
import matplotlib
5+
matplotlib.use('Agg')
6+
7+
8+
def pytest_addoption(parser):
9+
"""Register --regenerate-physics CLI flag for physics regression tests."""
10+
parser.addoption(
11+
'--regenerate-physics', action='store_true', default=False,
12+
help='Regenerate physics regression fixture and skip comparison.',
13+
)
14+
15+
import numpy as np
16+
import pytest
17+
import xarray as xr
18+
19+
from pygenray.ray_objects import Ray, RayFan
20+
21+
22+
def _make_ray(launch_angle: float, source_depth: float, n_bottom: int = 0,
23+
n_surface: int = 0, N: int = 10, R: float = 10000.0) -> Ray:
24+
"""Helper: build a synthetic Ray without running the ODE solver.
25+
26+
The y array uses the positive-z convention expected by Ray.__init__:
27+
y[0,:] = travel times
28+
y[1,:] = depth (positive = deeper)
29+
y[2,:] = ray parameter sin(θ)/c (positive for downward ray in ODE)
30+
"""
31+
r = np.linspace(0.0, R, N)
32+
t = r / 1500.0
33+
# Depth increases linearly (simulating a shallow downward ray)
34+
z_ode = np.linspace(source_depth, source_depth + R * 0.01, N)
35+
p_ode = np.ones(N) * np.sin(np.radians(abs(launch_angle) + 1e-3)) / 1500.0
36+
y = np.vstack([t, z_ode, p_ode]) # shape (3, N)
37+
return Ray(r=r, y=y, n_bottom=n_bottom, n_surface=n_surface,
38+
launch_angle=launch_angle, source_depth=source_depth)
39+
40+
41+
@pytest.fixture
42+
def simple_ray():
43+
"""Single synthetic Ray with 10 range steps."""
44+
return _make_ray(launch_angle=-10.0, source_depth=100.0)
45+
46+
47+
@pytest.fixture
48+
def simple_rayfan():
49+
"""Small RayFan of 3 synthetic Rays — no ray tracing needed."""
50+
rays = [
51+
_make_ray(launch_angle=-5.0, source_depth=100.0, n_bottom=0),
52+
_make_ray(launch_angle=5.0, source_depth=150.0, n_bottom=1),
53+
_make_ray(launch_angle=-10.0, source_depth=200.0, n_bottom=0),
54+
]
55+
return RayFan(rays)

tests/fixtures/munk_regression.npz

7.41 KB
Binary file not shown.

tests/test_environment.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""
2+
Tests for pygenray.environment: munk_ssp, OceanEnvironment2D, eflat, eflatinv.
3+
"""
4+
import numpy as np
5+
import pytest
6+
import xarray as xr
7+
from matplotlib import pyplot as plt
8+
9+
from pygenray.environment import (
10+
OceanEnvironment2D,
11+
eflat,
12+
eflatinv,
13+
munk_ssp,
14+
)
15+
16+
17+
# ---------------------------------------------------------------------------
18+
# munk_ssp
19+
# ---------------------------------------------------------------------------
20+
21+
class TestMunkSSP:
22+
def test_output_shape_matches_input(self):
23+
z = np.arange(0, 5000, 10)
24+
c = munk_ssp(z)
25+
assert c.shape == z.shape
26+
27+
def test_minimum_at_sofar_depth(self):
28+
sofar = 1300.0
29+
z = np.arange(0, 6000, 1)
30+
c = munk_ssp(z, sofar_depth=sofar)
31+
# Minimum should be at the SOFAR depth
32+
assert z[np.argmin(c)] == pytest.approx(sofar, abs=2.0)
33+
34+
def test_default_params_near_1500_at_sofar(self):
35+
sofar = 1300.0
36+
c_sofar = munk_ssp(np.array([sofar]))
37+
assert c_sofar[0] == pytest.approx(1500.0, abs=5.0)
38+
39+
def test_scalar_input(self):
40+
c = munk_ssp(np.array([0.0]))
41+
assert c.shape == (1,)
42+
43+
44+
# ---------------------------------------------------------------------------
45+
# OceanEnvironment2D construction
46+
# ---------------------------------------------------------------------------
47+
48+
class TestOceanEnvironment2DConstruction:
49+
def test_default_init_attributes_exist(self):
50+
env = OceanEnvironment2D()
51+
for attr in ('sound_speed', 'bathymetry', 'dcdz', 'bottom_angle',
52+
'bottom_angle_interp'):
53+
assert hasattr(env, attr), f"Missing attribute: {attr}"
54+
55+
def test_default_sound_speed_is_2d(self):
56+
env = OceanEnvironment2D()
57+
assert env.sound_speed.ndim == 2
58+
assert set(env.sound_speed.dims) == {'range', 'depth'}
59+
60+
def test_default_flat_earth_attributes_exist(self):
61+
env = OceanEnvironment2D(flat_earth_transform=True)
62+
assert hasattr(env, 'sound_speed_fe')
63+
assert hasattr(env, 'bathymetry_fe')
64+
65+
def test_flat_earth_false_no_fe_attributes(self):
66+
env = OceanEnvironment2D(flat_earth_transform=False)
67+
assert not hasattr(env, 'sound_speed_fe')
68+
assert not hasattr(env, 'bathymetry_fe')
69+
70+
def test_custom_1d_sound_speed(self):
71+
z = np.arange(0.0, 3000.0, 10.0)
72+
c_vals = munk_ssp(z)
73+
ssp = xr.DataArray(c_vals, dims=['depth'], coords={'depth': z})
74+
bathy = xr.DataArray(
75+
np.ones(20) * 4000.0, dims=['range'],
76+
coords={'range': np.linspace(0, 50e3, 20)}
77+
)
78+
env = OceanEnvironment2D(sound_speed=ssp, bathymetry=bathy,
79+
flat_earth_transform=False)
80+
assert env.sound_speed.ndim == 1
81+
assert 'depth' in env.sound_speed.dims
82+
83+
def test_custom_2d_sound_speed(self):
84+
z = np.arange(0.0, 3000.0, 50.0)
85+
r = np.linspace(0.0, 50e3, 20)
86+
c_2d = np.outer(np.ones(len(r)), munk_ssp(z))
87+
ssp = xr.DataArray(c_2d, dims=['range', 'depth'],
88+
coords={'range': r, 'depth': z})
89+
env = OceanEnvironment2D(sound_speed=ssp, flat_earth_transform=False)
90+
assert env.sound_speed.ndim == 2
91+
92+
def test_custom_bathymetry_stored(self):
93+
bathy_vals = np.ones(20) * 3500.0
94+
r = np.linspace(0.0, 50e3, 20)
95+
bathy = xr.DataArray(bathy_vals, dims=['range'], coords={'range': r})
96+
env = OceanEnvironment2D(bathymetry=bathy, flat_earth_transform=False)
97+
np.testing.assert_array_equal(env.bathymetry.values, bathy_vals)
98+
99+
# --- invalid inputs ---
100+
101+
def test_sound_speed_not_dataarray_raises_type_error(self):
102+
with pytest.raises(TypeError):
103+
OceanEnvironment2D(sound_speed=np.ones(100))
104+
105+
def test_sound_speed_3d_raises_value_error(self):
106+
da = xr.DataArray(
107+
np.ones((5, 10, 20)),
108+
dims=['range', 'depth', 'extra'],
109+
coords={'range': np.arange(5), 'depth': np.arange(10),
110+
'extra': np.arange(20)}
111+
)
112+
with pytest.raises(ValueError):
113+
OceanEnvironment2D(sound_speed=da)
114+
115+
def test_sound_speed_missing_depth_dim_raises_value_error(self):
116+
da = xr.DataArray(np.ones(50), dims=['range'],
117+
coords={'range': np.arange(50)})
118+
with pytest.raises(ValueError):
119+
OceanEnvironment2D(sound_speed=da)
120+
121+
def test_2d_sound_speed_missing_range_dim_raises_value_error(self):
122+
da = xr.DataArray(
123+
np.ones((10, 20)),
124+
dims=['depth', 'extra'],
125+
coords={'depth': np.arange(10), 'extra': np.arange(20)}
126+
)
127+
with pytest.raises(ValueError):
128+
OceanEnvironment2D(sound_speed=da)
129+
130+
def test_bathymetry_not_dataarray_raises_type_error(self):
131+
with pytest.raises(TypeError):
132+
OceanEnvironment2D(bathymetry=np.ones(50))
133+
134+
def test_bathymetry_missing_range_dim_raises_value_error(self):
135+
da = xr.DataArray(np.ones(50), dims=['depth'],
136+
coords={'depth': np.arange(50)})
137+
with pytest.raises(ValueError):
138+
OceanEnvironment2D(bathymetry=da)
139+
140+
141+
# ---------------------------------------------------------------------------
142+
# eflat / eflatinv round-trip
143+
# ---------------------------------------------------------------------------
144+
145+
class TestEflat:
146+
LAT = 35.0
147+
148+
def test_depth_roundtrip(self):
149+
dep = np.array([100.0, 500.0, 1000.0, 2000.0, 4000.0])
150+
depf, _ = eflat(dep, self.LAT)
151+
dep_rec, _ = eflatinv(depf, np.array([self.LAT]))
152+
np.testing.assert_allclose(dep_rec, dep, atol=1.0,
153+
err_msg="Depth round-trip outside 1 m tolerance")
154+
155+
def test_sound_speed_roundtrip(self):
156+
dep = np.array([100.0, 500.0, 1000.0, 2000.0])
157+
cs = np.array([1500.0, 1490.0, 1480.0, 1510.0])
158+
depf, csf = eflat(dep, self.LAT, cs)
159+
_, cs_rec = eflatinv(depf, np.array([self.LAT]), csf)
160+
np.testing.assert_allclose(cs_rec, cs, rtol=1e-4,
161+
err_msg="Sound speed round-trip outside 0.01% tolerance")
162+
163+
def test_eflat_increases_depth(self):
164+
"""Flat-earth transformation should increase effective depths."""
165+
dep = np.array([100.0, 1000.0, 3000.0])
166+
depf, _ = eflat(dep, self.LAT)
167+
assert np.all(depf > dep)
168+
169+
170+
# ---------------------------------------------------------------------------
171+
# OceanEnvironment2D.plot smoke test
172+
# ---------------------------------------------------------------------------
173+
174+
class TestOceanEnvironment2DPlot:
175+
def test_plot_runs_without_error(self):
176+
env = OceanEnvironment2D()
177+
fig, ax = plt.subplots()
178+
plt.sca(ax)
179+
env.plot()
180+
plt.close('all')

0 commit comments

Comments
 (0)