Skip to content

Commit 4dca4c4

Browse files
committed
Add DNA-LM tokenization
1 parent 501d637 commit 4dca4c4

4 files changed

Lines changed: 269 additions & 6 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ data/databases/open/probeBase_false_genome_results.csv
1414
data/databases/open/probeBase_formatted.csv
1515
data/databases/open/probeBase_formamide.tsv
1616
data/databases/open/probeBase_genome_results.csv
17+
data/databases/open/test_ML_database_tokenized.csv
1718
genome_parse_results.csv
1819
data/genomes/
1920
data/articles2/
@@ -74,4 +75,4 @@ PROBEst.Rproj
7475
.Rproj.user
7576
.Rhistory
7677
articles.tar.bz2
77-
data/databases/articles/artificial_database_structure
78+
data/databases/articles/artificial_database_structure

scripts/generator/ML_filtration.py

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
GAILDeep, GAILWide, GAILNarrow, GAILWithDropout,
1414
GAILWideDeep, GAILWideDropout, GAILWideBatchNorm, GAILWideExtra, GAILWideBalanced
1515
)
16+
from PROBESt.tokenization import tokenize_table
1617

1718
MODELS = {
1819
"ShallowNet": lambda n: TorchClassifier(ShallowNet(n), weight_pos=5),
@@ -22,10 +23,34 @@
2223
"TabTransformer": lambda n: TorchClassifier(TabTransformer(n), weight_pos=5),
2324
}
2425

25-
def main():
26-
# Load data
27-
data_path = 'data/databases/open/test_ML_database.csv'
28-
data = pd.read_csv(data_path)
26+
def load_and_prepare_data(data_path, use_tokenized=False, add_tokens=50, force_regenerate=False):
27+
"""Load and prepare data, optionally with tokenization.
28+
29+
Args:
30+
data_path: Path to the CSV file
31+
use_tokenized: If True, tokenize sequences and add token columns
32+
add_tokens: Number of top k-mers to add as columns per sequence column (if use_tokenized=True)
33+
force_regenerate: If True, regenerate tokenized file even if it exists
34+
35+
Returns:
36+
Tuple of (train_data, val_data, test_data)
37+
"""
38+
if use_tokenized:
39+
print(f"\n{'='*60}")
40+
print("Tokenizing sequences...")
41+
print(f"{'='*60}")
42+
# Tokenize the table
43+
tokenized_path = data_path.rsplit('.', 1)[0] + '_tokenized.csv'
44+
if force_regenerate or not os.path.exists(tokenized_path):
45+
print(f"Generating tokenized file: {tokenized_path}")
46+
tokenize_table(data_path, output_csv=tokenized_path, add_tokens=add_tokens,
47+
drop_original_sequences=True)
48+
else:
49+
print(f"Using existing tokenized file: {tokenized_path}")
50+
data = pd.read_csv(tokenized_path)
51+
print(f"Loaded tokenized data from {tokenized_path}")
52+
else:
53+
data = pd.read_csv(data_path)
2954

3055
# Convert boolean 'type' column to numeric
3156
data['type'] = data['type'].astype(int)
@@ -47,6 +72,14 @@ def main():
4772
print(f"Validation set size: {len(val_data)}")
4873
print(f"Test set size: {len(test_data)}")
4974

75+
return train_data, val_data, test_data
76+
77+
78+
def main():
79+
# Load data
80+
data_path = 'data/databases/open/test_ML_database.csv'
81+
train_data, val_data, test_data = load_and_prepare_data(data_path, use_tokenized=False)
82+
5083
# Get input size
5184
input_size = train_data.shape[1] - 1
5285

