Skip to content

Commit c8e3917

Browse files
bump
1 parent 00b7dee commit c8e3917

File tree

2 files changed

+225
-3
lines changed

2 files changed

+225
-3
lines changed

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "grpredict"
7-
version = "26.4.7rc"
7+
version = "26.4.7"
88
description = "Estimate growth curves using an Extended Kalman Filter"
99
readme = "README.md"
1010
requires-python = ">=3.8"
@@ -31,6 +31,7 @@ classifiers = [
3131
"Programming Language :: Python :: 3.11",
3232
"Programming Language :: Python :: 3.12",
3333
"Programming Language :: Python :: 3.13",
34+
"Programming Language :: Python :: 3.14",
3435
]
3536

3637
[project.urls]
@@ -39,8 +40,6 @@ Repository = "https://github.com/Pioreactor/grpredict"
3940
[project.optional-dependencies]
4041
dev = [
4142
"pytest",
42-
"black",
43-
"mypy"
4443
]
4544

4645
[tool.setuptools.package-data]
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
#!/Users/camerondavidson-pilon/code/grpredict/.venv/bin/python3.14
2+
from __future__ import annotations
3+
4+
import os
5+
import sys
6+
from pathlib import Path
7+
8+
os.environ.setdefault("MPLCONFIGDIR", str(Path("/tmp") / "mplconfig_grpredict"))
9+
10+
import matplotlib
11+
12+
matplotlib.use("Agg")
13+
import matplotlib.pyplot as plt
14+
import numpy as np
15+
16+
ROOT = Path(__file__).resolve().parents[1]
17+
sys.path.insert(0, str(ROOT))
18+
sys.path.insert(0, str(ROOT / "src"))
19+
20+
from tests.simulation_utils import plausible_growth_rate_profiles
21+
from tests.simulation_utils import simulate_profiled_od_observations
22+
from tests.test_simulated_profiles_with_ekf import make_single_sensor_ekf
23+
from tests.test_simulated_profiles_with_ekf import run_ekf_over_observations
24+
25+
PLOTS_DIR = ROOT / "scratch" / "plots"
26+
GROWTH_RATE_OUTPUT_PATH = PLOTS_DIR / "ekf_profile_noise_grid.png"
27+
OD_OUTPUT_PATH = PLOTS_DIR / "ekf_profile_noise_od_grid.png"
28+
29+
DT_HOURS = 5.0 / 60.0 / 60.0
30+
SEED = 321
31+
PROFILE_LAYOUT: list[tuple[str, float, str]] = [
32+
("lag_log_stationary", 12.0, "lag_log_stationary"),
33+
("constant_growth", 12.0, "constant"),
34+
("washout_recovery", 12.0, "washout_recovery."),
35+
]
36+
NOISE_FAMILIES: list[tuple[str, str]] = [
37+
("nominal_near_iid", "Nominal near-iid"),
38+
("nominal_colored", "Nominal colored"),
39+
("noisy_colored", "Noisy colored"),
40+
]
41+
42+
43+
def build_panel_data(profile_name: str, total_hours: float, noise_family: str) -> dict[str, np.ndarray]:
44+
growth_rates = plausible_growth_rate_profiles(total_hours, DT_HOURS)[profile_name]
45+
simulated = simulate_profiled_od_observations(
46+
growth_rates,
47+
profile_name=noise_family,
48+
dt_hours=DT_HOURS,
49+
seed=SEED,
50+
)
51+
estimated_rates = run_ekf_over_observations(
52+
simulated["observed_od"],
53+
DT_HOURS,
54+
noise_family,
55+
)
56+
ekf = make_single_sensor_ekf(noise_family)
57+
estimated_od = np.empty_like(simulated["observed_od"])
58+
estimated_od[0] = float(ekf.state_[0])
59+
for index, observation in enumerate(simulated["observed_od"][1:], start=1):
60+
state, _ = ekf.update([float(observation)], DT_HOURS)
61+
estimated_od[index] = float(state[0])
62+
return {
63+
"time_hours": simulated["time_hours"],
64+
"rate_time_hours": simulated["time_hours"][1:],
65+
"true_rates": simulated["growth_rates"],
66+
"estimated_rates": estimated_rates,
67+
"latent_od": simulated["latent_od"],
68+
"observed_od": simulated["observed_od"],
69+
"estimated_od": estimated_od,
70+
}
71+
72+
73+
def collect_panels() -> dict[tuple[int, int], dict[str, np.ndarray]]:
74+
panels: dict[tuple[int, int], dict[str, np.ndarray]] = {}
75+
for row_index, (profile_name, total_hours, _) in enumerate(PROFILE_LAYOUT):
76+
for col_index, (noise_family, _) in enumerate(NOISE_FAMILIES):
77+
panels[(row_index, col_index)] = build_panel_data(profile_name, total_hours, noise_family)
78+
return panels
79+
80+
81+
def render_growth_rate_grid(panels: dict[tuple[int, int], dict[str, np.ndarray]]) -> None:
82+
figure, axes = plt.subplots(3, 3, figsize=(16, 10), sharex=False, sharey=True, constrained_layout=True)
83+
max_abs_rate = 0.0
84+
for panel in panels.values():
85+
panel_max = float(
86+
np.max(np.abs(np.concatenate([panel["true_rates"], panel["estimated_rates"]])))
87+
)
88+
max_abs_rate = max(max_abs_rate, panel_max)
89+
y_limit = max(0.30, np.ceil(max_abs_rate * 20.0) / 20.0)
90+
for row_index, (_, _, profile_label) in enumerate(PROFILE_LAYOUT):
91+
for col_index, (_, noise_label) in enumerate(NOISE_FAMILIES):
92+
axis = axes[row_index, col_index]
93+
panel = panels[(row_index, col_index)]
94+
95+
axis.plot(
96+
panel["rate_time_hours"],
97+
panel["true_rates"],
98+
color="#1f77b4",
99+
linewidth=2.2,
100+
label="True growth rate",
101+
)
102+
axis.plot(
103+
panel["rate_time_hours"],
104+
panel["estimated_rates"],
105+
color="#d62728",
106+
linewidth=1.5,
107+
alpha=0.9,
108+
label="EKF estimate",
109+
)
110+
axis.axhline(0.0, color="#666666", linewidth=0.8, alpha=0.5)
111+
axis.set_ylim(-0.08 if y_limit > 0.08 else -y_limit, y_limit)
112+
axis.grid(alpha=0.18)
113+
114+
if row_index == 0:
115+
axis.set_title(noise_label)
116+
if col_index == 0:
117+
axis.set_ylabel(f"{profile_label}\nGrowth rate (1/h)")
118+
if row_index == len(PROFILE_LAYOUT) - 1:
119+
axis.set_xlabel("Time (hours)")
120+
121+
rmse = float(np.sqrt(np.mean((panel["estimated_rates"] - panel["true_rates"]) ** 2)))
122+
axis.text(
123+
0.02,
124+
0.96,
125+
f"RMSE {rmse:.3f}",
126+
transform=axis.transAxes,
127+
va="top",
128+
ha="left",
129+
fontsize=9,
130+
bbox={"facecolor": "white", "edgecolor": "none", "alpha": 0.75, "pad": 1.5},
131+
)
132+
133+
handles, labels = axes[0, 0].get_legend_handles_labels()
134+
figure.legend(handles, labels, loc="upper center", ncol=2, frameon=False)
135+
figure.suptitle(
136+
"Simulated Growth Rate vs EKF Estimate Across Profiles and Noise Families\n"
137+
f"dt={DT_HOURS * 3600:.0f}s, seed={SEED}",
138+
fontsize=14,
139+
)
140+
figure.savefig(GROWTH_RATE_OUTPUT_PATH, dpi=180)
141+
plt.close(figure)
142+
143+
144+
def render_observed_od_grid(panels: dict[tuple[int, int], dict[str, np.ndarray]]) -> None:
145+
figure, axes = plt.subplots(3, 3, figsize=(16, 10), sharex=False, sharey=False, constrained_layout=True)
146+
for row_index, (_, _, profile_label) in enumerate(PROFILE_LAYOUT):
147+
for col_index, (_, noise_label) in enumerate(NOISE_FAMILIES):
148+
axis = axes[row_index, col_index]
149+
panel = panels[(row_index, col_index)]
150+
151+
axis.plot(
152+
panel["time_hours"],
153+
panel["latent_od"],
154+
color="#1f77b4",
155+
linewidth=2.0,
156+
label="Latent OD",
157+
)
158+
axis.plot(
159+
panel["time_hours"],
160+
panel["observed_od"],
161+
color="#2ca02c",
162+
linewidth=0.9,
163+
alpha=0.65,
164+
label="Observed OD",
165+
)
166+
axis.plot(
167+
panel["time_hours"],
168+
panel["estimated_od"],
169+
color="#d62728",
170+
linewidth=1.5,
171+
alpha=0.9,
172+
label="KF OD",
173+
)
174+
axis.scatter(
175+
panel["time_hours"],
176+
panel["observed_od"],
177+
color="#2ca02c",
178+
s=4,
179+
alpha=0.35,
180+
)
181+
axis.grid(alpha=0.18)
182+
183+
if row_index == 0:
184+
axis.set_title(noise_label)
185+
if col_index == 0:
186+
axis.set_ylabel(f"{profile_label}\nOD")
187+
if row_index == len(PROFILE_LAYOUT) - 1:
188+
axis.set_xlabel("Time (hours)")
189+
190+
residual_std = float(np.std(panel["observed_od"] - panel["latent_od"]))
191+
axis.text(
192+
0.02,
193+
0.96,
194+
f"resid sd {residual_std:.3f}",
195+
transform=axis.transAxes,
196+
va="top",
197+
ha="left",
198+
fontsize=9,
199+
bbox={"facecolor": "white", "edgecolor": "none", "alpha": 0.75, "pad": 1.5},
200+
)
201+
202+
handles, labels = axes[0, 0].get_legend_handles_labels()
203+
figure.legend(handles, labels, loc="upper center", ncol=3, frameon=False)
204+
figure.suptitle(
205+
"Simulated Latent OD, Observed OD, and KF OD Across Profiles and Noise Families\n"
206+
f"dt={DT_HOURS * 3600:.0f}s, seed={SEED}",
207+
fontsize=14,
208+
)
209+
figure.savefig(OD_OUTPUT_PATH, dpi=180)
210+
plt.close(figure)
211+
212+
213+
def main() -> None:
214+
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
215+
panels = collect_panels()
216+
render_growth_rate_grid(panels)
217+
render_observed_od_grid(panels)
218+
print(GROWTH_RATE_OUTPUT_PATH)
219+
print(OD_OUTPUT_PATH)
220+
221+
222+
if __name__ == "__main__":
223+
main()

0 commit comments

Comments
 (0)