Skip to content

Commit 6afd980

Browse files
authored
Fix NaN values in adata.obs (#74)
* Update validation.py * raise error and drop_na_cells arg * warning message * Update validation.py * docstring for drop na flag * fix to no mod adata * label_na as unknown * Update main.py
1 parent e27bd80 commit 6afd980

2 files changed

Lines changed: 71 additions & 2 deletions

File tree

cytetype/main.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
aggregate_cluster_metadata,
2222
extract_visualization_coordinates,
2323
)
24-
from .preprocessing.validation import materialize_canonical_gene_symbols_column
24+
from .preprocessing.validation import (
25+
materialize_canonical_gene_symbols_column,
26+
_generate_unique_na_label,
27+
)
2528
from .core.payload import build_annotation_payload, save_query_to_file
2629
from .core.artifacts import (
2730
_is_integer_valued,
@@ -87,6 +90,7 @@ def __init__(
8790
max_metadata_categories: int = 500,
8891
api_url: str = "https://prod.cytetype.nygen.io",
8992
auth_token: str | None = None,
93+
label_na: bool = False,
9094
) -> None:
9195
"""Initialize CyteType with AnnData object and perform data preparation.
9296
@@ -125,6 +129,11 @@ def __init__(
125129
deployment. Defaults to "https://prod.cytetype.nygen.io".
126130
auth_token (str | None, optional): Bearer token for API authentication. If provided,
127131
will be included in the Authorization header as "Bearer {auth_token}". Defaults to None.
132+
label_na (bool, optional): If True, cells with NaN values in the
133+
``group_key`` column are assigned an ``'Unknown'`` cluster label
134+
(or ``'Unknown 2'``, etc. if that label already exists). The original
135+
AnnData object is not modified. If False (default), a ``ValueError``
136+
is raised instead.
128137
129138
Raises:
130139
KeyError: If the required keys are missing in `adata.obs` or `adata.uns`
@@ -152,8 +161,40 @@ def __init__(
152161
self._original_gene_symbols_column = self.gene_symbols_column
153162

154163
self.coordinates_key = validate_adata(
155-
adata, group_key, rank_key, self.gene_symbols_column, coordinates_key
164+
adata, group_key, rank_key, self.gene_symbols_column, coordinates_key,
165+
label_na=label_na,
156166
)
167+
168+
if label_na:
169+
nan_mask = adata.obs[group_key].isna()
170+
if nan_mask.any():
171+
n_nan = int(nan_mask.sum())
172+
pct = round(100 * n_nan / adata.n_obs, 1)
173+
existing_labels = set(
174+
str(v) for v in adata.obs[group_key].dropna().unique()
175+
)
176+
na_label = _generate_unique_na_label(existing_labels)
177+
logger.warning(
178+
f"⚠️ Relabeling {n_nan} cells ({pct}%) with NaN values "
179+
f"in '{group_key}' as '{na_label}'."
180+
)
181+
adata = anndata.AnnData(
182+
X=adata.X,
183+
obs=adata.obs.copy(),
184+
var=adata.var,
185+
uns=adata.uns,
186+
obsm=adata.obsm,
187+
varm=adata.varm,
188+
layers=adata.layers,
189+
obsp=adata.obsp,
190+
varp=adata.varp,
191+
)
192+
col = adata.obs[group_key]
193+
if hasattr(col, "cat"):
194+
col = col.cat.add_categories(na_label)
195+
adata.obs[group_key] = col.fillna(na_label)
196+
self.adata = adata
197+
157198
(
158199
self.gene_symbols_column,
159200
self._original_gene_symbols_column,

cytetype/preprocessing/validation.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,43 @@ def _ur_sort_key(ur: float) -> float:
266266
return None
267267

268268

269+
def _generate_unique_na_label(existing_labels: set[str]) -> str:
270+
label = "Unknown"
271+
if label not in existing_labels:
272+
return label
273+
n = 2
274+
while f"{label} {n}" in existing_labels:
275+
n += 1
276+
return f"{label} {n}"
277+
278+
269279
def validate_adata(
270280
adata: anndata.AnnData,
271281
cell_group_key: str,
272282
rank_genes_key: str,
273283
gene_symbols_col: str | None,
274284
coordinates_key: str,
285+
label_na: bool = False,
275286
) -> str | None:
276287
if cell_group_key not in adata.obs:
277288
raise KeyError(f"Cell group key '{cell_group_key}' not found in `adata.obs`.")
289+
290+
nan_mask = adata.obs[cell_group_key].isna()
291+
n_nan = int(nan_mask.sum())
292+
if n_nan > 0:
293+
pct = round(100 * n_nan / adata.n_obs, 1)
294+
if n_nan == adata.n_obs:
295+
raise ValueError(
296+
f"All {n_nan} cells have NaN values in '{cell_group_key}'. "
297+
f"Cannot proceed with annotation."
298+
)
299+
if not label_na:
300+
raise ValueError(
301+
f"{n_nan} cells ({pct}%) have NaN values in '{cell_group_key}'. "
302+
f"Either fix the data or set label_na=True to assign these cells "
303+
f"an 'Unknown' cluster label."
304+
)
305+
278306
if adata.X is None:
279307
raise ValueError(
280308
"`adata.X` is required for ranking genes. Please ensure it contains log1p normalized data."

0 commit comments

Comments
 (0)