
Commit 704fe14

micedre authored and meilame-tayebjee committed
Adapt examples to new package architecture
1 parent 1b62eee commit 704fe14

5 files changed

Lines changed: 575 additions & 330 deletions
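
In short, the new architecture splits what the old create_fasttext factory bundled together: tokenization moves into a trainable WordPieceTokenizer, model hyperparameters into a ModelConfig, and training options into a TrainingConfig (patience_train becomes patience_early_stopping, with lr, num_workers, and trainer_params set explicitly). A minimal before/after sketch of the constructor change, using only calls that appear in the diff below; the two-sample X_train is a hypothetical stand-in for the example's real dataset:

# Before this commit: one factory call carried model and tokenization
# parameters alike, followed by an explicit build step.
#
#     from torchTextClassifiers import create_fasttext
#     classifier = create_fasttext(
#         embedding_dim=100, sparse=False, num_tokens=10000, min_count=1,
#         min_n=3, max_n=6, len_word_ngrams=2, num_classes=2,
#     )
#     classifier.build(X_train, y_train)

# After: tokenizer and model config are created separately and composed.
import numpy as np

from torchTextClassifiers import ModelConfig, torchTextClassifiers
from torchTextClassifiers.tokenizers import WordPieceTokenizer

X_train = np.array(["good product", "bad service"])  # hypothetical stand-in data

tokenizer = WordPieceTokenizer(vocab_size=5000, output_dim=128)
tokenizer.train(X_train.tolist())  # trained once, shared across classifiers

classifier = torchTextClassifiers(
    tokenizer=tokenizer,
    model_config=ModelConfig(embedding_dim=100, num_classes=2),
)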

File tree

examples/advanced_training.py

Lines changed: 139 additions & 75 deletions
@@ -6,13 +6,44 @@
 and training monitoring.
 """

+import os
+import random
+import warnings
+
 import numpy as np
-from torchTextClassifiers import create_fasttext
+import torch
+from pytorch_lightning import seed_everything
+
+from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
+from torchTextClassifiers.tokenizers import WordPieceTokenizer

 def main():
+    # Set seed for reproducibility
+    SEED = 42
+
+    # Set environment variables for full reproducibility
+    os.environ['PYTHONHASHSEED'] = str(SEED)
+    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+
+    # Use PyTorch Lightning's seed_everything for comprehensive seeding
+    seed_everything(SEED, workers=True)
+
+    # Make PyTorch operations deterministic
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.use_deterministic_algorithms(True, warn_only=True)
+
+    # Suppress PyTorch Lightning warnings for cleaner output
+    warnings.filterwarnings(
+        'ignore',
+        message='.*',
+        category=UserWarning,
+        module='pytorch_lightning'
+    )
+
     print("⚙️ Advanced Training Configuration Example")
     print("=" * 50)
-
+
     # Create a larger dataset for demonstrating advanced training
     print("📝 Creating training dataset...")

@@ -67,54 +98,63 @@ def main():
     print(f"Training samples: {len(X_train)}")
     print(f"Validation samples: {len(X_val)}")
     print(f"Test samples: {len(X_test)}")
-
-    # Create FastText classifier
-    print("\n🏗️ Creating FastText classifier...")
-    classifier = create_fasttext(
+
+    # Create and train tokenizer (shared across all examples)
+    print("\n🏗️ Creating and training WordPiece tokenizer...")
+    tokenizer = WordPieceTokenizer(vocab_size=5000, output_dim=128)
+    training_corpus = X_train.tolist()
+    tokenizer.train(training_corpus)
+    print("✅ Tokenizer trained successfully!")
+
+    # Example 1: Basic training with default settings
+    print("\n🎯 Example 1: Basic training with default settings...")
+
+    model_config = ModelConfig(
         embedding_dim=100,
-        sparse=False,
-        num_tokens=10000,
-        min_count=1,
-        min_n=3,
-        max_n=6,
-        len_word_ngrams=2,
         num_classes=2
     )
-
-    # Build the model
-    print("\n🔨 Building model...")
-    classifier.build(X_train, y_train)
-    print("✅ Model built successfully!")
-
-    # Example 1: Basic training with default settings
-    print("\n🎯 Example 1: Basic training with default settings...")
-    classifier.train(
-        X_train, y_train, X_val, y_val,
+
+    classifier = torchTextClassifiers(
+        tokenizer=tokenizer,
+        model_config=model_config
+    )
+    print("✅ Classifier created successfully!")
+
+    training_config = TrainingConfig(
         num_epochs=15,
         batch_size=8,
-        patience_train=5,
+        lr=1e-3,
+        patience_early_stopping=5,
+        num_workers=0,
+        trainer_params={'deterministic': True}
+    )
+
+    classifier.train(
+        X_train, y_train, X_val, y_val,
+        training_config=training_config,
         verbose=True
     )
-
-    basic_accuracy = classifier.validate(X_test, y_test)
+
+    result = classifier.predict(X_test)
+    basic_predictions = result["prediction"].squeeze().numpy()
+    basic_accuracy = (basic_predictions == y_test).mean()
     print(f"✅ Basic training completed! Accuracy: {basic_accuracy:.3f}")

     # Example 2: Advanced training with custom Lightning trainer parameters
     print("\n🚀 Example 2: Advanced training with custom parameters...")
-
+
     # Create a new classifier for comparison
-    advanced_classifier = create_fasttext(
+    advanced_model_config = ModelConfig(
         embedding_dim=100,
-        sparse=False,
-        num_tokens=10000,
-        min_count=1,
-        min_n=3,
-        max_n=6,
-        len_word_ngrams=2,
         num_classes=2
     )
-    advanced_classifier.build(X_train, y_train)
-
+
+    advanced_classifier = torchTextClassifiers(
+        tokenizer=tokenizer,
+        model_config=advanced_model_config
+    )
+    print("✅ Advanced classifier created successfully!")
+
     # Custom trainer parameters for advanced features
     advanced_trainer_params = {
         'accelerator': 'auto',  # Use GPU if available, else CPU
@@ -125,62 +165,77 @@ def main():
         'enable_progress_bar': True,  # Show progress bar
         'log_every_n_steps': 5,  # Log every 5 steps
     }
-
-    advanced_classifier.train(
-        X_train, y_train, X_val, y_val,
+
+    advanced_training_config = TrainingConfig(
         num_epochs=20,
         batch_size=4,  # Smaller batch size with grad accumulation
-        patience_train=7,
-        trainer_params=advanced_trainer_params,
+        lr=1e-3,
+        patience_early_stopping=7,
+        num_workers=0,
+        cpu_run=False,  # Don't override accelerator from trainer_params
+        trainer_params=advanced_trainer_params
+    )
+
+    advanced_classifier.train(
+        X_train, y_train, X_val, y_val,
+        training_config=advanced_training_config,
         verbose=True
     )
-
-    advanced_accuracy = advanced_classifier.validate(X_test, y_test)
+
+    advanced_result = advanced_classifier.predict(X_test)
+    advanced_predictions = advanced_result["prediction"].squeeze().numpy()
+    advanced_accuracy = (advanced_predictions == y_test).mean()
     print(f"✅ Advanced training completed! Accuracy: {advanced_accuracy:.3f}")

     # Example 3: Training with CPU-only (useful for small datasets or debugging)
     print("\n💻 Example 3: CPU-only training...")
-
-    cpu_classifier = create_fasttext(
+
+    cpu_model_config = ModelConfig(
         embedding_dim=64,  # Smaller embedding for faster CPU training
-        sparse=True,  # Sparse embeddings for efficiency
-        num_tokens=5000,
-        min_count=1,
-        min_n=3,
-        max_n=6,
-        len_word_ngrams=2,
         num_classes=2
     )
-    cpu_classifier.build(X_train, y_train)
-
-    cpu_classifier.train(
-        X_train, y_train, X_val, y_val,
+
+    cpu_classifier = torchTextClassifiers(
+        tokenizer=tokenizer,
+        model_config=cpu_model_config
+    )
+    print("✅ CPU classifier created successfully!")
+
+    cpu_training_config = TrainingConfig(
         num_epochs=10,
         batch_size=16,  # Larger batch size for CPU
-        cpu_run=True,  # Force CPU usage
+        lr=1e-3,
+        patience_early_stopping=3,
+        cpu_run=False,  # Don't override accelerator from trainer_params
         num_workers=0,  # No multiprocessing for CPU
-        patience_train=3,
+        trainer_params={'deterministic': True, 'accelerator': 'cpu'}
+    )
+
+    cpu_classifier.train(
+        X_train, y_train, X_val, y_val,
+        training_config=cpu_training_config,
         verbose=True
     )
-
-    cpu_accuracy = cpu_classifier.validate(X_test, y_test)
+
+    cpu_result = cpu_classifier.predict(X_test)
+    cpu_predictions = cpu_result["prediction"].squeeze().numpy()
+    cpu_accuracy = (cpu_predictions == y_test).mean()
     print(f"✅ CPU training completed! Accuracy: {cpu_accuracy:.3f}")

     # Example 4: Custom training with specific Lightning callbacks
     print("\n🔧 Example 4: Training with custom callbacks...")
-
-    custom_classifier = create_fasttext(
+
+    custom_model_config = ModelConfig(
         embedding_dim=128,
-        sparse=False,
-        num_tokens=8000,
-        min_count=1,
-        min_n=3,
-        max_n=6,
-        len_word_ngrams=2,
         num_classes=2
     )
-    custom_classifier.build(X_train, y_train)
-
+
+    custom_classifier = torchTextClassifiers(
+        tokenizer=tokenizer,
+        model_config=custom_model_config
+    )
+    print("✅ Custom classifier created successfully!")
+
     # Custom trainer with specific monitoring and checkpointing
     custom_trainer_params = {
         'max_epochs': 25,
@@ -189,18 +244,27 @@ def main():
         'check_val_every_n_epoch': 2,  # Validate every 2 epochs
         'enable_checkpointing': True,
         'enable_model_summary': True,
+        'deterministic': True,
     }
-
-    custom_classifier.train(
-        X_train, y_train, X_val, y_val,
+
+    custom_training_config = TrainingConfig(
         num_epochs=25,
         batch_size=6,
-        patience_train=8,
-        trainer_params=custom_trainer_params,
+        lr=1e-3,
+        patience_early_stopping=8,
+        num_workers=0,
+        trainer_params=custom_trainer_params
+    )
+
+    custom_classifier.train(
+        X_train, y_train, X_val, y_val,
+        training_config=custom_training_config,
        verbose=True
     )
-
-    custom_accuracy = custom_classifier.validate(X_test, y_test)
+
+    custom_result = custom_classifier.predict(X_test)
+    custom_predictions = custom_result["prediction"].squeeze().numpy()
+    custom_accuracy = (custom_predictions == y_test).mean()
     print(f"✅ Custom training completed! Accuracy: {custom_accuracy:.3f}")

     # Compare all training approaches
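
A behavioural change that recurs in all four examples above: the old classifier.validate(X_test, y_test) helper is gone, and accuracy is now computed by hand from predict(), which (per the added lines) returns a dict holding a "prediction" tensor. A standalone sketch of just that post-processing, with a dummy result dict standing in for real predict() output:

import numpy as np
import torch

y_test = np.array([1, 0, 1])
# Dummy stand-in for `result = classifier.predict(X_test)`:
result = {"prediction": torch.tensor([[1], [0], [0]])}

predictions = result["prediction"].squeeze().numpy()  # (n, 1) tensor -> 1-D array
accuracy = (predictions == y_test).mean()             # fraction of exact matches
print(f"Accuracy: {accuracy:.3f}")                    # -> Accuracy: 0.667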
