Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions src/rail/plotting/data_extraction_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,51 @@ def make_z_true_multi_z_point_dict(
return out_dict


def get_pz_pdf_data(
    project: RailProject,
    selection: str,
    flavor: str,
    tag: str,
    algo: str,
) -> dict[str, Any] | None:
    """Get the true redshifts and the p(z) PDF estimates
    for a particular analysis selection and flavor

    Parameters
    ----------
    project: RailProject
        Object with information about the structure of the current project

    selection: str
        Data selection in question, e.g., 'gold', or 'blended'

    flavor: str
        Analysis flavor in question, e.g., 'baseline' or 'zCosmos'

    tag: str
        File tag, e.g., 'test' or 'train', or 'train_zCosmos'

    algo: str
        Algorithm we want the estimates for, e.g., 'knn', 'bpz', etc...

    Returns
    -------
    pz_data: dict[str, Any] | None
        Dict with two keys, `truth` (the result of `extract_z_true`) and
        `pz` (the result of `extract_z_pdf`), or None if no path could be
        resolved for the p(z) estimate file

    Notes
    -----
    Only the estimate path is checked against None; the true-redshift path
    is handed to `extract_z_true` as is.
    """
    z_true_path = path_funcs.get_z_true_path(project, selection, flavor, tag)
    z_estimate_path = path_funcs.get_ceci_pz_output_path(
        project, selection, flavor, algo
    )
    if z_estimate_path is None:  # pragma: no cover
        return None
    z_true_data = extract_z_true(z_true_path)
    z_pdf_data = extract_z_pdf(z_estimate_path)
    return {"truth": z_true_data, "pz": z_pdf_data}



def get_pz_point_estimate_data(
project: RailProject,
selection: str,
Expand Down
204 changes: 204 additions & 0 deletions src/rail/plotting/pz_data_holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .data_extraction_funcs import (
get_multi_pz_point_estimate_data,
get_pz_pdf_data,
get_pz_point_estimate_data,
)
from .dataset import RailDataset
Expand All @@ -18,6 +19,7 @@
RailDatasetListHolder,
RailProjectHolder,
)
from .pz_dist_plotters import RailPZDistributionDataset
from .pz_plotters import RailPZMultiPointEstimateDataset, RailPZPointEstimateDataset


Expand Down Expand Up @@ -265,3 +267,205 @@ def get_extractor_inputs(self) -> dict[str, Any]:
)
self._validate_extractor_inputs(**the_extractor_inputs)
return the_extractor_inputs


