@@ -57,12 +57,15 @@ class MetricsEvaluator:
5757 outdir: str = "./cell-eval-outdir"
5858 Output directory.
5959 allow_discrete: bool = False
60- Allow discrete data .
60+ Deprecated; cell-eval now trusts the caller-provided scale .
6161 prefix: str | None = None
6262 Prefix for output files.
6363 pdex_kwargs: dict[str, Any] | None = None
6464 Keyword arguments for parallel_differential_expression.
65- These will overwrite arguments passed to MetricsEvaluator.__init__ if they conflict.
65+ These may not conflict with ``is_log1p``.
66+ is_log1p: bool = True
67+ Whether input AnnData matrices are already log1p-normalized. cell-eval trusts
68+ this flag and does not inspect or transform input matrices.
6669 """
6770
6871 def __init__ (
@@ -79,6 +82,7 @@ def __init__(
7982 prefix : str | None = None ,
8083 pdex_kwargs : dict [str , Any ] | None = None ,
8184 skip_de : bool = False ,
85+ is_log1p : bool = True ,
8286 ):
8387 # Enable a global string cache for categorical columns
8488 pl .enable_string_cache ()
@@ -97,7 +101,6 @@ def __init__(
97101 pred = adata_pred ,
98102 control_pert = control_pert ,
99103 pert_col = pert_col ,
100- allow_discrete = allow_discrete ,
101104 )
102105
103106 if skip_de :
@@ -108,10 +111,10 @@ def __init__(
108111 de_pred = de_pred ,
109112 de_real = de_real ,
110113 num_threads = num_threads ,
111- allow_discrete = allow_discrete ,
112114 outdir = outdir ,
113115 prefix = prefix ,
114116 pdex_kwargs = pdex_kwargs or {},
117+ is_log1p = is_log1p ,
115118 )
116119
117120 self .outdir = outdir
@@ -170,7 +173,6 @@ def _build_anndata_pair(
170173 pred : ad .AnnData | str ,
171174 control_pert : str ,
172175 pert_col : str ,
173- allow_discrete : bool = False ,
174176):
175177 if isinstance (real , str ):
176178 logger .info (f"Reading real anndata from { real } " )
@@ -183,10 +185,6 @@ def _build_anndata_pair(
183185 _cast_float16_to_float32 (real , which = "real" )
184186 _cast_float16_to_float32 (pred , which = "pred" )
185187
186- # Validate that the input is normalized and log-transformed
187- _convert_to_normlog (real , which = "real" , allow_discrete = allow_discrete )
188- _convert_to_normlog (pred , which = "pred" , allow_discrete = allow_discrete )
189-
190188 # Build the anndata pair
191189 return PerturbationAnndataPair (
192190 real = real , pred = pred , control_pert = control_pert , pert_col = pert_col
@@ -234,31 +232,31 @@ def _build_de_comparison(
234232 de_pred : pl .DataFrame | str | None = None ,
235233 de_real : pl .DataFrame | str | None = None ,
236234 num_threads : int = 1 ,
237- allow_discrete : bool = False ,
238235 outdir : str | None = None ,
239236 prefix : str | None = None ,
240237 pdex_kwargs : dict [str , Any ] | None = None ,
238+ is_log1p : bool = True ,
241239):
242240 return initialize_de_comparison (
243241 real = _load_or_build_de (
244242 mode = "real" ,
245243 de_path = de_real ,
246244 anndata_pair = anndata_pair ,
247245 num_threads = num_threads ,
248- allow_discrete = allow_discrete ,
249246 outdir = outdir ,
250247 prefix = prefix ,
251248 pdex_kwargs = pdex_kwargs or {},
249+ is_log1p = is_log1p ,
252250 ),
253251 pred = _load_or_build_de (
254252 mode = "pred" ,
255253 de_path = de_pred ,
256254 anndata_pair = anndata_pair ,
257255 num_threads = num_threads ,
258- allow_discrete = allow_discrete ,
259256 outdir = outdir ,
260257 prefix = prefix ,
261258 pdex_kwargs = pdex_kwargs or {},
259+ is_log1p = is_log1p ,
262260 ),
263261 )
264262
@@ -267,21 +265,23 @@ def _build_pdex_kwargs(
267265 reference : str ,
268266 groupby : str ,
269267 threads : int ,
270- allow_discrete : bool ,
268+ is_log1p : bool ,
271269 pdex_kwargs : dict [str , Any ] | None = None ,
272270) -> dict [str , Any ]:
273- pdex_kwargs = pdex_kwargs or {}
271+ pdex_kwargs = dict ( pdex_kwargs or {})
274272 if "reference" not in pdex_kwargs :
275273 pdex_kwargs ["reference" ] = reference
276274 if "groupby" not in pdex_kwargs :
277275 pdex_kwargs ["groupby" ] = groupby
278276 if "threads" not in pdex_kwargs :
279277 pdex_kwargs ["threads" ] = threads
280278 if "is_log1p" not in pdex_kwargs :
281- if allow_discrete :
282- pdex_kwargs ["is_log1p" ] = False
283- else :
284- pdex_kwargs ["is_log1p" ] = True
279+ pdex_kwargs ["is_log1p" ] = is_log1p
280+ elif bool (pdex_kwargs ["is_log1p" ]) != is_log1p :
281+ raise ValueError (
282+ "Conflicting log1p configuration: MetricsEvaluator(is_log1p="
283+ f"{ is_log1p } ) but pdex_kwargs['is_log1p']={ pdex_kwargs ['is_log1p' ]!r} "
284+ )
285285 return pdex_kwargs
286286
287287
@@ -292,8 +292,8 @@ def _load_or_build_de(
292292 num_threads : int = 1 ,
293293 outdir : str | None = None ,
294294 prefix : str | None = None ,
295- allow_discrete : bool = False ,
296295 pdex_kwargs : dict [str , Any ] | None = None ,
296+ is_log1p : bool = True ,
297297) -> pl .DataFrame :
298298 if de_path is None :
299299 if anndata_pair is None :
@@ -303,8 +303,8 @@ def _load_or_build_de(
303303 reference = anndata_pair .control_pert ,
304304 groupby = anndata_pair .pert_col ,
305305 threads = num_threads ,
306- allow_discrete = allow_discrete ,
307306 pdex_kwargs = pdex_kwargs or {},
307+ is_log1p = is_log1p ,
308308 )
309309 logger .info (f"Using the following pdex kwargs: { pdex_kwargs } " )
310310 frame = pdex (
0 commit comments