Skip to content

Commit e71ab0f

Browse files
chore: disable GQA for label attention for simplicity.
1 parent 819ffbc commit e71ab0f

2 files changed

Lines changed: 21 additions & 29 deletions

File tree

tests/test_pipeline.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def run_full_pipeline(
9494
label_attention_config=(
9595
LabelAttentionConfig(
9696
n_head=attention_config.n_head,
97-
n_kv_head=attention_config.n_kv_head,
9897
num_classes=model_params["num_classes"],
9998
)
10099
if label_attention_enabled
@@ -156,7 +155,6 @@ def run_full_pipeline(
156155
label_attention_config=(
157156
LabelAttentionConfig(
158157
n_head=attention_config.n_head,
159-
n_kv_head=attention_config.n_kv_head,
160158
num_classes=model_params["num_classes"],
161159
)
162160
if label_attention_enabled
@@ -199,9 +197,9 @@ def run_full_pipeline(
199197

200198
# Test label attention assertions
201199
if label_attention_enabled:
202-
assert predictions["label_attention_attributions"] is not None, (
203-
"Label attention attributions should not be None when label_attention_enabled is True"
204-
)
200+
assert (
201+
predictions["label_attention_attributions"] is not None
202+
), "Label attention attributions should not be None when label_attention_enabled is True"
205203
label_attention_attributions = predictions["label_attention_attributions"]
206204
expected_shape = (
207205
len(sample_text_data), # batch_size
@@ -215,9 +213,9 @@ def run_full_pipeline(
215213
)
216214
else:
217215
# When label attention is not enabled, the attributions should be None
218-
assert predictions.get("label_attention_attributions") is None, (
219-
"Label attention attributions should be None when label_attention_enabled is False"
220-
)
216+
assert (
217+
predictions.get("label_attention_attributions") is None
218+
), "Label attention attributions should be None when label_attention_enabled is False"
221219

222220
# Test explainability functions
223221
text_idx = 0

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
@dataclass
1313
class LabelAttentionConfig:
1414
n_head: int
15-
n_kv_head: int
1615
num_classes: int
1716

1817

@@ -306,34 +305,29 @@ def __init__(self, config: TextEmbedderConfig):
306305
self.embedding_dim = config.embedding_dim
307306
self.num_classes = label_attention_config.num_classes
308307
self.n_head = label_attention_config.n_head
309-
self.n_kv_head = label_attention_config.n_kv_head
310-
self.enable_gqa = (
311-
self.n_head != self.n_kv_head
312-
) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
313-
308+
314309
# Validate head configuration
315310
self.head_dim = self.embedding_dim // self.n_head
316-
311+
317312
if self.head_dim * self.n_head != self.embedding_dim:
318313
raise ValueError(
319314
f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). "
320315
f"Got head_dim = {self.head_dim} with remainder {self.embedding_dim % self.n_head}"
321316
)
322-
323-
if self.n_head % self.n_kv_head != 0:
324-
raise ValueError(
325-
f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head}) for Group Query Attention. "
326-
f"Got remainder {self.n_head % self.n_kv_head}"
327-
)
328317

329318
self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim)
330319

331320
self.c_q = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
332-
self.c_k = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False)
333-
self.c_v = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False)
321+
self.c_k = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
322+
self.c_v = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
334323
self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
335324

336-
def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = None, compute_attention_matrix: Optional[bool] = False):
325+
def forward(
326+
self,
327+
token_embeddings,
328+
attention_mask: Optional[torch.Tensor] = None,
329+
compute_attention_matrix: Optional[bool] = False,
330+
):
337331
"""
338332
Args:
339333
token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input.
@@ -362,8 +356,8 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non
362356
all_label_embeddings = norm(all_label_embeddings)
363357

364358
q = self.c_q(all_label_embeddings).view(B, self.num_classes, self.n_head, self.head_dim)
365-
k = self.c_k(token_embeddings).view(B, T, self.n_kv_head, self.head_dim)
366-
v = self.c_v(token_embeddings).view(B, T, self.n_kv_head, self.head_dim)
359+
k = self.c_k(token_embeddings).view(B, T, self.n_head, self.head_dim)
360+
v = self.c_v(token_embeddings).view(B, T, self.n_head, self.head_dim)
367361

368362
q, k = norm(q), norm(k) # QK norm
369363
q, k, v = (
@@ -379,11 +373,11 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non
379373
attn_mask = None
380374
if attention_mask is not None:
381375
# Convert: 0 (padding) -> True (mask out), 1 (real) -> False (attend to)
382-
attn_mask = (attention_mask == 0) # (B, T)
376+
attn_mask = attention_mask == 0 # (B, T)
383377
# Expand to (B, 1, 1, T) for broadcasting across heads and queries
384378
attn_mask = attn_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, T)
385379

386-
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False, enable_gqa=self.enable_gqa)
380+
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
387381

388382
# Re-assemble the heads side by side and project back to residual stream
389383
y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1) # (bs, n_labels, d_model)
@@ -400,7 +394,7 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non
400394
# attn_mask is already in the right shape: (B, 1, 1, T)
401395
# We need to apply it to scores of shape (B, n_head, n_labels, T)
402396
# Set masked positions to -inf so they become 0 after softmax
403-
attention_scores = attention_scores.masked_fill(attn_mask, float('-inf'))
397+
attention_scores = attention_scores.masked_fill(attn_mask, float("-inf"))
404398

405399
attention_matrix = torch.softmax(attention_scores, dim=-1)
406400

0 commit comments

Comments
 (0)