class RailPZPdfDataHolder(RailDatasetHolder):
    """Class to extract true redshifts and one p(z) pdf
    from a RailProject.

    This will return a dict:

    truth: np.ndarray
        True redshifts

    pz: qp.Ensemble
        p(z) PDF estimates for the same objects
    """

    config_options: dict[str, StageParameter] = dict(
        name=StageParameter(str, None, fmt="%s", required=True, msg="Dataset name"),
        project=StageParameter(
            str, None, fmt="%s", required=True, msg="RailProject name"
        ),
        selection=StageParameter(
            str, None, fmt="%s", required=True, msg="RailProject data selection"
        ),
        flavor=StageParameter(
            str, None, fmt="%s", required=True, msg="RailProject analysis flavor"
        ),
        tag=StageParameter(
            str, None, fmt="%s", required=True, msg="RailProject file tag"
        ),
        algo=StageParameter(
            str, None, fmt="%s", required=True, msg="RailProject algorithm"
        ),
    )

    # Names and expected types of the keyword arguments built by
    # `get_extractor_inputs` and forwarded to `get_pz_pdf_data`
    extractor_inputs: dict = {
        "project": RailProject,
        "selection": str,
        "flavor": str,
        "tag": str,
        "algo": str,
    }

    output_type: type[RailDataset] = RailPZDistributionDataset

    def __init__(self, **kwargs: Any):
        RailDatasetHolder.__init__(self, **kwargs)
        # Looked up lazily on the first call to `get_extractor_inputs`
        self._project: RailProject | None = None

    def __repr__(self) -> str:
        ret_str = (
            f"{self.__class__.__name__} "
            "( "
            f"{self.config.project}, "
            f"{self.config.selection}_{self.config.flavor}_{self.config.tag}_{self.config.algo}"
            ")"
        )
        return ret_str

    def _get_data(self, **kwargs: Any) -> dict[str, Any] | None:
        """Forward the extractor inputs to `get_pz_pdf_data`"""
        return get_pz_pdf_data(**kwargs)

    def get_extractor_inputs(self) -> dict[str, Any]:
        """Build and validate the keyword arguments for `_get_data`

        Returns
        -------
        dict[str, Any]
            The project plus the configured selection, flavor, tag and algo
        """
        if self._project is None:
            # Fetch the project by its configured name once and reuse it
            self._project = RailDatasetFactory.get_project(
                self.config.project
            ).resolve()
        the_extractor_inputs = dict(
            project=self._project,
            selection=self.config.selection,
            flavor=self.config.flavor,
            tag=self.config.tag,
            algo=self.config.algo,
        )
        self._validate_extractor_inputs(**the_extractor_inputs)
        return the_extractor_inputs

    @classmethod
    def generate_dataset_dict(
        cls,
        **kwargs: Any,
    ) -> tuple[
        list[RailProjectHolder], list[RailDatasetHolder], list[RailDatasetListHolder]
    ]:
        """Build one dataset holder per (selection, flavor, algorithm)
        of a project, grouped into dataset lists

        Parameters
        ----------
        **kwargs
            See Notes

        Notes
        -----
        dataset_list_name: str
            Name for the resulting DatasetList, defaults to
            `{project.name}_pz` if missing or empty

        project_file: str
            Config file for project to inspect (required)

        selections: list[str]
            Selections to use, every selection of the project is used
            if this is missing or contains 'all'

        flavors: list[str]
            Flavors to use, every flavor of the project is used
            if this is missing or contains 'all'

        split_mode: DatasetSplitMode
            How the datasets are grouped into dataset lists, defaults to
            `DatasetSplitMode.by_algo`:

            no_split: a single list named `dataset_list_name`

            by_flavor: one list per selection and flavor, named
            `{dataset_list_name}_{selection}_{flavor}`

            by_algo: one list per selection and algorithm, named
            `{dataset_list_name}_{selection}_{algo}`

        Every dataset is created with the file tag 'test'.
        Flavors whose pipelines contain neither 'all' nor 'pz' are skipped,
        as are combinations for which no p(z) output path is found.
        Dataset lists that end up empty are not returned.

        Returns
        -------
        list[RailProjectHolder]
            Underlying RailProjects

        list[RailDatasetHolder]
            Extracted datasets

        list[RailDatasetListHolder]
            Extracted dataset lists
        """
        dataset_list_name: str | None = kwargs.get("dataset_list_name")
        project_file = kwargs["project_file"]
        project = RailProject.load_config(project_file)
        selections = kwargs.get("selections")
        flavors = kwargs.get("flavors")
        split_mode = kwargs.get("split_mode", DatasetSplitMode.by_algo)

        flavor_dict = project.get_flavors()
        if flavors is None or "all" in flavors:
            flavors = list(flavor_dict.keys())
        if selections is None or "all" in selections:
            selections = list(project.get_selections().keys())

        project_name = project.name
        if not dataset_list_name:
            dataset_list_name = f"{project_name}_pz"

        projects: list[RailProjectHolder] = [
            RailProjectHolder(
                name=project_name,
                yaml_file=project_file,
            )
        ]
        datasets: list[RailDatasetHolder] = []
        dataset_lists: list[RailDatasetListHolder] = []

        # Maps dataset list name -> names of the datasets it contains
        dataset_list_dict: dict[str, list[str]] = {}
        dataset_key = dataset_list_name
        if split_mode == DatasetSplitMode.no_split:
            dataset_list_dict[dataset_key] = []

        for key in flavors:
            val = flavor_dict[key]
            pipelines = val["pipelines"]
            if "all" not in pipelines and "pz" not in pipelines:  # pragma: no cover
                continue
            try:
                algos = val["pipeline_overrides"]["default"]["kwargs"]["algorithms"]
            except KeyError:
                # No per-flavor override: use every p(z) algorithm of the project
                algos = list(project.get_pzalgorithms().keys())

            for selection_ in selections:
                if split_mode == DatasetSplitMode.by_flavor:
                    dataset_key = f"{dataset_list_name}_{selection_}_{key}"
                    dataset_list_dict[dataset_key] = []

                for algo_ in algos:
                    if split_mode == DatasetSplitMode.by_algo:
                        # The same (selection, algo) list is shared across flavors
                        dataset_key = f"{dataset_list_name}_{selection_}_{algo_}"
                        dataset_list_dict.setdefault(dataset_key, [])

                    path = path_funcs.get_ceci_pz_output_path(
                        project,
                        selection=selection_,
                        flavor=key,
                        algo=algo_,
                    )
                    if path is None:
                        continue
                    dataset_name = f"{selection_}_{key}_{algo_}"
                    dataset = cls(
                        name=dataset_name,
                        project=project_name,
                        flavor=key,
                        algo=algo_,
                        tag="test",
                        selection=selection_,
                    )
                    datasets.append(dataset)
                    dataset_list_dict[dataset_key].append(dataset_name)

        for ds_name, ds_list in dataset_list_dict.items():
            # Skip empty lists
            if not ds_list:
                continue
            dataset_list = RailDatasetListHolder(
                name=ds_name,
                dataset_class=cls.output_type.full_class_name(),
                datasets=ds_list,
            )
            dataset_lists.append(dataset_list)

        return (projects, datasets, dataset_lists)


Loading
Loading