@@ -352,6 +385,95 @@ def main():
352385
os.makedirs(output_dir, exist_ok=True)
353386
test_predictions.to_csv(os.path.join(output_dir, 'test_predictions.csv'), index=False)
354387
print(f"\nPredictions saved to {os.path.join(output_dir, 'test_predictions.csv')}")
388+
389+
# Test GAIL_Wide_Custom2 on both tokenized and non-tokenized data
390+
print("\n" + "="*60)
391+
print("Testing GAIL_Wide_Custom2 on tokenized vs non-tokenized data")
392+
print("="*60)
393+
394+
# Test on non-tokenized data (already loaded)
395+
print("\n" + "-"*60)
396+
print("Testing GAIL_Wide_Custom2 on NON-TOKENIZED data")
397+
print("-"*60)
398+
gail_wide_custom2_non_tokenized = TorchClassifier(
399+
GAILWide(input_size, hidden1=384, hidden2=192), weight_pos=5
400+
)
401+
X_train = train_data.drop(columns=['type'])
402+
y_train = train_data['type']
403+
gail_wide_custom2_non_tokenized.train(
404+
X_train, y_train, epochs=150, batch_size=32,
405+
val_data=val_data, track_curves=True
406+
)
407+
non_tokenized_metrics = validate_filtration_AI(
408+
gail_wide_custom2_non_tokenized, val_data,
409+
output_name='GAIL_Wide_Custom2_non_tokenized.png'
410+
)
411+
print("\nGAIL_Wide_Custom2 (non-tokenized) validation metrics:")
412+
for metric, value in non_tokenized_metrics.items():
413+
print(f" {metric}: {value:.4f}")
414+
415+
# Plot learning curves for non-tokenized
416+
if hasattr(gail_wide_custom2_non_tokenized, 'train_losses') and len(gail_wide_custom2_non_tokenized.train_losses) > 0:
417+
plot_learning_curves(
418+
gail_wide_custom2_non_tokenized, output_dir=output_dir,
419+
output_name='learning_curves_GAIL_Wide_Custom2_non_tokenized.png'
420+
)
421+
print(f"Learning curves saved to {os.path.join(output_dir, 'learning_curves_GAIL_Wide_Custom2_non_tokenized.png')}")
422+
423+
# Test on tokenized data
424+
print("\n" + "-"*60)
425+
print("Testing GAIL_Wide_Custom2 on TOKENIZED data")
426+
print("-"*60)
427+
train_data_tokenized, val_data_tokenized, test_data_tokenized = load_and_prepare_data(
428+
data_path, use_tokenized=True, add_tokens=50, force_regenerate=True
429+
)
430+
input_size_tokenized = train_data_tokenized.shape[1] - 1
431+
print(f"Tokenized data input size: {input_size_tokenized} (vs {input_size} for non-tokenized)")
432+
433+
gail_wide_custom2_tokenized = TorchClassifier(
434+
GAILWide(input_size_tokenized, hidden1=384, hidden2=192), weight_pos=5
435+
)
436+
X_train_tokenized = train_data_tokenized.drop(columns=['type'])
437+
y_train_tokenized = train_data_tokenized['type']
438+
gail_wide_custom2_tokenized.train(
439+
X_train_tokenized, y_train_tokenized, epochs=150, batch_size=32,
440+
val_data=val_data_tokenized, track_curves=True
441+
)
442+
tokenized_metrics = validate_filtration_AI(
443+
gail_wide_custom2_tokenized, val_data_tokenized,
444+
output_name='GAIL_Wide_Custom2_tokenized.png'
445+
)
446+
print("\nGAIL_Wide_Custom2 (tokenized) validation metrics:")
447+
for metric, value in tokenized_metrics.items():
448+
print(f" {metric}: {value:.4f}")
449+
450+
# Plot learning curves for tokenized
451+
if hasattr(gail_wide_custom2_tokenized, 'train_losses') and len(gail_wide_custom2_tokenized.train_losses) > 0:
452+
plot_learning_curves(
453+
gail_wide_custom2_tokenized, output_dir=output_dir,
454+
output_name='learning_curves_GAIL_Wide_Custom2_tokenized.png'
455+
)
456+
print(f"Learning curves saved to {os.path.join(output_dir, 'learning_curves_GAIL_Wide_Custom2_tokenized.png')}")
457+
458+
# Compare results
459+
print("\n" + "="*60)
460+
print("COMPARISON: GAIL_Wide_Custom2 - Tokenized vs Non-Tokenized")
461+
print("="*60)
462+
print(f"{'Metric':<20} {'Non-Tokenized':<15} {'Tokenized':<15} {'Difference':<15}")
463+
print("-"*60)
464+
for metric in non_tokenized_metrics.keys():
465+
non_val = non_tokenized_metrics[metric]
466+
tok_val = tokenized_metrics[metric]
467+
diff = tok_val - non_val
468+
print(f"{metric:<20} {non_val:<15.4f} {tok_val:<15.4f} {diff:+.4f}")
469+
470+
# Determine winner
471+
if tokenized_metrics['f1'] > non_tokenized_metrics['f1']:
472+
print(f"\n✓ Tokenized version performs better (F1: {tokenized_metrics['f1']:.4f} vs {non_tokenized_metrics['f1']:.4f})")
473+
elif non_tokenized_metrics['f1'] > tokenized_metrics['f1']:
474+
print(f"\n✓ Non-tokenized version performs better (F1: {non_tokenized_metrics['f1']:.4f} vs {tokenized_metrics['f1']:.4f})")
475+
else:
476+
print(f"\n= Both versions perform equally (F1: {non_tokenized_metrics['f1']:.4f})")
355477

356478
if __name__ == '__main__':
357479
main()

src/PROBESt/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
from . import bash_wrappers
1111
from . import models_registry
1212
from . import AI
13-
from . import filtration
13+
from . import filtration
14+
from . import tokenization

