|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from unittest.mock import patch |
| 4 | + |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import numpy as np |
| 7 | +from pykoopman.common.examples import advance_linear_system |
| 8 | +from pykoopman.common.examples import drss |
| 9 | +from pykoopman.common.examples import forced_duffing |
| 10 | +from pykoopman.common.examples import Linear2Ddynamics |
| 11 | +from pykoopman.common.examples import lorenz |
| 12 | +from pykoopman.common.examples import rev_dvdp |
| 13 | +from pykoopman.common.examples import rk4 |
| 14 | +from pykoopman.common.examples import sine_wave |
| 15 | +from pykoopman.common.examples import slow_manifold |
| 16 | +from pykoopman.common.examples import square_wave |
| 17 | +from pykoopman.common.examples import vdp_osc |
| 18 | + |
| 19 | + |
| 20 | +def test_drss_shapes(): |
| 21 | + n_states = 3 |
| 22 | + n_controls = 2 |
| 23 | + n_measurements = 4 |
| 24 | + A, B, C = drss(n=n_states, p=n_controls, m=n_measurements) |
| 25 | + assert A.shape == (n_states, n_states) |
| 26 | + assert B.shape == (n_states, n_controls) |
| 27 | + assert C.shape == (n_measurements, n_states) |
| 28 | + |
| 29 | + |
| 30 | +def test_drss_identity_measurement(): |
| 31 | + # If m=0, C should be identity |
| 32 | + n_states = 3 |
| 33 | + A, B, C = drss(n=n_states, m=0) |
| 34 | + assert C.shape == (n_states, n_states) |
| 35 | + np.testing.assert_array_equal(C, np.eye(n_states)) |
| 36 | + |
| 37 | + |
| 38 | +def test_advance_linear_system(): |
| 39 | + n = 2 |
| 40 | + A = np.eye(n) |
| 41 | + B = np.eye(n) |
| 42 | + C = np.eye(n) |
| 43 | + x0 = np.array([1.0, 1.0]) |
| 44 | + # consistent for 1 step |
| 45 | + # n_steps to simulate |
| 46 | + n_steps = 2 |
| 47 | + # Expanding u to match steps if needed, but the function handles 1D u as row vector |
| 48 | + # Let's provide u of shape (p, n_steps-1) |
| 49 | + u_seq = np.ones((n, n_steps - 1)) |
| 50 | + |
| 51 | + x, y = advance_linear_system(x0, u_seq, n_steps, A, B, C) |
| 52 | + |
| 53 | + # x shape should be (n, n_steps)?? Wait, let's check docstring or implementation. |
| 54 | + # Implementation: x = np.zeros([n, len(x0)]) -> Wait, len(x0) is n. |
| 55 | + # The implementation: |
| 56 | + # x = np.zeros([n, len(x0)]) ??? No, x0 is (n,). len(x0) is n. |
| 57 | + # But usually x should be (n_states, n_time_steps). |
| 58 | + # docstring says: returns x of shape (n, len(x0)). |
| 59 | + # This seems like a potential bug or confusion in docstring vs code |
| 60 | + # if n_steps != n_states. |
| 61 | + # Let's look at code: 'x = np.zeros([n, len(x0)])' |
| 62 | + # where n is passed as arg 'n' (steps). |
| 63 | + # The argument name 'n' shadows dimension 'n'. |
| 64 | + # In function def: advance_linear_system(x0, u, n, ...): |
| 65 | + # n is "Number of steps to simulate" |
| 66 | + # But inside: x = np.zeros([n, len(x0)]) |
| 67 | + # So dim 0 is n (steps), dim 1 is len(x0) (states?). |
| 68 | + # Usually states are rows or columns. |
| 69 | + # If n=steps, then x is (steps, states) or (states, steps). |
| 70 | + # Code: x[0, :] = x0. So x is (steps, states). |
| 71 | + |
| 72 | + # x has shape (n_steps, n_states) |
| 73 | + assert x.shape == (n_steps, len(x0)) |
| 74 | + assert y.shape == (n_steps, C.shape[0]) |
| 75 | + |
| 76 | + |
| 77 | +def test_vdp_osc_rk4(): |
| 78 | + t = 0 |
| 79 | + x = np.array([[1.0], [0.5]]) |
| 80 | + u = 0.0 |
| 81 | + dt = 0.01 |
| 82 | + |
| 83 | + # Check vdp_osc structure |
| 84 | + dx = vdp_osc(t, x, u) |
| 85 | + assert dx.shape == x.shape |
| 86 | + |
| 87 | + # Check rk4 integration step |
| 88 | + x_next = rk4(t, x, u, dt, vdp_osc) |
| 89 | + assert x_next.shape == x.shape |
| 90 | + assert not np.array_equal(x, x_next) |
| 91 | + |
| 92 | + |
| 93 | +def test_square_and_sine_wave(): |
| 94 | + # Just smoke tests to ensure they run/return floats |
| 95 | + val_sq = square_wave(10) |
| 96 | + assert ( |
| 97 | + isinstance(val_sq, float) |
| 98 | + or isinstance(val_sq, int) |
| 99 | + or isinstance(val_sq, np.float64) |
| 100 | + ) |
| 101 | + |
| 102 | + val_sin = sine_wave(10) |
| 103 | + assert isinstance(val_sin, float) or isinstance(val_sin, np.floating) |
| 104 | + |
| 105 | + |
| 106 | +def test_lorenz(): |
| 107 | + x = [10.0, 10.0, 10.0] |
| 108 | + t = 0.0 |
| 109 | + dx = lorenz(x, t) |
| 110 | + assert len(dx) == 3 |
| 111 | + |
| 112 | + |
| 113 | +def test_rev_dvdp(): |
| 114 | + x = np.array( |
| 115 | + [[1.0], [0.5]] |
| 116 | + ) # needs to be 2D array (2, 1) based on code usage of x[0,:]? |
| 117 | + # Code: x[0,:] - ... |
| 118 | + # So if we pass (2, 1), x[0,:] is shape (1,). |
| 119 | + t = 0 |
| 120 | + x_next = rev_dvdp(t, x) |
| 121 | + assert x_next.shape == x.shape |
| 122 | + |
| 123 | + |
| 124 | +def test_linear_2d_dynamics(): |
| 125 | + sys = Linear2Ddynamics() |
| 126 | + x = np.array([[1.0], [1.0]]) |
| 127 | + |
| 128 | + # Test linear_map |
| 129 | + y = sys.linear_map(x) |
| 130 | + assert y.shape == x.shape |
| 131 | + |
| 132 | + # Test collect_data |
| 133 | + n_int = 10 |
| 134 | + n_traj = 1 |
| 135 | + X, Y = sys.collect_data(x, n_int, n_traj) |
| 136 | + # shapes: (n_states, n_int * n_traj) |
| 137 | + assert X.shape == (2, n_int * n_traj) |
| 138 | + assert Y.shape == (2, n_int * n_traj) |
| 139 | + |
| 140 | + |
| 141 | +def test_slow_manifold(): |
| 142 | + model = slow_manifold() |
| 143 | + x = np.array([[0.1], [0.1]]) # (2, 1) to match usage of x[0, :] |
| 144 | + |
| 145 | + # Test sys |
| 146 | + t = 0 |
| 147 | + u = 0 |
| 148 | + dx = model.sys(t, x, u) |
| 149 | + assert dx.shape == x.shape |
| 150 | + |
| 151 | + # Test simulate (requires x0 to be (2, 1)) |
| 152 | + x0 = np.array([[0.1], [0.1]]) |
| 153 | + n_int = 100 |
| 154 | + X = model.simulate(x0, n_int) |
| 155 | + assert X.shape == (2, n_int * 1) # n_traj is 1 |
| 156 | + |
| 157 | + |
| 158 | +def test_forced_duffing(): |
| 159 | + dt = 0.01 |
| 160 | + d = 0.1 |
| 161 | + alpha = 1.0 |
| 162 | + beta = 1.0 |
| 163 | + model = forced_duffing(dt, d, alpha, beta) |
| 164 | + |
| 165 | + assert model.n_states == 2 |
| 166 | + |
| 167 | + # Test sys |
| 168 | + t = 0 |
| 169 | + x = np.array([[1.0], [1.0]]) |
| 170 | + u = 0.0 |
| 171 | + dx = model.sys(t, x, u) |
| 172 | + assert dx.shape == x.shape |
| 173 | + |
| 174 | + # Test simulate |
| 175 | + x0 = np.array([[1.0], [1.0]]) |
| 176 | + n_int = 10 |
| 177 | + u_seq = np.zeros((n_int, 1)) # (n_int, n_traj) ? |
| 178 | + # collect_data_discrete uses u[step, :] which implies u is (n_int, n_traj) ?? |
| 179 | + # Let's check simulate implementation: u is passed as (n_int, ...?) |
| 180 | + # simulate(x0, n_int, u) -> u[step, :] |
| 181 | + # if x0 has n_traj=1. |
| 182 | + |
| 183 | + # Wait, in forced_duffing.simulate: |
| 184 | + # u[step, :] is passed to rk4. |
| 185 | + # if u is (n_int, 1), u[step,:] is shape (1,). |
| 186 | + # sys takes u. |
| 187 | + # sys implementation: ... + u |
| 188 | + # if u is scalar or (1,) it broadcasts. |
| 189 | + |
| 190 | + X = model.simulate(x0, n_int, u_seq) |
| 191 | + assert X.shape == (2, n_int * 1) |
| 192 | + |
| 193 | + # Test collect_data_continuous |
| 194 | + u_static = 0.0 |
| 195 | + X_c, Y_c = model.collect_data_continuous(x0, u_static) |
| 196 | + assert X_c.shape == x0.shape |
| 197 | + assert Y_c.shape == x0.shape |
| 198 | + |
| 199 | + # Test collect_data_discrete |
| 200 | + X_d, Y_d = model.collect_data_discrete(x0, n_int, u_seq) |
| 201 | + assert X_d.shape == (2, n_int * 1) |
| 202 | + assert Y_d.shape == (2, n_int * 1) |
| 203 | + |
| 204 | + |
| 205 | +@patch("matplotlib.pyplot.show") |
| 206 | +def test_forced_duffing_visualize(mock_show): |
| 207 | + dt = 0.01 |
| 208 | + model = forced_duffing(dt, 0.1, 1.0, 1.0) |
| 209 | + t = np.linspace(0, 1, 100) |
| 210 | + X = np.random.rand(2, 100) |
| 211 | + |
| 212 | + model.visualize_trajectories(t, X, n_traj=1) |
| 213 | + mock_show.assert_not_called() |
| 214 | + # visualize_trajectories doesn't call show() in source?? |
| 215 | + # Let's check source. |
| 216 | + # visualize_trajectories: plt.subplots... axs.plot... axs.set... No plt.show() |
| 217 | + # It just makes plots. |
| 218 | + plt.close() # Close to avoid warning |
| 219 | + |
| 220 | + model.visualize_state_space(X, X, n_traj=1) |
| 221 | + # visualize_state_space: plt.subplots... axs.plot... No plt.show() |
| 222 | + plt.close() |
0 commit comments