Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/rail/cli/rail_project/project_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def subsample_command(
@project_options.reducer_class_name()
@project_options.input_selection()
@project_options.selection()
@project_options.sim_version()
def reduce_command(
config_file: str, run_mode: project_options.RunMode, **kwargs: Any
) -> int:
Expand Down Expand Up @@ -233,6 +234,7 @@ def run_group() -> None:
@project_options.selection()
@project_options.flavor()
@project_options.run_mode()
@project_options.convert_output()
@project_options.site()
def photmetric_errors_pipeline(config_file: str, **kwargs: Any) -> int:
"""Run the photometric errors analysis pipeline"""
Expand All @@ -257,6 +259,7 @@ def photmetric_errors_pipeline(config_file: str, **kwargs: Any) -> int:
@project_options.selection()
@project_options.flavor()
@project_options.run_mode()
@project_options.convert_output()
@project_options.site()
def prepare_pipeline(config_file: str, **kwargs: Any) -> int:
"""Run the truth-to-observed data pipeline"""
Expand All @@ -281,6 +284,7 @@ def prepare_pipeline(config_file: str, **kwargs: Any) -> int:
@project_options.selection()
@project_options.flavor()
@project_options.run_mode()
@project_options.convert_output()
@project_options.site()
def truth_to_observed_pipeline(config_file: str, **kwargs: Any) -> int:
"""Run the truth-to-observed data pipeline"""
Expand All @@ -305,6 +309,7 @@ def truth_to_observed_pipeline(config_file: str, **kwargs: Any) -> int:
@project_options.selection()
@project_options.flavor()
@project_options.run_mode()
@project_options.convert_output()
@project_options.site()
def blending_pipeline(config_file: str, **kwargs: Any) -> int:
"""Run the blending analysis pipeline"""
Expand All @@ -329,6 +334,7 @@ def blending_pipeline(config_file: str, **kwargs: Any) -> int:
@project_options.selection()
@project_options.flavor()
@project_options.run_mode()
@project_options.convert_output()
@project_options.site()
def spectroscopic_selection_pipeline(config_file: str, **kwargs: Any) -> int:
"""Run the spectroscopic selection data pipeline"""
Expand Down
15 changes: 15 additions & 0 deletions src/rail/cli/rail_project/project_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"args",
"basename",
"config_path",
"convert_output",
"catalog_template",
"file_template",
"force",
Expand All @@ -27,6 +28,7 @@
"reducer_class_name",
"run_mode",
"selection",
"sim_version",
"site",
"splitter_class_name",
"subsampler_class_name",
Expand Down Expand Up @@ -67,6 +69,11 @@
type=click.Path(),
)

convert_output = PartialOption(
"--convert-output/--no-convert-output",
help="Convert outputfiles",
default=True,
)

catalog_template = PartialOption(
"--catalog-template",
Expand Down Expand Up @@ -238,13 +245,21 @@
help="Mode to run script",
)

sim_version = PartialOption(
"--sim-version",
type=str,
help="Optional override to simulation version",
)


size = PartialOption(
"--size",
type=int,
default=100_000,
help="Number of objects in file",
)


