Skip to content

Commit 965c1bf

Browse files
committed
Make cell-eval trust explicit log1p scale
1 parent 44f9009 commit 965c1bf

5 files changed

Lines changed: 68 additions & 33 deletions

File tree

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,13 @@ cell-eval run \
5858
-ap <your/path/to/pred>.h5ad \
5959
-ar <your/path/to/real>.h5ad \
6060
--num-threads 64 \
61+
--is-log1p \
6162
--profile full
6263
```
6364

65+
`cell-eval run` trusts the declared input scale. It does not guess whether
66+
matrices are log1p-normalized and does not transform `.X` internally. Pass `--no-is-log1p` when the inputs are not log1p-normalized.
67+
6468
To run this as a python module you will need to use the `MetricsEvaluator` class.
6569

6670
```python
@@ -75,6 +79,7 @@ evaluator = MetricsEvaluator(
7579
control_pert="control",
7680
pert_col="perturbation",
7781
num_threads=64,
82+
is_log1p=True,
7883
)
7984
(results, agg_results) = evaluator.compute()
8085
```

src/cell_eval/_baseline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def build_base_mean_adata(
8585
groupby=pert_col,
8686
reference=control_pert,
8787
threads=num_threads,
88-
allow_discrete=allow_discrete,
88+
is_log1p=not allow_discrete,
8989
pdex_kwargs=pdex_kwargs,
9090
)
9191
frame = pdex(

src/cell_eval/_cli/_run.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,13 @@ def parse_args_run(parser: ap.ArgumentParser):
8282
parser.add_argument(
8383
"--allow-discrete",
8484
action="store_true",
85-
help="Allow discrete data to be evaluated (usually expected to be norm-logged inputs)",
85+
help="Deprecated; inputs are trusted as provided. Use --no-is-log1p to configure pdex for non-log1p inputs.",
86+
)
87+
parser.add_argument(
88+
"--is-log1p",
89+
action=ap.BooleanOptionalAction,
90+
default=True,
91+
help="Whether input AnnData .X is already log1p-normalized. cell-eval trusts this flag and does not transform inputs [default: %(default)s]",
8692
)
8793
parser.add_argument(
8894
"--profile",
@@ -159,6 +165,7 @@ def run_evaluation(args: ap.Namespace):
159165
allow_discrete=args.allow_discrete,
160166
prefix=ct,
161167
skip_de=args.profile == "pds",
168+
is_log1p=args.is_log1p,
162169
)
163170
evaluator.compute(
164171
profile=args.profile,
@@ -179,6 +186,7 @@ def run_evaluation(args: ap.Namespace):
179186
outdir=args.outdir,
180187
allow_discrete=args.allow_discrete,
181188
skip_de=args.profile == "pds",
189+
is_log1p=args.is_log1p,
182190
)
183191
evaluator.compute(
184192
profile=args.profile,

src/cell_eval/_evaluator.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,15 @@ class MetricsEvaluator:
5757
outdir: str = "./cell-eval-outdir"
5858
Output directory.
5959
allow_discrete: bool = False
60-
Allow discrete data.
60+
Deprecated; no longer used to validate or convert the input scale. Declare the scale with ``is_log1p`` instead.
6161
prefix: str | None = None
6262
Prefix for output files.
6363
pdex_kwargs: dict[str, Any] | None = None
6464
Keyword arguments for parallel_differential_expression.
65-
These will overwrite arguments passed to MetricsEvaluator.__init__ if they conflict.
65+
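These override ``reference``, ``groupby`` and ``threads`` if given, but must agree with ``is_log1p``: a conflicting ``pdex_kwargs["is_log1p"]`` raises ValueError.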
66+
is_log1p: bool = True
67+
Whether input AnnData matrices are already log1p-normalized. cell-eval trusts
68+
this flag and does not inspect or transform input matrices.
6669
"""
6770

6871
def __init__(
@@ -79,6 +82,7 @@ def __init__(
7982
prefix: str | None = None,
8083
pdex_kwargs: dict[str, Any] | None = None,
8184
skip_de: bool = False,
85+
is_log1p: bool = True,
8286
):
8387
# Enable a global string cache for categorical columns
8488
pl.enable_string_cache()
@@ -97,7 +101,6 @@ def __init__(
97101
pred=adata_pred,
98102
control_pert=control_pert,
99103
pert_col=pert_col,
100-
allow_discrete=allow_discrete,
101104
)
102105

103106
if skip_de:
@@ -108,10 +111,10 @@ def __init__(
108111
de_pred=de_pred,
109112
de_real=de_real,
110113
num_threads=num_threads,
111-
allow_discrete=allow_discrete,
112114
outdir=outdir,
113115
prefix=prefix,
114116
pdex_kwargs=pdex_kwargs or {},
117+
is_log1p=is_log1p,
115118
)
116119

117120
self.outdir = outdir
@@ -170,7 +173,6 @@ def _build_anndata_pair(
170173
pred: ad.AnnData | str,
171174
control_pert: str,
172175
pert_col: str,
173-
allow_discrete: bool = False,
174176
):
175177
if isinstance(real, str):
176178
logger.info(f"Reading real anndata from {real}")
@@ -183,10 +185,6 @@ def _build_anndata_pair(
183185
_cast_float16_to_float32(real, which="real")
184186
_cast_float16_to_float32(pred, which="pred")
185187

186-
# Validate that the input is normalized and log-transformed
187-
_convert_to_normlog(real, which="real", allow_discrete=allow_discrete)
188-
_convert_to_normlog(pred, which="pred", allow_discrete=allow_discrete)
189-
190188
# Build the anndata pair
191189
return PerturbationAnndataPair(
192190
real=real, pred=pred, control_pert=control_pert, pert_col=pert_col
@@ -234,31 +232,31 @@ def _build_de_comparison(
234232
de_pred: pl.DataFrame | str | None = None,
235233
de_real: pl.DataFrame | str | None = None,
236234
num_threads: int = 1,
237-
allow_discrete: bool = False,
238235
outdir: str | None = None,
239236
prefix: str | None = None,
240237
pdex_kwargs: dict[str, Any] | None = None,
238+
is_log1p: bool = True,
241239
):
242240
return initialize_de_comparison(
243241
real=_load_or_build_de(
244242
mode="real",
245243
de_path=de_real,
246244
anndata_pair=anndata_pair,
247245
num_threads=num_threads,
248-
allow_discrete=allow_discrete,
249246
outdir=outdir,
250247
prefix=prefix,
251248
pdex_kwargs=pdex_kwargs or {},
249+
is_log1p=is_log1p,
252250
),
253251
pred=_load_or_build_de(
254252
mode="pred",
255253
de_path=de_pred,
256254
anndata_pair=anndata_pair,
257255
num_threads=num_threads,
258-
allow_discrete=allow_discrete,
259256
outdir=outdir,
260257
prefix=prefix,
261258
pdex_kwargs=pdex_kwargs or {},
259+
is_log1p=is_log1p,
262260
),
263261
)
264262

@@ -267,21 +265,23 @@ def _build_pdex_kwargs(
267265
reference: str,
268266
groupby: str,
269267
threads: int,
270-
allow_discrete: bool,
268+
is_log1p: bool,
271269
pdex_kwargs: dict[str, Any] | None = None,
272270
) -> dict[str, Any]:
273-
pdex_kwargs = pdex_kwargs or {}
271+
pdex_kwargs = dict(pdex_kwargs or {})
274272
if "reference" not in pdex_kwargs:
275273
pdex_kwargs["reference"] = reference
276274
if "groupby" not in pdex_kwargs:
277275
pdex_kwargs["groupby"] = groupby
278276
if "threads" not in pdex_kwargs:
279277
pdex_kwargs["threads"] = threads
280278
if "is_log1p" not in pdex_kwargs:
281-
if allow_discrete:
282-
pdex_kwargs["is_log1p"] = False
283-
else:
284-
pdex_kwargs["is_log1p"] = True
279+
pdex_kwargs["is_log1p"] = is_log1p
280+
elif bool(pdex_kwargs["is_log1p"]) != is_log1p:
281+
raise ValueError(
282+
"Conflicting log1p configuration: MetricsEvaluator(is_log1p="
283+
f"{is_log1p}) but pdex_kwargs['is_log1p']={pdex_kwargs['is_log1p']!r}"
284+
)
285285
return pdex_kwargs
286286

287287

@@ -292,8 +292,8 @@ def _load_or_build_de(
292292
num_threads: int = 1,
293293
outdir: str | None = None,
294294
prefix: str | None = None,
295-
allow_discrete: bool = False,
296295
pdex_kwargs: dict[str, Any] | None = None,
296+
is_log1p: bool = True,
297297
) -> pl.DataFrame:
298298
if de_path is None:
299299
if anndata_pair is None:
@@ -303,8 +303,8 @@ def _load_or_build_de(
303303
reference=anndata_pair.control_pert,
304304
groupby=anndata_pair.pert_col,
305305
threads=num_threads,
306-
allow_discrete=allow_discrete,
307306
pdex_kwargs=pdex_kwargs or {},
307+
is_log1p=is_log1p,
308308
)
309309
logger.info(f"Using the following pdex kwargs: {pdex_kwargs}")
310310
frame = pdex(

tests/test_eval.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def test_broken_adata_not_normlog():
7070
control_pert=CONTROL_VAR,
7171
pert_col=PERT_COL,
7272
outdir=OUTDIR,
73+
is_log1p=False,
7374
)
7475
evaluator.compute(
7576
break_on_error=True,
@@ -86,32 +87,37 @@ def test_broken_adata_not_normlog_skip_check():
8687
pert_col=PERT_COL,
8788
outdir=OUTDIR,
8889
allow_discrete=True,
90+
is_log1p=False,
8991
)
9092
evaluator.compute(
9193
break_on_error=True,
9294
)
9395

9496

95-
def test_broken_adata_invalid_pred_scale():
96-
"""Test that predicted data with invalid scale is rejected."""
97+
def test_eval_trusts_declared_log1p_without_scale_scan():
98+
"""MetricsEvaluator should not scan or transform .X when is_log1p is set."""
9799
adata_real = build_random_anndata(normlog=True)
98100
adata_pred = adata_real.copy()
99101

100-
# Create invalid predicted data: mix of raw counts and log1p
102+
# Values above the old log1p threshold used to fail during scale guessing.
101103
adata_pred.X = np.random.uniform(
102104
0,
103105
5000,
104106
size=adata_pred.X.shape, # type: ignore
105107
)
108+
before = np.asarray(adata_pred.X).copy()
106109

107-
with pytest.raises(ValueError, match="Invalid scale.*exceeds log1p threshold"):
108-
MetricsEvaluator(
109-
adata_pred=adata_pred,
110-
adata_real=adata_real,
111-
control_pert=CONTROL_VAR,
112-
pert_col=PERT_COL,
113-
outdir=OUTDIR,
114-
)
110+
evaluator = MetricsEvaluator(
111+
adata_pred=adata_pred,
112+
adata_real=adata_real,
113+
control_pert=CONTROL_VAR,
114+
pert_col=PERT_COL,
115+
outdir=OUTDIR,
116+
skip_de=True,
117+
is_log1p=True,
118+
)
119+
120+
np.testing.assert_array_equal(evaluator.anndata_pair.pred.X, before)
115121

116122

117123
def test_broken_adata_missing_pertcol_in_real():
@@ -292,6 +298,22 @@ def test_eval_pdex_kwargs_duplicated():
292298
)
293299

294300

301+
def test_eval_pdex_kwargs_is_log1p_conflict():
302+
adata_real = build_random_anndata()
303+
adata_pred = downsample_cells(adata_real, fraction=0.5)
304+
with pytest.raises(ValueError, match="Conflicting log1p configuration"):
305+
MetricsEvaluator(
306+
adata_pred=adata_pred,
307+
adata_real=adata_real,
308+
control_pert="control",
309+
pert_col="perturbation",
310+
is_log1p=True,
311+
pdex_kwargs={
312+
"is_log1p": False,
313+
},
314+
)
315+
316+
295317
def validate_expected_files(
296318
outdir: str, prefix: str | None = None, remove: bool = True
297319
):

0 commit comments

Comments
 (0)