Commit 15fe25a

Merge branch 'main' into renovate/actions-upload-pages-artifact-5.x
2 parents: 3842411 + f90d687

16 files changed: 723 additions & 485 deletions

README.md

Lines changed: 2 additions & 2 deletions

@@ -10,8 +10,8 @@ A unified, extensible framework for text classification with categorical variabl
 - **ValueEncoder**: Pass raw string categorical values and labels directly — no manual integer encoding required. Build a `ValueEncoder` from `DictEncoder` or sklearn `LabelEncoder` instances once, and the wrapper handles encoding at train time and label decoding after prediction automatically.
 - **Unified yet highly customizable**:
   - Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer.
-  - Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures - including **self-attention**. All of them are `torch.nn.Module` !
-  - The `TextClassificationModel` class combines these components and can be extended for custom behavior.
+  - Text embedding is split into two composable stages: **`TokenEmbedder`** (token → per-token vectors, with optional self-attention) and **`SentenceEmbedder`** (aggregation: mean / first / last / label attention). Combine them with `CategoricalVariableNet` and `ClassificationHead` — all are `torch.nn.Module`.
+  - The `TextClassificationModel` class assembles these components and can be extended for custom behavior.
 - **Multiclass / multilabel classification support**: Support for both multiclass (only one label is true) and multi-label (several labels can be true) classification tasks.
 - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
 - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
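The two-stage text-embedding design this change describes (per-token vectors, then a masked aggregation into one sentence vector) can be sketched in plain PyTorch. This is an illustrative standalone sketch, not the library's actual `TokenEmbedder`/`SentenceEmbedder` classes; the class and argument names here are made up for the example.

```python
import torch
import torch.nn as nn

class TwoStageTextEncoder(nn.Module):
    """Illustrative sketch: token embedding -> masked mean pooling -> head."""

    def __init__(self, vocab_size=100, embedding_dim=16, num_classes=4, padding_idx=0):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.head = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        token_vecs = self.token_embedding(input_ids)      # (batch, seq_len, dim)
        mask = attention_mask.unsqueeze(-1).float()       # (batch, seq_len, 1)
        summed = (token_vecs * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)           # avoid division by zero
        sentence_vec = summed / counts                    # masked mean: (batch, dim)
        return self.head(sentence_vec)

encoder = TwoStageTextEncoder()
input_ids = torch.tensor([[5, 7, 2, 0], [3, 9, 0, 0]])
attention_mask = (input_ids != 0).long()
logits = encoder(input_ids, attention_mask)
# logits.shape == (2, 4)
```

Splitting the token stage from the aggregation stage is what lets the library swap pooling strategies (mean / first / last / label attention) without touching the token embedder.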

docs/source/api/components.rst

Lines changed: 99 additions & 27 deletions

@@ -8,22 +8,27 @@ Modular torch.nn.Module components for building custom architectures.
 Text Embedding
 --------------
 
-TextEmbedder
-~~~~~~~~~~~~
+Text embedding is split into two composable stages:
 
-Embeds text tokens with optional self-attention.
+1. **TokenEmbedder** — maps each token to a dense vector (with optional self-attention). Output: ``(batch, seq_len, embedding_dim)``.
+2. **SentenceEmbedder** — aggregates token vectors into a sentence embedding. Output: ``(batch, embedding_dim)`` or ``(batch, num_classes, embedding_dim)`` with label attention.
 
-.. autoclass:: torchTextClassifiers.model.components.text_embedder.TextEmbedder
+TokenEmbedder
+~~~~~~~~~~~~~
+
+Embeds tokenized text with optional self-attention.
+
+.. autoclass:: torchTextClassifiers.model.components.text_embedder.TokenEmbedder
    :members:
    :undoc-members:
    :show-inheritance:
 
-TextEmbedderConfig
-~~~~~~~~~~~~~~~~~~
+TokenEmbedderConfig
+~~~~~~~~~~~~~~~~~~~
 
-Configuration for TextEmbedder.
+Configuration for TokenEmbedder.
 
-.. autoclass:: torchTextClassifiers.model.components.text_embedder.TextEmbedderConfig
+.. autoclass:: torchTextClassifiers.model.components.text_embedder.TokenEmbedderConfig
    :members:
    :undoc-members:
    :show-inheritance:
@@ -32,31 +37,90 @@ Example:
 
 .. code-block:: python
 
-   from torchTextClassifiers.model.components import TextEmbedder, TextEmbedderConfig
+   from torchTextClassifiers.model.components import (
+       TokenEmbedder, TokenEmbedderConfig, AttentionConfig,
+   )
 
-   # Simple text embedder
-   config = TextEmbedderConfig(
+   # Simple token embedder (no self-attention)
+   config = TokenEmbedderConfig(
        vocab_size=5000,
        embedding_dim=128,
-       attention_config=None
+       padding_idx=0,
    )
-   embedder = TextEmbedder(config)
+   token_embedder = TokenEmbedder(config)
+   out = token_embedder(input_ids, attention_mask)
+   # out["token_embeddings"]: (batch, seq_len, 128)
 
    # With self-attention
-   from torchTextClassifiers.model.components import AttentionConfig
-
    attention_config = AttentionConfig(
-       n_embd=128,
+       n_layers=2,
        n_head=4,
-       n_layer=2,
-       dropout=0.1
+       n_kv_head=4,
+       positional_encoding=False,
    )
-   config = TextEmbedderConfig(
+   config = TokenEmbedderConfig(
        vocab_size=5000,
        embedding_dim=128,
-       attention_config=attention_config
+       padding_idx=0,
+       attention_config=attention_config,
+   )
+   token_embedder = TokenEmbedder(config)
+
+SentenceEmbedder
+~~~~~~~~~~~~~~~~
+
+Aggregates per-token embeddings into a single sentence embedding.
+
+.. autoclass:: torchTextClassifiers.model.components.text_embedder.SentenceEmbedder
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+SentenceEmbedderConfig
+~~~~~~~~~~~~~~~~~~~~~~
+
+Configuration for SentenceEmbedder.
+
+.. autoclass:: torchTextClassifiers.model.components.text_embedder.SentenceEmbedderConfig
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+LabelAttentionConfig
+~~~~~~~~~~~~~~~~~~~~
+
+Configuration for the label-attention aggregation mode.
+
+.. autoclass:: torchTextClassifiers.model.components.text_embedder.LabelAttentionConfig
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Example:
+
+.. code-block:: python
+
+   from torchTextClassifiers.model.components import (
+       SentenceEmbedder, SentenceEmbedderConfig,
+       LabelAttentionConfig,
    )
-   embedder = TextEmbedder(config)
+
+   # Mean-pooling (default)
+   sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean"))
+   out = sentence_embedder(token_embeddings, attention_mask)
+   # out["sentence_embedding"]: (batch, 128)
+
+   # Label attention — one embedding per class
+   sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(
+       aggregation_method=None,
+       label_attention_config=LabelAttentionConfig(
+           n_head=4,
+           num_classes=6,
+           embedding_dim=128,
+       ),
+   ))
+   out = sentence_embedder(token_embeddings, attention_mask)
+   # out["sentence_embedding"]: (batch, num_classes, 128)
 
 Categorical Features
 --------------------
@@ -246,22 +310,31 @@ Components can be composed to create custom architectures:
 
 .. code-block:: python
 
+   import torch
    import torch.nn as nn
    from torchTextClassifiers.model.components import (
-       TextEmbedder, CategoricalVariableNet, ClassificationHead
+       TokenEmbedder, TokenEmbedderConfig,
+       SentenceEmbedder, SentenceEmbedderConfig,
+       CategoricalVariableNet, ClassificationHead,
    )
 
   class CustomModel(nn.Module):
       def __init__(self):
           super().__init__()
-          self.text_embedder = TextEmbedder(text_config)
+          self.token_embedder = TokenEmbedder(TokenEmbedderConfig(
+              vocab_size=5000, embedding_dim=128, padding_idx=0,
+          ))
+          self.sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig())
           self.cat_net = CategoricalVariableNet(...)
           self.head = ClassificationHead(...)
 
-      def forward(self, input_ids, categorical_data):
-          text_features = self.text_embedder(input_ids)
+      def forward(self, input_ids, attention_mask, categorical_data):
+          token_out = self.token_embedder(input_ids, attention_mask)
+          sent_out = self.sentence_embedder(
+              token_out["token_embeddings"], token_out["attention_mask"]
+          )
           cat_features = self.cat_net(categorical_data)
-          combined = torch.cat([text_features, cat_features], dim=1)
+          combined = torch.cat([sent_out["sentence_embedding"], cat_features], dim=1)
           return self.head(combined)
 
 See Also
@@ -270,4 +343,3 @@ See Also
 * :doc:`model` - How components are used in models
 * :doc:`../architecture/overview` - Architecture explanation
 * :doc:`configs` - ModelConfig for component configuration
-
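The label-attention aggregation mode added in this diff (one sentence embedding per class) can be sketched independently of the library: give each class a learned query vector that attends over the token embeddings. This is a minimal single-head sketch with invented names, not the library's `LabelAttentionConfig` implementation, which is multi-head.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelAttentionPooling(nn.Module):
    """Sketch: one learned query per class; returns (batch, num_classes, dim)."""

    def __init__(self, num_classes: int, embedding_dim: int):
        super().__init__()
        self.label_queries = nn.Parameter(torch.randn(num_classes, embedding_dim))

    def forward(self, token_embeddings, attention_mask):
        # (num_classes, dim) x (batch, seq, dim) -> (batch, num_classes, seq)
        scores = torch.einsum("cd,btd->bct", self.label_queries, token_embeddings)
        scores = scores / token_embeddings.size(-1) ** 0.5
        # padded positions receive zero attention weight
        scores = scores.masked_fill(attention_mask.unsqueeze(1) == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        # per-class weighted sum of token vectors
        return torch.einsum("bct,btd->bcd", weights, token_embeddings)

pool = LabelAttentionPooling(num_classes=6, embedding_dim=32)
token_embeddings = torch.randn(2, 5, 32)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
per_class = pool(token_embeddings, attention_mask)
# per_class.shape == (2, 6, 32)
```

Each class then gets its own view of the sentence, which a per-class classification head can score, matching the ``(batch, num_classes, embedding_dim)`` output shape documented above.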
docs/source/api/index.rst

Lines changed: 3 additions & 2 deletions

@@ -11,7 +11,7 @@ The API is organized into several modules:
 * :doc:`wrapper` - High-level torchTextClassifiers wrapper class
 * :doc:`configs` - Configuration classes (ModelConfig, TrainingConfig)
 * :doc:`tokenizers` - Text tokenization (NGram, WordPiece, HuggingFace)
-* :doc:`components` - Model components (TextEmbedder, CategoricalVariableNet, etc.)
+* :doc:`components` - Model components (TokenEmbedder, SentenceEmbedder, CategoricalVariableNet, etc.)
 * :doc:`model` - Core PyTorch models
 * :doc:`dataset` - Dataset classes for data loading
 
@@ -30,7 +30,8 @@ Most Used Classes
 Architecture Components
 ~~~~~~~~~~~~~~~~~~~~~~~
 
-* :class:`torchTextClassifiers.model.components.TextEmbedder` - Text embedding layer
+* :class:`torchTextClassifiers.model.components.text_embedder.TokenEmbedder` - Token embedding layer
+* :class:`torchTextClassifiers.model.components.text_embedder.SentenceEmbedder` - Sentence aggregation layer
 * :class:`torchTextClassifiers.model.components.CategoricalVariableNet` - Categorical features
 * :class:`torchTextClassifiers.model.components.ClassificationHead` - Classification layer
 * :class:`torchTextClassifiers.model.components.Attention.AttentionConfig` - Attention configuration

docs/source/api/model.rst

Lines changed: 20 additions & 16 deletions

@@ -20,49 +20,53 @@ Core PyTorch nn.Module combining all components.
 
 **Architecture:**
 
-The model combines three main components:
+The model combines four main components:
 
-1. **TextEmbedder**: Converts tokens to embeddings
-2. **CategoricalVariableNet** (optional): Handles categorical features
-3. **ClassificationHead**: Produces class logits
+1. **TokenEmbedder**: Maps each token to a dense vector (with optional self-attention)
+2. **SentenceEmbedder**: Aggregates token vectors into a sentence representation
+3. **CategoricalVariableNet** (optional): Handles categorical features
+4. **ClassificationHead**: Produces class logits
 
 Example:
 
 .. code-block:: python
 
    from torchTextClassifiers.model import TextClassificationModel
   from torchTextClassifiers.model.components import (
-       TextEmbedder, TextEmbedderConfig,
-       CategoricalVariableNet, CategoricalForwardType,
-       ClassificationHead
+       TokenEmbedder, TokenEmbedderConfig,
+       SentenceEmbedder, SentenceEmbedderConfig,
+       CategoricalVariableNet,
+       ClassificationHead,
    )
 
    # Create components
-   text_embedder = TextEmbedder(TextEmbedderConfig(
+   token_embedder = TokenEmbedder(TokenEmbedderConfig(
        vocab_size=5000,
-       embedding_dim=128
+       embedding_dim=128,
+       padding_idx=0,
    ))
+   sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean"))
 
    cat_net = CategoricalVariableNet(
-       vocabulary_sizes=[10, 20],
-       embedding_dims=[8, 16],
-       forward_type=CategoricalForwardType.AVERAGE_AND_CONCAT
+       categorical_vocabulary_sizes=[10, 20],
+       categorical_embedding_dims=[8, 16],
    )
 
    classification_head = ClassificationHead(
        input_dim=128 + 24,  # text_dim + cat_dim
-       num_classes=5
+       num_classes=5,
    )
 
    # Combine into model
    model = TextClassificationModel(
-       text_embedder=text_embedder,
+       token_embedder=token_embedder,
+       sentence_embedder=sentence_embedder,
        categorical_variable_net=cat_net,
-       classification_head=classification_head
+       classification_head=classification_head,
    )
 
    # Forward pass
-   logits = model(input_ids, categorical_data)
+   logits = model(input_ids, attention_mask, categorical_data)
 
 PyTorch Lightning Module
 -------------------------
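The ``input_dim=128 + 24`` in the example above is dimension arithmetic: a 128-d sentence embedding concatenated with categorical embeddings of 8 and 16 dims (8 + 16 = 24). A quick shape check, assuming concatenation is how the model combines the two feature streams (the library may offer other combination modes):

```python
import torch

# 128-d sentence embedding plus two categorical embeddings of 8 and 16 dims
sentence_embedding = torch.zeros(2, 128)
cat_embeddings = [torch.zeros(2, 8), torch.zeros(2, 16)]
combined = torch.cat([sentence_embedding, *cat_embeddings], dim=1)
# combined.shape == (2, 152), i.e. input_dim = 128 + 24
```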

0 commit comments
