|
21 | 21 | aggregate_cluster_metadata, |
22 | 22 | extract_visualization_coordinates, |
23 | 23 | ) |
24 | | -from .preprocessing.validation import materialize_canonical_gene_symbols_column |
| 24 | +from .preprocessing.validation import ( |
| 25 | + materialize_canonical_gene_symbols_column, |
| 26 | + _generate_unique_na_label, |
| 27 | +) |
25 | 28 | from .core.payload import build_annotation_payload, save_query_to_file |
26 | 29 | from .core.artifacts import ( |
27 | 30 | _is_integer_valued, |
@@ -87,6 +90,7 @@ def __init__( |
87 | 90 | max_metadata_categories: int = 500, |
88 | 91 | api_url: str = "https://prod.cytetype.nygen.io", |
89 | 92 | auth_token: str | None = None, |
| 93 | + label_na: bool = False, |
90 | 94 | ) -> None: |
91 | 95 | """Initialize CyteType with AnnData object and perform data preparation. |
92 | 96 |
|
@@ -125,6 +129,11 @@ def __init__( |
125 | 129 | deployment. Defaults to "https://prod.cytetype.nygen.io". |
126 | 130 | auth_token (str | None, optional): Bearer token for API authentication. If provided, |
127 | 131 | will be included in the Authorization header as "Bearer {auth_token}". Defaults to None. |
| 132 | + label_na (bool, optional): If True, cells with NaN values in the |
| 133 | + ``group_key`` column are assigned an ``'Unknown'`` cluster label |
| 134 | + (or ``'Unknown 2'``, etc. if that label already exists). The original |
| 135 | + AnnData object is not modified. If False (default), a ``ValueError`` |
| 136 | + is raised instead. |
128 | 137 |
|
129 | 138 | Raises: |
130 | 139 | KeyError: If the required keys are missing in `adata.obs` or `adata.uns` |
@@ -152,8 +161,40 @@ def __init__( |
152 | 161 | self._original_gene_symbols_column = self.gene_symbols_column |
153 | 162 |
|
154 | 163 | self.coordinates_key = validate_adata( |
155 | | - adata, group_key, rank_key, self.gene_symbols_column, coordinates_key |
| 164 | + adata, group_key, rank_key, self.gene_symbols_column, coordinates_key, |
| 165 | + label_na=label_na, |
156 | 166 | ) |
| 167 | + |
| 168 | + if label_na: |
| 169 | + nan_mask = adata.obs[group_key].isna() |
| 170 | + if nan_mask.any(): |
| 171 | + n_nan = int(nan_mask.sum()) |
| 172 | + pct = round(100 * n_nan / adata.n_obs, 1) |
| 173 | + existing_labels = set( |
| 174 | + str(v) for v in adata.obs[group_key].dropna().unique() |
| 175 | + ) |
| 176 | + na_label = _generate_unique_na_label(existing_labels) |
| 177 | + logger.warning( |
| 178 | + f"⚠️ Relabeling {n_nan} cells ({pct}%) with NaN values " |
| 179 | + f"in '{group_key}' as '{na_label}'." |
| 180 | + ) |
| 181 | + adata = anndata.AnnData( |
| 182 | + X=adata.X, |
| 183 | + obs=adata.obs.copy(), |
| 184 | + var=adata.var, |
| 185 | + uns=adata.uns, |
| 186 | + obsm=adata.obsm, |
| 187 | + varm=adata.varm, |
| 188 | + layers=adata.layers, |
| 189 | + obsp=adata.obsp, |
| 190 | + varp=adata.varp, |
| 191 | + ) |
| 192 | + col = adata.obs[group_key] |
| 193 | + if hasattr(col, "cat"): |
| 194 | + col = col.cat.add_categories(na_label) |
| 195 | + adata.obs[group_key] = col.fillna(na_label) |
| 196 | + self.adata = adata |
| 197 | + |
157 | 198 | ( |
158 | 199 | self.gene_symbols_column, |
159 | 200 | self._original_gene_symbols_column, |
|
0 commit comments