Skip to content

Commit e27bd80

Browse files
Update version to 0.19.3 and enhance gene symbol handling in CyteType (#72)
- Bump package version to 0.19.3.
- Introduce materialization of a canonical gene symbols column in AnnData, improving gene symbol management.
- Refactor CyteType initialization to handle gene symbols more flexibly, including support for temporary columns.
- Update save_features_matrix to conditionally include gene symbols metadata in output files.
- Enhance tests to validate new gene symbol handling and ensure proper functionality.
1 parent 7a9b4a1 commit e27bd80

6 files changed

Lines changed: 313 additions & 122 deletions

File tree

cytetype/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.19.2"
1+
__version__ = "0.19.3"
22

33
import requests
44

cytetype/core/artifacts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _write_var_metadata(
4040
n_cols: int,
4141
var_df: pd.DataFrame,
4242
var_names: pd.Index | Sequence[Any] | None,
43+
gene_symbols_column: str | None = None,
4344
) -> None:
4445
if len(var_df) != n_cols:
4546
raise ValueError(
@@ -68,6 +69,8 @@ def _write_var_metadata(
6869
data=_as_string_values(var_df.index),
6970
dtype=text_dtype,
7071
)
72+
if gene_symbols_column is not None:
73+
var_group.attrs["gene_symbols_column"] = gene_symbols_column
7174

7275
columns_group = var_group.create_group("columns")
7376
for i, col_name in enumerate(var_df.columns):
@@ -414,6 +417,7 @@ def save_features_matrix(
414417
mat: Any,
415418
var_df: pd.DataFrame | None = None,
416419
var_names: pd.Index | Sequence[Any] | None = None,
420+
gene_symbols_column: str | None = None,
417421
raw_mat: Any | None = None,
418422
raw_col_indices: "np.ndarray | None" = None,
419423
raw_cell_batch: int = 2000,
@@ -454,6 +458,7 @@ def save_features_matrix(
454458
n_cols=n_cols,
455459
var_df=var_df,
456460
var_names=var_names,
461+
gene_symbols_column=gene_symbols_column,
457462
)
458463

459464
if raw_mat is not None:

cytetype/main.py

Lines changed: 150 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
from .preprocessing import (
1717
validate_adata,
1818
resolve_gene_symbols_column,
19-
clean_gene_names,
2019
aggregate_expression_percentages,
2120
extract_marker_genes,
2221
aggregate_cluster_metadata,
2322
extract_visualization_coordinates,
2423
)
24+
from .preprocessing.validation import materialize_canonical_gene_symbols_column
2525
from .core.payload import build_annotation_payload, save_query_to_file
2626
from .core.artifacts import (
2727
_is_integer_valued,
@@ -140,136 +140,163 @@ def __init__(
140140
self.api_url = api_url
141141
self.auth_token = auth_token
142142
self._artifact_build_errors: list[tuple[str, Exception]] = []
143+
self._vars_h5_path: str | None = None
144+
self._obs_duckdb_path: str | None = None
145+
self._original_gene_symbols_column: str | None = None
146+
self._temporary_gene_symbols_column: str | None = None
143147

144-
self.gene_symbols_column = resolve_gene_symbols_column(
145-
adata, gene_symbols_column
146-
)
147-
148-
self.coordinates_key = validate_adata(
149-
adata, group_key, rank_key, self.gene_symbols_column, coordinates_key
150-
)
148+
try:
149+
self.gene_symbols_column = resolve_gene_symbols_column(
150+
adata, gene_symbols_column
151+
)
152+
self._original_gene_symbols_column = self.gene_symbols_column
151153

152-
# Use original labels as IDs if all are short (<=3 chars), otherwise enumerate
153-
_unique_group_categories: list[str | int] = natsorted(
154-
adata.obs[group_key].unique().tolist()
155-
)
156-
_short_ids = all(len(str(x)) <= 3 for x in _unique_group_categories)
157-
self.cluster_map = {
158-
str(x): str(x) if _short_ids else str(n)
159-
for n, x in enumerate(_unique_group_categories)
160-
}
161-
self.clusters = [
162-
self.cluster_map[str(x)] for x in adata.obs[group_key].values.tolist()
163-
]
164-
165-
gene_names = (
166-
adata.var[self.gene_symbols_column].tolist()
167-
if self.gene_symbols_column is not None
168-
else adata.var_names.tolist()
169-
)
170-
gene_names = clean_gene_names(gene_names)
171-
self.expression_percentages = aggregate_expression_percentages(
172-
adata=adata,
173-
clusters=self.clusters,
174-
gene_names=gene_names,
175-
cell_batch_size=pcent_batch_size,
176-
)
154+
self.coordinates_key = validate_adata(
155+
adata, group_key, rank_key, self.gene_symbols_column, coordinates_key
156+
)
157+
(
158+
self.gene_symbols_column,
159+
self._original_gene_symbols_column,
160+
) = materialize_canonical_gene_symbols_column(
161+
adata, self.gene_symbols_column
162+
)
163+
self._temporary_gene_symbols_column = self.gene_symbols_column
177164

178-
logger.info("Extracting marker genes...")
179-
self.marker_genes = extract_marker_genes(
180-
adata=self.adata,
181-
cell_group_key=self.group_key,
182-
rank_genes_key=self.rank_key,
183-
cluster_map=self.cluster_map,
184-
n_top_genes=n_top_genes,
185-
gene_symbols_col=self.gene_symbols_column,
186-
)
165+
# Use original labels as IDs if all are short (<=3 chars), otherwise enumerate
166+
_unique_group_categories: list[str | int] = natsorted(
167+
adata.obs[group_key].unique().tolist()
168+
)
169+
_short_ids = all(len(str(x)) <= 3 for x in _unique_group_categories)
170+
self.cluster_map = {
171+
str(x): str(x) if _short_ids else str(n)
172+
for n, x in enumerate(_unique_group_categories)
173+
}
174+
self.clusters = [
175+
self.cluster_map[str(x)] for x in adata.obs[group_key].values.tolist()
176+
]
177+
178+
gene_names = adata.var[self.gene_symbols_column].tolist()
179+
self.expression_percentages = aggregate_expression_percentages(
180+
adata=adata,
181+
clusters=self.clusters,
182+
gene_names=gene_names,
183+
cell_batch_size=pcent_batch_size,
184+
)
187185

188-
if aggregate_metadata:
189-
logger.info("Aggregating cluster metadata...")
190-
self.group_metadata = aggregate_cluster_metadata(
186+
logger.info("Extracting marker genes...")
187+
self.marker_genes = extract_marker_genes(
191188
adata=self.adata,
192-
group_key=self.group_key,
193-
min_percentage=min_percentage,
194-
max_categories=max_metadata_categories,
189+
cell_group_key=self.group_key,
190+
rank_genes_key=self.rank_key,
191+
cluster_map=self.cluster_map,
192+
n_top_genes=n_top_genes,
193+
gene_symbols_col=self.gene_symbols_column,
195194
)
196-
# Replace keys in group_metadata using cluster_map
197-
self.group_metadata = {
198-
self.cluster_map.get(str(key), str(key)): value
199-
for key, value in self.group_metadata.items()
200-
}
201-
self.group_metadata = {
202-
k: self.group_metadata[k] for k in sorted(self.group_metadata.keys())
195+
196+
if aggregate_metadata:
197+
logger.info("Aggregating cluster metadata...")
198+
self.group_metadata = aggregate_cluster_metadata(
199+
adata=self.adata,
200+
group_key=self.group_key,
201+
min_percentage=min_percentage,
202+
max_categories=max_metadata_categories,
203+
)
204+
# Replace keys in group_metadata using cluster_map
205+
self.group_metadata = {
206+
self.cluster_map.get(str(key), str(key)): value
207+
for key, value in self.group_metadata.items()
208+
}
209+
self.group_metadata = {
210+
k: self.group_metadata[k]
211+
for k in sorted(self.group_metadata.keys())
212+
}
213+
else:
214+
self.group_metadata = {}
215+
216+
# Prepare visualization data with sampling
217+
sampled_coordinates, sampled_cluster_labels = (
218+
extract_visualization_coordinates(
219+
adata=adata,
220+
coordinates_key=self.coordinates_key,
221+
group_key=self.group_key,
222+
cluster_map=self.cluster_map,
223+
max_cells_per_group=self.max_cells_per_group,
224+
)
225+
)
226+
227+
self.visualization_data = {
228+
"coordinates": sampled_coordinates,
229+
"clusters": sampled_cluster_labels,
203230
}
204-
else:
205-
self.group_metadata = {}
206-
207-
# Prepare visualization data with sampling
208-
sampled_coordinates, sampled_cluster_labels = extract_visualization_coordinates(
209-
adata=adata,
210-
coordinates_key=self.coordinates_key,
211-
group_key=self.group_key,
212-
cluster_map=self.cluster_map,
213-
max_cells_per_group=self.max_cells_per_group,
214-
)
215231

216-
self.visualization_data = {
217-
"coordinates": sampled_coordinates,
218-
"clusters": sampled_cluster_labels,
219-
}
232+
# Resolve raw counts once and cache
233+
self._raw_counts_result = self._resolve_raw_counts()
234+
if self._raw_counts_result is None:
235+
logger.warning(
236+
"No integer raw counts found in adata.layers['counts'], "
237+
"adata.raw.X, or adata.X. Skipping raw counts in vars.h5."
238+
)
220239

221-
# Resolve raw counts once and cache
222-
self._raw_counts_result = self._resolve_raw_counts()
223-
if self._raw_counts_result is None:
224-
logger.warning(
225-
"No integer raw counts found in adata.layers['counts'], "
226-
"adata.raw.X, or adata.X. Skipping raw counts in vars.h5."
227-
)
240+
# Build vars.h5
241+
try:
242+
raw_mat, raw_col_indices = (
243+
self._raw_counts_result
244+
if self._raw_counts_result is not None
245+
else (None, None)
246+
)
247+
save_features_matrix(
248+
out_file=vars_h5_path,
249+
mat=self.adata.X,
250+
var_df=self.adata.var,
251+
var_names=self.adata.var_names,
252+
raw_mat=raw_mat,
253+
raw_col_indices=raw_col_indices,
254+
gene_symbols_column=self.gene_symbols_column,
255+
)
256+
sys.stderr.flush()
257+
self._vars_h5_path = vars_h5_path
258+
except Exception as exc:
259+
logger.warning(f"vars.h5 artifact failed during build: {exc}")
260+
self._artifact_build_errors.append(("vars_h5", exc))
228261

229-
# Build vars.h5
230-
try:
231-
raw_mat, raw_col_indices = (
232-
self._raw_counts_result
233-
if self._raw_counts_result is not None
234-
else (None, None)
235-
)
236-
save_features_matrix(
237-
out_file=vars_h5_path,
238-
mat=self.adata.X,
239-
var_df=self.adata.var,
240-
var_names=self.adata.var_names,
241-
raw_mat=raw_mat,
242-
raw_col_indices=raw_col_indices,
243-
)
244-
sys.stderr.flush()
245-
self._vars_h5_path: str | None = vars_h5_path
246-
except Exception as exc:
247-
logger.warning(f"vars.h5 artifact failed during build: {exc}")
248-
self._vars_h5_path = None
249-
self._artifact_build_errors.append(("vars_h5", exc))
250-
251-
# Build obs.duckdb
252-
try:
253-
logger.info("Writing obs data to duckdb artifact...")
254-
obsm_coordinates = (
255-
self.adata.obsm[self.coordinates_key]
256-
if self.coordinates_key and self.coordinates_key in self.adata.obsm
257-
else None
258-
)
259-
save_obs_duckdb_file(
260-
out_file=obs_duckdb_path,
261-
obs_df=self.adata.obs,
262-
obsm_coordinates=obsm_coordinates,
263-
coordinates_key=self.coordinates_key,
262+
# Build obs.duckdb
263+
try:
264+
logger.info("Writing obs data to duckdb artifact...")
265+
obsm_coordinates = (
266+
self.adata.obsm[self.coordinates_key]
267+
if self.coordinates_key and self.coordinates_key in self.adata.obsm
268+
else None
269+
)
270+
save_obs_duckdb_file(
271+
out_file=obs_duckdb_path,
272+
obs_df=self.adata.obs,
273+
obsm_coordinates=obsm_coordinates,
274+
coordinates_key=self.coordinates_key,
275+
)
276+
sys.stderr.flush()
277+
self._obs_duckdb_path = obs_duckdb_path
278+
except Exception as exc:
279+
logger.warning(f"obs.duckdb artifact failed during build: {exc}")
280+
self._artifact_build_errors.append(("obs_duckdb", exc))
281+
282+
logger.info("Data preparation completed. Ready for submitting jobs.")
283+
except Exception:
284+
self._cleanup_temporary_gene_symbols_column()
285+
raise
286+
287+
def _cleanup_temporary_gene_symbols_column(self) -> None:
288+
temp_column = self._temporary_gene_symbols_column
289+
if temp_column is None:
290+
return
291+
292+
if temp_column in self.adata.var.columns:
293+
del self.adata.var[temp_column]
294+
logger.info(
295+
f"Deleted temporary canonical gene-symbol column '{temp_column}'."
264296
)
265-
sys.stderr.flush()
266-
self._obs_duckdb_path: str | None = obs_duckdb_path
267-
except Exception as exc:
268-
logger.warning(f"obs.duckdb artifact failed during build: {exc}")
269-
self._obs_duckdb_path = None
270-
self._artifact_build_errors.append(("obs_duckdb", exc))
271297

272-
logger.info("Data preparation completed. Ready for submitting jobs.")
298+
self.gene_symbols_column = self._original_gene_symbols_column
299+
self._temporary_gene_symbols_column = None
273300

274301
def _resolve_raw_counts(
275302
self,
@@ -356,7 +383,8 @@ def cleanup(self) -> None:
356383
"""Delete the artifact files built during initialization.
357384
358385
Call this after run() completes to remove the vars.h5 and obs.duckdb
359-
files from disk. Paths are cleared so repeated calls are safe.
386+
files from disk and drop the temporary canonical gene-symbol column.
387+
Paths are cleared so repeated calls are safe.
360388
"""
361389
for attr, path in [
362390
("_vars_h5_path", self._vars_h5_path),
@@ -370,6 +398,8 @@ def cleanup(self) -> None:
370398
logger.warning(f"Failed to delete artifact {path}: {exc}")
371399
setattr(self, attr, None)
372400

401+
self._cleanup_temporary_gene_symbols_column()
402+
373403
def run(
374404
self,
375405
study_context: str,

0 commit comments

Comments
 (0)