diff --git a/bbconf/db_utils.py b/bbconf/db_utils.py index 8261154..e456642 100644 --- a/bbconf/db_utils.py +++ b/bbconf/db_utils.py @@ -14,7 +14,7 @@ event, select, ) -from sqlalchemy.dialects.postgresql import ARRAY, JSON +from sqlalchemy.dialects.postgresql import ARRAY, JSON, JSONB from sqlalchemy.engine import URL, Engine, create_engine from sqlalchemy.event import listens_for from sqlalchemy.exc import IntegrityError, ProgrammingError @@ -253,10 +253,10 @@ class BedStats(Base): intron_percentage: Mapped[Optional[float]] intergenic_percentage: Mapped[Optional[float]] promotercore_percentage: Mapped[Optional[float]] - tssdist: Mapped[Optional[float]] + median_neighbor_distance: Mapped[Optional[float]] distributions: Mapped[Optional[dict]] = mapped_column( - JSON, + JSONB, nullable=True, comment="Full distribution arrays from gtars genomicdist (JSONB)", ) @@ -344,7 +344,7 @@ class BedSets(Base): JSON, comment="Median values of the bedset" ) bedset_stats: Mapped[Optional[dict]] = mapped_column( - JSON, + JSONB, nullable=True, comment="Pre-aggregated distribution statistics from gtars (JSONB)", ) diff --git a/bbconf/models/bed_models.py b/bbconf/models/bed_models.py index fec4a11..3f0f3b2 100644 --- a/bbconf/models/bed_models.py +++ b/bbconf/models/bed_models.py @@ -51,6 +51,7 @@ class BedStatsModel(BaseModel): number_of_regions: float | None = None gc_content: float | None = None median_tss_dist: float | None = None + median_neighbor_distance: float | None = None mean_region_width: float | None = None exon_frequency: float | None = None @@ -209,6 +210,13 @@ class BedListResult(BaseModel): results: list[BedMetadataBasic] +class BedBatchResult(BaseModel): + count: int + limit: int + offset: int + results: list[BedMetadataAll] + + class QdrantSearchResult(BaseModel): id: str payload: dict = None diff --git a/bbconf/models/bedset_models.py b/bbconf/models/bedset_models.py index 45dae37..804ecb6 100644 --- a/bbconf/models/bedset_models.py +++ b/bbconf/models/bedset_models.py @@ -22,18 +22,31 @@ class BedSetDistributions(BaseModel): Stored in the bedset_stats JSONB database column. Populated when member bed files have been processed with the gtars analysis backend. + + Only distributions that are meaningful at collection-level are kept: + - scalar_summaries: mean ± sd + 25-bin histograms of per-file scalar values + - tss_histogram: summed per-bin counts across files (fixed ±100 kb axis) + - region_distribution: per-chrom bin-wise mean ± sd across files + (requires gtars ≥ PR #248 for reference-aligned bin widths) + - partitions: mean ± sd of per-file partition percentages + + Dropped (retained in per-file distributions blob, not aggregated): + - widths_histogram: per-file variable-range bins aren't summable; use + scalar_summaries.mean_region_width histogram instead + - neighbor_distances KDE: per-file within-bedset variance is low; use + scalar_summaries.median_neighbor_distance instead + - gc_content KDE: per-file distribution is unimodal; use + scalar_summaries.gc_content mean instead + - chromosome_summaries: redundant with region_distribution + - expected_partitions: per-file null hypothesis, not collection property """ n_files: int = 0 composition: Optional[dict] = None scalar_summaries: Optional[dict] = None tss_histogram: Optional[dict] = None - widths_histogram: Optional[dict] = None - neighbor_distances: Optional[dict] = None - gc_content: Optional[dict] = None region_distribution: Optional[dict] = None partitions: Optional[dict] = None - chromosome_summaries: Optional[dict] = None class BedSetPlots(BaseModel): diff --git a/bbconf/modules/aggregation.py b/bbconf/modules/aggregation.py new file mode 100644 index 0000000..b461c90 --- /dev/null +++ b/bbconf/modules/aggregation.py @@ -0,0 +1,384 @@ +"""Collection-level aggregation of per-file genomic distributions. + +All aggregation is pushed to SQL (PostgreSQL). No per-row Python loops. +Used by both BedAgentBedSet.create() and BedAgentBedFile.aggregate_collection(). +""" + +import logging +import math +from typing import Dict, List, Optional + +from sqlalchemy import text +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session + +from bbconf.const import PKG_NAME +from bbconf.models.bedset_models import BedSetDistributions + +_LOGGER = logging.getLogger(PKG_NAME) + +# Number of bins when building histograms of per-file scalar means +_SCALAR_HIST_BINS = 25 +# Default decimal precision for stored floats +DEFAULT_PRECISION = 3 + + +def round_floats(obj, ndigits: int = DEFAULT_PRECISION): + """Recursively round floats in nested dicts/lists.""" + if isinstance(obj, float): + return round(obj, ndigits) + if isinstance(obj, dict): + return {k: round_floats(v, ndigits) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [round_floats(v, ndigits) for v in obj] + return obj + + +def aggregate_collection( + engine: Engine, + bed_ids: List[str], + precision: int = DEFAULT_PRECISION, +) -> BedSetDistributions: + """Aggregate per-file distributions into collection-level stats. + + All aggregation is done in SQL. Python only reshapes query results + into the BedSetDistributions model. + + :param engine: SQLAlchemy engine + :param bed_ids: list of bed file identifiers + :param precision: decimal places for stored floats (default 3) + :return: BedSetDistributions with aggregated distributions + """ + if not bed_ids: + return BedSetDistributions(n_files=0) + + n = len(bed_ids) + + with Session(engine) as session: + composition = _sql_aggregate_composition(session, bed_ids) + scalar_summaries = _sql_aggregate_scalars(session, bed_ids) + region_distribution = _sql_aggregate_region_distribution(session, bed_ids) + tss_histogram = _sql_aggregate_tss_histogram(session, bed_ids) + partitions = _sql_aggregate_partitions(session, bed_ids) + + stats = BedSetDistributions( + n_files=n, + composition=composition, + scalar_summaries=scalar_summaries, + tss_histogram=tss_histogram, + region_distribution=region_distribution, + partitions=partitions, + ) + + if precision is not None: + stats = BedSetDistributions(**round_floats(stats.model_dump(), precision)) + + return stats + + +# --------------------------------------------------------------------------- +# SQL aggregation helpers +# --------------------------------------------------------------------------- + + +def _sql_aggregate_composition(session: Session, bed_ids: List[str]) -> Optional[dict]: + """Count distinct values per metadata column via SQL GROUP BY.""" + fields = ["genome_alias", "assay", "cell_type", "tissue", "target"] + result = {} + + for field in fields: + if field == "genome_alias": + sql = text( + """ + SELECT genome_alias AS val, COUNT(*) AS cnt + FROM bed + WHERE id = ANY(:bed_ids) AND genome_alias IS NOT NULL + GROUP BY genome_alias + ORDER BY cnt DESC + """ + ) + else: + sql = text( + f""" + SELECT m.{field} AS val, COUNT(*) AS cnt + FROM bed b + JOIN bed_metadata m ON m.id = b.id + WHERE b.id = ANY(:bed_ids) AND m.{field} IS NOT NULL + GROUP BY m.{field} + ORDER BY cnt DESC + """ + ) + rows = session.execute(sql, {"bed_ids": bed_ids}).all() + if rows: + result[field] = {row.val: row.cnt for row in rows} + + return result if result else None + + +def _sql_aggregate_scalars(session: Session, bed_ids: List[str]) -> Optional[dict]: + """Compute mean, sd, and histogram for scalar columns in SQL. + + Uses a single query for mean/sd/min/max/count, then width_bucket + for histogram binning. + """ + scalar_columns = [ + "number_of_regions", + "mean_region_width", + "median_tss_dist", + "gc_content", + "median_neighbor_distance", + ] + + # 1. Mean, sd, min, max, count in one query + agg_exprs = ", ".join( + f"AVG({col}) AS {col}_mean, " + f"STDDEV({col}) AS {col}_sd, " + f"MIN({col}) AS {col}_min, " + f"MAX({col}) AS {col}_max, " + f"COUNT({col}) AS {col}_n" + for col in scalar_columns + ) + sql = text(f"SELECT {agg_exprs} FROM bed_stats WHERE id = ANY(:bed_ids)") + row = session.execute(sql, {"bed_ids": bed_ids}).one() + + result = {} + for col in scalar_columns: + n = getattr(row, f"{col}_n") + if not n: + continue + mean_val = float(getattr(row, f"{col}_mean")) + sd_val = float(getattr(row, f"{col}_sd") or 0.0) + col_min = float(getattr(row, f"{col}_min")) + col_max = float(getattr(row, f"{col}_max")) + + # 2. Histogram via width_bucket (PostgreSQL) + histogram = _sql_histogram(session, bed_ids, col, col_min, col_max, n) + + result[col] = { + "mean": mean_val, + "sd": sd_val, + "n": n, + "histogram": histogram, + } + + return result if result else None + + +def _sql_histogram( + session: Session, + bed_ids: List[str], + column: str, + col_min: float, + col_max: float, + n: int, +) -> dict: + """Build a histogram for a single scalar column using width_bucket.""" + num_bins = min(_SCALAR_HIST_BINS, max(3, math.ceil(math.sqrt(n)))) + + if col_min == col_max: + # All values identical — single bin + return { + "counts": [n], + "edges": [col_min, col_max], + } + + sql = text( + f""" + SELECT + width_bucket({column}, :lo, :hi, :bins) AS bucket, + COUNT(*) AS cnt + FROM bed_stats + WHERE id = ANY(:bed_ids) AND {column} IS NOT NULL + GROUP BY bucket + ORDER BY bucket + """ + ) + rows = session.execute( + sql, + {"bed_ids": bed_ids, "lo": col_min, "hi": col_max, "bins": num_bins}, + ).all() + + # width_bucket returns 1..num_bins (in-range) plus 0 (below) and num_bins+1 (above/equal to hi) + counts = [0] * num_bins + for bucket, cnt in rows: + if bucket == 0: + counts[0] += cnt + elif bucket > num_bins: + counts[-1] += cnt + else: + counts[bucket - 1] += cnt + + # Compute edges + step = (col_max - col_min) / num_bins + edges = [col_min + i * step for i in range(num_bins + 1)] + + return {"counts": counts, "edges": edges} + + +def _sql_aggregate_region_distribution( + session: Session, bed_ids: List[str] +) -> Optional[dict]: + """Aggregate per-chromosome region_distribution via SQL JSONB unnest. + + Requires that member files used gtars >= PR #248 with --chrom-sizes so that + bin widths are reference-aligned across files (same bin_idx -> same bp + range on a given chromosome, regardless of file). + + Returns {chrom: {mean: [...], sd: [...], n: int}} or None if no data. + """ + sql = text( + """ + WITH per_file AS ( + SELECT distributions->'distributions'->'region_distribution' AS rd + FROM bed_stats + WHERE id = ANY(:bed_ids) + AND distributions IS NOT NULL + AND distributions->'distributions'->'region_distribution' IS NOT NULL + ), + unnested AS ( + SELECT + chrom, + ordinality - 1 AS bin_idx, + (val)::float AS count + FROM per_file, + jsonb_each(rd) AS per_chrom(chrom, counts), + jsonb_array_elements_text(counts) WITH ORDINALITY AS t(val, ordinality) + ) + SELECT + chrom, + bin_idx, + AVG(count) AS mean, + COALESCE(STDDEV(count), 0.0) AS sd, + COUNT(*) AS n + FROM unnested + GROUP BY chrom, bin_idx + ORDER BY chrom, bin_idx + """ + ) + + rows = session.execute(sql, {"bed_ids": bed_ids}).all() + if not rows: + return None + + result: Dict[str, dict] = {} + for chrom, bin_idx, mean_val, sd_val, n_val in rows: + if chrom not in result: + result[chrom] = {"mean": [], "sd": [], "n": int(n_val)} + while len(result[chrom]["mean"]) <= bin_idx: + result[chrom]["mean"].append(0.0) + result[chrom]["sd"].append(0.0) + result[chrom]["mean"][bin_idx] = float(mean_val) + result[chrom]["sd"][bin_idx] = float(sd_val) + + return result if result else None + + +def _sql_aggregate_tss_histogram( + session: Session, bed_ids: List[str] +) -> Optional[dict]: + """Aggregate fixed-axis tss_distances histogram via SQL. + + TSS distances use a fixed 100-bin axis (+/-100 kb), so element-wise + AVG/STDDEV across files is valid without re-binning. + + Returns {mean: [...], sd: [...], n: int, x_min, x_max, bins} or None. + """ + sql = text( + """ + WITH per_file AS ( + SELECT + distributions->'distributions'->'tss_distances'->'counts' AS counts, + distributions->'distributions'->'tss_distances'->>'x_min' AS x_min, + distributions->'distributions'->'tss_distances'->>'x_max' AS x_max, + distributions->'distributions'->'tss_distances'->>'bins' AS bins + FROM bed_stats + WHERE id = ANY(:bed_ids) + AND distributions IS NOT NULL + AND distributions->'distributions'->'tss_distances'->'counts' IS NOT NULL + ), + unnested AS ( + SELECT + ordinality - 1 AS bin_idx, + (val)::float AS count, + x_min, x_max, bins + FROM per_file, + jsonb_array_elements_text(counts) WITH ORDINALITY AS t(val, ordinality) + ) + SELECT + bin_idx, + AVG(count) AS mean, + COALESCE(STDDEV(count), 0.0) AS sd, + COUNT(*) AS n, + MAX(x_min) AS x_min, + MAX(x_max) AS x_max, + MAX(bins) AS bins + FROM unnested + GROUP BY bin_idx + ORDER BY bin_idx + """ + ) + + rows = session.execute(sql, {"bed_ids": bed_ids}).all() + if not rows: + return None + + n_bins = len(rows) + result = { + "mean": [0.0] * n_bins, + "sd": [0.0] * n_bins, + "n": int(rows[0][3]), + } + x_min, x_max, bins_str = rows[0][4], rows[0][5], rows[0][6] + if x_min is not None: + try: + result["x_min"] = float(x_min) + result["x_max"] = float(x_max) + result["bins"] = int(bins_str) if bins_str else n_bins + except (ValueError, TypeError): + pass + + for bin_idx, mean_val, sd_val, _n, _xmin, _xmax, _bins in rows: + result["mean"][bin_idx] = float(mean_val) + result["sd"][bin_idx] = float(sd_val) + + return result + + +def _sql_aggregate_partitions(session: Session, bed_ids: List[str]) -> Optional[dict]: + """Aggregate genomic partitions from flat percentage columns. + + Uses the pre-computed *_percentage columns on bed_stats, which are + populated by both R and gtars backends for all beds. + """ + partition_columns = [ + ("exon", "exon_percentage"), + ("intron", "intron_percentage"), + ("intergenic", "intergenic_percentage"), + ("promoterprox", "promoterprox_percentage"), + ("promotercore", "promotercore_percentage"), + ("fiveutr", "fiveutr_percentage"), + ("threeutr", "threeutr_percentage"), + ] + + agg_exprs = ", ".join( + f"AVG({col}) * 100 AS {name}_mean, " + f"COALESCE(STDDEV({col}) * 100, 0.0) AS {name}_sd, " + f"COUNT({col}) AS {name}_n" + for name, col in partition_columns + ) + sql = text(f"SELECT {agg_exprs} FROM bed_stats WHERE id = ANY(:bed_ids)") + + row = session.execute(sql, {"bed_ids": bed_ids}).one() + + result = {} + for name, _col in partition_columns: + n = getattr(row, f"{name}_n") + if not n: + continue + result[name] = { + "mean_pct": float(getattr(row, f"{name}_mean")), + "sd_pct": float(getattr(row, f"{name}_sd")), + "n": int(n), + } + + return result if result else None diff --git a/bbconf/modules/bedfiles.py b/bbconf/modules/bedfiles.py index 0d3fc9b..4542e54 100644 --- a/bbconf/modules/bedfiles.py +++ b/bbconf/modules/bedfiles.py @@ -45,6 +45,7 @@ UniverseNotFoundError, ) from bbconf.models.bed_models import ( + BedBatchResult, BedClassification, BedEmbeddingResult, BedFiles, @@ -67,6 +68,7 @@ UniverseMetadata, VectorMetadata, ) +from bbconf.models.bedset_models import BedSetDistributions _LOGGER = getLogger(PKG_NAME) @@ -214,12 +216,13 @@ def get(self, identifier: str, full: bool = False) -> BedMetadataAll: ), ) - def get_stats(self, identifier: str) -> BedStatsModel: + def get_stats(self, identifier: str, distributions: bool = True) -> BedStatsModel: """ Get file statistics by identifier. Args: identifier: Bed file identifier. + distributions: include distribution arrays in the result. Returns: Project statistics as BedStats object. @@ -232,8 +235,85 @@ def get_stats(self, identifier: str) -> BedStatsModel: raise BEDFileNotFoundError(f"Bed file with id: {identifier} not found.") bed_stats = BedStatsModel(**bed_object.__dict__) + if not distributions: + bed_stats.distributions = None + return bed_stats + def get_batch( + self, + identifiers: list, + full: bool = False, + distributions: bool = False, + ) -> "BedBatchResult": + """ + Get multiple bed file records by identifiers in a single DB round-trip. + + :param identifiers: list of bed file identifiers + :param full: if True, include scalar stats for each record + :param distributions: if True, include distribution arrays in stats + :return: BedListResult with matching records + """ + statement = select(Bed).where(Bed.id.in_(identifiers)) + + with Session(self._sa_engine) as session: + beds = session.scalars(statement) + results = [] + for bed_object in beds: + annotation = StandardMeta( + **( + bed_object.annotations.__dict__ + if bed_object.annotations + else {} + ) + ) + if full and bed_object.stats: + bed_stats = BedStatsModel(**bed_object.stats.__dict__) + if not distributions: + bed_stats.distributions = None + else: + bed_stats = None + + results.append( + BedMetadataAll( + id=bed_object.id, + name=bed_object.name, + description=bed_object.description, + submission_date=bed_object.submission_date, + last_update_date=bed_object.last_update_date, + genome_alias=bed_object.genome_alias, + genome_digest=bed_object.genome_digest, + bed_compliance=bed_object.bed_compliance, + data_format=bed_object.data_format, + is_universe=bed_object.is_universe, + license_id=bed_object.license_id or DEFAULT_LICENSE, + processed=bed_object.processed, + annotation=annotation, + stats=bed_stats, + compliant_columns=bed_object.compliant_columns, + non_compliant_columns=bed_object.non_compliant_columns, + ) + ) + + return BedBatchResult( + count=len(results), + limit=len(identifiers), + offset=0, + results=results, + ) + + def aggregate_collection(self, bed_ids: list) -> BedSetDistributions: + """Aggregate per-file distributions into collection-level stats. + + Thin wrapper around the standalone aggregate_collection() function. + + :param bed_ids: list of bed file identifiers + :return: BedSetDistributions with aggregated distributions + """ + from bbconf.modules.aggregation import aggregate_collection + + return aggregate_collection(self._sa_engine, bed_ids) + def get_plots(self, identifier: str) -> BedPlots: """ Get file plots by identifier. diff --git a/bbconf/modules/bedsets.py b/bbconf/modules/bedsets.py index cad38dc..18df72c 100644 --- a/bbconf/modules/bedsets.py +++ b/bbconf/modules/bedsets.py @@ -19,6 +19,7 @@ from bbconf.models.bedset_models import ( BedMetadataBasic, BedSetBedFiles, + BedSetDistributions, BedSetListResult, BedSetMetadata, BedSetPEP, @@ -26,6 +27,7 @@ BedSetStats, FileModel, ) +from bbconf.modules.aggregation import aggregate_collection _LOGGER = logging.getLogger(PKG_NAME) @@ -77,9 +79,18 @@ def get(self, identifier: str, full: bool = False) -> BedSetMetadata: mean=BedStatsModel(**bedset_obj.bedset_means), sd=BedStatsModel(**bedset_obj.bedset_standard_deviation), ).model_dump() + + # Include new distribution stats if available + if bedset_obj.bedset_stats: + distributions = BedSetDistributions( + **bedset_obj.bedset_stats + ).model_dump() + else: + distributions = None else: plots = None stats = None + distributions = None bedset_metadata = BedSetMetadata( id=bedset_obj.id, @@ -87,6 +98,7 @@ def get(self, identifier: str, full: bool = False) -> BedSetMetadata: description=bedset_obj.description, md5sum=bedset_obj.md5sum, statistics=stats, + distributions=distributions, plots=plots, bed_ids=list_of_bedfiles, submission_date=bedset_obj.submission_date, @@ -177,6 +189,35 @@ def get_statistics(self, identifier: str) -> BedSetStats: sd=BedStatsModel(**bedset_object.bedset_standard_deviation), ) + def get_distributions(self, identifier: str) -> BedSetDistributions: + """ + Get distribution statistics for bedset by identifier. + + Returns aggregated distribution data from the JSONB column. + Falls back to wrapping old scalar stats if JSONB is not populated. + + Args: + identifier: Bedset identifier. + + Returns: + BedSetDistributions with aggregated distributions. + """ + statement = select(BedSets).where(BedSets.id == identifier) + with Session(self._db_engine.engine) as session: + bedset_object = session.scalar(statement) + if not bedset_object: + raise BedSetNotFoundError(f"Bedset with id: {identifier} not found.") + if bedset_object.bedset_stats: + return BedSetDistributions(**bedset_object.bedset_stats) + # Fallback: wrap old scalar columns + return BedSetDistributions( + n_files=0, + scalar_summaries=_old_stats_to_scalar_summaries( + bedset_object.bedset_means, + bedset_object.bedset_standard_deviation, + ), + ) + def get_bedset_pep(self, identifier: str) -> dict: """ Create pep file for a bedset. @@ -337,8 +378,17 @@ def create( if statistics: stats = self._calculate_statistics(bedid_list) + # Also compute distribution-level aggregation (for gtars-processed beds) + try: + dist_stats = aggregate_collection(self._db_engine.engine, bedid_list) + except Exception as e: + _LOGGER.warning( + f"Distribution aggregation failed (beds may lack distributions): {e}" + ) + dist_stats = None else: stats = None + dist_stats = None if self.exists(identifier): if no_fail and not overwrite: _LOGGER.warning( @@ -371,6 +421,7 @@ def create( summary=annotation.get("summary"), bedset_means=stats.mean.model_dump() if stats else None, bedset_standard_deviation=stats.sd.model_dump() if stats else None, + bedset_stats=dist_stats.model_dump() if dist_stats else None, md5sum=compute_md5sum_bedset(bedid_list), author=annotation.get("author"), source=annotation.get("source"), @@ -711,3 +762,36 @@ def add_bedfile(self, identifier: str, bedfile: str) -> None: def delete_bedfile(self, identifier: str, bedfile: str) -> None: raise NotImplementedError + + +def _old_stats_to_scalar_summaries( + bedset_means: dict | None, + bedset_sd: dict | None, +) -> dict | None: + """Convert old bedset_means/bedset_standard_deviation to scalar_summaries format. + + Maps the 4 key scalar fields from old-style BedSetStats(mean, sd) to the + new BedSetDistributions.scalar_summaries format. + """ + if not bedset_means: + return None + + scalar_keys = [ + "number_of_regions", + "mean_region_width", + "median_tss_dist", + "gc_content", + "median_neighbor_distance", + ] + result = {} + for k in scalar_keys: + mean_val = bedset_means.get(k) + sd_val = (bedset_sd or {}).get(k) + if mean_val is not None: + result[k] = { + "mean": mean_val, + "sd": sd_val or 0.0, + "n": 0, + } + + return result if result else None