-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
68 lines (56 loc) · 1.96 KB
/
main.py
File metadata and controls
68 lines (56 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from typing import Callable, TypeAlias
from functools import partial
import models
import solvers
RHS: TypeAlias = Callable[[float, np.ndarray], np.ndarray]
Solver: TypeAlias = Callable[[RHS, tuple[float, float], npt.ArrayLike, float], tuple[np.ndarray, np.ndarray]]
def _plot_time_series(
    t: np.ndarray,
    y: np.ndarray,
    states: tuple[int, ...],
    labels: tuple[str, ...],
    system: str,
) -> None:
    """Plot the selected state columns of ``y`` against ``t`` on one new figure.

    Args:
        t: Time values returned by the solver (x-axis).
        y: Solver output, indexed as ``y[:, state]`` (one column per state).
        states: Column indices of ``y`` to draw.
        labels: Display name for each state index.
        system: Model name, used in the figure title.
    """
    plt.figure()
    for state in states:
        plt.plot(t, y[:, state], label=labels[state])
    plt.xlabel("t")
    # A single curve can name its own axis; several curves share a generic
    # axis label and are told apart by the legend.
    plt.ylabel(labels[states[0]] if len(states) == 1 else "states")
    plt.title(f"{system}: time-series")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()


def _plot_phase_pairs(
    y: np.ndarray,
    pairs: tuple[tuple[int, int], ...],
    labels: tuple[str, ...],
    system: str,
) -> None:
    """Draw one phase-space figure per ``(i, j)`` pair: state ``j`` against state ``i``.

    Args:
        y: Solver output, indexed as ``y[:, state]`` (one column per state).
        pairs: ``(i, j)`` column-index pairs; ``i`` on the x-axis, ``j`` on the y-axis.
        labels: Display name for each state index.
        system: Model name, used in the figure titles.
    """
    for i, j in pairs:
        plt.figure()
        # The title names both states, so no per-curve label or legend is drawn.
        plt.plot(y[:, i], y[:, j])
        plt.xlabel(labels[i])
        plt.ylabel(labels[j])
        plt.title(f"{system}: {labels[i]} vs {labels[j]}")
        plt.grid(True)
        plt.axis('equal')
        plt.tight_layout()


def main() -> None:
    """Solve the configured model with the configured solver and plot the result.

    Edit the settings block below to choose the model (``system``, looked up by
    name in ``models``), the solver (looked up by name in ``solvers``), the
    keyword parameters bound into the model function, the time span, the
    initial state ``x0`` and the step size ``h``. Then show time-series and/or
    phase-space figures.

    Raises:
        ValueError: If ``system`` has no entry in ``models.labels``, if
            ``solver`` does not name a callable in ``solvers``, or if a
            requested plot index is out of range for the model's state labels.
            These are checked before solving, so a typo does not cost a full
            integration run.
    """
    # change settings here
    system: str = "sir"
    solver: str = "ie"
    plot_states: tuple[int, ...] = (0, 1, 2)  # plot states against time
    plot_pairs: tuple[tuple[int, int], ...] = ()  # plot states against each other
    params = dict(
        beta=0.40,
        gamma=0.05,
    )
    t_span: tuple[float, float] = (0.0, 100.0)
    x0: list[float] = [0.99, 0.01, 0.0]
    h: float = 0.05

    # solver, don't change anything here
    # NOTE(review): models.labels is presumably a dict keyed by system name.
    if system not in models.labels:
        raise ValueError(
            f"unknown system {system!r}; models.labels has: {list(models.labels)}"
        )
    labels: tuple[str, ...] = models.labels[system]
    base_func = getattr(models, system)
    func: RHS = partial(base_func, **params)

    solve = getattr(solvers, solver, None)
    if not callable(solve):
        raise ValueError(f"unknown solver {solver!r}: no callable solvers.{solver}")

    # Check plot indices against the label tuple up front. Both plotting helpers
    # index labels[...] with these values, so a bad index would otherwise only
    # fail after the (possibly long) integration has finished.
    n_states = len(labels)
    requested = [*plot_states, *(k for pair in plot_pairs for k in pair)]
    out_of_range = [k for k in requested if not -n_states <= k < n_states]
    if out_of_range:
        raise ValueError(
            f"plot indices {out_of_range} out of range for {system!r} "
            f"with {n_states} states {labels}"
        )

    t, y = solve(func, t_span, x0, h)

    # Plot time-series (states vs. time)
    if plot_states:
        _plot_time_series(t, y, plot_states, labels, system)
    # Plot phase-space (state pairs)
    if plot_pairs:
        _plot_phase_pairs(y, plot_pairs, labels, system)
    plt.show()
if __name__ == "__main__":
    # Only run the solve-and-plot demo when executed as a script, not on import.
    main()