test_file_template = PartialOption(
"--test-file-template",
type=str,
Expand Down
49 changes: 31 additions & 18 deletions src/rail/projects/pipeline_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,15 +540,15 @@ def tomography_input_callback(
return input_files


def truth_to_observed_convert_commands(
sink_dir: str, **kwargs: Any
) -> list[list[str]]:
phot_errors = kwargs.get("error_models", [])
def truth_to_observed_convert_commands(sink_dir: str, **kwargs: Any) -> list[list[str]]:
phot_errors = kwargs.get("error_models", {})
if phot_errors is not None:
assert isinstance(phot_errors, dict)
spec_selections = kwargs.get("selectors", [])
spec_selections = kwargs.get("selectors", {})
if spec_selections is not None:
assert isinstance(spec_selections, dict)
models_to_run_select = kwargs.get("models_to_run_select", [])

convert_commands = []

for phot_error_ in phot_errors:
Expand All @@ -561,8 +561,9 @@ def truth_to_observed_convert_commands(
f"{sink_dir}/output_error_model_{phot_error_}.hdf5",
]
convert_commands += [convert_command]

for spec_selection_ in spec_selections:
if spec_selection_ not in models_to_run_select:
continue
convert_command = [
"tables-io",
"convert",
Expand Down Expand Up @@ -590,7 +591,7 @@ def prepare_convert_commands(sink_dir: str, **_kwargs: Any) -> list[list[str]]:

def photometric_errors_convert_commands(
sink_dir: str, **_kwargs: Any
) -> list[list[str]]:
) -> list[list[str]]:
convert_command = [
"tables-io",
"convert",
Expand Down Expand Up @@ -787,7 +788,9 @@ def __init__(self, **kwargs: Any):
def __repr__(self) -> str:
return f"{self.config.pipeline_template} {self.config.path}"

def _parse_pipeline_kwargs(self, project: RailProject, **kwargs: Any) -> dict[str, Any]:
def _parse_pipeline_kwargs(
self, project: RailProject, **kwargs: Any
) -> dict[str, Any]:
"""Parse the set of kwargs to expand out 'all'"""
overrides: dict[str, Any] = {}
for key, val in kwargs.items():
Expand All @@ -801,6 +804,9 @@ def _parse_pipeline_kwargs(self, project: RailProject, **kwargs: Any) -> dict[st
temp_dict = project.get_summarizers()
elif key == "error_models":
temp_dict = project.get_error_models()
elif key == "models_to_run_select":
overrides[key] = val
continue
else:
continue
if "all" in val:
Expand All @@ -810,7 +816,7 @@ def _parse_pipeline_kwargs(self, project: RailProject, **kwargs: Any) -> dict[st
algo_name_: temp_dict[algo_name_] for algo_name_ in val
}
return overrides

def build(
self,
project: RailProject,
Expand Down Expand Up @@ -843,7 +849,7 @@ def build(
stages_config = None

parsed_overrides = self._parse_pipeline_kwargs(project, **pipeline_kwargs)
pipeline_kwargs.update(**parsed_overrides)
pipeline_kwargs.update(**parsed_overrides)

catalog_tag = project.get_flavor(self.config.flavor).get("catalog_tag", None)
if catalog_tag:
Expand Down Expand Up @@ -965,6 +971,8 @@ def make_pipeline_catalog_commands(
"""
pipeline_name = self.config.pipeline_template
pipeline_info = project.get_pipeline(pipeline_name)
convert_output = kwargs.pop("convert_output", False)

flavor = self.config.flavor
pipeline_path = project.get_path(
"pipeline_path", pipeline=pipeline_name, flavor=flavor, **kwargs
Expand All @@ -988,9 +996,11 @@ def make_pipeline_catalog_commands(
all_commands: list[tuple[list[list[str]], str]] = []

pipeline_config_kwargs = pipeline_info.config.kwargs.copy()
parsed_overrides = self._parse_pipeline_kwargs(project, **pipeline_config_kwargs)
pipeline_config_kwargs.update(**parsed_overrides)

parsed_overrides = self._parse_pipeline_kwargs(
project, **pipeline_config_kwargs
)
pipeline_config_kwargs.update(**parsed_overrides)

selection = kwargs["selection"]

for source_catalog, sink_catalog in zip(
Expand All @@ -1008,11 +1018,14 @@ def make_pipeline_catalog_commands(
output_dir=sink_dir,
log_dir=sink_dir,
)
convert_commands = catalog_convert_commands_function(
sink_dir,
**kwargs,
**pipeline_config_kwargs,
)
if convert_output:
convert_commands = catalog_convert_commands_function(
sink_dir,
**kwargs,
**pipeline_config_kwargs,
)
else:
convert_commands = []
iter_commands = [
["mkdir", "-p", f"{sink_dir}"],
ceci_commands,
Expand Down
22 changes: 16 additions & 6 deletions src/rail/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class RailFlavor(Configurable):
pipelines=StageParameter(list, ["all"], fmt="%s", msg="pipelines being used"),
file_aliases=StageParameter(dict, {}, fmt="%s", msg="file aliases used"),
pipeline_overrides=StageParameter(dict, {}, fmt="%s", msg="file aliases used"),
path_overrides=StageParameter(
dict, {}, fmt="%s", required=False, msg="Overrieds to common paths"
),
)

def __init__(self, **kwargs: Any) -> None:
Expand Down Expand Up @@ -523,10 +526,9 @@ def subsample_data(
for key, val in subsampler_args.config.to_dict().items()
if key in subsampler_config_keys
}

subsampler = subsampler_class(**use_pairs)

basename_dict: dict[str, str] = subsampler.get_basename_dict()
basename_dict: dict[str, str] = subsampler.get_basename_dict(**kwargs)
sources_dict: dict[str, list[str]] = {}

for key, val in basename_dict.items():
Expand Down Expand Up @@ -620,9 +622,9 @@ def build_pipelines(
if self.config.CatalogLib:
for catalog_lib in self.config.CatalogLib:
catalog_utils.load_yaml(catalog_lib)
else:
else:
catalog_utils.load_yaml(catalog_utils.DEFAULT_CATAlOG_TAG_FILE)

flavor_dict = self.get_flavor(flavor)
pipelines_to_build = flavor_dict["pipelines"]
all_flavor_overrides = flavor_dict.get("pipeline_overrides", {}).copy()
Expand Down Expand Up @@ -677,7 +679,11 @@ def make_pipeline_single_input_command(
"""
pipeline_template = self.get_pipeline(pipeline_name)
pipeline_instance = pipeline_template.make_instance(self, flavor, {})
return pipeline_instance.make_pipeline_single_input_command(self, **kwargs)
flavor_dict = self.get_flavor(flavor)
path_overrides = flavor_dict.config.path_overrides
return pipeline_instance.make_pipeline_single_input_command(
self, **kwargs, **path_overrides
)

def make_pipeline_catalog_commands(
self,
Expand Down Expand Up @@ -705,7 +711,11 @@ def make_pipeline_catalog_commands(
"""
pipeline_template = self.get_pipeline(pipeline_name)
pipeline_instance = pipeline_template.make_instance(self, flavor, {})
return pipeline_instance.make_pipeline_catalog_commands(self, **kwargs)
flavor_dict = self.get_flavor(flavor)
path_overrides = flavor_dict.config.path_overrides
return pipeline_instance.make_pipeline_catalog_commands(
self, **kwargs, **path_overrides
)

def run_pipeline_single(
self,
Expand Down
Loading
Loading