Skip to content

Commit 12299bd

Browse files
committed
add test reporter
1 parent de8990a commit 12299bd

6 files changed

Lines changed: 775 additions & 0 deletions

File tree

fermentools/reporter/__init__.py

Whitespace-only changes.
Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
from typing import Optional, Union
2+
3+
from numpy.typing import ArrayLike
4+
import numpy as np
5+
6+
from sklearn.decomposition import PCA
7+
from sklearn.pipeline import Pipeline
8+
9+
from chemotools.outliers import QResiduals, HotellingT2
10+
11+
import matplotlib.pyplot as plt
12+
from matplotlib.patches import Ellipse
13+
import matplotlib.transforms as transforms
14+
15+
from reporter.plot.plot_spectra import plot_spectra
16+
from reporter._utils import (
17+
extract_and_validate_model,
18+
get_cut_wavenumbers_from_pipeline,
19+
)
20+
21+
22+
class PCAReport:
    """
    Diagnostic plots for a PCA model on spectral data.

    Parameters
    ----------
    model : Union[Pipeline, PCA]
        Model used to project the data (``model.transform``) and handed to
        the ``QResiduals`` / ``HotellingT2`` detectors. The report never
        fits ``model`` itself before calling ``transform``.
        NOTE(review): presumably it must already be fitted -- confirm.
    X_train : ArrayLike
        Training samples.
    y_train : ArrayLike
        Values used as the default colour coding of the training samples.
    X_test : Optional[ArrayLike]
        Optional test samples, drawn as blue squares where supported.
    y_test : Optional[ArrayLike]
        Stored on the instance; not used by any plot in this class.
    wavenumbers : Optional[ArrayLike]
        X-axis values for the spectra and loadings plots.

    Notes
    -----
    ``transformer`` and ``estimator`` come from
    ``extract_and_validate_model(model, model_type="pca")`` and
    ``cut_wavenumbers`` from ``get_cut_wavenumbers_from_pipeline``.
    NOTE(review): presumably these are the preprocessing steps, the final
    PCA estimator and the wavenumber axis after range cutting -- confirm
    against ``reporter._utils``.
    """

    def __init__(
        self,
        model: Union[Pipeline, PCA],
        X_train: ArrayLike,
        y_train: ArrayLike,
        X_test: Optional[ArrayLike] = None,
        y_test: Optional[ArrayLike] = None,
        wavenumbers: Optional[ArrayLike] = None,
    ):
        self.model = model
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.wavenumbers = wavenumbers
        self.transformer, self.estimator = extract_and_validate_model(
            model, model_type="pca"
        )
        self.cut_wavenumbers = get_cut_wavenumbers_from_pipeline(self.transformer)

    def plot_data(
        self,
        color_by: Optional[ArrayLike] = None,
        title: str = "Spectra",
        x_label: str = "Wavenumber",
        y_label: str = "Intensity",
    ):
        """
        Plot the raw training spectra with optional color coding.

        Parameters
        ----------
        color_by : Optional[ArrayLike]
            Values used to color the spectra. Defaults to ``y_train``.
        title : str
            Plot title.
        x_label : str
            Label of the x-axis.
        y_label : str
            Label of the y-axis.

        Raises
        ------
        ValueError
            If no wavenumbers were given to the report.
        """
        if self.wavenumbers is None:
            raise ValueError("Wavenumbers are not provided.")

        if color_by is None:
            color_by = self.y_train

        plot_spectra(
            x=self.wavenumbers,
            y=self.X_train,
            color_by=color_by,
            title=title,
            x_label=x_label,
            y_label=y_label,
        )

    def plot_preprocessed_data(
        self,
        color_by: Optional[ArrayLike] = None,
        title: str = "Spectra",
        x_label: str = "Wavenumber",
        y_label: str = "Intensity",
    ):
        """
        Plot the training spectra after ``transformer.transform``.

        Parameters
        ----------
        color_by : Optional[ArrayLike]
            Values used to color the spectra. Defaults to ``y_train``.
        title : str
            Plot title.
        x_label : str
            Label of the x-axis.
        y_label : str
            Label of the y-axis.

        Raises
        ------
        ValueError
            If no wavenumbers were given to the report.
        """
        if self.wavenumbers is None:
            raise ValueError("Wavenumbers are not provided.")

        # Use the cut axis when the pipeline provides one; otherwise the
        # original wavenumbers.
        if self.cut_wavenumbers is not None:
            wavenumbers = self.cut_wavenumbers
        else:
            wavenumbers = self.wavenumbers

        if color_by is None:
            color_by = self.y_train

        plot_spectra(
            x=wavenumbers,
            y=self.transformer.transform(self.X_train),
            color_by=color_by,
            title=title,
            x_label=x_label,
            y_label=y_label,
        )

    def _get_ellipse(self, x, y, ax, n_std: int = 3, edgecolor: str = "red"):
        """
        Draw a covariance ellipse for the points ``(x, y)`` on ``ax``.

        The ellipse is built in normalised space, where its axes lie along
        the diagonals with radii ``sqrt(1 + r)`` and ``sqrt(1 - r)`` (``r``
        being the Pearson correlation). It is then rotated by 45 degrees,
        scaled by ``n_std`` standard deviations per axis and moved to the
        data mean.

        Parameters
        ----------
        x, y : array-like
            Coordinates of the points.
        ax : matplotlib.axes.Axes
            Axes the ellipse patch is added to.
        n_std : int
            Number of standard deviations the ellipse spans.
        edgecolor : str
            Outline color of the ellipse.

        Returns
        -------
        None
            Nothing is drawn when the covariance cannot be estimated or one
            of the variables has zero variance.
        """
        cov = np.cov(x, y)
        if cov.shape != (2, 2) or np.any(np.isnan(cov)):
            return  # skip invalid data

        variance_product = cov[0, 0] * cov[1, 1]
        if variance_product <= 0:
            return  # constant coordinate: correlation is undefined

        # Clamp against rounding so that sqrt(1 +/- pearson) stays real.
        pearson = np.clip(cov[0, 1] / np.sqrt(variance_product), -1.0, 1.0)
        ell_radius_x = np.sqrt(1 + pearson)
        ell_radius_y = np.sqrt(1 - pearson)
        ellipse = Ellipse(
            (0, 0),
            width=ell_radius_x * 2,
            height=ell_radius_y * 2,
            edgecolor=edgecolor,
            facecolor="none",
            linewidth=2,
            alpha=0.5,
        )
        scale_x = np.std(x) * n_std
        scale_y = np.std(y) * n_std
        mean_x = np.mean(x)
        mean_y = np.mean(y)
        # The 45 degree rotation is what maps the normalised radii onto the
        # correlation direction; without it correlated data gets an
        # axis-aligned (wrong) ellipse.
        transf = (
            transforms.Affine2D()
            .rotate_deg(45)
            .scale(scale_x, scale_y)
            .translate(mean_x, mean_y)
        )
        ellipse.set_transform(transf + ax.transData)
        ax.add_patch(ellipse)

    def plot_scores(
        self, color_by=None, title="PCA Scores", x_axis=1, y_axis=2, n_std=3
    ):
        """
        Scatter plot of the PCA scores of the training (and test) data.

        Parameters
        ----------
        color_by : Optional[ArrayLike]
            Values used to color the training points. Defaults to
            ``y_train``.
        title : str
            Plot title.
        x_axis, y_axis : int
            1-based indices of the components shown on each axis.
        n_std : int
            Number of standard deviations spanned by the ellipses drawn
            around the training and the test scores.

        Returns
        -------
        PCAReport
            The report itself, to allow chaining.

        Raises
        ------
        ValueError
            If ``x_axis`` or ``y_axis`` is smaller than 1.
        """
        if x_axis < 1 or y_axis < 1:
            # Index 0 would silently wrap around to the last component.
            raise ValueError("x_axis and y_axis are 1-based component indices.")

        x_idx, y_idx = x_axis - 1, y_axis - 1
        X_train_pca = np.asarray(self.model.transform(self.X_train))

        if color_by is None:
            color_by = self.y_train

        fig, ax = plt.subplots(figsize=(5, 4))

        # Plot training data
        scatter = ax.scatter(
            X_train_pca[:, x_idx],
            X_train_pca[:, y_idx],
            c=color_by,
            cmap="turbo",
            edgecolor="k",
            s=50,
            label="Train",
        )

        # Add overall ellipse for training data
        self._get_ellipse(
            X_train_pca[:, x_idx],
            X_train_pca[:, y_idx],
            ax,
            n_std=n_std,
            edgecolor="red",
        )

        if self.X_test is not None:
            X_test_pca = np.asarray(self.model.transform(self.X_test))
            # Plot test data
            ax.scatter(
                X_test_pca[:, x_idx],
                X_test_pca[:, y_idx],
                c="blue",
                edgecolor="k",
                s=50,
                marker="s",
                label="Test",
            )

            # Add overall ellipse for test data, using the same n_std as the
            # training ellipse so both are comparable.
            self._get_ellipse(
                X_test_pca[:, x_idx],
                X_test_pca[:, y_idx],
                ax,
                n_std=n_std,
                edgecolor="blue",
            )

        ax.set_title(title)
        ax.set_xlabel(f"PC{x_axis}")
        ax.set_ylabel(f"PC{y_axis}")
        plt.colorbar(scatter, ax=ax, label="Color by (train)")
        ax.legend()
        ax.grid()
        plt.tight_layout()
        plt.show()
        return self

    def plot_loadings(self, title="PCA Loadings"):
        """
        Plot the loadings (``estimator.components_``) of every component.

        The x-axis uses ``cut_wavenumbers`` when available, then
        ``wavenumbers``, and finally the plain feature index when neither
        is set.

        Parameters
        ----------
        title : str
            Plot title.

        Returns
        -------
        PCAReport
            The report itself, to allow chaining.
        """
        loadings = self.estimator.components_.T

        if self.cut_wavenumbers is not None:
            wavenumbers = self.cut_wavenumbers
        elif self.wavenumbers is not None:
            wavenumbers = self.wavenumbers
        else:
            # No axis supplied: fall back to feature positions.
            wavenumbers = np.arange(loadings.shape[0])

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(wavenumbers, loadings)
        ax.set_title(title)
        ax.set_xlabel("Features")
        ax.set_ylabel("Loadings")
        ax.legend([f"PC{i + 1}" for i in range(loadings.shape[1])])
        ax.grid()
        plt.tight_layout()
        plt.show()
        return self

    def plot_scree(self, title="Cumulative Explained Variance"):
        """
        Plot the explained variance ratio per component and its cumulative sum.

        The per-component ratio is drawn as a line and the cumulative ratio
        as bars, both read from ``estimator.explained_variance_ratio_``.

        Parameters
        ----------
        title : str
            Plot title.

        Returns
        -------
        PCAReport
            The report itself, to allow chaining.
        """
        explained_variance = self.estimator.explained_variance_ratio_
        cumulative_variance = np.cumsum(explained_variance)
        components = range(1, len(explained_variance) + 1)

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(
            components,
            explained_variance,
            marker="o",
            color="blue",
            label="Explained variance",
        )
        ax.bar(
            components,
            cumulative_variance,
            alpha=0.5,
            color="orange",
            label="Cumulative explained variance",
        )
        ax.set_title(title)
        ax.set_xlabel("Principal Component")
        # Both series share this axis, so the label must not claim that
        # everything shown is cumulative.
        ax.set_ylabel("Explained Variance Ratio")
        ax.legend()
        ax.grid()
        plt.tight_layout()
        plt.show()
        return self

    def plot_residuals(
        self,
        color_by=None,
        label_by=None,
        title="Residuals",
    ):
        """
        Plot Q residuals against Hotelling T2 for the training (and test) data.

        Both detectors are fitted on ``X_train``; their ``critical_value_``
        is drawn as a dashed red threshold line on the matching axis.

        Parameters
        ----------
        color_by : Optional[ArrayLike]
            Values used to color the training points. Defaults to
            ``y_train``.
        label_by : Optional[iterable]
            Text labels annotated next to the training points, in order.
        title : str
            Plot title.

        Returns
        -------
        PCAReport
            The report itself, to allow chaining.
        """
        q_residuals = QResiduals(self.model)
        h_residuals = HotellingT2(self.model)

        q_train = q_residuals.fit_predict_residuals(self.X_train)
        h_train = h_residuals.fit_predict_residuals(self.X_train, y=None)

        if color_by is None:
            color_by = self.y_train

        fig, ax = plt.subplots(figsize=(5, 4))
        # Plot training data
        scatter = ax.scatter(
            h_train,
            q_train,
            c=color_by,
            cmap="turbo",
            edgecolor="k",
            s=50,
            label="Train",
        )

        # Plot test data
        if self.X_test is not None:
            q_test = q_residuals.predict_residuals(self.X_test)
            h_test = h_residuals.predict_residuals(self.X_test, y=None)

            ax.scatter(
                h_test,
                q_test,
                c="blue",
                edgecolor="k",
                s=50,
                marker="s",
                label="Test",
            )

        ax.axhline(
            q_residuals.critical_value_,
            color="red",
            linestyle="--",
            label="Q Residuals Threshold",
        )
        ax.axvline(
            h_residuals.critical_value_,
            color="red",
            linestyle="--",
            label="Hotelling T2 Threshold",
        )

        # Add sample labels
        if label_by is not None:
            for txt, h_value, q_value in zip(label_by, h_train, q_train):
                ax.annotate(
                    txt,
                    (h_value, q_value),
                    textcoords="offset points",
                    xytext=(0, 5),
                    ha="center",
                )

        ax.set_title(title)
        ax.set_xlabel("Hotelling T2")
        ax.set_ylabel("Q Residuals")
        plt.colorbar(scatter, ax=ax, label="Color by (train)")
        # The scatters and threshold lines carry labels; show them.
        ax.legend()
        ax.grid()
        plt.tight_layout()
        plt.show()
        return self

0 commit comments

Comments
 (0)