Commit e0f8f5e

docs: include label attention and value encoder
1 parent 22dfcff commit e0f8f5e

6 files changed

Lines changed: 800 additions & 570 deletions

README.md

Lines changed: 4 additions & 3 deletions
@@ -7,6 +7,7 @@ A unified, extensible framework for text classification with categorical variabl
 ## 🚀 Features
 
 - **Complex input support**: Handle text data alongside categorical variables seamlessly.
+- **ValueEncoder**: Pass raw string categorical values and labels directly — no manual integer encoding required. Build a `ValueEncoder` from `DictEncoder` or sklearn `LabelEncoder` instances once, and the wrapper handles encoding at train time and label decoding after prediction automatically.
 - **Unified yet highly customizable**:
   - Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer.
   - Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures - including **self-attention**. All of them are `torch.nn.Module`!
@@ -15,7 +16,9 @@ A unified, extensible framework for text classification with categorical variabl
 - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
 - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
   - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you
-- **Additional features**: explainability using Captum
+- **Explainability**:
+  - **Captum integration**: gradient-based token attribution via integrated gradients (`explain_with_captum=True`).
+  - **Label attention**: class-specific cross-attention that produces one sentence embedding per class, enabling token-level explanations for each label (`explain_with_label_attention=True`). Enable it by setting `n_heads_label_attention` in `ModelConfig`.
 
 
 ## 📦 Installation
@@ -57,5 +60,3 @@ See the [examples/](examples/) directory for:
 ## 📄 License
 
 This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
-
-
docs/source/architecture/overview.md

Lines changed: 111 additions & 10 deletions
@@ -11,11 +11,80 @@ At its core, torchTextClassifiers processes data through a simple pipeline:
 ```
 
 **Data Flow:**
-1. **Text** is tokenized into numerical tokens
-2. **Tokens** are embedded into dense vectors (with optional attention)
-3. **Categorical variables** (optional) are embedded separately
-4. **All embeddings** are combined
-5. **Classification head** produces final predictions
+1. **ValueEncoder** (optional) converts raw string categorical values and labels into integers
+2. **Text** is tokenized into numerical tokens
+3. **Tokens** are embedded into dense vectors (with optional self-attention)
+   — or into one embedding *per class* if **label attention** is enabled
+4. **Categorical variables** (optional) are embedded separately
+5. **All embeddings** are combined
+6. **Classification head** produces final predictions
+   — if a `ValueEncoder` was provided, integer predictions are decoded back to original labels
+
+## Component 0: ValueEncoder (optional preprocessing)
+
+**Purpose:** Encode raw string (or mixed-type) categorical values and labels into
+integer indices, and decode predicted integers back to original label values after
+inference.
+
+### When to Use
+
+Use `ValueEncoder` whenever your categorical features or labels are stored as strings
+(e.g. `"Electronics"`, `"positive"`) rather than integers. Without it, you must
+integer-encode inputs manually before passing them to `train` / `predict`.
+
+### Building a ValueEncoder
+
+```python
+from sklearn.preprocessing import LabelEncoder
+from torchTextClassifiers.value_encoder import DictEncoder, ValueEncoder
+
+# Option A: sklearn LabelEncoder (fit on train data)
+cat_encoder = LabelEncoder().fit(X_train_categories)
+
+# Option B: explicit dict mapping
+cat_encoder = DictEncoder({"Electronics": 0, "Audio": 1, "Books": 2})
+
+value_encoder = ValueEncoder(
+    label_encoder=LabelEncoder().fit(y_train),  # encodes/decodes labels
+    categorical_encoders={
+        "category": cat_encoder,  # one entry per categorical column
+        # "brand": brand_encoder,  # add more as needed
+    },
+)
+```
+
+### What It Provides
+
+```python
+value_encoder.vocabulary_sizes  # [3, ...] – inferred from each encoder
+value_encoder.num_classes  # 2 – inferred from label encoder
+```
+
+These are read automatically by `torchTextClassifiers` when constructing the model,
+so you don't need to set `num_classes` or `categorical_vocabulary_sizes` in `ModelConfig`
+manually.
+
+### Integration with the Wrapper
+
+```python
+classifier = torchTextClassifiers(
+    tokenizer=tokenizer,
+    model_config=ModelConfig(embedding_dim=64),  # num_classes inferred from encoder
+    value_encoder=value_encoder,
+)
+
+# Train with raw string inputs (default: raw_categorical_inputs=True, raw_labels=True)
+classifier.train(X_train, y_train, training_config)
+
+# Predict — output labels are decoded back to original strings automatically
+result = classifier.predict(X_test)
+print(result["prediction"])  # ["positive", "negative", ...]
+```
+
+The `ValueEncoder` is saved and reloaded with the model via `classifier.save()` /
+`torchTextClassifiers.load()`.
+
+---
 
 ## Component 1: Tokenizer
 
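The encode/decode round-trip that `ValueEncoder` performs in the hunk above can be illustrated with a minimal, dependency-free sketch. `ToyValueEncoder` and its method names are illustrative stand-ins, not the library's actual classes; plain dicts play the role of `DictEncoder` / sklearn `LabelEncoder`:

```python
# Minimal sketch of the ValueEncoder round-trip: raw strings in,
# integers for the model, original label values back out after prediction.
# ToyValueEncoder is a hypothetical stand-in, not the library's class.

class ToyValueEncoder:
    def __init__(self, label_values, categorical_values):
        # Forward (value -> int) and inverse (int -> value) maps.
        self.label_to_int = {v: i for i, v in enumerate(label_values)}
        self.int_to_label = {i: v for v, i in self.label_to_int.items()}
        self.cat_to_int = {
            col: {v: i for i, v in enumerate(vals)}
            for col, vals in categorical_values.items()
        }

    @property
    def num_classes(self):
        # Inferred from the label mapping, like value_encoder.num_classes.
        return len(self.label_to_int)

    @property
    def vocabulary_sizes(self):
        # One size per categorical column, like value_encoder.vocabulary_sizes.
        return [len(m) for m in self.cat_to_int.values()]

    def encode_labels(self, ys):
        return [self.label_to_int[y] for y in ys]

    def encode_column(self, col, xs):
        return [self.cat_to_int[col][x] for x in xs]

    def decode_labels(self, ints):
        return [self.int_to_label[i] for i in ints]


enc = ToyValueEncoder(
    label_values=["negative", "positive"],
    categorical_values={"category": ["Electronics", "Audio", "Books"]},
)
print(enc.num_classes)                                 # 2
print(enc.vocabulary_sizes)                            # [3]
print(enc.encode_labels(["positive", "negative"]))     # [1, 0]
print(enc.decode_labels([1, 0]))                       # ['positive', 'negative']
```

This is the same contract the wrapper relies on: vocabulary sizes and class count fall out of the mappings, so `ModelConfig` does not need them repeated.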
@@ -144,6 +213,39 @@ embedder = TextEmbedder(config)
 - `n_head`: Number of attention heads (typically 4, 8, or 16)
 - `n_layer`: Depth of transformer (start with 2-3)
 
+### With Label Attention (Optional Explainability Layer)
+
+Label attention replaces mean-pooling with a **cross-attention mechanism** where each
+class has a learnable embedding that attends over the token sequence:
+
+```
+Token embeddings (batch, seq_len, d)
+    ↓ cross-attention (labels as queries, tokens as keys/values)
+Sentence embeddings (batch, num_classes, d)   ← one per class
+    ↓
+ClassificationHead (d → 1)                    ← shared, applied per class
+    ↓
+Logits (batch, num_classes)
+```
+
+Enable it by setting `n_heads_label_attention` in `ModelConfig`:
+
+```python
+model_config = ModelConfig(
+    embedding_dim=96,
+    num_classes=6,
+    n_heads_label_attention=4,  # number of attention heads for label attention
+)
+```
+
+**Benefits:**
+- Free explainability at inference time (`explain_with_label_attention=True` in `predict`)
+- The returned attention matrix `(batch, n_head, num_classes, seq_len)` shows which
+  tokens each class focuses on
+- Can be combined with self-attention (`attention_config`)
+
+**Constraint:** `embedding_dim` must be divisible by `n_heads_label_attention`.
+
 ## Component 3: Categorical Variable Handler
 
 **Purpose:** Process categorical features (like user demographics, product categories) alongside text.
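The cross-attention step described in this hunk can be sketched for a single head and a single example with toy numbers (the real layer is a multi-head `torch.nn.Module` operating on batches; all values and the head weights below are made up for illustration):

```python
import math

# Single-head, single-example sketch of label attention:
# each class's learnable embedding (query) attends over the token
# embeddings (keys/values), yielding one sentence embedding per class.
d, seq_len, num_classes = 4, 3, 2
tokens = [[0.1 * (i + j) for j in range(d)] for i in range(seq_len)]           # (seq_len, d)
label_queries = [[0.05 * (c + j) for j in range(d)] for c in range(num_classes)]  # (num_classes, d)

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(xs):
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [x / s for x in e]

# Cross-attention: labels as queries, tokens as keys and values.
sentence_embeddings = []  # -> (num_classes, d), one embedding per class
attention = []            # -> (num_classes, seq_len), the per-class explanation
for q in label_queries:
    scores = [dot(q, k) / math.sqrt(d) for k in tokens]
    weights = softmax(scores)
    attention.append(weights)
    sentence_embeddings.append(
        [sum(w * tok[j] for w, tok in zip(weights, tokens)) for j in range(d)]
    )

# Shared classification head (d -> 1), applied per class -> logits (num_classes,)
head_w = [0.2, -0.1, 0.3, 0.05]
logits = [dot(emb, head_w) for emb in sentence_embeddings]

print(len(sentence_embeddings), len(sentence_embeddings[0]))  # 2 4
print(len(logits))                                            # 2
```

The `attention` rows are exactly what makes the mechanism explainable: each row is a distribution over tokens for one class, the single-head analogue of the `(batch, n_head, num_classes, seq_len)` matrix returned by `predict`.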
@@ -276,7 +378,7 @@ head = ClassificationHead(net=custom_head)
 ## Complete Architecture
 
 ```{thumbnail} diagrams/NN.drawio.png
-:alt:
+:alt:
 ```
 
 ### Full Model Assembly
@@ -592,9 +694,10 @@ categorical_embedding_dim = min(50, 10 // 2) = 5
 
 torchTextClassifiers provides a **component-based pipeline** for text classification:
 
+0. **ValueEncoder** (optional) → Encodes raw string inputs; decodes predictions back to original labels
 1. **Tokenizer** → Converts text to tokens
-2. **Text Embedder** → Creates semantic embeddings (with optional attention)
-3. **Categorical Handler** → Processes additional features (optional)
+2. **Text Embedder** → Creates semantic embeddings (with optional self-attention and/or label attention)
+3. **Categorical Handler** (optional) → Processes additional categorical features
 4. **Classification Head** → Produces predictions
 
 **Key Benefits:**
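The sizing heuristic quoted in the hunk context above, `categorical_embedding_dim = min(50, 10 // 2) = 5`, is easy to check in isolation; the helper name here is illustrative, not the library's API:

```python
# Illustrative helper for the embedding-size heuristic shown in the
# hunk context above: half the vocabulary size, capped at 50.
def categorical_embedding_dim(vocab_size: int, cap: int = 50) -> int:
    return min(cap, vocab_size // 2)

print(categorical_embedding_dim(10))   # 5, matching the doc's worked example
print(categorical_embedding_dim(500))  # 50, the cap kicks in
```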
@@ -610,5 +713,3 @@ torchTextClassifiers provides a **component-based pipeline** for text classifica
 - **Examples**: Explore complete examples in the repository
 
 Ready to build your classifier? Start with {doc}`../getting_started/quickstart`!
-
-