Skip to content

Commit 89cc8fe

Browse files
micedre authored and meilame-tayebjee committed
examples : fix basic_classification after refactor
1 parent ea26799 commit 89cc8fe

1 file changed

Lines changed: 37 additions & 21 deletions

File tree

examples/basic_classification.py

Lines changed: 37 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,9 @@
66
"""
77

88
import numpy as np
9-
from torchTextClassifiers import create_fasttext
9+
from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
10+
from torchTextClassifiers.tokenizers import WordPieceTokenizer
11+
1012

1113
def main():
1214
print("🚀 Basic Text Classification Example")
@@ -48,43 +50,57 @@ def main():
4850
print(f"Validation samples: {len(X_val)}")
4951
print(f"Test samples: {len(X_test)}")
5052

51-
# Create FastText classifier
52-
print("\n🏗️ Creating FastText classifier...")
53-
classifier = create_fasttext(
53+
# Create and train tokenizer
54+
print("\n🏗️ Creating and training WordPiece tokenizer...")
55+
tokenizer = WordPieceTokenizer(vocab_size=5000, output_dim=128)
56+
57+
# Train tokenizer on the training corpus
58+
training_corpus = X_train.tolist()
59+
tokenizer.train(training_corpus)
60+
print("✅ Tokenizer trained successfully!")
61+
62+
# Create model configuration
63+
print("\n🔧 Creating model configuration...")
64+
model_config = ModelConfig(
5465
embedding_dim=50,
55-
sparse=False,
56-
num_tokens=5000,
57-
min_count=1,
58-
min_n=3,
59-
max_n=6,
60-
len_word_ngrams=2,
6166
num_classes=2
6267
)
63-
64-
# Build the model
65-
print("\n🔨 Building model...")
66-
classifier.build(X_train, y_train)
67-
print("✅ Model built successfully!")
68+
69+
# Create classifier
70+
print("\n🔨 Creating classifier...")
71+
classifier = torchTextClassifiers(
72+
tokenizer=tokenizer,
73+
model_config=model_config
74+
)
75+
print("✅ Classifier created successfully!")
6876

6977
# Train the model
7078
print("\n🎯 Training model...")
71-
classifier.train(
72-
X_train, y_train, X_val, y_val,
79+
training_config = TrainingConfig(
7380
num_epochs=20,
7481
batch_size=4,
75-
patience_train=5,
82+
lr=1e-3,
83+
patience_early_stopping=5,
84+
num_workers=0 # Use 0 for simple examples to avoid multiprocessing issues
85+
)
86+
classifier.train(
87+
X_train, y_train, X_val, y_val,
88+
training_config=training_config,
7689
verbose=True
7790
)
7891
print("✅ Training completed!")
7992

8093
# Make predictions
8194
print("\n🔮 Making predictions...")
82-
predictions = classifier.predict(X_test)
95+
result = classifier.predict(X_test)
96+
predictions = result["prediction"].squeeze().numpy() # Extract predictions from dictionary
97+
confidence = result["confidence"].squeeze().numpy() # Extract confidence scores
8398
print(f"Predictions: {predictions}")
99+
print(f"Confidence: {confidence}")
84100
print(f"True labels: {y_test}")
85-
101+
86102
# Calculate accuracy
87-
accuracy = classifier.validate(X_test, y_test)
103+
accuracy = (predictions == y_test).mean()
88104
print(f"Test accuracy: {accuracy:.3f}")
89105

90106
# Show detailed results

0 commit comments

Comments
 (0)