Skip to content

Commit 34b9405

Browse files
committed
Update: tests for OCX and QAA
1 parent c670fff commit 34b9405

2 files changed

Lines changed: 104 additions & 11 deletions

File tree

test/oceancolour/test_ocx.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from uncertaintyx.oceancolour.ocx import OCI
1313

1414

15-
def read_test_data(
15+
def read_owt_data(
1616
package: str, filename: str
1717
) -> tuple[np.ndarray, np.ndarray, np.ndarray, int, int]:
1818
"""
@@ -39,10 +39,11 @@ class OCxTest(unittest.TestCase):
3939
"""Tests OCX model functions on optical water type classes."""
4040

4141
def test_ci(self):
42-
w, R, u, M, m = read_test_data( # noqa : N806
42+
"""Tests the chlorophyll index (CI) model function."""
43+
w, R, u, M, m = read_owt_data( # noqa : N806
4344
"test.resources", "owt.csv"
4445
)
45-
W = np.broadcast_to(w, R.shape) # noqa : N806
46+
W = np.broadcast_to(w, (M, m)) # noqa : N806
4647

4748
f = CI()
4849
x = np.stack([W[:, [1, 4, 5]], R[:, [1, 4, 5]]], axis=1)
@@ -89,7 +90,8 @@ def test_ci(self):
8990
self.assertTrue(np.isfinite(y[13]))
9091

9192
def test_oc4(self):
92-
_, R, u, M, m = read_test_data( # noqa : N806
93+
"""Tests the OC4 model function."""
94+
_, R, u, M, m = read_owt_data( # noqa : N806
9395
"test.resources", "owt.csv"
9496
)
9597

@@ -136,10 +138,11 @@ def test_oc4(self):
136138
self.assertAlmostEqual(4.362, v[13], delta=0.001)
137139

138140
def test_oci(self):
139-
w, R, u, M, m = read_test_data( # noqa : N806
141+
"""Tests the OCI model function."""
142+
w, R, u, M, m = read_owt_data( # noqa : N806
140143
"test.resources", "owt.csv"
141144
)
142-
W = np.broadcast_to(w, R.shape) # noqa : N806
145+
W = np.broadcast_to(w, (M, m)) # noqa : N806
143146

144147
f = OCI()
145148
x = np.stack([W[:, 1:], R[:, 1:]], axis=1)
@@ -164,10 +167,10 @@ def test_oci(self):
164167
self.assertAlmostEqual(3.427, y[13], delta=0.001)
165168

166169
U = np.square(u) # noqa : N806
167-
V = f.lpu_x(p, x, U) # noqa : N806
168-
self.assertEqual((M,), V.shape)
170+
U = f.lpu_x(p, x, U) # noqa : N806
171+
self.assertEqual((M,), U.shape)
169172

170-
v = np.sqrt(V)
173+
v = np.sqrt(U)
171174
self.assertAlmostEqual(0.026, v[0], delta=0.001)
172175
self.assertAlmostEqual(0.024, v[1], delta=0.001)
173176
self.assertAlmostEqual(0.042, v[2], delta=0.001)

test/oceancolour/test_qaa.py

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,29 @@ def matrix(result: Result, a: Any, b: Any, n: int = 1000) -> np.ndarray:
4747
return np.squeeze(result.ycov_p(np.linspace(a, b, n).reshape(1, n)))
4848

4949

50+
def read_owt_data(
51+
package: str, filename: str
52+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, int, int]:
53+
"""
54+
Returns the optical water types (OWT) data table.
55+
56+
:param package: The package name.
57+
:param filename: The filename.
58+
:returns: The data table.
59+
"""
60+
with resources.path(package, filename) as resource:
61+
rows = []
62+
with open(resource) as r:
63+
df = pd.read_csv(r, sep=";", header=None, index_col=0)
64+
for name, _ in df.items():
65+
rows.append(df[name].values)
66+
data = np.stack(rows, axis=-1)
67+
wav = data[0, :6]
68+
rrs = data[1:, :6]
69+
unc = data[1:, 6:]
70+
return wav, rrs, unc, rrs.shape[0], rrs.shape[1]
71+
72+
5073
def read_plot_data(
5174
package: str, filename: str
5275
) -> tuple[np.ndarray, np.ndarray]:
@@ -108,7 +131,7 @@ def test_lee_2010_figure_2(self):
108131
self.assertAlmostEqual(0.1, result.punc[0], delta=0.1)
109132
self.assertAlmostEqual(0.1, result.punc[1], delta=0.1)
110133
self.assertAlmostEqual(0.1, result.punc[1], delta=0.1)
111-
self.assertAlmostEqual(0.3, result.yvar_r, delta=0.1)
134+
self.assertAlmostEqual(0.3, result.yvar_r.item(), delta=0.1)
112135

113136
print()
114137
print("popt = ", result.popt)
@@ -155,7 +178,7 @@ def test_lee_2010_figure_3(self):
155178
self.assertAlmostEqual(0.001, result.punc[0], delta=0.001)
156179
self.assertAlmostEqual(0.001, result.punc[0], delta=0.001)
157180
self.assertAlmostEqual(5.0, result.punc[0], delta=5.0)
158-
self.assertAlmostEqual(1.5e-05, result.yvar_r, delta=0.1e-05)
181+
self.assertAlmostEqual(1.5e-05, result.yvar_r.item(), delta=0.1e-05)
159182

160183
print()
161184
print("popt = ", result.popt)
@@ -272,6 +295,73 @@ def test_qaa_multiple_batches(self):
272295
msg=f"{i}, {j}: assertion failed",
273296
)
274297

298+
def test_qaa_with_owt(self):
299+
"""
300+
Test QAA on optical water type (OWT) classes.
301+
"""
302+
w, R, u, M, m = read_owt_data( # noqa : N806
303+
"test.resources", "owt.csv"
304+
)
305+
W = np.broadcast_to(w, (M, m)) # noqa : N806
306+
307+
f = QAA()
308+
x = np.stack(
309+
[
310+
np.broadcast_to(W, (M, m)),
311+
R,
312+
np.broadcast_to(AW, (M, m)),
313+
np.broadcast_to(BW, (M, m)),
314+
],
315+
axis=1,
316+
)
317+
u = np.stack(
318+
[
319+
np.broadcast_to(0.5, (M, m)),
320+
u,
321+
np.broadcast_to(0.1 * AW, (M, m)),
322+
np.broadcast_to(0.1 * BW, (M, m)),
323+
],
324+
axis=1,
325+
)
326+
p = f.estimate()
327+
y = f.eval(p, x)
328+
self.assertEqual((M, 4, m), y.shape)
329+
330+
a = y[:, 0, :]
331+
self.assertAlmostEqual(0.014, a[0, 0], delta=0.001)
332+
self.assertAlmostEqual(0.015, a[0, 1], delta=0.001)
333+
self.assertAlmostEqual(0.019, a[0, 2], delta=0.001)
334+
self.assertAlmostEqual(0.034, a[0, 3], delta=0.001)
335+
self.assertAlmostEqual(0.060, a[0, 4], delta=0.001)
336+
self.assertAlmostEqual(0.291, a[0, 5], delta=0.001)
337+
338+
self.assertAlmostEqual(0.515, a[13, 0], delta=0.001)
339+
self.assertAlmostEqual(0.391, a[13, 1], delta=0.001)
340+
self.assertAlmostEqual(0.249, a[13, 2], delta=0.001)
341+
self.assertAlmostEqual(0.224, a[13, 3], delta=0.001)
342+
self.assertAlmostEqual(0.184, a[13, 4], delta=0.001)
343+
self.assertAlmostEqual(0.514, a[13, 5], delta=0.001)
344+
345+
U = np.square(u) # noqa : N806
346+
U = f.lpu_x(p, x, U, True) # noqa : N806
347+
self.assertEqual((M, 4, m), U.shape)
348+
349+
u = np.sqrt(U)
350+
ua = u[:, 0, :]
351+
self.assertAlmostEqual(0.013, ua[0, 0], delta=0.001)
352+
self.assertAlmostEqual(0.016, ua[0, 1], delta=0.001)
353+
self.assertAlmostEqual(0.024, ua[0, 2], delta=0.001)
354+
self.assertAlmostEqual(0.046, ua[0, 3], delta=0.001)
355+
self.assertAlmostEqual(0.006, ua[0, 4], delta=0.001)
356+
self.assertAlmostEqual(1.569, ua[0, 5], delta=0.001)
357+
358+
self.assertAlmostEqual(0.479, ua[13, 0], delta=0.001)
359+
self.assertAlmostEqual(0.359, ua[13, 1], delta=0.001)
360+
self.assertAlmostEqual(0.235, ua[13, 2], delta=0.001)
361+
self.assertAlmostEqual(0.210, ua[13, 3], delta=0.001)
362+
self.assertAlmostEqual(0.178, ua[13, 4], delta=0.001)
363+
self.assertAlmostEqual(0.083, ua[13, 5], delta=0.001)
364+
275365

276366
if __name__ == "__main__":
277367
unittest.main()

0 commit comments

Comments
 (0)