Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/rail/plotting/data_extraction_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ def get_pz_pdf_data(
return pz_data



def get_pz_point_estimate_data(
project: RailProject,
selection: str,
Expand Down
49 changes: 48 additions & 1 deletion src/rail/plotting/plot_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import os
from typing import TYPE_CHECKING, Any
import yaml

from matplotlib.figure import Figure
import numpy as np

if TYPE_CHECKING:
from .dataset_holder import RailDatasetHolder
Expand All @@ -23,13 +25,15 @@ def __init__(
figure: Figure | None = None,
plotter: RailPlotter | None = None,
dataset_holder: RailDatasetHolder | None = None,
data: dict[str, Any] | None = None,
):
"""C'tor"""
self._name = name
self._path = path
self._figure = figure
self._plotter = plotter
self._dataset_holder = dataset_holder
self._data = data

@property
def name(self) -> str:
Expand All @@ -56,6 +60,11 @@ def dataset_holder(self) -> RailDatasetHolder | None:
"""Return the dataset used to make the plot"""
return self._dataset_holder

@property
def data(self) -> dict[str, Any] | None:
"""Return the data used to make the plot"""
return self._data

def set_path(
self,
path: str | None = None,
Expand Down Expand Up @@ -83,6 +92,29 @@ def savefig(
fullpath = os.path.join(outdir, relpath)
self.figure.savefig(fullpath, **kwargs)

def savedata(
self,
outdir: str = ".",
) -> None:
if self.data is None: # pragma: no cover
raise ValueError(f"Tried to savedata missing data {self.name}")

fullpath = os.path.join(outdir, f"{self.path}.yaml")

yaml_data: dict[str, Any] = {}
for k, v in self.data.items():
if isinstance(v, (np.floating, float)):
yaml_data[k] = float(v)
elif isinstance(v, np.ndarray) and v.ndim == 0:
yaml_data[k] = float(v)
elif isinstance(v, np.ndarray):
yaml_data[k] = v.tolist()
else:
yaml_data[k] = v

with open(fullpath, "w", encoding="utf-8") as fout:
yaml.dump(yaml_data, fout, default_flow_style=False, sort_keys=False)


class RailPlotDict:
"""Simple class for dicts of matplotlib Figures
Expand Down Expand Up @@ -116,7 +148,7 @@ def savefigs(
os.makedirs(outpath)
for _key, val in self._plots.items():
if val.path: # pragma: no cover
val.savefig(val.path, outpath, **kwargs)
val.savefig(val.path, os.path.dirname(outpath), **kwargs)
else:
val.savefig(
os.path.join(os.path.basename(outpath), f"{val.name}.{figtype}"),
Expand All @@ -125,3 +157,18 @@ def savefigs(
)
if purge:
val.set_figure(None)

def savedata(
self,
outpath: str = ".",
) -> None:

if not os.path.exists(outpath): # pragma: no cover
os.makedirs(outpath)

for _key, val in self._plots.items():
if val.path: # pragma: no cover
val.savedata(outpath)
else:
val.set_path(val.name)
val.savedata(outpath)
2 changes: 0 additions & 2 deletions src/rail/plotting/pz_data_holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,5 +467,3 @@ def generate_dataset_dict(
dataset_lists.append(dataset_list)

return (projects, datasets, dataset_lists)


49 changes: 30 additions & 19 deletions src/rail/plotting/pz_dist_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ class RailPZDistributionDataset(RailDataset):
"""Dataet to hold a vector p(z) point estimates and corresponding
true redshifts
"""

data_types = dict(truth=np.ndarray, pz=qp.Ensemble)


class PZPlotterPITProb(RailPlotter):
"""Class to plot the p(z_true)
"""
"""Class to plot the p(z_true)"""

config_options: dict[str, StageParameter] = RailPlotter.config_options.copy()
config_options.update(
Expand All @@ -44,22 +44,27 @@ def _make_prob_plot(
figure, axes = plt.subplots(figsize=(7, 6))

pit = qp.metrics.PIT(pz, truth)
bin_edges = np.linspace(
0., 1., self.config.n_prob_bins + 1
)
bin_edges = np.linspace(0.0, 1.0, self.config.n_prob_bins + 1)

pdf_vals = pit.pit.pdf(np.linspace(0, 1))
mean = 1./pdf_vals.mean()
_ = axes.plot(np.linspace(0, 1), pit.pit.pdf(np.linspace(0, 1)))
_ = axes.plot([0, 1], [mean, mean], "--", color='black')
pdf_vals = np.squeeze(pit.pit.pdf(bin_edges))
mean = 1.0 / pdf_vals.mean()
_ = axes.plot(bin_edges, pdf_vals)
_ = axes.plot([0, 1], [mean, mean], "--", color="black")
_ = axes.set_xlabel("Q")
_ = axes.set_ylabel(r"$P(z_{\rm ref})$")
_ = plt.xlim(0, 1)

plot_name = self._make_full_plot_name(prefix, "")

return RailPlotHolder(
name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder
name=plot_name,
figure=figure,
plotter=self,
dataset_holder=dataset_holder,
data=dict(
x_vals=bin_edges,
y_vals=pdf_vals,
),
)

def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
Expand Down Expand Up @@ -89,10 +94,8 @@ def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
return out_dict



class PZPlotterPITQQ(RailPlotter):
"""Class to plot the p(z_true > z(Q))
"""
"""Class to plot the p(z_true > z(Q))"""

config_options: dict[str, StageParameter] = RailPlotter.config_options.copy()
config_options.update(
Expand All @@ -111,17 +114,16 @@ def _make_pit_qq_plot(
figure, axes = plt.subplots(figsize=(7, 6))

pit = qp.metrics.PIT(pz, truth)
bin_edges = np.linspace(
0., 1., self.config.n_prob_bins + 1
)
bin_edges = np.linspace(0.0, 1.0, self.config.n_prob_bins + 1)

ks = pit.evaluate_PIT_KS().statistic
outlier = pit.evaluate_PIT_outlier_rate()
CvM = pit.evaluate_PIT_CvM().statistic
ksamp = pit.evaluate_PIT_anderson_ksamp().statistic
cdf_vals = pit.pit.cdf(bin_edges)

_ = axes.plot(np.linspace(0, 1, 101), pit.pit.cdf(np.linspace(0, 1, 101)))
_ = axes.plot([0, 1], [0, 1], "--", color='black')
_ = axes.plot(bin_edges, cdf_vals)
_ = axes.plot([0, 1], [0, 1], "--", color="black")
_ = axes.plot(
[],
[],
Expand All @@ -144,7 +146,16 @@ def _make_pit_qq_plot(
plot_name = self._make_full_plot_name(prefix, "")

return RailPlotHolder(
name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder
name=plot_name,
figure=figure,
plotter=self,
dataset_holder=dataset_holder,
data=dict(
ks=ks,
outlier=outlier,
CvM=CvM,
ksamp=ksamp,
),
)

def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
Expand Down
69 changes: 42 additions & 27 deletions src/rail/plotting/pz_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _make_2d_hist_plot(
self.config.z_min, self.config.z_max, self.config.n_zbins + 1
)
dz = (pointEstimate - truth) / (1 + truth)
mean, _mean_err, std, outlier_rate, abs_outlier_rate = (
mean, mean_err, std, outlier_rate, abs_outlier_rate = (
self.get_biweight_mean_sigma_outlier(dz, nclip=self.config.n_clip)
)
mean, std, outlier_rate, abs_outlier_rate = (
Expand Down Expand Up @@ -125,7 +125,17 @@ def _make_2d_hist_plot(
plot_name = self._make_full_plot_name(prefix, "")

return RailPlotHolder(
name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder
name=plot_name,
figure=figure,
plotter=self,
dataset_holder=dataset_holder,
data=dict(
mean=mean,
mean_err=mean_err,
std=std,
outlier_rate=outlier_rate,
abs_outlier_rate=abs_outlier_rate,
),
)

def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
Expand Down Expand Up @@ -421,7 +431,11 @@ def _make_biweight_stats_plot(
axes[1].set_ylabel(r"$(z_{phot} - z_{spec})/(1+z_{spec})$")
plot_name = self._make_full_plot_name(prefix, "")
return RailPlotHolder(
name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder
name=plot_name,
figure=figure,
plotter=self,
dataset_holder=dataset_holder,
data=results,
)

def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
Expand Down Expand Up @@ -505,16 +519,16 @@ def process_data(
z_mean.append(np.mean(zx[bin_indices == i]))

return {
"z_mean": z_mean,
"biweight_mean": biweight_mean,
"biweight_std": biweight_std,
"biweight_sigma": biweight_sigma,
"biweight_outlier": biweight_outlier,
"qt_95_low": qt_95_low,
"qt_68_low": qt_68_low,
"median": median,
"qt_68_high": qt_68_high,
"qt_95_high": qt_95_high,
"z_mean": np.array(z_mean),
"biweight_mean": np.array(biweight_mean),
"biweight_std": np.array(biweight_std),
"biweight_sigma": np.array(biweight_sigma),
"biweight_outlier": np.array(biweight_outlier),
"qt_95_low": np.array(qt_95_low),
"qt_68_low": np.array(qt_68_low),
"median": np.array(median),
"qt_68_high": np.array(qt_68_high),
"qt_95_high": np.array(qt_95_high),
}


Expand Down Expand Up @@ -604,7 +618,11 @@ def _make_biweight_stats_plot(
plot_name = self._make_full_plot_name(prefix, "")

return RailPlotHolder(
name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder
name=plot_name,
figure=figure,
plotter=self,
dataset_holder=dataset_holder,
data=results,
)

def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
Expand Down Expand Up @@ -697,17 +715,14 @@ def process_data( # pylint: disable=too-many-arguments
mag_mean = (mag_bins[:-1] + mag_bins[1:]) / 2

return {
"mag_mean": mag_mean,
"biweight_mean": biweight_mean,
"biweight_std": biweight_std,
"biweight_sigma": biweight_sigma,
"biweight_outlier": biweight_outlier,
"qt_95_low": qt_95_low,
"qt_68_low": qt_68_low,
"median": median,
"qt_68_high": qt_68_high,
"qt_95_high": qt_95_high,
"mag_mean": np.array(mag_mean),
"biweight_mean": np.array(biweight_mean),
"biweight_std": np.array(biweight_std),
"biweight_sigma": np.array(biweight_sigma),
"biweight_outlier": np.array(biweight_outlier),
"qt_95_low": np.array(qt_95_low),
"qt_68_low": np.array(qt_68_low),
"median": np.array(median),
"qt_68_high": np.array(qt_68_high),
"qt_95_high": np.array(qt_95_high),
}



2 changes: 2 additions & 0 deletions src/rail/plotting/utility_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any

import numpy as np
from pathlib import Path

from rail.utils.path_utils import find_rail_file


Expand Down
Loading