1616from .preprocessing import (
1717 validate_adata ,
1818 resolve_gene_symbols_column ,
19- clean_gene_names ,
2019 aggregate_expression_percentages ,
2120 extract_marker_genes ,
2221 aggregate_cluster_metadata ,
2322 extract_visualization_coordinates ,
2423)
24+ from .preprocessing .validation import materialize_canonical_gene_symbols_column
2525from .core .payload import build_annotation_payload , save_query_to_file
2626from .core .artifacts import (
2727 _is_integer_valued ,
@@ -140,136 +140,163 @@ def __init__(
140140 self .api_url = api_url
141141 self .auth_token = auth_token
142142 self ._artifact_build_errors : list [tuple [str , Exception ]] = []
143+ self ._vars_h5_path : str | None = None
144+ self ._obs_duckdb_path : str | None = None
145+ self ._original_gene_symbols_column : str | None = None
146+ self ._temporary_gene_symbols_column : str | None = None
143147
144- self .gene_symbols_column = resolve_gene_symbols_column (
145- adata , gene_symbols_column
146- )
147-
148- self .coordinates_key = validate_adata (
149- adata , group_key , rank_key , self .gene_symbols_column , coordinates_key
150- )
148+ try :
149+ self .gene_symbols_column = resolve_gene_symbols_column (
150+ adata , gene_symbols_column
151+ )
152+ self ._original_gene_symbols_column = self .gene_symbols_column
151153
152- # Use original labels as IDs if all are short (<=3 chars), otherwise enumerate
153- _unique_group_categories : list [str | int ] = natsorted (
154- adata .obs [group_key ].unique ().tolist ()
155- )
156- _short_ids = all (len (str (x )) <= 3 for x in _unique_group_categories )
157- self .cluster_map = {
158- str (x ): str (x ) if _short_ids else str (n )
159- for n , x in enumerate (_unique_group_categories )
160- }
161- self .clusters = [
162- self .cluster_map [str (x )] for x in adata .obs [group_key ].values .tolist ()
163- ]
164-
165- gene_names = (
166- adata .var [self .gene_symbols_column ].tolist ()
167- if self .gene_symbols_column is not None
168- else adata .var_names .tolist ()
169- )
170- gene_names = clean_gene_names (gene_names )
171- self .expression_percentages = aggregate_expression_percentages (
172- adata = adata ,
173- clusters = self .clusters ,
174- gene_names = gene_names ,
175- cell_batch_size = pcent_batch_size ,
176- )
154+ self .coordinates_key = validate_adata (
155+ adata , group_key , rank_key , self .gene_symbols_column , coordinates_key
156+ )
157+ (
158+ self .gene_symbols_column ,
159+ self ._original_gene_symbols_column ,
160+ ) = materialize_canonical_gene_symbols_column (
161+ adata , self .gene_symbols_column
162+ )
163+ self ._temporary_gene_symbols_column = self .gene_symbols_column
177164
178- logger .info ("Extracting marker genes..." )
179- self .marker_genes = extract_marker_genes (
180- adata = self .adata ,
181- cell_group_key = self .group_key ,
182- rank_genes_key = self .rank_key ,
183- cluster_map = self .cluster_map ,
184- n_top_genes = n_top_genes ,
185- gene_symbols_col = self .gene_symbols_column ,
186- )
165+ # Use original labels as IDs if all are short (<=3 chars), otherwise enumerate
166+ _unique_group_categories : list [str | int ] = natsorted (
167+ adata .obs [group_key ].unique ().tolist ()
168+ )
169+ _short_ids = all (len (str (x )) <= 3 for x in _unique_group_categories )
170+ self .cluster_map = {
171+ str (x ): str (x ) if _short_ids else str (n )
172+ for n , x in enumerate (_unique_group_categories )
173+ }
174+ self .clusters = [
175+ self .cluster_map [str (x )] for x in adata .obs [group_key ].values .tolist ()
176+ ]
177+
178+ gene_names = adata .var [self .gene_symbols_column ].tolist ()
179+ self .expression_percentages = aggregate_expression_percentages (
180+ adata = adata ,
181+ clusters = self .clusters ,
182+ gene_names = gene_names ,
183+ cell_batch_size = pcent_batch_size ,
184+ )
187185
188- if aggregate_metadata :
189- logger .info ("Aggregating cluster metadata..." )
190- self .group_metadata = aggregate_cluster_metadata (
186+ logger .info ("Extracting marker genes..." )
187+ self .marker_genes = extract_marker_genes (
191188 adata = self .adata ,
192- group_key = self .group_key ,
193- min_percentage = min_percentage ,
194- max_categories = max_metadata_categories ,
189+ cell_group_key = self .group_key ,
190+ rank_genes_key = self .rank_key ,
191+ cluster_map = self .cluster_map ,
192+ n_top_genes = n_top_genes ,
193+ gene_symbols_col = self .gene_symbols_column ,
195194 )
196- # Replace keys in group_metadata using cluster_map
197- self .group_metadata = {
198- self .cluster_map .get (str (key ), str (key )): value
199- for key , value in self .group_metadata .items ()
200- }
201- self .group_metadata = {
202- k : self .group_metadata [k ] for k in sorted (self .group_metadata .keys ())
195+
196+ if aggregate_metadata :
197+ logger .info ("Aggregating cluster metadata..." )
198+ self .group_metadata = aggregate_cluster_metadata (
199+ adata = self .adata ,
200+ group_key = self .group_key ,
201+ min_percentage = min_percentage ,
202+ max_categories = max_metadata_categories ,
203+ )
204+ # Replace keys in group_metadata using cluster_map
205+ self .group_metadata = {
206+ self .cluster_map .get (str (key ), str (key )): value
207+ for key , value in self .group_metadata .items ()
208+ }
209+ self .group_metadata = {
210+ k : self .group_metadata [k ]
211+ for k in sorted (self .group_metadata .keys ())
212+ }
213+ else :
214+ self .group_metadata = {}
215+
216+ # Prepare visualization data with sampling
217+ sampled_coordinates , sampled_cluster_labels = (
218+ extract_visualization_coordinates (
219+ adata = adata ,
220+ coordinates_key = self .coordinates_key ,
221+ group_key = self .group_key ,
222+ cluster_map = self .cluster_map ,
223+ max_cells_per_group = self .max_cells_per_group ,
224+ )
225+ )
226+
227+ self .visualization_data = {
228+ "coordinates" : sampled_coordinates ,
229+ "clusters" : sampled_cluster_labels ,
203230 }
204- else :
205- self .group_metadata = {}
206-
207- # Prepare visualization data with sampling
208- sampled_coordinates , sampled_cluster_labels = extract_visualization_coordinates (
209- adata = adata ,
210- coordinates_key = self .coordinates_key ,
211- group_key = self .group_key ,
212- cluster_map = self .cluster_map ,
213- max_cells_per_group = self .max_cells_per_group ,
214- )
215231
216- self .visualization_data = {
217- "coordinates" : sampled_coordinates ,
218- "clusters" : sampled_cluster_labels ,
219- }
232+ # Resolve raw counts once and cache
233+ self ._raw_counts_result = self ._resolve_raw_counts ()
234+ if self ._raw_counts_result is None :
235+ logger .warning (
236+ "No integer raw counts found in adata.layers['counts'], "
237+ "adata.raw.X, or adata.X. Skipping raw counts in vars.h5."
238+ )
220239
221- # Resolve raw counts once and cache
222- self ._raw_counts_result = self ._resolve_raw_counts ()
223- if self ._raw_counts_result is None :
224- logger .warning (
225- "No integer raw counts found in adata.layers['counts'], "
226- "adata.raw.X, or adata.X. Skipping raw counts in vars.h5."
227- )
240+ # Build vars.h5
241+ try :
242+ raw_mat , raw_col_indices = (
243+ self ._raw_counts_result
244+ if self ._raw_counts_result is not None
245+ else (None , None )
246+ )
247+ save_features_matrix (
248+ out_file = vars_h5_path ,
249+ mat = self .adata .X ,
250+ var_df = self .adata .var ,
251+ var_names = self .adata .var_names ,
252+ raw_mat = raw_mat ,
253+ raw_col_indices = raw_col_indices ,
254+ gene_symbols_column = self .gene_symbols_column ,
255+ )
256+ sys .stderr .flush ()
257+ self ._vars_h5_path = vars_h5_path
258+ except Exception as exc :
259+ logger .warning (f"vars.h5 artifact failed during build: { exc } " )
260+ self ._artifact_build_errors .append (("vars_h5" , exc ))
228261
229- # Build vars.h5
230- try :
231- raw_mat , raw_col_indices = (
232- self ._raw_counts_result
233- if self ._raw_counts_result is not None
234- else (None , None )
235- )
236- save_features_matrix (
237- out_file = vars_h5_path ,
238- mat = self .adata .X ,
239- var_df = self .adata .var ,
240- var_names = self .adata .var_names ,
241- raw_mat = raw_mat ,
242- raw_col_indices = raw_col_indices ,
243- )
244- sys .stderr .flush ()
245- self ._vars_h5_path : str | None = vars_h5_path
246- except Exception as exc :
247- logger .warning (f"vars.h5 artifact failed during build: { exc } " )
248- self ._vars_h5_path = None
249- self ._artifact_build_errors .append (("vars_h5" , exc ))
250-
251- # Build obs.duckdb
252- try :
253- logger .info ("Writing obs data to duckdb artifact..." )
254- obsm_coordinates = (
255- self .adata .obsm [self .coordinates_key ]
256- if self .coordinates_key and self .coordinates_key in self .adata .obsm
257- else None
258- )
259- save_obs_duckdb_file (
260- out_file = obs_duckdb_path ,
261- obs_df = self .adata .obs ,
262- obsm_coordinates = obsm_coordinates ,
263- coordinates_key = self .coordinates_key ,
262+ # Build obs.duckdb
263+ try :
264+ logger .info ("Writing obs data to duckdb artifact..." )
265+ obsm_coordinates = (
266+ self .adata .obsm [self .coordinates_key ]
267+ if self .coordinates_key and self .coordinates_key in self .adata .obsm
268+ else None
269+ )
270+ save_obs_duckdb_file (
271+ out_file = obs_duckdb_path ,
272+ obs_df = self .adata .obs ,
273+ obsm_coordinates = obsm_coordinates ,
274+ coordinates_key = self .coordinates_key ,
275+ )
276+ sys .stderr .flush ()
277+ self ._obs_duckdb_path = obs_duckdb_path
278+ except Exception as exc :
279+ logger .warning (f"obs.duckdb artifact failed during build: { exc } " )
280+ self ._artifact_build_errors .append (("obs_duckdb" , exc ))
281+
282+ logger .info ("Data preparation completed. Ready for submitting jobs." )
283+ except Exception :
284+ self ._cleanup_temporary_gene_symbols_column ()
285+ raise
286+
def _cleanup_temporary_gene_symbols_column(self) -> None:
    """Drop the temporary canonical gene-symbol column and restore state.

    Removes the column named by ``self._temporary_gene_symbols_column``
    from ``self.adata.var`` when it is still present, points
    ``self.gene_symbols_column`` back at
    ``self._original_gene_symbols_column`` and clears the temporary
    marker. Once the marker is cleared, later calls return immediately,
    so the method can be called more than once.

    A column is only deleted when its name differs from the original
    gene-symbol column, so the caller's own ``adata.var`` column is never
    removed by this cleanup.
    """
    temp_column = self._temporary_gene_symbols_column
    if temp_column is None:
        # Nothing is recorded as temporary (or cleanup already ran).
        return

    original_column = self._original_gene_symbols_column
    # Guard against the "temporary" name being the caller's own column:
    # in that case there is no separate column that belongs to us.
    is_own_column = temp_column != original_column
    if is_own_column and temp_column in self.adata.var.columns:
        del self.adata.var[temp_column]
        logger.info(
            f"Deleted temporary canonical gene-symbol column '{temp_column}'."
        )

    self.gene_symbols_column = original_column
    self._temporary_gene_symbols_column = None
273300
274301 def _resolve_raw_counts (
275302 self ,
@@ -356,7 +383,8 @@ def cleanup(self) -> None:
356383 """Delete the artifact files built during initialization.
357384
358385 Call this after run() completes to remove the vars.h5 and obs.duckdb
359- files from disk. Paths are cleared so repeated calls are safe.
386+ files from disk and drop the temporary canonical gene-symbol column.
387+ Paths are cleared so repeated calls are safe.
360388 """
361389 for attr , path in [
362390 ("_vars_h5_path" , self ._vars_h5_path ),
@@ -370,6 +398,8 @@ def cleanup(self) -> None:
370398 logger .warning (f"Failed to delete artifact { path } : { exc } " )
371399 setattr (self , attr , None )
372400
401+ self ._cleanup_temporary_gene_symbols_column ()
402+
373403 def run (
374404 self ,
375405 study_context : str ,
0 commit comments