src/PROBESt/tokenization.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Tokenization module for DNA sequences.
2+
3+
This module provides functions to tokenize DNA sequences into k-mers,
4+
similar to how DNA language models process sequences.
5+
"""
6+
7+
import pandas as pd
8+
import numpy as np
9+
from typing import List, Dict, Optional
10+
from collections import Counter
11+
12+
13+
def tokenize_seq(sequence: str, k: int = 3) -> List[str]:
14+
"""Tokenize a DNA sequence into k-mers.
15+
16+
Args:
17+
sequence: DNA sequence string (e.g., "ATCGATCG")
18+
k: Size of k-mers (default: 3, producing 3-mers like "ATC", "TCG")
19+
20+
Returns:
21+
List of k-mer tokens extracted from the sequence.
22+
For example, "ATCGATCG" with k=3 returns ["ATC", "TCG", "CGA", "GAT", "ATC", "TCG"]
23+
24+
Example:
25+
>>> tokenize_seq("ATCGATCG", k=3)
26+
['ATC', 'TCG', 'CGA', 'GAT', 'ATC', 'TCG']
27+
"""
28+
if not sequence or pd.isna(sequence):
29+
return []
30+
31+
sequence = str(sequence).upper().strip()
32+
if len(sequence) < k:
33+
return []
34+
35+
tokens = []
36+
for i in range(len(sequence) - k + 1):
37+
kmer = sequence[i:i+k]
38+
# Only include valid DNA k-mers (containing only A, T, C, G)
39+
if all(base in 'ATCG' for base in kmer):
40+
tokens.append(kmer)
41+
42+
return tokens
43+
44+
45+
def tokenize_table(input_csv: str, output_csv: Optional[str] = None,
46+
add_tokens: int = 100, k: int = 3,
47+
sequence_columns: Optional[List[str]] = None,
48+
drop_original_sequences: bool = True) -> pd.DataFrame:
49+
"""Tokenize DNA sequences in a CSV table and add token count columns.
50+
51+
This function reads a CSV file, tokenizes sequences in specified columns
52+
(default: 'sseq' and 'qseq'), and adds new columns with k-mer counts.
53+
The new columns are named like 'sseq_token_TTC', 'qseq_token_AAA', etc.
54+
55+
Args:
56+
input_csv: Path to input CSV file
57+
output_csv: Path to output CSV file (if None, appends '_tokenized' to input name)
58+
add_tokens: Number of top k-mers to add as columns per sequence column (default: 100)
59+
k: Size of k-mers (default: 3)
60+
sequence_columns: List of column names to tokenize (default: ['sseq', 'qseq'])
61+
drop_original_sequences: If True, drop original sequence columns after tokenization (default: True)
62+
63+
Returns:
64+
DataFrame with original columns plus new token count columns (or without original sequences if dropped)
65+
66+
Example:
67+
>>> df = tokenize_table('data.csv', add_tokens=50)
68+
>>> # Adds columns like 'sseq_token_AAA', 'sseq_token_TTC', etc.
69+
"""
70+
if sequence_columns is None:
71+
sequence_columns = ['sseq', 'qseq']
72+
73+
# Read the CSV
74+
df = pd.read_csv(input_csv)
75+
76+
# Collect all k-mers from all sequences to find the most frequent ones
77+
all_kmers = Counter()
78+
79+
for col in sequence_columns:
80+
if col not in df.columns:
81+
print(f"Warning: Column '{col}' not found in CSV. Skipping.")
82+
continue
83+
84+
for seq in df[col]:
85+
if pd.notna(seq):
86+
tokens = tokenize_seq(str(seq), k=k)
87+
all_kmers.update(tokens)
88+
89+
# Get top k-mers to add as columns
90+
top_kmers = [kmer for kmer, _ in all_kmers.most_common(add_tokens)]
91+
92+
print(f"Found {len(all_kmers)} unique {k}-mers. Adding top {len(top_kmers)} as columns.")
93+
94+
# Build all token count columns at once to avoid DataFrame fragmentation
95+
new_columns = {}
96+
97+
for col in sequence_columns:
98+
if col not in df.columns:
99+
continue
100+
101+
# Pre-compute tokens for all sequences to avoid repeated computation
102+
print(f"Tokenizing {col} column...")
103+
all_tokens = df[col].apply(
104+
lambda seq: tokenize_seq(str(seq), k=k) if pd.notna(seq) else []
105+
)
106+
107+
# Count k-mers for each row and each top k-mer
108+
for kmer in top_kmers:
109+
col_name = f"{col}_token_{kmer}"
110+
new_columns[col_name] = all_tokens.apply(
111+
lambda tokens: tokens.count(kmer) if isinstance(tokens, list) else 0
112+
)
113+
114+
# Add all new columns at once using pd.concat to avoid fragmentation
115+
if new_columns:
116+
new_df = pd.DataFrame(new_columns, index=df.index)
117+
df = pd.concat([df, new_df], axis=1)
118+
119+
# Count added columns before dropping original sequences
120+
num_added_columns = len(new_columns) if new_columns else 0
121+
122+
# Optionally drop original sequence columns after tokenization
123+
if drop_original_sequences:
124+
for col in sequence_columns:
125+
if col in df.columns:
126+
df = df.drop(columns=[col])
127+
print(f"Dropped original sequence column: {col}")
128+
129+
# Save to output file
130+
if output_csv is None:
131+
base_name = input_csv.rsplit('.', 1)[0]
132+
output_csv = f"{base_name}_tokenized.csv"
133+
134+
df.to_csv(output_csv, index=False)
135+
print(f"Tokenized table saved to {output_csv}")
136+
print(f"Added {num_added_columns} new token columns.")
137+
138+
return df
139+

0 commit comments

Comments
 (0)