
Commit 934b041

feat!(attention): enable attention logic

Port the attention blocks from nanochat, allow an attention_config to be passed through model_config, initialize weights accordingly, and update _get_sentence_embedding.

1 parent a5b3e4d · commit 934b041

5 files changed

Lines changed: 288 additions & 8 deletions

File tree

torchTextClassifiers/model/components/__init__.py
torchTextClassifiers/model/components/attention.py
torchTextClassifiers/model/components/text_embedder.py
torchTextClassifiers/model/model.py
torchTextClassifiers/torchTextClassifiers.py

torchTextClassifiers/model/components/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,6 @@
+from .attention import (
+    AttentionConfig as AttentionConfig,
+)
 from .categorical_var_net import (
     CategoricalForwardType as CategoricalForwardType,
 )
@@ -6,3 +9,4 @@
 )
 from .classification_head import ClassificationHead as ClassificationHead
 from .text_embedder import TextEmbedder as TextEmbedder
+from .text_embedder import TextEmbedderConfig as TextEmbedderConfig
torchTextClassifiers/model/components/attention.py

Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@
+"""Largely inspired from Andrej Karpathy's nanochat, see here https://github.com/karpathy/nanochat/blob/master/nanochat/gpt.py"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+### Some utils used in text_embedder.py for the attention blocks ###
+
+
+def apply_rotary_emb(x, cos, sin):
+    assert x.ndim == 4  # multihead attention
+
+    d = x.shape[3] // 2
+    x1, x2 = x[..., :d], x[..., d:]  # split up last dim into two halves
+    y1 = x1 * cos + x2 * sin  # rotate pairs of dims
+    y2 = x1 * (-sin) + x2 * cos
+    out = torch.cat([y1, y2], 3)  # re-assemble
+    out = out.to(x.dtype)  # ensure input/output dtypes match
+    return out
+
+
+def norm(x):
+    # Purely functional rmsnorm with no learnable params
+    return F.rms_norm(x, (x.size(-1),))
+
+
+#### Config #####
+@dataclass
+class AttentionConfig:
+    n_layers: int
+    n_head: int
+    n_kv_head: int
+    sequence_len: Optional[int] = None
+    positional_encoding: bool = True
+    aggregation_method: str = "mean"  # or 'last', or 'first'
+
+
+#### Attention Block #####
+
+# Composed of SelfAttentionLayer and MLP with residual connections
+
+
+class Block(nn.Module):
+    def __init__(self, config: AttentionConfig, layer_idx: int):
+        super().__init__()
+
+        self.layer_idx = layer_idx
+        self.attn = SelfAttentionLayer(config, layer_idx)
+        self.mlp = MLP(config)
+
+    def forward(self, x, cos_sin):
+        x = x + self.attn(norm(x), cos_sin)
+        x = x + self.mlp(norm(x))
+        return x
+
+
+##### Components of the Block #####
+
+
+class SelfAttentionLayer(nn.Module):
+    def __init__(self, config: AttentionConfig, layer_idx):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.n_head = config.n_head
+        self.n_kv_head = config.n_kv_head
+        self.enable_gqa = (
+            self.n_head != self.n_kv_head
+        )  # Grouped Query Attention (GQA): duplicate key/value heads to match query heads if desired
+        self.n_embd = config.n_embd
+        self.head_dim = self.n_embd // self.n_head
+        assert self.n_embd % self.n_head == 0
+        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
+        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
+        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+
+        self.apply_positional_encoding = config.positional_encoding
+
+    def forward(self, x, cos_sin=None):
+        B, T, C = x.size()
+
+        # Project the input to get queries, keys, and values
+        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+
+        if self.apply_positional_encoding:
+            assert cos_sin is not None, "Rotary embeddings require precomputed cos/sin tensors"
+            cos, sin = cos_sin
+            q, k = (
+                apply_rotary_emb(q, cos, sin),
+                apply_rotary_emb(k, cos, sin),
+            )  # QK rotary embedding
+
+        q, k = norm(q), norm(k)  # QK norm
+        q, k, v = (
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+        )  # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
+
+        # is_causal=False for non-autoregressive models (BERT-like)
+        y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=self.enable_gqa)
+
+        # Re-assemble the heads side by side and project back to residual stream
+        y = y.transpose(1, 2).contiguous().view(B, T, -1)
+        y = self.c_proj(y)
+
+        return y
+
+
+class MLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+
+    def forward(self, x):
+        x = self.c_fc(x)
+        x = F.relu(x).square()
+        x = self.c_proj(x)
+        return x
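
For orientation, here is a minimal, hypothetical sketch of exercising the new Block on its own. It assumes the module is importable at the path used by the other files in this commit, and it sets n_embd on the config by hand because AttentionConfig carries no such field; TextEmbedder injects it from embedding_dim (see text_embedder.py below).

import torch

from torchTextClassifiers.model.components.attention import AttentionConfig, Block

config = AttentionConfig(n_layers=1, n_head=4, n_kv_head=2, positional_encoding=False)
config.n_embd = 64  # normally injected by TextEmbedder from embedding_dim

block = Block(config, layer_idx=0)
x = torch.randn(8, 16, 64)   # (batch, seq_len, n_embd)
y = block(x, cos_sin=None)   # rotary embeddings are disabled, so no cos/sin is needed
assert y.shape == x.shape    # bidirectional attention + MLP with residual connections

With n_head != n_kv_head as above, the enable_gqa flag of F.scaled_dot_product_attention is exercised, so this path needs a PyTorch version that supports that argument.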

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 146 additions & 7 deletions
@@ -1,21 +1,107 @@
+import math
+from dataclasses import dataclass
+from typing import Optional
+
 import torch
 from torch import nn
 
+from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm
+
+
+@dataclass
+class TextEmbedderConfig:
+    vocab_size: int
+    embedding_dim: int
+    padding_idx: int
+    attention_config: Optional[AttentionConfig] = None
+
 
 class TextEmbedder(nn.Module):
-    def __init__(self, vocab_size: int, embedding_dim: int, padding_idx: int):
+    def __init__(self, text_embedder_config: TextEmbedderConfig):
         super().__init__()
 
-        self.vocab_size = vocab_size
-        self.embedding_dim = embedding_dim
-        self.padding_idx = padding_idx
+        self.config = text_embedder_config
+
+        self.attention_config = text_embedder_config.attention_config
+        if self.attention_config is not None:
+            self.attention_config.n_embd = text_embedder_config.embedding_dim
+
+        self.vocab_size = text_embedder_config.vocab_size
+        self.embedding_dim = text_embedder_config.embedding_dim
+        self.padding_idx = text_embedder_config.padding_idx
 
         self.embedding_layer = nn.Embedding(
-            embedding_dim=embedding_dim,
-            num_embeddings=vocab_size,
+            embedding_dim=self.embedding_dim,
+            num_embeddings=self.vocab_size,
             padding_idx=self.padding_idx,
         )
 
+        if self.attention_config is not None:
+            self.transformer = nn.ModuleDict(
+                {
+                    "h": nn.ModuleList(
+                        [
+                            Block(self.attention_config, layer_idx)
+                            for layer_idx in range(self.attention_config.n_layers)
+                        ]
+                    ),
+                }
+            )
+
+            head_dim = self.attention_config.n_embd // self.attention_config.n_head
+
+            if head_dim * self.attention_config.n_head != self.attention_config.n_embd:
+                raise ValueError("embedding_dim must be divisible by n_head.")
+
+            if self.attention_config.positional_encoding:
+                if head_dim % 2 != 0:
+                    raise ValueError(
+                        "embedding_dim / n_head must be even for rotary positional embeddings."
+                    )
+
+                if self.attention_config.sequence_len is None:
+                    raise ValueError(
+                        "sequence_len must be specified in AttentionConfig when positional_encoding is True."
+                    )
+
+                self.rotary_seq_len = self.attention_config.sequence_len * 10
+                cos, sin = self._precompute_rotary_embeddings(
+                    seq_len=self.rotary_seq_len, head_dim=head_dim
+                )
+
+                self.register_buffer(
+                    "cos", cos, persistent=False
+                )  # persistent=False means it's not saved to the checkpoint
+                self.register_buffer("sin", sin, persistent=False)
+
+    def init_weights(self):
+        self.apply(self._init_weights)
+
+        # zero out c_proj weights in all blocks
+        if self.attention_config is not None:
+            for block in self.transformer.h:
+                torch.nn.init.zeros_(block.mlp.c_proj.weight)
+                torch.nn.init.zeros_(block.attn.c_proj.weight)
+            # init the rotary embeddings
+            head_dim = self.attention_config.n_embd // self.attention_config.n_head
+            cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+            self.cos, self.sin = cos, sin
+        # Cast the embeddings from fp32 to bf16: the optimizer can tolerate it and it saves memory, both in the model and in the activations
+        if self.embedding_layer.weight.device.type == "cuda":
+            self.embedding_layer.to(dtype=torch.bfloat16)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            # https://arxiv.org/pdf/2310.17813
+            fan_out = module.weight.size(0)
+            fan_in = module.weight.size(1)
+            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
+
     def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
         """Converts input token IDs to their corresponding embeddings."""
 
@@ -36,6 +122,19 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
             encoded_text
         )  # (batch_size, seq_len, embedding_dim)
 
+        token_embeddings = norm(token_embeddings)
+
+        if self.attention_config is not None:
+            if self.attention_config.positional_encoding:
+                cos_sin = self.cos[:, :seq_len], self.sin[:, :seq_len]
+            else:
+                cos_sin = None
+
+            for block in self.transformer.h:
+                token_embeddings = block(token_embeddings, cos_sin)
+
+            token_embeddings = norm(token_embeddings)
+
         text_embedding = self._get_sentence_embedding(
             token_embeddings=token_embeddings, attention_mask=attention_mask
         )
@@ -57,9 +156,28 @@ def _get_sentence_embedding(
 
         # average over non-pad token embeddings
         # attention mask has 1 for non-pad tokens and 0 for pad token positions
-        # TODO: add attention logic at some point
 
         # mask pad-tokens
+
+        if self.attention_config is not None:
+            if self.attention_config.aggregation_method is not None:
+                if self.attention_config.aggregation_method == "first":
+                    return token_embeddings[:, 0, :]
+                elif self.attention_config.aggregation_method == "last":
+                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
+                    return token_embeddings[
+                        torch.arange(token_embeddings.size(0)),
+                        lengths - 1,
+                        :,
+                    ]
+                else:
+                    if self.attention_config.aggregation_method != "mean":
+                        raise ValueError(
+                            f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'."
+                        )
+
+        assert self.attention_config is None or self.attention_config.aggregation_method == "mean"
+
         mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
         masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
 
@@ -79,3 +197,24 @@ def __call__(self, *args, **kwargs):
                 f"(got shape {tuple(out.shape)})"
             )
         return out
+
+    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+        # autodetect the device from model embeddings
+        if device is None:
+            device = next(self.parameters()).device
+
+        # stride the channels
+        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+        inv_freq = 1.0 / (base ** (channel_range / head_dim))
+        # stride the time steps
+        t = torch.arange(seq_len, dtype=torch.float32, device=device)
+        # calculate the rotation frequencies at each (time, channel) pair
+        freqs = torch.outer(t, inv_freq)
+        cos, sin = freqs.cos(), freqs.sin()
+        cos, sin = cos.bfloat16(), sin.bfloat16()  # keep them in bfloat16
+        cos, sin = (
+            cos[None, :, None, :],
+            sin[None, :, None, :],
+        )  # add batch and head dims for later broadcasting
+
+        return cos, sin
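
A rough usage sketch of the updated TextEmbedder follows. It assumes forward is called directly with token ids and an attention mask, as the signature above suggests; the vocabulary, shapes, and hyperparameters are illustrative only.

import torch

from torchTextClassifiers.model.components.attention import AttentionConfig
from torchTextClassifiers.model.components.text_embedder import TextEmbedder, TextEmbedderConfig

attention_config = AttentionConfig(n_layers=2, n_head=4, n_kv_head=4, sequence_len=32)
config = TextEmbedderConfig(
    vocab_size=1000,
    embedding_dim=64,  # must be divisible by n_head; head_dim must be even for rotary embeddings
    padding_idx=0,
    attention_config=attention_config,  # pass None to keep the previous mean-pooling-only behaviour
)

embedder = TextEmbedder(config)
embedder.init_weights()

input_ids = torch.randint(1, 1000, (8, 32))  # (batch, seq_len)
attention_mask = (input_ids != 0).long()     # 1 for real tokens, 0 for padding positions
sentence_embeddings = embedder(input_ids, attention_mask)  # expected shape: (batch, embedding_dim)

Note that sequence_len must be set whenever positional_encoding is left at its default of True, since the cos/sin tables are precomputed for sequence_len * 10 positions at construction time.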

torchTextClassifiers/model/model.py

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,10 @@ def __init__(
 
         self.num_classes = self.classification_head.num_classes
 
+        torch.nn.init.zeros_(self.classification_head.net.weight)
+        if self.text_embedder is not None:
+            self.text_embedder.init_weights()
+
     def _validate_component_connections(self):
         def _check_text_categorical_connection(self, text_embedder, cat_var_net):
             if cat_var_net.forward_type == CategoricalForwardType.SUM_TO_TEXT:
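
The zero initialization above follows the nanochat recipe: with the classification head and every block's c_proj starting at zero, the residual branches act as identities at initialization and the classifier's output does not depend on its input, which tends to make early training better behaved. A tiny illustration of the head part, assuming (an assumption; only the .net.weight attribute is visible in this diff) that ClassificationHead.net is a plain nn.Linear:

import torch
import torch.nn as nn

head = nn.Linear(64, 10)           # stand-in for classification_head.net
torch.nn.init.zeros_(head.weight)  # same call as in the diff above
logits = head(torch.randn(4, 64))  # every row equals the bias vector: inputs do not affect the logits at init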

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 8 additions & 1 deletion
@@ -23,10 +23,12 @@
 from torchTextClassifiers.dataset import TextClassificationDataset
 from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
 from torchTextClassifiers.model.components import (
+    AttentionConfig,
     CategoricalForwardType,
     CategoricalVariableNet,
     ClassificationHead,
     TextEmbedder,
+    TextEmbedderConfig,
 )
 from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput
 
@@ -48,6 +50,7 @@ class ModelConfig:
     categorical_vocabulary_sizes: Optional[List[int]] = None
     categorical_embedding_dims: Optional[Union[List[int], int]] = None
     num_classes: Optional[int] = None
+    attention_config: Optional[AttentionConfig] = None
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
@@ -139,10 +142,14 @@ def __init__(
             )
             self.embedding_dim = self.tokenizer.output_dim
         else:
-            self.text_embedder = TextEmbedder(
+            text_embedder_config = TextEmbedderConfig(
                 vocab_size=self.vocab_size,
                 embedding_dim=self.embedding_dim,
                 padding_idx=tokenizer.padding_idx,
+                attention_config=model_config.attention_config,
+            )
+            self.text_embedder = TextEmbedder(
+                text_embedder_config=text_embedder_config,
             )
 
         classif_head_input_dim = self.embedding_dim
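
Putting the pieces together, attention can now be enabled from the public config. The sketch below is hypothetical: ModelConfig has more fields than the four visible in this diff, so a real constructor call will include additional arguments; only attention_config (defaulting to None) and num_classes are taken from the changes above.

from torchTextClassifiers.model.components import AttentionConfig

attention_config = AttentionConfig(
    n_layers=2,
    n_head=4,
    n_kv_head=4,
    sequence_len=128,           # required because positional_encoding defaults to True
    aggregation_method="mean",  # or "first" / "last", consumed by _get_sentence_embedding
)

model_config = ModelConfig(
    num_classes=10,
    attention_config=attention_config,  # omit or set to None to keep the previous behaviour
    # ... remaining ModelConfig fields as before ...
)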
