@@ -8,22 +8,27 @@ Modular torch.nn.Module components for building custom architectures.
88Text Embedding
99--------------
1010
11- TextEmbedder
12- ~~~~~~~~~~~~
11+ Text embedding is split into two composable stages:
1312
14- Embeds text tokens with optional self-attention.
13+ 1. **TokenEmbedder** — maps each token to a dense vector (with optional self-attention). Output: ``(batch, seq_len, embedding_dim)``.
14+ 2. **SentenceEmbedder** — aggregates token vectors into a sentence embedding. Output: ``(batch, embedding_dim)`` or ``(batch, num_classes, embedding_dim)`` with label attention.
1515
16- .. autoclass:: torchTextClassifiers.model.components.text_embedder.TextEmbedder
16+ TokenEmbedder
17+ ~~~~~~~~~~~~~
18+
19+ Embeds tokenized text with optional self-attention.
20+
21+ .. autoclass:: torchTextClassifiers.model.components.text_embedder.TokenEmbedder
1722 :members:
1823 :undoc-members:
1924 :show-inheritance:
2025
21- TextEmbedderConfig
22- ~~~~~~~~~~~~~~~~~~
26+ TokenEmbedderConfig
27+ ~~~~~~~~~~~~~~~~~~~
2328
24- Configuration for TextEmbedder.
29+ Configuration for TokenEmbedder.
2530
26- .. autoclass:: torchTextClassifiers.model.components.text_embedder.TextEmbedderConfig
31+ .. autoclass:: torchTextClassifiers.model.components.text_embedder.TokenEmbedderConfig
2732 :members:
2833 :undoc-members:
2934 :show-inheritance:
@@ -32,31 +37,90 @@ Example:
3237
3338.. code-block:: python
3439
35- from torchTextClassifiers.model.components import TextEmbedder, TextEmbedderConfig
40+ from torchTextClassifiers.model.components import (
41+ TokenEmbedder, TokenEmbedderConfig, AttentionConfig,
42+ )
3643
37- # Simple text embedder
38- config = TextEmbedderConfig(
44+ # Simple token embedder (no self-attention)
45+ config = TokenEmbedderConfig(
3946 vocab_size=5000,
4047 embedding_dim=128,
41- attention_config=None
48+ padding_idx=0,
4249 )
43- embedder = TextEmbedder(config)
50+ token_embedder = TokenEmbedder(config)
51+ out = token_embedder(input_ids, attention_mask)
52+ # out["token_embeddings"]: (batch, seq_len, 128)
4453
4554 # With self-attention
46- from torchTextClassifiers.model.components import AttentionConfig
47-
4855 attention_config = AttentionConfig(
49- n_embd=128,
56+ n_layers=2,
5057 n_head=4,
51- n_layer=2,
52- dropout=0.1
58+ n_kv_head=4,
59+ positional_encoding=False,
5360 )
54- config = TextEmbedderConfig(
61+ config = TokenEmbedderConfig(
5562 vocab_size=5000,
5663 embedding_dim=128,
57- attention_config=attention_config
64+ padding_idx=0,
65+ attention_config=attention_config,
66+ )
67+ token_embedder = TokenEmbedder(config)
68+
69+ SentenceEmbedder
70+ ~~~~~~~~~~~~~~~~
71+
72+ Aggregates per-token embeddings into a single sentence embedding.
73+
74+ .. autoclass:: torchTextClassifiers.model.components.text_embedder.SentenceEmbedder
75+ :members:
76+ :undoc-members:
77+ :show-inheritance:
78+
79+ SentenceEmbedderConfig
80+ ~~~~~~~~~~~~~~~~~~~~~~
81+
82+ Configuration for SentenceEmbedder.
83+
84+ .. autoclass:: torchTextClassifiers.model.components.text_embedder.SentenceEmbedderConfig
85+ :members:
86+ :undoc-members:
87+ :show-inheritance:
88+
89+ LabelAttentionConfig
90+ ~~~~~~~~~~~~~~~~~~~~
91+
92+ Configuration for the label-attention aggregation mode.
93+
94+ .. autoclass:: torchTextClassifiers.model.components.text_embedder.LabelAttentionConfig
95+ :members:
96+ :undoc-members:
97+ :show-inheritance:
98+
99+ Example:
100+
101+ .. code-block:: python
102+
103+ from torchTextClassifiers.model.components import (
104+ SentenceEmbedder, SentenceEmbedderConfig,
105+ LabelAttentionConfig,
58106 )
59- embedder = TextEmbedder(config)
107+
108+ # Mean-pooling (default)
109+ sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean"))
110+ out = sentence_embedder(token_embeddings, attention_mask)
111+ # out["sentence_embedding"]: (batch, 128)
112+
113+ # Label attention — one embedding per class
114+ sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(
115+ aggregation_method=None,
116+ label_attention_config=LabelAttentionConfig(
117+ n_head=4,
118+ num_classes=6,
119+ embedding_dim=128,
120+ ),
121+ ))
122+ out = sentence_embedder(token_embeddings, attention_mask)
123+ # out["sentence_embedding"]: (batch, num_classes, 128)
60124
61125 Categorical Features
62126--------------------
@@ -246,22 +310,31 @@ Components can be composed to create custom architectures:
246310
248312.. code-block:: python
248312
313+ import torch
249314 import torch.nn as nn
250315 from torchTextClassifiers.model.components import (
251- TextEmbedder, CategoricalVariableNet, ClassificationHead
316+ TokenEmbedder, TokenEmbedderConfig,
317+ SentenceEmbedder, SentenceEmbedderConfig,
318+ CategoricalVariableNet, ClassificationHead,
252319 )
253320
254321 class CustomModel(nn.Module):
255322 def __init__(self):
256323 super().__init__()
257- self.text_embedder = TextEmbedder(text_config)
324+ self.token_embedder = TokenEmbedder(TokenEmbedderConfig(
325+ vocab_size=5000, embedding_dim=128, padding_idx=0,
326+ ))
327+ self.sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig())
258328 self.cat_net = CategoricalVariableNet(...)
259329 self.head = ClassificationHead(...)
260330
261- def forward(self, input_ids, categorical_data):
262- text_features = self.text_embedder(input_ids)
331+ def forward(self, input_ids, attention_mask, categorical_data):
332+ token_out = self.token_embedder(input_ids, attention_mask)
333+ sent_out = self.sentence_embedder(
334+ token_out["token_embeddings"], token_out["attention_mask"]
335+ )
263336 cat_features = self.cat_net(categorical_data)
264- combined = torch.cat([text_features, cat_features], dim=1)
337+ combined = torch.cat([sent_out["sentence_embedding"], cat_features], dim=1)
265338 return self.head(combined)
266339
267340 See Also
@@ -270,4 +343,3 @@ See Also
270343* :doc: `model ` - How components are used in models
271344* :doc: `../architecture/overview ` - Architecture explanation
272345* :doc: `configs ` - ModelConfig for component configuration
273-
0 commit comments