Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/rail/projects/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,26 @@ def setup_project_area() -> int: # pragma: no cover
"tests/temp_data/data/test/ci_test_blend_baseline_100k.hdf5"
):
return 3

if not os.path.exists("tests/temp_data/data/ci_test_subsample_gold_baseline.tgz"):
urllib.request.urlretrieve(
# "https://portal.nersc.gov/cfs/lsst/PZ/test_data/ci_test.tgz",
"http://s3df.slac.stanford.edu/people/echarles/xfer/ci_test_subsample_gold_baseline.tgz",
"tests/temp_data/data/ci_test_subsample_gold_baseline.tgz",
)
if not os.path.exists("tests/temp_data/data/ci_test_subsample_gold_baseline.tgz"):
return 4

if not os.path.exists("tests/temp_data/data/ci_test_subsample_gold_baseline"):
status = subprocess.run(
["tar", "zxvf", "tests/temp_data/data/ci_test_subsample_gold_baseline.tgz", "-C", "tests/temp_data/data"], check=False
)
if status.returncode != 0:
return status.returncode

if not os.path.exists("tests/temp_data/data/ci_test_subsample_gold_baseline/28/output_error_model_roman_medium.pq"):
return 2

return 0


Expand Down
25 changes: 25 additions & 0 deletions src/rail/projects/subsample_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ class RailSubsample(Configurable):
inputs=StageParameter(
dict, None, fmt="%s", required=False, msg="Input catalog detatils"
),
# Fields used by SpecAreaSubsampler (ignored by other subsamplers)
ra_col=StageParameter(
str, "ra", required=False, fmt="%s", msg="RA column name for area cuts"
),
dec_col=StageParameter(
str, "dec", required=False, fmt="%s", msg="Dec column name for area cuts"
),
spec_inputs=StageParameter(
dict,
None,
required=False,
fmt="%s",
msg="Per-survey spec inputs with optional area_cut (SpecAreaSubsampler)",
),
photometric_inputs=StageParameter(
dict,
None,
required=False,
fmt="%s",
msg="Photometric inputs to inner-join with spec union (SpecAreaSubsampler)",
),
)
yaml_tag = "Subsample"

Expand All @@ -48,6 +69,10 @@ def __init__(self, **kwargs: Any):
def __repr__(self) -> str:
return f"N={self.config.num_objects} seed={self.config.seed}"

def is_spec_area_subsample(self) -> bool:
"""Return True if this subsample is configured for SpecAreaSubsampler."""
return self.config.spec_inputs is not None


class RailSubsampleFactory(RailFactoryMixin):
"""Factory class to make subsamples
Expand Down
286 changes: 286 additions & 0 deletions src/rail/projects/subsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, cast

import numpy as np
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
from ceci.config import StageParameter
Expand Down Expand Up @@ -208,3 +209,288 @@ def run(
output,
)
print("done")


class SpecAreaSubsampler(RailSubsampler):
"""Combine spectroscopic surveys with per-survey area cuts, then join photometric data.

Unlike MultiCatalogSubsampler (which uses a single pre-merged spec file),
this class takes individual per-survey spec selection files, applies optional
area cuts (square RA/Dec boxes) to specified surveys, unions the results,
then inner-joins with photometric inputs (e.g., Roman). All selected
objects are written to the output — no random sub-sampling.

Area cut convention
-------------------
For a survey entry that contains an ``area_cut`` key with sub-keys
``ra_center``, ``dec_center``, and ``area_sq_deg``, the selection is::

|RA - ra_center| <= sqrt(area_sq_deg)/2 / cos(dec_center)
|Dec - dec_center| <= sqrt(area_sq_deg)/2

This produces a box that is square in projected (RA*cos(dec), Dec) space
and covers approximately ``area_sq_deg`` square degrees.
Surveys without an ``area_cut`` entry are included in their entirety.

Configuration
-------------
spec_inputs : dict
Mapping of survey label to input parameters dict. Each entry must
contain ``basename``, ``bands``, ``mag_band_name_template``, and
``mag_err_band_name_template``. Optional keys: ``extra_cols``
(list), ``cuts`` (list of 3-element filter specs), and ``area_cut``
(dict with ``ra_center``, ``dec_center``, ``area_sq_deg``).
photometric_inputs : dict
Mapping of catalog label to input parameters dict (same structure as
``spec_inputs`` but without ``area_cut``). These catalogs are
inner-joined with the unioned spec catalog on ``object_id_col``.
object_id_col : str
Column used as the join key (default: ``"object_id"``).
ra_col : str
RA column name used for area cuts (default: ``"ra"``).
dec_col : str
Dec column name used for area cuts (default: ``"dec"``).

Example YAML (Subsample entry)
-------------------------------
.. code-block:: yaml

