| 1 | +# Sharded Eden DataLoader Implementation |
| 2 | + |
| 3 | +## Overview |
| 4 | + |
| 5 | +The `sharded_eden_dataloader.py` module implements a dataloader for genomic sequences that reads from pre-computed window mappings and per-sample SQLite databases. Moving expensive operations, such as enumerating windows, to a pre-processing phase reduces the computational overhead during training. |
| 6 | + |
| 7 | +## Key Features |
| 8 | + |
| 9 | +### 1. Split-Specific Window Databases |
| 10 | + |
| 11 | +- **Sharded**: Uses separate pre-computed window databases for each split: |
| 12 | + - `train_window_db_path`: SQLite database containing window mappings for training data |
| 13 | + - `val_window_db_path`: SQLite database containing window mappings for validation data |
| 14 | + - `test_window_db_path`: SQLite database containing window mappings for test data |
| 15 | + |
| 16 | +### 2. SQLite Database Storage |
| 17 | + |
| 18 | +- **Sharded**: Uses SQLite databases organized by sample: |
| 19 | + - **Per-Sample Sequence Databases**: Each sample has its own SQLite file at `sequence_db_dir/<sample_id>/glm_dataset_<sample_id>.sqlite` |
| 20 | + - **Split-Specific Window Databases**: Pre-computed window mappings stored in separate databases for each data split |
| 21 | + |
| 22 | +### 3. Virtual Window Pre-computation |
| 23 | + |
| 24 | +- **Sharded**: Window mappings are pre-computed from Parquet files and stored in split-specific databases |
| 25 | + |
| 26 | +## Sequence ID Format |
| 27 | + |
| 28 | +Sequence IDs follow a specific format: `BCR__ECT-SAMPLE1__CT1-1` |
| 29 | + |
| 30 | +The sample ID is extracted with `extract_sample_id(sequence_id)`, which evaluates `".".join(sequence_id.split("__")[1].split("-")[1:])`. For the ID above it returns `SAMPLE1`. |
| 31 | + |
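A minimal sketch of the extraction rule quoted above, written as a standalone function. The input check and the second sequence ID (`BCR__ECT-SAMPLE-2__CT1-1`) are illustrative assumptions, not taken from the module; the second call shows that hyphens inside the sample field come out as dots.

```python
def extract_sample_id(sequence_id: str) -> str:
    """Return the sample ID embedded in an ID such as 'BCR__ECT-SAMPLE1__CT1-1'."""
    # Fields are separated by double underscores; the sample lives in the second field.
    fields = sequence_id.split("__")
    if len(fields) < 2:
        # Assumption: the real helper may handle malformed IDs differently.
        raise ValueError(f"Unexpected sequence ID format: {sequence_id!r}")
    # Drop the leading prefix (e.g. 'ECT') and join the remaining hyphen-separated parts with dots.
    return ".".join(fields[1].split("-")[1:])


print(extract_sample_id("BCR__ECT-SAMPLE1__CT1-1"))   # SAMPLE1
print(extract_sample_id("BCR__ECT-SAMPLE-2__CT1-1"))  # SAMPLE.2 (hypothetical ID)
```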
| 32 | +## Database Schema |
| 33 | + |
| 34 | +### Per-Sample Sequence Database |
| 35 | + |
| 36 | +Each sample has its own SQLite file with the following schema: |
| 37 | + |
| 38 | +```sql |
| 39 | +CREATE TABLE sequences ( |
| 40 | + contig_id TEXT PRIMARY KEY, |
| 41 | + nt_sequence TEXT NOT NULL |
| 42 | +); |
| 43 | +``` |
| 44 | + |
| 45 | +### Split-Specific Window Database |
| 46 | + |
| 47 | +Each split (train/validation/test) has its own window database: |
| 48 | + |
| 49 | +```sql |
| 50 | +CREATE TABLE metadata ( |
| 51 | + key TEXT PRIMARY KEY, |
| 52 | + value INTEGER NOT NULL |
| 53 | +); |
| 54 | + |
| 55 | +CREATE TABLE window_mappings ( |
| 56 | + window_idx INTEGER PRIMARY KEY, |
| 57 | + sequence_id TEXT NOT NULL, |
| 58 | + window_in_seq_idx INTEGER NOT NULL |
| 59 | +); |
| 60 | +CREATE INDEX idx_sequence_id ON window_mappings(sequence_id); |
| 61 | +``` |
| 62 | + |
| 63 | +The metadata table stores the `window_size` and `stride` parameters used during pre-computation. |
| 64 | + |
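To illustrate how the two tables work together, here is a small sketch that builds the window schema in memory and resolves a global `window_idx` to a sequence and a character offset. The helper name `lookup_window` and the rule `start = window_in_seq_idx * stride` are assumptions made for illustration; the actual dataset may compute offsets differently.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE metadata (key TEXT PRIMARY KEY, value INTEGER NOT NULL);
    CREATE TABLE window_mappings (
        window_idx INTEGER PRIMARY KEY,
        sequence_id TEXT NOT NULL,
        window_in_seq_idx INTEGER NOT NULL
    );
    CREATE INDEX idx_sequence_id ON window_mappings(sequence_id);
    """
)
conn.executemany(
    "INSERT INTO metadata (key, value) VALUES (?, ?)",
    [("window_size", 8192), ("stride", 7992)],
)
conn.executemany(
    "INSERT INTO window_mappings VALUES (?, ?, ?)",
    [
        (0, "BCR__ECT-SAMPLE1__CT1-1", 0),
        (1, "BCR__ECT-SAMPLE1__CT1-1", 1),
        (2, "BCR__ECT-SAMPLE1__CT1-2", 0),
    ],
)


def lookup_window(conn, window_idx):
    """Return (sequence_id, start_offset, window_size) for a global window index."""
    meta = dict(conn.execute("SELECT key, value FROM metadata"))
    row = conn.execute(
        "SELECT sequence_id, window_in_seq_idx FROM window_mappings WHERE window_idx = ?",
        (window_idx,),
    ).fetchone()
    if row is None:
        raise IndexError(f"window_idx {window_idx} not found")
    sequence_id, window_in_seq_idx = row
    # Assumption: consecutive windows of one sequence start `stride` characters apart.
    start = window_in_seq_idx * meta["stride"]
    return sequence_id, start, meta["window_size"]


print(lookup_window(conn, 1))  # ('BCR__ECT-SAMPLE1__CT1-1', 7992, 8192)
```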
| 65 | +## Directory Structure |
| 66 | + |
| 67 | +``` |
| 68 | +sequence_db_dir/ |
| 69 | +├── SAMPLE1/ |
| 70 | +│ └── glm_dataset_SAMPLE1.sqlite |
| 71 | +├── SAMPLE2/ |
| 72 | +│ └── glm_dataset_SAMPLE2.sqlite |
| 73 | +├── SAMPLE3/ |
| 74 | +│ └── glm_dataset_SAMPLE3.sqlite |
| 75 | +└── ... |
| 76 | + |
| 77 | +Window databases (separate files): |
| 78 | +├── train_windows.db |
| 79 | +├── val_windows.db |
| 80 | +└── test_windows.db |
| 81 | +``` |
| 82 | + |
| 83 | +## Usage Example |
| 84 | + |
| 85 | +```python |
| 86 | +from bionemo.evo2.run.sharded_eden_dataloader import ShardedEdenDataModule |
| 87 | + |
| 88 | +# Create the data module |
| 89 | +data_module = ShardedEdenDataModule( |
| 90 | + sequence_db_dir="path/to/sequence_db_dir", # Directory containing sample folders |
| 91 | + train_window_db_path="path/to/train_windows.db", |
| 92 | + val_window_db_path="path/to/val_windows.db", |
| 93 | + test_window_db_path="path/to/test_windows.db", |
| 94 | + seq_length=8192, |
| 95 | + micro_batch_size=1, |
| 96 | + global_batch_size=4, |
| 97 | + num_workers=8, |
| 98 | + rc_aug=True, |
| 99 | + use_control_tags=True, |
| 100 | +) |
| 101 | + |
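| 102 | +# Use with a PyTorch Lightning trainer; assumes `pl` is the Lightning package (e.g. `import lightning.pytorch as pl`) and `model` is defined elsewhere |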
| 103 | +trainer = pl.Trainer(...) |
| 104 | +trainer.fit(model, data_module) |
| 105 | +``` |
| 106 | + |
| 107 | +## Pre-processing Workflow |
| 108 | + |
| 109 | +### 1. Create Sample Sequence Databases |
| 110 | + |
| 111 | +For each sample, create its SQLite database: |
| 112 | + |
| 113 | +```python |
| 114 | +import sqlite3 |
| 115 | +import os |
| 116 | + |
| 117 | + |
| 118 | +def create_sample_database(sample_id, sequences, output_dir): |
| 119 | + """Create SQLite database for a single sample.""" |
| 120 | + # Create sample directory |
| 121 | + sample_dir = os.path.join(output_dir, sample_id) |
| 122 | + os.makedirs(sample_dir, exist_ok=True) |
| 123 | + |
| 124 | + # Create database |
| 125 | + db_path = os.path.join(sample_dir, f"glm_dataset_{sample_id}.sqlite") |
| 126 | + conn = sqlite3.connect(db_path) |
| 127 | + cursor = conn.cursor() |
| 128 | + |
| 129 | + # Create table |
| 130 | + cursor.execute( |
| 131 | + """ |
| 132 | + CREATE TABLE sequences ( |
| 133 | + contig_id TEXT PRIMARY KEY, |
| 134 | + nt_sequence TEXT NOT NULL |
| 135 | + ) |
| 136 | + """ |
| 137 | + ) |
| 138 | + |
| 139 | + # Insert sequences for this sample |
| 140 | + for seq_id, sequence in sequences: |
| 141 | + cursor.execute( |
| 142 | + "INSERT INTO sequences (contig_id, nt_sequence) VALUES (?, ?)", |
| 143 | + (seq_id, sequence), |
| 144 | + ) |
| 145 | + |
| 146 | + conn.commit() |
| 147 | + conn.close() |
| 148 | + |
| 149 | + |
| 150 | +# Example usage |
| 151 | +# Group sequences by sample_id (`extract_sample_id` is the helper described under "Sequence ID Format") |
| 152 | +from collections import defaultdict |
| 153 | + |
| 154 | +sequences_by_sample = defaultdict(list) |
| 155 | +for seq_id, sequence in all_sequences: # all_sequences is your data |
| 156 | + sample_id = extract_sample_id(seq_id) |
| 157 | + sequences_by_sample[sample_id].append((seq_id, sequence)) |
| 158 | + |
| 159 | +# Create database for each sample |
| 160 | +for sample_id, sequences in sequences_by_sample.items(): |
| 161 | + create_sample_database(sample_id, sequences, "path/to/sequence_db_dir") |
| 162 | +``` |
| 163 | + |
| 164 | +### 2. Create Split Data Files |
| 165 | + |
| 166 | +Create Parquet files for each split containing sequence metadata: |
| 167 | + |
| 168 | +```python |
| 169 | +import polars as pl |
| 170 | + |
| 171 | +# Create train split Parquet file |
| 172 | +train_data = pl.DataFrame( |
| 173 | + { |
| 174 | + "contig_id": ["BCR__ECT-SAMPLE1__CT1-1", "BCR__ECT-SAMPLE1__CT1-2", ...], |
| 175 | + "length": [1500, 2000, ...], # sequence lengths |
| 176 | + } |
| 177 | +) |
| 178 | +train_data.write_parquet("train_split.parquet") |
| 179 | + |
| 180 | +# Similarly for validation and test splits |
| 181 | +val_data = pl.DataFrame( |
| 182 | + {"contig_id": ["BCR__ECT-SAMPLE2__CT1-1", ...], "length": [1800, ...]} |
| 183 | +) |
| 184 | +val_data.write_parquet("val_split.parquet") |
| 185 | + |
| 186 | +test_data = pl.DataFrame( |
| 187 | + {"contig_id": ["BCR__ECT-SAMPLE3__CT1-1", ...], "length": [1600, ...]} |
| 188 | +) |
| 189 | +test_data.write_parquet("test_split.parquet") |
| 190 | +``` |
| 191 | + |
| 192 | +### 3. Create Window Mapping Databases Using the CLI |
| 193 | + |
| 194 | +The package includes a CLI tool for pre-computing the window databases: |
| 195 | + |
| 196 | +```bash |
| 197 | +# Pre-compute window mappings for training split |
| 198 | +python -m bionemo.evo2.run.sharded_eden_dataloader precompute \ |
| 199 | + train_split.parquet \ |
| 200 | + train_windows.db \ |
| 201 | + --window-size 8192 \ |
| 202 | + --stride 7992 |
| 203 | + |
| 204 | +# Pre-compute window mappings for validation split |
| 205 | +python -m bionemo.evo2.run.sharded_eden_dataloader precompute \ |
| 206 | + val_split.parquet \ |
| 207 | + val_windows.db \ |
| 208 | + --window-size 8192 \ |
| 209 | + --stride 7992 |
| 210 | + |
| 211 | +# Pre-compute window mappings for test split |
| 212 | +python -m bionemo.evo2.run.sharded_eden_dataloader precompute \ |
| 213 | + test_split.parquet \ |
| 214 | + test_windows.db \ |
| 215 | + --window-size 8192 \ |
| 216 | + --stride 7992 |
| 217 | +``` |
| 218 | + |
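The sketch below shows roughly what such a pre-computation step could do. It is not the CLI's actual code: the function names are hypothetical, a plain list of `(contig_id, length)` pairs stands in for the Parquet file, and the window-count rule (one window for short sequences, then one more per full stride, with any remaining tail dropped) is an assumption.

```python
import sqlite3


def count_windows(length, window_size, stride):
    """Number of windows for one sequence (assumed rule, see the note above)."""
    if length <= window_size:
        return 1
    return 1 + (length - window_size) // stride


def build_window_db(contigs, db_path, window_size, stride):
    """Create the metadata and window_mappings tables and fill them from (contig_id, length) pairs."""
    conn = sqlite3.connect(db_path)
    conn.executescript(
        """
        CREATE TABLE metadata (key TEXT PRIMARY KEY, value INTEGER NOT NULL);
        CREATE TABLE window_mappings (
            window_idx INTEGER PRIMARY KEY,
            sequence_id TEXT NOT NULL,
            window_in_seq_idx INTEGER NOT NULL
        );
        CREATE INDEX idx_sequence_id ON window_mappings(sequence_id);
        """
    )
    conn.executemany(
        "INSERT INTO metadata (key, value) VALUES (?, ?)",
        [("window_size", window_size), ("stride", stride)],
    )
    rows = []
    window_idx = 0  # global index across all sequences in the split
    for contig_id, length in contigs:
        for k in range(count_windows(length, window_size, stride)):
            rows.append((window_idx, contig_id, k))
            window_idx += 1
    conn.executemany("INSERT INTO window_mappings VALUES (?, ?, ?)", rows)
    conn.commit()
    return conn


contigs = [("BCR__ECT-SAMPLE1__CT1-1", 1500), ("BCR__ECT-SAMPLE1__CT1-2", 20000)]
conn = build_window_db(contigs, ":memory:", window_size=8192, stride=7992)
total = conn.execute("SELECT COUNT(*) FROM window_mappings").fetchone()[0]
print(total)  # 3
```

With a window size of 8192 and a stride of 7992, consecutive windows overlap by 200 characters.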
| 219 | +## Implementation Details |
| 220 | + |
| 221 | +### Key Components |
| 222 | + |
| 223 | +1. **ShardedEdenDataModule**: |
| 224 | + |
| 225 | + - Uses separate window databases for each split (train/val/test) |
| 226 | + - Manages per-sample SQLite file paths |
| 227 | + - Creates datasets with directory and database paths |
| 228 | + - Handles distributed training setup with Megatron integration |
| 229 | + |
| 230 | +2. **ShardedEdenDataset**: |
| 231 | + |
| 232 | + - Automatically discovers sample SQLite files from directory structure |
| 233 | + - Maps sequence IDs to appropriate sample databases using `extract_sample_id()` |
| 234 | + - Pre-opens all database connections for performance |
| 235 | + - Attaches the window database to each sequence connection for efficient JOINs |
| 236 | + - Implements sequence caching with connection pooling |
| 237 | + - Maintains compatibility with original tokenization and formatting logic |
| 238 | + - Optionally logs window accesses for performance analysis |
| 239 | + |
| 240 | +3. **CLI Tool**: |
| 241 | + |
| 242 | + - `precompute`: Creates window databases from Parquet files |
| 243 | + |
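The automatic discovery of sample databases listed for `ShardedEdenDataset` could look roughly like the sketch below. The function name `discover_sample_dbs` is hypothetical, and raising `FileNotFoundError` for a sample folder without its SQLite file is an assumption about how validation might behave; only the `<sample_id>/glm_dataset_<sample_id>.sqlite` layout comes from this document.

```python
import tempfile
from pathlib import Path


def discover_sample_dbs(sequence_db_dir):
    """Map each sample ID (the sub-directory name) to its expected SQLite file."""
    root = Path(sequence_db_dir)
    mapping = {}
    for sample_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        db_path = sample_dir / f"glm_dataset_{sample_dir.name}.sqlite"
        if not db_path.is_file():
            # Assumption: a sample folder without its database is treated as an error.
            raise FileNotFoundError(f"Missing database for sample {sample_dir.name}: {db_path}")
        mapping[sample_dir.name] = db_path
    return mapping


# Build a throwaway directory tree that mirrors the documented layout.
with tempfile.TemporaryDirectory() as tmp:
    for sample_id in ("SAMPLE1", "SAMPLE2"):
        sample_dir = Path(tmp) / sample_id
        sample_dir.mkdir()
        (sample_dir / f"glm_dataset_{sample_id}.sqlite").touch()
    found = discover_sample_dbs(tmp)
    print(sorted(found))  # ['SAMPLE1', 'SAMPLE2']
```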
| 244 | +### Advanced Features |
| 245 | + |
| 246 | +#### Window Access Logging |
| 247 | + |
| 248 | +Enable detailed logging of window access patterns: |
| 249 | + |
| 250 | +```python |
| 251 | +dataset = ShardedEdenDataset( |
| 252 | + # ... other parameters ... |
| 253 | + log_windows=True, |
| 254 | + log_dir="sequence_logs", |
| 255 | +) |
| 256 | +``` |
| 257 | + |
| 258 | +This creates CSV logs tracking which windows are accessed, useful for analyzing data loading patterns. |
| 259 | + |
| 260 | +#### Connection Management |
| 261 | + |
| 262 | +- All database connections are pre-opened during initialization for performance |
| 263 | +- Database connections are pooled and reused per sample |
| 264 | +- Sequence data is fetched on-demand using SQL SUBSTR for memory efficiency |
| 265 | +- Position IDs are shared across instances to reduce memory usage |
| 266 | +- Connections are properly closed when the dataset is destroyed |
| 267 | + |
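As an illustration of the points above, the following sketch keeps one pre-opened connection per sample and reads a single window with `SUBSTR`, so only the requested slice leaves SQLite. The class `SampleConnectionPool` and its methods are hypothetical names, not the module's API. Note that SQLite's `SUBSTR` is 1-based, so a 0-based start offset is shifted by one.

```python
import sqlite3


class SampleConnectionPool:
    """Holds one SQLite connection per sample, opened up front and reused."""

    def __init__(self, db_paths):
        # Pre-open every connection during initialization.
        self._conns = {sample_id: sqlite3.connect(str(path)) for sample_id, path in db_paths.items()}

    def get(self, sample_id):
        return self._conns[sample_id]

    def fetch_window(self, sample_id, contig_id, start, length):
        """Return `length` characters of a sequence beginning at 0-based offset `start`."""
        row = self.get(sample_id).execute(
            "SELECT SUBSTR(nt_sequence, ?, ?) FROM sequences WHERE contig_id = ?",
            (start + 1, length, contig_id),  # SUBSTR positions start at 1
        ).fetchone()
        if row is None:
            raise KeyError(f"Sequence {contig_id!r} not found for sample {sample_id!r}")
        return row[0]

    def close(self):
        for conn in self._conns.values():
            conn.close()
        self._conns.clear()


# In-memory database standing in for glm_dataset_SAMPLE1.sqlite.
pool = SampleConnectionPool({"SAMPLE1": ":memory:"})
conn = pool.get("SAMPLE1")
conn.execute("CREATE TABLE sequences (contig_id TEXT PRIMARY KEY, nt_sequence TEXT NOT NULL)")
conn.execute("INSERT INTO sequences VALUES (?, ?)", ("BCR__ECT-SAMPLE1__CT1-1", "ACGTACGTAC"))
print(pool.fetch_window("SAMPLE1", "BCR__ECT-SAMPLE1__CT1-1", start=2, length=4))  # GTAC
pool.close()
```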
| 268 | +#### Metadata Validation |
| 269 | + |
| 270 | +The implementation validates that window databases were created with compatible parameters: |
| 271 | + |
| 272 | +- Checks stored `window_size` matches dataset `seq_length` |
| 273 | +- Checks stored `stride` matches dataset `stride` |
| 274 | +- Provides clear error messages for mismatches |
| 275 | + |
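A check along these lines might be sketched as follows. The function name `validate_window_db` and the error messages are assumptions; only the `metadata` table and its `window_size` and `stride` keys come from the schema documented earlier.

```python
import sqlite3


def validate_window_db(conn, seq_length, stride):
    """Raise ValueError if the stored pre-computation parameters do not match the dataset's."""
    stored = dict(conn.execute("SELECT key, value FROM metadata"))
    expected = {"window_size": seq_length, "stride": stride}
    for key, want in expected.items():
        if key not in stored:
            raise ValueError(f"Window database is missing metadata key '{key}'")
        if stored[key] != want:
            raise ValueError(
                f"Window database was built with {key}={stored[key]}, but the dataset expects {want}"
            )


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metadata (key TEXT PRIMARY KEY, value INTEGER NOT NULL)")
conn.executemany("INSERT INTO metadata VALUES (?, ?)", [("window_size", 8192), ("stride", 7992)])

validate_window_db(conn, seq_length=8192, stride=7992)  # matching parameters: no error

try:
    validate_window_db(conn, seq_length=4096, stride=7992)
except ValueError as err:
    print(err)
```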
| 276 | +### Error Handling |
| 277 | + |
| 278 | +- Validates sample SQLite files exist during initialization |
| 279 | +- Handles missing sequences gracefully with informative error messages |
| 280 | +- Ensures proper cleanup of database connections |
| 281 | +- Provides detailed debugging information for database issues |
| 282 | +- Validates Parquet file schema during pre-computation |