Commit b2e797b

doc: fix readme
1 parent 45ace28 commit b2e797b

1 file changed

Lines changed: 12 additions & 130 deletions

README.md
````diff
@@ -1,14 +1,18 @@
 # torchTextClassifiers
 
-A unified, extensible framework for text classification built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
+A unified, extensible framework for text classification with categorical variables, built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
 
 ## 🚀 Features
 
-- **Unified API**: Consistent interface for different classifier wrappers
-- **Extensible**: Easy to add new classifier implementations through the wrapper pattern
-- **FastText Support**: Built-in FastText classifier with n-gram tokenization
-- **Flexible Preprocessing**: Each classifier can implement its own text preprocessing approach
+- **Mixed input support**: Handle text data alongside categorical variables seamlessly.
+- **Unified yet highly customizable**:
+  - Use any tokenizer from HuggingFace, or the original fastText n-gram tokenizer.
+  - Mix and match the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures, including **self-attention**. Each of them is a `torch.nn.Module`!
+  - The `TextClassificationModel` class combines these components and can be extended for custom behavior.
 - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
+- **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
+  - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you.
+- **Additional features**: explainability using Captum
 
 
 ## 📦 Installation
````
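The component architecture sketched in the new feature list (three `torch.nn.Module` building blocks combined by a model class) can be illustrated as follows. This is a minimal sketch only: the class names come from the README bullets, but every constructor signature, shape, and pooling choice here is an assumption, not the library's actual API.

```python
import torch
import torch.nn as nn


class TextEmbedder(nn.Module):
    """Embeds token ids and mean-pools them into one vector per example (assumed behavior)."""

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # token_ids: (batch, seq_len) -> (batch, embedding_dim)
        return self.embedding(token_ids).mean(dim=1)


class CategoricalVariableNet(nn.Module):
    """Embeds a single categorical variable (assumed single-column case)."""

    def __init__(self, num_categories: int, embedding_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(num_categories, embedding_dim)

    def forward(self, categories: torch.Tensor) -> torch.Tensor:
        # categories: (batch,) -> (batch, embedding_dim)
        return self.embedding(categories)


class ClassificationHead(nn.Module):
    """Maps the concatenated features to class logits."""

    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.linear(features)


class TextClassificationModel(nn.Module):
    """Combines the three components, in the spirit of the README description."""

    def __init__(self, text_embedder, cat_net, head):
        super().__init__()
        self.text_embedder = text_embedder
        self.cat_net = cat_net
        self.head = head

    def forward(self, token_ids, categories):
        features = torch.cat(
            [self.text_embedder(token_ids), self.cat_net(categories)], dim=-1
        )
        return self.head(features)


model = TextClassificationModel(
    TextEmbedder(vocab_size=100, embedding_dim=16),
    CategoricalVariableNet(num_categories=5, embedding_dim=4),
    ClassificationHead(input_dim=20, num_classes=2),
)
logits = model(torch.randint(0, 100, (8, 12)), torch.randint(0, 5, (8,)))
print(logits.shape)  # torch.Size([8, 2])
```

Because each piece is a plain `nn.Module`, any one of them can be swapped for a custom implementation (e.g. a self-attention text encoder) without touching the others.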
````diff
@@ -25,140 +29,18 @@ uv sync
 pip install -e .
 ```
 
-## 🎯 Quick Start
-
-### Basic FastText Classification
-
-```python
-import numpy as np
-from torchTextClassifiers import create_fasttext
-
-# Create a FastText classifier
-classifier = create_fasttext(
-    embedding_dim=100,
-    sparse=False,
-    num_tokens=10000,
-    min_count=2,
-    min_n=3,
-    max_n=6,
-    len_word_ngrams=2,
-    num_classes=2
-)
-
-# Prepare your data
-X_train = np.array([
-    "This is a positive example",
-    "This is a negative example",
-    "Another positive case",
-    "Another negative case"
-])
-y_train = np.array([1, 0, 1, 0])
-
-X_val = np.array([
-    "Validation positive",
-    "Validation negative"
-])
-y_val = np.array([1, 0])
-
-# Build the model
-classifier.build(X_train, y_train)
-
-# Train the model
-classifier.train(
-    X_train, y_train, X_val, y_val,
-    num_epochs=50,
-    batch_size=32,
-    patience_train=5,
-    verbose=True
-)
-
-# Make predictions
-X_test = np.array(["This is a test sentence"])
-predictions = classifier.predict(X_test)
-print(f"Predictions: {predictions}")
-
-# Validate on test set
-accuracy = classifier.validate(X_test, np.array([1]))
-print(f"Accuracy: {accuracy:.3f}")
-```
-
-### Custom Classifier Implementation
-
-```python
-import numpy as np
-from torchTextClassifiers import torchTextClassifiers
-from torchTextClassifiers.classifiers.simple_text_classifier import SimpleTextWrapper, SimpleTextConfig
-
-# Example: TF-IDF based classifier (alternative to tokenization)
-config = SimpleTextConfig(
-    hidden_dim=128,
-    num_classes=2,
-    max_features=5000,
-    learning_rate=1e-3,
-    dropout_rate=0.2
-)
-
-# Create classifier with TF-IDF preprocessing
-wrapper = SimpleTextWrapper(config)
-classifier = torchTextClassifiers(wrapper)
-
-# Text data
-X_train = np.array(["Great product!", "Terrible service", "Love it!"])
-y_train = np.array([1, 0, 1])
-
-# Build and train
-classifier.build(X_train, y_train)
-# ... continue with training
-```
-
-
-### Training Customization
-
-```python
-# Custom PyTorch Lightning trainer parameters
-trainer_params = {
-    'accelerator': 'gpu',
-    'devices': 1,
-    'precision': 16,  # Mixed precision training
-    'gradient_clip_val': 1.0,
-}
-
-classifier.train(
-    X_train, y_train, X_val, y_val,
-    num_epochs=100,
-    batch_size=64,
-    patience_train=10,
-    trainer_params=trainer_params,
-    verbose=True
-)
-```
-
-## 🔬 Testing
-
-Run the test suite:
-
-```bash
-# Run all tests
-uv run pytest
-
-# Run with coverage
-uv run pytest --cov=torchTextClassifiers
-
-# Run specific test file
-uv run pytest tests/test_torchTextClassifiers.py -v
-```
+## 📝 Usage
 
+Check out the [notebook](notebooks/example.ipynb) for a quick start.
 
 ## 📚 Examples
 
 See the [examples/](examples/) directory for:
 - Basic text classification
 - Multi-class classification
 - Mixed features (text + categorical)
-- Custom classifier implementation
 - Advanced training configurations
-
-
+- Prediction and explainability
 
 ## 📄 License
 
````
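The orchestration role the new README assigns to the `torchTextClassifiers` wrapper (one class coordinating tokenizer and model behind a simple train/predict API) can be sketched schematically. Everything below except the pattern itself is hypothetical: the tokenizer, the stand-in model, and all method names are illustrative, not the library's API.

```python
# Schematic sketch of a wrapper orchestrating a tokenizer and a model.
# All names and interfaces here are assumptions for illustration only.
class WhitespaceTokenizer:
    def __init__(self):
        self.vocab = {}

    def fit(self, texts):
        # Assign each unseen token the next free id (0 is reserved for unknown/padding).
        for text in texts:
            for token in text.lower().split():
                self.vocab.setdefault(token, len(self.vocab) + 1)

    def encode(self, text):
        return [self.vocab.get(tok, 0) for tok in text.lower().split()]


class MajorityClassModel:
    """Stand-in for the neural model: always predicts the most frequent training label."""

    def fit(self, encoded_texts, labels):
        self.majority = max(set(labels), key=labels.count)

    def predict(self, encoded_texts):
        return [self.majority for _ in encoded_texts]


class Wrapper:
    """Hides tokenizer + model coordination behind one train/predict interface."""

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

    def train(self, texts, labels):
        self.tokenizer.fit(texts)
        self.model.fit([self.tokenizer.encode(t) for t in texts], labels)

    def predict(self, texts):
        return self.model.predict([self.tokenizer.encode(t) for t in texts])


clf = Wrapper(WhitespaceTokenizer(), MajorityClassModel())
clf.train(["good movie", "bad movie", "great film"], [1, 0, 1])
print(clf.predict(["nice film"]))  # [1]
```

The point of the pattern is that callers never touch tokenization directly, which is what lets the tokenizer (HuggingFace, fastText n-grams, ...) be swapped without changing user code.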