- Subsample:
name: "taskset_2_spec_train_10yr"
object_id_col: "object_id"
ra_col: "ra"
dec_col: "dec"
spec_inputs:
zCOSMOS:
basename: output_select_lsst_obs_cond_10yr_zCOSMOS.pq
bands: ['u', 'g', 'r', 'i', 'z', 'y']
mag_band_name_template: "mag_{band}_lsst"
mag_err_band_name_template: "mag_{band}_lsst_err"
extra_cols: ['redshift', 'ra', 'dec']
cuts:
- ['mag_i_lsst', '<', 25.4]
area_cut:
ra_center: 9.0
dec_center: -42.0
area_sq_deg: 2.0
VVDSf02:
basename: output_select_lsst_obs_cond_10yr_VVDSf02.pq
bands: ['u', 'g', 'r', 'i', 'z', 'y']
mag_band_name_template: "mag_{band}_lsst"
mag_err_band_name_template: "mag_{band}_lsst_err"
extra_cols: ['redshift', 'ra', 'dec']
cuts:
- ['mag_i_lsst', '<', 25.4]
area_cut:
ra_center: 14.0
dec_center: -42.0
area_sq_deg: 0.6
DEEP2_LSST:
basename: output_select_lsst_obs_cond_10yr_DEEP2_LSST.pq
bands: ['u', 'g', 'r', 'i', 'z', 'y']
mag_band_name_template: "mag_{band}_lsst"
mag_err_band_name_template: "mag_{band}_lsst_err"
extra_cols: ['redshift', 'ra', 'dec']
cuts:
- ['mag_i_lsst', '<', 25.4]
area_cut:
ra_center: 9.0
dec_center: -46.0
area_sq_deg: 2.0
DESI_BGS:
basename: output_select_lsst_obs_cond_10yr_DESI_BGS_color.pq
bands: ['u', 'g', 'r', 'i', 'z', 'y']
mag_band_name_template: "mag_{band}_lsst"
mag_err_band_name_template: "mag_{band}_lsst_err"
extra_cols: ['redshift', 'ra', 'dec']
cuts:
- ['mag_i_lsst', '<', 25.4]
# DESI_LRG, DESI_ELG_LOP etc. follow same pattern (no area_cut = full sky)
photometric_inputs:
roman:
basename: output_deredden_roman_medium.pq
bands: ['Y', 'J', 'H']
mag_band_name_template: "mag_{band}_roman"
mag_err_band_name_template: "mag_{band}_roman_err"
extra_cols: []
"""

config_options: dict[str, StageParameter] = dict(
name=StageParameter(str, None, fmt="%s", required=True, msg="Subsampler Name"),
object_id_col=StageParameter(
str, "object_id", fmt="%s", msg="Object Id column name"
),
ra_col=StageParameter(str, "ra", fmt="%s", msg="RA column name for area cuts"),
dec_col=StageParameter(
str, "dec", fmt="%s", msg="Dec column name for area cuts"
),
spec_inputs=StageParameter(
dict,
None,
fmt="%s",
msg="Per-survey spec selection inputs (with optional area_cut)",
),
photometric_inputs=StageParameter(
dict,
None,
fmt="%s",
msg="Photometric catalog inputs to inner-join with the spec union",
),
)

def get_basename_dict(self, **kwargs: Any) -> dict[str, str]:
out_dict: dict[str, str] = {}
for key, val in (self.config.spec_inputs or {}).items():
assert "basename" in val, f"spec_inputs['{key}'] missing 'basename'"
out_dict[key] = val["basename"]
for key, val in (self.config.photometric_inputs or {}).items():
assert "basename" in val, f"photometric_inputs['{key}'] missing 'basename'"
out_dict[key] = val["basename"]
return out_dict

@staticmethod
def _get_mag_columns(input_params: dict[str, Any]) -> list[str]:
try:
bands = input_params["bands"]
except KeyError: # pragma: no cover
raise KeyError(
f"Input parameters does not include 'bands'"
f" {list(input_params.keys())}"
) from None
try:
mag_band_name_template = input_params["mag_band_name_template"]
except KeyError: # pragma: no cover
raise KeyError(
f"Input parameters does not include 'mag_band_name_template'"
f" {list(input_params.keys())}"
) from None
try:
mag_err_band_name_template = input_params["mag_err_band_name_template"]
except KeyError: # pragma: no cover
raise KeyError(
f"Input parameters does not include 'mag_err_band_name_template'"
f" {list(input_params.keys())}"
) from None
out_list: list[str] = [
mag_band_name_template.format(band=band_) for band_ in bands
]
out_list += [mag_err_band_name_template.format(band=band_) for band_ in bands]
return out_list

def _make_area_cut_filters(self, input_params: dict[str, Any]) -> list[list]:
"""Return RA/Dec box filter entries for this input, or [] if no area_cut.

