|
6 | 6 | """ |
7 | 7 |
|
8 | 8 | import numpy as np |
9 | | -from torchTextClassifiers import create_fasttext |
| 9 | +from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers |
| 10 | +from torchTextClassifiers.tokenizers import WordPieceTokenizer |
| 11 | + |
10 | 12 |
|
11 | 13 | def main(): |
12 | 14 | print("🚀 Basic Text Classification Example") |
@@ -48,43 +50,57 @@ def main(): |
48 | 50 | print(f"Validation samples: {len(X_val)}") |
49 | 51 | print(f"Test samples: {len(X_test)}") |
50 | 52 |
|
51 | | - # Create FastText classifier |
52 | | - print("\n🏗️ Creating FastText classifier...") |
53 | | - classifier = create_fasttext( |
| 53 | + # Create and train tokenizer |
| 54 | + print("\n🏗️ Creating and training WordPiece tokenizer...") |
| 55 | + tokenizer = WordPieceTokenizer(vocab_size=5000, output_dim=128) |
| 56 | + |
| 57 | + # Train tokenizer on the training corpus |
| 58 | + training_corpus = X_train.tolist() |
| 59 | + tokenizer.train(training_corpus) |
| 60 | + print("✅ Tokenizer trained successfully!") |
| 61 | + |
| 62 | + # Create model configuration |
| 63 | + print("\n🔧 Creating model configuration...") |
| 64 | + model_config = ModelConfig( |
54 | 65 | embedding_dim=50, |
55 | | - sparse=False, |
56 | | - num_tokens=5000, |
57 | | - min_count=1, |
58 | | - min_n=3, |
59 | | - max_n=6, |
60 | | - len_word_ngrams=2, |
61 | 66 | num_classes=2 |
62 | 67 | ) |
63 | | - |
64 | | - # Build the model |
65 | | - print("\n🔨 Building model...") |
66 | | - classifier.build(X_train, y_train) |
67 | | - print("✅ Model built successfully!") |
| 68 | + |
| 69 | + # Create classifier |
| 70 | + print("\n🔨 Creating classifier...") |
| 71 | + classifier = torchTextClassifiers( |
| 72 | + tokenizer=tokenizer, |
| 73 | + model_config=model_config |
| 74 | + ) |
| 75 | + print("✅ Classifier created successfully!") |
68 | 76 |
|
69 | 77 | # Train the model |
70 | 78 | print("\n🎯 Training model...") |
71 | | - classifier.train( |
72 | | - X_train, y_train, X_val, y_val, |
| 79 | + training_config = TrainingConfig( |
73 | 80 | num_epochs=20, |
74 | 81 | batch_size=4, |
75 | | - patience_train=5, |
| 82 | + lr=1e-3, |
| 83 | + patience_early_stopping=5, |
| 84 | + num_workers=0 # Use 0 for simple examples to avoid multiprocessing issues |
| 85 | + ) |
| 86 | + classifier.train( |
| 87 | + X_train, y_train, X_val, y_val, |
| 88 | + training_config=training_config, |
76 | 89 | verbose=True |
77 | 90 | ) |
78 | 91 | print("✅ Training completed!") |
79 | 92 |
|
80 | 93 | # Make predictions |
81 | 94 | print("\n🔮 Making predictions...") |
82 | | - predictions = classifier.predict(X_test) |
| 95 | + result = classifier.predict(X_test) |
| 96 | + predictions = result["prediction"].squeeze().numpy() # Extract predictions from dictionary |
| 97 | + confidence = result["confidence"].squeeze().numpy() # Extract confidence scores |
83 | 98 | print(f"Predictions: {predictions}") |
| 99 | + print(f"Confidence: {confidence}") |
84 | 100 | print(f"True labels: {y_test}") |
85 | | - |
| 101 | + |
86 | 102 | # Calculate accuracy |
87 | | - accuracy = classifier.validate(X_test, y_test) |
| 103 | + accuracy = (predictions == y_test).mean() |
88 | 104 | print(f"Test accuracy: {accuracy:.3f}") |
89 | 105 |
|
90 | 106 | # Show detailed results |
|
0 commit comments