#!/usr/bin/env python3
"""
Enhanced subsampling infrastructure for single-cell data processing.
Based on subsampling.ipynb with added features for production use.
"""
import argparse
import concurrent.futures
import hashlib
import logging
import os
import platform
import sys
import time
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Dict

import anndata as ad
import numpy as np
import scanpy as sc
import scipy.sparse as sp
from tqdm.auto import tqdm

try:
    import palantir
except ImportError:
    palantir = None

class AsyncFileWriter:
    """Manages asynchronous file writing with process limits."""

    def __init__(self, max_workers: int = 4):
        self.max_workers = max_workers
        self.executor = ProcessPoolExecutor(max_workers=max_workers)
        self.write_tasks = []

    def submit_write(self, adata: ad.AnnData, filepath: str, compression: str = "gzip"):
        """Submit an AnnData write operation to the async queue."""
        future = self.executor.submit(self._write_adata, adata, filepath, compression)
        self.write_tasks.append(future)
        return future

    @staticmethod
    def _write_adata(adata: ad.AnnData, filepath: str, compression: str):
        """Static method for writing AnnData - used by ProcessPoolExecutor."""
        try:
            adata.write(filepath, compression=compression)
            return {"success": True, "filepath": filepath, "size": os.path.getsize(filepath)}
        except Exception as e:
            return {"success": False, "filepath": filepath, "error": str(e)}

    def wait_for_completion(self, logger: logging.Logger):
        """Wait for all pending write operations to complete."""
        if not self.write_tasks:
            return
        logger.info(f"Waiting for {len(self.write_tasks)} write operations to complete...")
        completed = 0
        for future in concurrent.futures.as_completed(self.write_tasks):
            result = future.result()
            completed += 1
            if result["success"]:
                size_mb = result["size"] / (1024 * 1024)
                logger.info(f"[{completed}/{len(self.write_tasks)}] Wrote {result['filepath']} ({size_mb:.1f} MB)")
            else:
                logger.error(f"[{completed}/{len(self.write_tasks)}] Failed to write {result['filepath']}: {result['error']}")
        self.write_tasks.clear()

    def __del__(self):
        if hasattr(self, 'executor'):
            self.executor.shutdown(wait=False)

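# Usage sketch for AsyncFileWriter (illustrative, not executed here):
# writes run in worker processes, so the main loop can keep subsampling
# while earlier rounds are still being serialized to disk.
#
#   writer = AsyncFileWriter(max_workers=2)
#   writer.submit_write(adata, "round_001.h5ad")  # returns immediately
#   writer.wait_for_completion(logging.getLogger(__name__))
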
def setup_logging(log_file: str, level: str = "INFO") -> logging.Logger:
    """Set up comprehensive logging for the subsampling process."""
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f'Invalid log level: {level}')

    # Create formatter
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Set up a dedicated named logger (not the root logger)
    logger = logging.getLogger('subsample_dataset')
    logger.setLevel(numeric_level)
    logger.handlers.clear()  # Remove any existing handlers

    # File handler
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    return logger

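# Usage sketch (illustrative): one named logger, two sinks.
#
#   logger = setup_logging("run.log", level="DEBUG")
#   logger.info("recorded in run.log and echoed to the console")
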
def get_runtime_info() -> Dict:
    """Collect comprehensive runtime information."""
    return {
        'timestamp': datetime.now().isoformat(),
        'platform': platform.platform(),
        'python_version': platform.python_version(),
        'hostname': platform.node(),
        'cpu_count': os.cpu_count(),
        # Total physical memory via POSIX sysconf, when available
        'total_memory_bytes': (
            os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
            if hasattr(os, 'sysconf') and 'SC_PHYS_PAGES' in os.sysconf_names
            else None
        ),
        'environment_variables': {
            'SLURM_JOB_ID': os.environ.get('SLURM_JOB_ID'),
            'SLURM_ARRAY_TASK_ID': os.environ.get('SLURM_ARRAY_TASK_ID'),
            'SLURMD_NODENAME': os.environ.get('SLURMD_NODENAME'),
        }
    }

def make_seed(name: str, salt: str = "settylab") -> int:
    """Generate a reproducible random seed from the dataset name and a salt."""
    h = hashlib.sha256((name + salt).encode('utf-8')).digest()
    seed = int.from_bytes(h[:4], 'little')
    return seed

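# The seed is a pure function of (name, salt), so runs are reproducible
# across machines and sessions (illustrative, hypothetical names):
#
#   make_seed("my_dataset") == make_seed("my_dataset")   # always True
#   make_seed("my_dataset") != make_seed("my_dataset", "other_salt")
#   # a different salt yields a different seed (with overwhelming probability)
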
def subsample_counts(X, p: float, rng: np.random.Generator):
    """Subsample read counts using binomial sampling for both sparse and dense matrices."""
    if sp.issparse(X):
        # Sparse matrix handling: thin the stored counts in place on a copy,
        # then drop entries that became zero. (Rebuilding the matrix from
        # masked data/indices alone would leave indptr inconsistent.)
        X_sub = X.copy()
        X_sub.data = rng.binomial(X_sub.data.astype(int), p)
        X_sub.eliminate_zeros()
        return X_sub
    else:
        # Dense matrix handling
        X_int = X.astype(int)
        X_sub = rng.binomial(X_int, p)
        return X_sub

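# Binomial thinning preserves each count's expectation up to the factor p:
# E[Binomial(n, p)] = n * p, so total reads shrink by roughly a factor of p
# per round. A quick sanity check (illustrative, hypothetical values):
#
#   rng = np.random.default_rng(0)
#   X = sp.random(100, 50, density=0.1, format="csr",
#                 data_rvs=lambda n: np.full(n, 10.0))
#   X_half = subsample_counts(X, 0.5, rng)
#   assert 0.4 * X.sum() < X_half.sum() < 0.6 * X.sum()
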
def process_adata(adata: ad.AnnData, raw_count_layer: str = "raw_counts", skip_processing: bool = False) -> None:
    """Process AnnData object: optionally normalize, log transform, and create layers."""
    # Ensure raw counts are preserved
    if raw_count_layer is None or raw_count_layer == "X":
        # Raw counts are in adata.X; save them in the raw_counts layer
        if "raw_counts" not in adata.layers:
            adata.layers["raw_counts"] = adata.X.copy()
    else:
        # Raw counts should be in the specified layer
        if raw_count_layer not in adata.layers:
            adata.layers[raw_count_layer] = adata.X.copy()

    if skip_processing:
        # Data is already processed; just ensure the expected layers exist
        if "normalized_counts" not in adata.layers:
            adata.layers["normalized_counts"] = adata.X.copy()
        if "logged_counts" not in adata.layers:
            adata.layers["logged_counts"] = adata.X.copy()
    else:
        # Normalize to total counts
        sc.pp.normalize_total(adata)
        adata.layers["normalized_counts"] = adata.X.copy()
        # Log transform
        sc.pp.log1p(adata)
        adata.layers["logged_counts"] = adata.X.copy()

def compute_pca(adata: ad.AnnData, n_components: int = 50, layer: str = "logged_counts") -> None:
    """Compute PCA on the specified layer."""
    # Temporarily set X to the desired layer for PCA computation
    original_X = adata.X.copy()
    adata.X = adata.layers[layer].copy()
    # Compute PCA
    sc.tl.pca(adata, n_comps=n_components)
    # Restore original X
    adata.X = original_X

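# scanpy's sc.tl.pca populates the standard AnnData slots, which the
# downstream code relies on:
#   adata.obsm["X_pca"] -> cell embeddings (n_obs x n_components)
#   adata.varm["PCs"]   -> principal axes  (n_vars x n_components)
#   adata.uns["pca"]    -> variance / variance_ratio metadata
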
def compute_diffusion_maps(adata: ad.AnnData, n_components: int = 10, knn: int = 30, alpha: float = 0.0,
                           layer: str = "logged_counts", use_pca: bool = True) -> None:
    """Compute diffusion maps using palantir."""
    if palantir is None:
        raise ImportError("palantir is required for diffusion map computation. Install with: pip install palantir")

    # Ensure PCA is computed if we want to use it
    if use_pca and 'X_pca' not in adata.obsm:
        compute_pca(adata, n_components=n_components, layer=layer)

    if use_pca and 'X_pca' in adata.obsm:
        # Use PCA coordinates for diffusion map computation
        palantir.utils.run_diffusion_maps(adata, n_components=n_components, knn=knn, alpha=alpha, pca_key='X_pca')
    else:
        # Use the specified layer directly: temporarily set X to the
        # desired layer for diffusion map computation
        original_X = adata.X.copy()
        adata.X = adata.layers[layer].copy()
        # Compute diffusion maps using palantir
        palantir.utils.run_diffusion_maps(adata, n_components=n_components, knn=knn, alpha=alpha)
        # Restore original X
        adata.X = original_X

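# palantir stores the diffusion components in adata.obsm["DM_EigenVectors"],
# which is the key the subsampling loop below copies forward when
# precompute_embeddings is enabled.
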
def subsample_dataset(
    dataset_name: str,
    input_path: str,
    output_dir: str,
    fraction_reads: float = 0.99,
    fraction_cells: float = 0.99,
    min_reads_per_cell: int = 1000,
    min_total_cells: int = 100,
    raw_count_layer: str = "raw_counts",
    embedding_key: str = "X_umap",
    random_state_salt: str = "settylab",
    subsample_reads: bool = True,
    pca_components: int = 50,
    max_write_workers: int = 4,
    skip_processing: bool = False,
    precompute_embeddings: bool = False,
    logger: logging.Logger = None
) -> Dict:
    """
    Main subsampling function with comprehensive features.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset for output directory and seed generation
    input_path : str
        Path to input H5AD file
    output_dir : str
        Base output directory
    fraction_reads : float
        Fraction of reads to keep in each round (ignored if subsample_reads=False)
    fraction_cells : float
        Fraction of cells to keep in each round
    min_reads_per_cell : int
        Minimum reads per cell threshold
    min_total_cells : int
        Minimum total cells to continue subsampling
    raw_count_layer : str
        Layer containing raw count data
    embedding_key : str
        Key for UMAP embedding in obsm
    random_state_salt : str
        Salt for reproducible random seed generation
    subsample_reads : bool
        Whether to subsample reads (True) or only cells (False)
    pca_components : int
        Number of PCA components to compute
    max_write_workers : int
        Maximum number of parallel write processes
    skip_processing : bool
        Skip normalization and log transformation (data already processed)
    precompute_embeddings : bool
        Whether to compute PCA and diffusion maps on the full dataset before subsampling
    logger : logging.Logger
        Logger instance

    Returns
    -------
    Dict with processing statistics and runtime information
    """
    start_time = time.time()
    if logger is None:
        logger = logging.getLogger(__name__)

    logger.info(f"Starting subsampling for dataset: {dataset_name}")
    logger.info(f"Input path: {input_path}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Subsample reads: {subsample_reads}")
    logger.info(f"Fraction reads: {fraction_reads}")
    logger.info(f"Fraction cells: {fraction_cells}")
    logger.info(f"Min reads per cell: {min_reads_per_cell}")
    logger.info(f"Min total cells: {min_total_cells}")
    logger.info(f"PCA components: {pca_components}")
    logger.info(f"Precompute embeddings: {precompute_embeddings}")

    # Setup random state
    seed = make_seed(dataset_name, random_state_salt)
    rng = np.random.default_rng(seed)
    logger.info(f"Using random seed: {seed}")

    # Load data
    logger.info("Loading AnnData object...")
    load_start = time.time()
    adata = ad.read_h5ad(input_path)
    load_time = time.time() - load_start
    logger.info(f"Loaded data in {load_time:.2f}s: {adata.n_obs} cells × {adata.n_vars} genes")

    # Validate input and get count data
    if raw_count_layer is None or raw_count_layer == "X":
        # Use adata.X directly
        X = adata.X.copy()
        logger.info("Using adata.X as count data")
    else:
        # Use specified layer
        if raw_count_layer not in adata.layers:
            raise ValueError(f"Layer '{raw_count_layer}' not found in AnnData object")
        X = adata.layers[raw_count_layer].copy()
        logger.info(f"Using layer '{raw_count_layer}' as count data")

    if embedding_key not in adata.obsm:
        raise ValueError(f"Embedding '{embedding_key}' not found in AnnData.obsm")

    # Accept both sparse and dense matrices; only convert to integer when
    # subsampling reads (binomial sampling requires integer counts)
    if sp.issparse(X):
        logger.info(f"Using sparse matrix format: {type(X)}")
        if subsample_reads:
            X.data = X.data.astype(int)
            logger.info("Converted sparse matrix data to integer (required for read subsampling)")
    else:
        logger.info(f"Using dense matrix format: {type(X)}")
        if subsample_reads:
            X = X.astype(int)
            logger.info("Converted dense matrix to integer (required for read subsampling)")

    obs = adata.obs.copy()
    var = adata.var.copy()
    umap = adata.obsm[embedding_key].copy()

    # Store original embeddings for later use
    original_pca = None
    original_diffusion_maps = None

    # Compute PCA and diffusion maps on full dataset if requested
    if precompute_embeddings:
        logger.info("Computing PCA and diffusion maps on full dataset...")
        # Create a temporary AnnData object for embedding computation
        temp_adata = ad.AnnData(X.copy(), obs=obs.copy(), var=var.copy())
        # Set the count data in the appropriate location
        if raw_count_layer is None or raw_count_layer == "X":
            temp_adata.layers["raw_counts"] = X.copy()
        else:
            temp_adata.layers[raw_count_layer] = X.copy()
        # Process data (normalize, log transform) if needed
        process_adata(temp_adata, raw_count_layer if raw_count_layer != "X" else "raw_counts", skip_processing)
        # Compute PCA on full dataset
        logger.info(f"Computing PCA on full dataset ({temp_adata.n_obs} cells)")
        compute_pca(temp_adata, n_components=pca_components)
        original_pca = temp_adata.obsm['X_pca'].copy()
        # Compute diffusion maps on full dataset
        logger.info(f"Computing diffusion maps on full dataset ({temp_adata.n_obs} cells)")
        compute_diffusion_maps(temp_adata)
        original_diffusion_maps = temp_adata.obsm['DM_EigenVectors'].copy()
        logger.info("Completed embedding computation on full dataset")

    # Create output directory
    dataset_dir = Path(output_dir) / dataset_name
    dataset_dir.mkdir(parents=True, exist_ok=True)

    # Initialize async writer
    writer = AsyncFileWriter(max_workers=max_write_workers)

    # Setup progress tracking
    n_cells_original = X.shape[0]
    target_n_cells = max(int(n_cells_original * fraction_cells), min_total_cells)
    pbar = tqdm(total=100, desc=f"Subsampling {dataset_name}", leave=True)
    target_log = np.log(min_total_cells / n_cells_original) if min_total_cells < n_cells_original else 0
    percent_done = 0
    round_num = 0

    # Statistics tracking
    stats = {
        'dataset_name': dataset_name,
        'original_cells': n_cells_original,
        'original_genes': adata.n_vars,
        'rounds_completed': 0,
        'final_cells': 0,
        'files_written': [],
        'total_time': 0,
        'load_time': load_time,
        'processing_time': 0,
        'write_time': 0,
        'subsample_reads': subsample_reads,
        'parameters': {
            'fraction_reads': fraction_reads,
            'fraction_cells': fraction_cells,
            'min_reads_per_cell': min_reads_per_cell,
            'min_total_cells': min_total_cells,
            'pca_components': pca_components,
        }
    }

    logger.info(f"Starting iterative subsampling from {n_cells_original} cells")
    processing_start = time.time()
    try:
        while X.shape[0] > min_total_cells:  # strict '>' so we stop once min_total_cells is reached
            round_start = time.time()
            round_num += 1
            logger.info(f"Round {round_num}: {X.shape[0]} cells")

            # Subsample reads if requested
            if subsample_reads:
                logger.debug(f"Subsampling reads with fraction {fraction_reads}")
                X = subsample_counts(X, fraction_reads, rng)
                if sp.issparse(X):
                    X.eliminate_zeros()

            # Filter cells with too few reads
            reads_per_cell = np.asarray(X.sum(axis=1)).flatten()
            keep = reads_per_cell >= min_reads_per_cell
            if np.sum(keep) < min_total_cells:
                logger.info(f"Insufficient cells after filtering ({np.sum(keep)} < {min_total_cells}), stopping")
                break
            X = X[keep]
            obs = obs.iloc[keep]
            umap = umap[keep]
            # Also filter precomputed embeddings if they exist
            if precompute_embeddings and original_pca is not None:
                original_pca = original_pca[keep]
            if precompute_embeddings and original_diffusion_maps is not None:
                original_diffusion_maps = original_diffusion_maps[keep]

            # Remember the cell count so we can verify progress later
            cells_before_cell_subsampling = X.shape[0]

            # Update progress
            if target_log != 0:
                current_log = np.log(X.shape[0] / n_cells_original)
                current_percent = 100 * (current_log / target_log)
                current_percent = np.clip(current_percent, 0, 100)
                step = max(0, current_percent - percent_done)
                pbar.update(step)
                percent_done = current_percent

            # Subsample cells if needed
            if X.shape[0] > target_n_cells:
                n_target = int(X.shape[0] * fraction_cells)
                # Ensure we don't go below min_total_cells and make progress
                n_target = max(n_target, min_total_cells)
                if n_target >= X.shape[0]:
                    logger.warning(f"Cell subsampling would not reduce cell count ({n_target} >= {X.shape[0]})")
                    # Force a reduction to avoid an infinite loop
                    n_target = max(min_total_cells, X.shape[0] - 1)
                idx = rng.choice(X.shape[0], size=n_target, replace=False)
                logger.debug(f"Subsampling cells from {X.shape[0]} to {n_target}")
                X = X[idx]
                obs = obs.iloc[idx]
                umap = umap[idx]
                # Also subsample precomputed embeddings if they exist
                if precompute_embeddings and original_pca is not None:
                    original_pca = original_pca[idx]
                if precompute_embeddings and original_diffusion_maps is not None:
                    original_diffusion_maps = original_diffusion_maps[idx]

            # Safety check: ensure we're making progress
            if X.shape[0] == cells_before_cell_subsampling:
                logger.warning(f"No progress made in round {round_num} (still {X.shape[0]} cells)")
                if not subsample_reads:
                    logger.warning("Not subsampling reads and no cell reduction - forcing exit to avoid infinite loop")
                    break

            # Create AnnData object for this round
            adata_round = ad.AnnData(X.copy(), obs=obs.copy(), var=var.copy())
            # Set the count data in the appropriate location
            if raw_count_layer is None or raw_count_layer == "X":
                # Data is already in adata_round.X, but also save it in the
                # raw_counts layer for consistency
                adata_round.layers["raw_counts"] = X.copy()
            else:
                # Save in the specified layer
                adata_round.layers[raw_count_layer] = X.copy()
            adata_round.obsm[embedding_key] = umap.copy()
            adata_round.obsm[f"{embedding_key}_original"] = umap.copy()  # Preserve a copy of the original embedding

            # Process data (normalize, log transform) if needed
            process_adata(adata_round, raw_count_layer if raw_count_layer != "X" else "raw_counts", skip_processing)

            # Add precomputed embeddings if available
            if precompute_embeddings and original_pca is not None:
                logger.debug(f"Adding precomputed PCA with {original_pca.shape[1]} components")
                adata_round.obsm['X_pca'] = original_pca.copy()
                # Also store metadata about PCA (if available from scanpy)
                if hasattr(adata_round, 'varm'):
                    adata_round.varm['PCs'] = np.eye(adata_round.n_vars, pca_components)  # Placeholder
            else:
                # Compute PCA normally
                logger.debug(f"Computing PCA with {pca_components} components")
                compute_pca(adata_round, n_components=pca_components)

            # Add precomputed diffusion maps if available
            if precompute_embeddings and original_diffusion_maps is not None:
                logger.debug(f"Adding precomputed diffusion maps with {original_diffusion_maps.shape[1]} components")
                adata_round.obsm['DM_EigenVectors'] = original_diffusion_maps.copy()

            # Prepare output path and submit to the async writer
            save_path = dataset_dir / f"round_{round_num:03d}_n_{adata_round.n_obs}.h5ad"
            writer.submit_write(adata_round, str(save_path), compression="gzip")
            stats['files_written'].append(str(save_path))

            # Update target for next iteration
            target_n_cells = max(int(X.shape[0] * fraction_cells), min_total_cells)

            round_time = time.time() - round_start
            logger.info(f"Round {round_num} completed in {round_time:.2f}s: {adata_round.n_obs} cells")
            stats['rounds_completed'] = round_num
            stats['final_cells'] = adata_round.n_obs

    except Exception as e:
        logger.error(f"Error during subsampling: {str(e)}", exc_info=True)
        raise
    finally:
        # Ensure progress bar reaches 100%
        pbar.n = 100
        pbar.refresh()
        pbar.close()

    # Wait for all writes to complete
    write_start = time.time()
    writer.wait_for_completion(logger)
    write_time = time.time() - write_start

    processing_time = time.time() - processing_start
    total_time = time.time() - start_time

    # Update final statistics
    stats.update({
        'processing_time': processing_time,
        'write_time': write_time,
        'total_time': total_time
    })

    logger.info(f"Subsampling completed: {round_num} rounds, {stats['final_cells']} final cells")
    logger.info(f"Total time: {total_time:.2f}s (load: {load_time:.2f}s, processing: {processing_time:.2f}s, write: {write_time:.2f}s)")
    logger.info(f"Files written: {len(stats['files_written'])}")

    return stats

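# Resulting on-disk layout (illustrative, hypothetical cell counts; file
# names follow the round_{round:03d}_n_{n_obs}.h5ad pattern used above):
#   <output_dir>/<dataset_name>/
#       round_001_n_9702.h5ad
#       round_002_n_9413.h5ad
#       ...
#       subsampling_<dataset_name>_<timestamp>.log   (written by main)
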
def main():
    """Main entry point with argument parsing."""
    parser = argparse.ArgumentParser(
        description="Enhanced subsampling infrastructure for single-cell data",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Required arguments
    parser.add_argument("dataset_name", help="Name of the dataset")
    parser.add_argument("input_path", help="Path to input H5AD file")

    # Optional arguments with defaults from the original notebook
    parser.add_argument("--output-dir", default="data/subsampled",
                        help="Output directory for subsampled data")
    parser.add_argument("--fraction-reads", type=float, default=0.99,
                        help="Fraction of reads to keep in each round")
    parser.add_argument("--fraction-cells", type=float, default=0.99,
                        help="Fraction of cells to keep in each round")
    parser.add_argument("--min-reads-per-cell", type=int, default=1000,
                        help="Minimum reads per cell threshold")
    parser.add_argument("--min-total-cells", type=int, default=100,
                        help="Minimum total cells to continue subsampling")
    parser.add_argument("--raw-count-layer", default="raw_counts",
                        help="Layer containing raw count data")
    parser.add_argument("--embedding-key", default="X_umap",
                        help="Key for UMAP embedding in obsm")
    parser.add_argument("--random-state-salt", default="settylab",
                        help="Salt for reproducible random seed generation")
    parser.add_argument("--no-subsample-reads", action="store_true",
                        help="Only subsample cells, not reads")
    parser.add_argument("--pca-components", type=int, default=50,
                        help="Number of PCA components to compute")
    parser.add_argument("--max-write-workers", type=int, default=4,
                        help="Maximum number of parallel write processes")
    parser.add_argument("--skip-processing", action="store_true",
                        help="Skip normalization and log transformation (data already processed)")
    parser.add_argument("--precompute-embeddings", action="store_true",
                        help="Compute PCA and diffusion maps on full dataset before subsampling")
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        default="INFO", help="Logging level")

    args = parser.parse_args()

    # Setup output directory and logging
    output_dir = Path(args.output_dir)
    dataset_dir = output_dir / args.dataset_name
    dataset_dir.mkdir(parents=True, exist_ok=True)
    log_file = dataset_dir / f"subsampling_{args.dataset_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
    logger = setup_logging(str(log_file), args.log_level)

    # Log runtime information
    runtime_info = get_runtime_info()
    logger.info("=== Runtime Information ===")
    for key, value in runtime_info.items():
        logger.info(f"{key}: {value}")

    # Log all arguments
    logger.info("=== Configuration ===")
    for arg, value in vars(args).items():
        logger.info(f"{arg}: {value}")

    try:
        # Run subsampling
        stats = subsample_dataset(
            dataset_name=args.dataset_name,
            input_path=args.input_path,
            output_dir=args.output_dir,
            fraction_reads=args.fraction_reads,
            fraction_cells=args.fraction_cells,
            min_reads_per_cell=args.min_reads_per_cell,
            min_total_cells=args.min_total_cells,
            raw_count_layer=args.raw_count_layer,
            embedding_key=args.embedding_key,
            random_state_salt=args.random_state_salt,
            subsample_reads=not args.no_subsample_reads,
            pca_components=args.pca_components,
            max_write_workers=args.max_write_workers,
            skip_processing=args.skip_processing,
            precompute_embeddings=args.precompute_embeddings,
            logger=logger
        )

        # Log final statistics
        logger.info("=== Final Statistics ===")
        for key, value in stats.items():
            if key != 'files_written':  # Don't log the full file list
                logger.info(f"{key}: {value}")
        logger.info("Subsampling completed successfully!")
    except Exception as e:
        logger.error(f"Subsampling failed: {str(e)}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()
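
# Example invocation (hypothetical dataset name and path; the flags match
# the argparse options defined above):
#   ./subsample_dataset.py pbmc data/pbmc.h5ad \
#       --output-dir data/subsampled \
#       --fraction-cells 0.95 \
#       --no-subsample-reads \
#       --log-level DEBUG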