The box is square in projected (RA*cos(dec), Dec) space:
RA half-width = sqrt(area) / 2 / cos(dec_center)
Dec half-width = sqrt(area) / 2
so that the covered sky area is approximately area_sq_deg.
"""
area_cut = input_params.get("area_cut")
if not area_cut:
return []
ra0 = float(area_cut["ra_center"])
dec0 = float(area_cut["dec_center"])
area = float(area_cut["area_sq_deg"])
half_side_dec = np.sqrt(area) / 2.0
half_side_ra = half_side_dec / np.cos(np.radians(dec0))
ra_col = self.config.ra_col
dec_col = self.config.dec_col
return [
[ra_col, ">=", ra0 - half_side_ra],
[ra_col, "<=", ra0 + half_side_ra],
[dec_col, ">=", dec0 - half_side_dec],
[dec_col, "<=", dec0 + half_side_dec],
]

def _select_input(
self,
input_params: dict[str, Any],
file_list: list[str],
apply_area_cut: bool = False,
) -> pa.Table:
"""Filter dataset, apply optional area cut, project to required columns."""
all_cuts: list = []
sub_sel_cuts = input_params.get("cuts", [])
if sub_sel_cuts:
all_cuts += list(sub_sel_cuts)
if apply_area_cut:
all_cuts += self._make_area_cut_filters(input_params)

parsed_cuts = parse_item(all_cuts) if all_cuts else []
dataset = ds.dataset(file_list)

save_cols: list[str] = [self.config.object_id_col]
save_cols += self._get_mag_columns(input_params)
save_cols += input_params.get("extra_cols", [])

filtered = filter_dataset(dataset, parsed_cuts, save_cols) # type: ignore
return filtered.to_table()

def run(
self,
input_files: dict[str, list[str]],
output: str,
) -> None:
spec_inputs = self.config.spec_inputs or {}
photo_inputs = self.config.photometric_inputs or {}

# 1. Process each spec survey, applying area cut where configured
spec_tables: list[pa.Table] = []
for key, input_params in spec_inputs.items():
if key not in input_files:
print(f"Warning: spec key '{key}' not found in input_files, skipping")
continue
table = self._select_input(
input_params, input_files[key], apply_area_cut=True
)
has_area_cut = bool(input_params.get("area_cut"))
print(
f"{key}: {table.num_rows} rows"
f" (area cut {'applied' if has_area_cut else 'not applied'})"
)
spec_tables.append(table)

if not spec_tables:
raise ValueError("No spec survey inputs produced any data")

# 2. Union all spec tables; deduplicate by object_id
# (an object may pass color cuts for multiple surveys)
spec_combined = pa.concat_tables(spec_tables, promote_options="default")
print(f"Combined spec rows before dedup: {spec_combined.num_rows}")

object_id_col = self.config.object_id_col
seen: set = set()
keep_indices: list[int] = []
for i, oid in enumerate(spec_combined.column(object_id_col).to_pylist()):
if oid not in seen:
seen.add(oid)
keep_indices.append(i)
if len(keep_indices) < spec_combined.num_rows:
spec_combined = spec_combined.take(keep_indices)
print(f"Combined spec rows after dedup: {spec_combined.num_rows}")

# 3. Inner-join with each photometric input on object_id
result: pa.Table = spec_combined
for key, input_params in photo_inputs.items():
if key not in input_files:
print(
f"Warning: photo key '{key}' not found in input_files, skipping"
)
continue
photo_table = self._select_input(
input_params, input_files[key], apply_area_cut=False
)
print(f"{key}: {photo_table.num_rows} rows")
result = result.join(
photo_table,
keys=object_id_col,
join_type="inner",
)

print(f"Total objects after join: {result.num_rows}")
print(f"writing {output}")

output_dir = os.path.dirname(output)
os.makedirs(output_dir, exist_ok=True)
pq.write_table(result, output)
print("done")
30 changes: 30 additions & 0 deletions tests/ci_subsample.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
Project:

Name: ci_test

# Include other configuration files
Includes:
- tests/ci_subsample_library.yaml

PathTemplates: {}

CommonPaths:
root: tests/temp_data
scratch_root: "{root}"
catalogs_dir: "{root}/data"
project: ci_test_subsample
sim_version: v1.1.3

# Baseline configuraiton, included in others by default
Baseline:
catalog_tag: flagship
pipelines: ['all']
file_aliases: # Set the training and test files
area_1k: area_1k
multi_1k: multi_1k

# These are variables that we iterate over when running over entire catalogs
IterationVars:
healpix:
- 27
- 28
Loading
Loading