Skip to content

Commit c1d0d05

Browse files
authored
Merge pull request #234 from LeonHafner/leonhafner/fix_log2fc
fix: consume pdex's log2_fold_change directly (stop double-logging)
2 parents 2db6b9c + 1acc943 commit c1d0d05

11 files changed

Lines changed: 66 additions & 42 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "cell-eval"
3-
version = "0.7.0"
3+
version = "0.7.1"
44
description = "Evaluation metrics for single-cell perturbation predictions"
55
readme = "README.md"
66
authors = [
@@ -11,7 +11,7 @@ authors = [
1111
requires-python = ">=3.11"
1212
dependencies = [
1313
"igraph>=0.11.8",
14-
"pdex>=0.2.0",
14+
"pdex>=0.2.2",
1515
"polars>=1.30.0",
1616
"pyyaml>=6.0.2",
1717
"scanpy>=1.10.3",

src/cell_eval/_baseline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def build_base_mean_adata(
7777

7878
if output_path is not None:
7979
logger.info(f"Saving baseline data to {output_path}")
80-
baseline_adata.write_h5ad(output_path) # type: ignore[invalid-argument-type]
80+
baseline_adata.write_h5ad(output_path) # ty: ignore[invalid-argument-type]
8181

8282
if output_de_path is not None:
8383
logger.info("Calculating differential expression")

src/cell_eval/_cli/_prep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def strip_anndata(
226226

227227
# Write the h5ad file
228228
logger.info(f"Writing h5ad output to {tmp_h5ad}")
229-
minimal.write_h5ad(tmp_h5ad) # type: ignore[invalid-argument-type]
229+
minimal.write_h5ad(tmp_h5ad) # ty: ignore[invalid-argument-type]
230230

231231
# Zstd compress the h5ad file (will create pred.h5ad.zst)
232232
logger.info(f"Zstd compressing {tmp_h5ad}")

src/cell_eval/_types/_anndata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,9 @@ def _bulk_anndata(
125125

126126
# Create a polars dataframe with the groupby key
127127
frame = pl.DataFrame(
128-
matrix,
128+
matrix, # ty: ignore[invalid-argument-type]
129129
).with_columns(
130-
groupby_key=adata.obs[groupby_key].to_numpy(str), # type: ignore
130+
groupby_key=adata.obs[groupby_key].to_numpy(str),
131131
)
132132

133133
# Pseudobulk (mean) the dataframe by the groupby key

src/cell_eval/_types/_de.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def initialize_de_comparison(
1717
pred: pl.DataFrame,
1818
target_col: str = "target",
1919
feature_col: str = "feature",
20-
fold_change_col: str = "fold_change",
2120
log2_fold_change_col: str = "log2_fold_change",
2221
abs_log2_fold_change_col: str = "abs_log2_fold_change",
2322
pvalue_col: str = "p_value",
@@ -27,7 +26,6 @@ def initialize_de_comparison(
2726
DEResults,
2827
target_col=target_col,
2928
feature_col=feature_col,
30-
fold_change_col=fold_change_col,
3129
log2_fold_change_col=log2_fold_change_col,
3230
abs_log2_fold_change_col=abs_log2_fold_change_col,
3331
)
@@ -47,7 +45,6 @@ class DEResults:
4745
# Column names configuration
4846
target_col: str = "target"
4947
feature_col: str = "feature"
50-
fold_change_col: str = "fold_change"
5148
log2_fold_change_col: str = "log2_fold_change"
5249
abs_log2_fold_change_col: str = "abs_log2_fold_change"
5350
pvalue_col: str = "p_value"
@@ -58,7 +55,7 @@ def __post_init__(self) -> None:
5855
required_cols = {
5956
self.target_col,
6057
self.feature_col,
61-
self.fold_change_col,
58+
self.log2_fold_change_col,
6259
self.pvalue_col,
6360
self.fdr_col,
6461
}
@@ -67,7 +64,6 @@ def __post_init__(self) -> None:
6764
raise ValueError(f"Missing required columns: {missing}")
6865

6966
numeric_cols = [
70-
self.fold_change_col,
7167
self.pvalue_col,
7268
self.fdr_col,
7369
self.log2_fold_change_col,
@@ -80,31 +76,32 @@ def __post_init__(self) -> None:
8076
]
8177

8278
logger.info(f"Checking DE data integrity... ({self.name})")
83-
fc_num_null = self.data.filter(pl.col(self.fold_change_col).is_null()).height
84-
fc_num_inf = self.data.filter(pl.col(self.fold_change_col).is_infinite()).height
85-
fc_num_nan = self.data.filter(pl.col(self.fold_change_col).is_nan()).height
86-
if fc_num_null > 0:
79+
lfc_num_null = self.data.filter(
80+
pl.col(self.log2_fold_change_col).is_null()
81+
).height
82+
lfc_num_inf = self.data.filter(
83+
pl.col(self.log2_fold_change_col).is_infinite()
84+
).height
85+
lfc_num_nan = self.data.filter(
86+
pl.col(self.log2_fold_change_col).is_nan()
87+
).height
88+
if lfc_num_null > 0:
8789
logger.warning(
88-
f"Identified {fc_num_null} missing fold change values ({self.name})"
90+
f"Identified {lfc_num_null} missing log2 fold change values ({self.name})"
8991
)
90-
if fc_num_inf > 0:
92+
if lfc_num_inf > 0:
9193
logger.warning(
92-
f"Identified {fc_num_inf} infinite fold change values ({self.name})"
94+
f"Identified {lfc_num_inf} infinite log2 fold change values ({self.name})"
9395
)
94-
if fc_num_nan > 0:
96+
if lfc_num_nan > 0:
9597
logger.warning(
96-
f"Identified {fc_num_nan} NaN fold change values ({self.name})"
98+
f"Identified {lfc_num_nan} NaN log2 fold change values ({self.name})"
9799
)
98100
logger.info(f"DE data integrity check complete. ({self.name})")
99101

100-
# Add log2 fold change columns if not present
101-
if self.log2_fold_change_col not in self.data.columns:
102+
# Derive abs(log2_fold_change) if not already provided.
103+
if self.abs_log2_fold_change_col not in self.data.columns:
102104
self.data = self.data.with_columns(
103-
pl.col(self.fold_change_col)
104-
.log(base=2)
105-
.alias(self.log2_fold_change_col)
106-
.fill_nan(0.0)
107-
).with_columns(
108105
pl.col(self.log2_fold_change_col)
109106
.abs()
110107
.alias(self.abs_log2_fold_change_col)
@@ -153,7 +150,10 @@ def get_top_genes(
153150
# Set FDR threshold if not provided
154151
fdr_threshold = fdr_threshold if fdr_threshold is not None else 0.05
155152

156-
descending = sort_by in {DESortBy.FOLD_CHANGE, DESortBy.ABS_FOLD_CHANGE}
153+
descending = sort_by in {
154+
DESortBy.LOG2_FOLD_CHANGE,
155+
DESortBy.ABS_LOG2_FOLD_CHANGE,
156+
}
157157

158158
# Create a rank matrix where each row is the ordinal rank of a gene and each column is a perturbation.
159159
# The rank is sensitive to the sort-by column and is computed post-filtering for FDR.
@@ -219,7 +219,7 @@ def compute_overlap(
219219
k: int | None,
220220
metric: Literal["overlap", "precision"] = "overlap",
221221
fdr_threshold: float | None = None,
222-
sort_by: DESortBy = DESortBy.ABS_FOLD_CHANGE,
222+
sort_by: DESortBy = DESortBy.ABS_LOG2_FOLD_CHANGE,
223223
) -> dict[str, float]:
224224
"""
225225
Compute overlap metrics across perturbations.

src/cell_eval/_types/_enums.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
class DESortBy(enum.Enum):
55
"""Sorting options for differential expression results."""
66

7-
FOLD_CHANGE = "log2_fold_change"
8-
ABS_FOLD_CHANGE = "abs_log2_fold_change"
7+
LOG2_FOLD_CHANGE = "log2_fold_change"
8+
ABS_LOG2_FOLD_CHANGE = "abs_log2_fold_change"
99
PVALUE = "p_value"
1010
FDR = "fdr"
1111

src/cell_eval/metrics/_de.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def de_overlap_metric(
1313
k: int | None,
1414
metric: Literal["precision", "overlap"] = "overlap",
1515
fdr_threshold: float = 0.05,
16-
sort_by: DESortBy = DESortBy.ABS_FOLD_CHANGE,
16+
sort_by: DESortBy = DESortBy.ABS_LOG2_FOLD_CHANGE,
1717
) -> dict[str, float]:
1818
"""Compute overlap between real and predicted DE genes.
1919
@@ -124,7 +124,9 @@ def __call__(self, data: DEComparison) -> dict[str, float]:
124124
suffix="_pred",
125125
how="left",
126126
)
127-
.with_columns(pl.col(f"{data.real.fold_change_col}_pred").fill_null(0.0))
127+
.with_columns(
128+
pl.col(f"{data.real.log2_fold_change_col}_pred").fill_null(0.0)
129+
)
128130
)
129131

130132
for row in (
@@ -133,8 +135,8 @@ def __call__(self, data: DEComparison) -> dict[str, float]:
133135
)
134136
.agg(
135137
pl.corr(
136-
pl.col(data.real.fold_change_col).cast(pl.Float64),
137-
pl.col(f"{data.real.fold_change_col}_pred").cast(pl.Float64),
138+
pl.col(data.real.log2_fold_change_col).cast(pl.Float64),
139+
pl.col(f"{data.real.log2_fold_change_col}_pred").cast(pl.Float64),
138140
method="spearman",
139141
).alias("spearman_corr"),
140142
)

src/cell_eval/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def guess_is_lognorm(
4141
if isinstance(adata.X, csr_matrix) or isinstance(adata.X, csc_matrix):
4242
frac, _ = np.modf(adata.X.data)
4343
elif adata.is_view:
44-
frac, _ = np.modf(adata.X.toarray()) # type: ignore[unresolved-attribute]
44+
frac, _ = np.modf(adata.X.toarray()) # ty: ignore[unresolved-attribute]
4545
elif adata.X is None:
4646
raise ValueError("adata.X is None")
4747
else:
@@ -60,8 +60,8 @@ def guess_is_lognorm(
6060
max_val = adata.X.max()
6161
min_val = adata.X.min()
6262
else:
63-
max_val = float(np.max(adata.X)) # type: ignore[no-matching-overload]
64-
min_val = float(np.min(adata.X)) # type: ignore[no-matching-overload]
63+
max_val = float(np.max(adata.X)) # ty: ignore[no-matching-overload]
64+
min_val = float(np.min(adata.X)) # ty: ignore[no-matching-overload]
6565

6666
# Validate range
6767
if min_val < 0:

tests/test_de_float_types.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ def test_de_spearman_lfc_mixed_float_types() -> None:
1010
{
1111
"target": ["pert1", "pert1", "pert2", "pert2"],
1212
"feature": ["gene1", "gene2", "gene1", "gene2"],
13-
"fold_change": [1.5, 2.0, 0.5, 1.2],
13+
"log2_fold_change": [1.5, 2.0, 0.5, 1.2],
1414
"p_value": [0.01, 0.02, 0.03, 0.04],
1515
"fdr": [0.01, 0.02, 0.03, 0.04],
1616
}
1717
)
1818

19-
pred_df = real_df.with_columns(pl.col("fold_change").cast(pl.Float64))
19+
pred_df = real_df.with_columns(pl.col("log2_fold_change").cast(pl.Float64))
2020

2121
comparison = DEComparison(
2222
real=DEResults(real_df, name="real"),
@@ -27,3 +27,25 @@ def test_de_spearman_lfc_mixed_float_types() -> None:
2727

2828
assert isinstance(result, dict)
2929
assert all(isinstance(value, (int, float)) for value in result.values())
30+
31+
32+
def test_de_results_preserves_negative_log2_fold_change() -> None:
33+
"""`abs_log2_fold_change` must be |log2_fold_change|; negatives must not be zeroed.
34+
35+
Pre-fix, `__post_init__` applied `log(base=2).fill_nan(0.0)` to the values, which
36+
silently zeroed every downregulated gene.
37+
"""
38+
df = pl.DataFrame(
39+
{
40+
"target": ["p", "p", "p", "p"],
41+
"feature": ["g1", "g2", "g3", "g4"],
42+
"log2_fold_change": [-1.0, 0.0, 1.0, 2.0],
43+
"p_value": [0.01, 0.01, 0.01, 0.01],
44+
"fdr": [0.01, 0.01, 0.01, 0.01],
45+
}
46+
)
47+
de = DEResults(df, name="real")
48+
lfc = de.data["log2_fold_change"].cast(pl.Float64).to_list()
49+
abs_lfc = de.data["abs_log2_fold_change"].cast(pl.Float64).to_list()
50+
assert lfc == [-1.0, 0.0, 1.0, 2.0]
51+
assert abs_lfc == [1.0, 0.0, 1.0, 2.0]

tests/test_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def test_unknown_alternative_de_metric():
196196
control_pert=CONTROL_VAR,
197197
pert_col=PERT_COL,
198198
outdir=OUTDIR,
199-
de_method="unknown", # type: ignore[unknown-argument]
199+
de_method="unknown", # ty: ignore[unknown-argument]
200200
).compute()
201201

202202

0 commit comments

Comments
 (0)