@dataclass
class LabelAttentionConfig:
    """Configuration for a label-attention head.

    Holds the hyperparameters consumed by the attention module that
    attends from per-class label embeddings over token embeddings.

    Attributes:
        n_head: Number of attention heads. Must evenly divide the
            embedding dimension used by the attention module (the module
            validates this and raises ValueError otherwise).
        num_classes: Number of label queries, i.e. one learned label
            embedding per output class.
    """

    # Number of query/key/value attention heads (standard multi-head
    # attention; the former n_kv_head / GQA option was removed).
    n_head: int
    # Number of output classes / learned label embeddings.
    num_classes: int
@@ -306,34 +305,29 @@ def __init__(self, config: TextEmbedderConfig):
306305 self .embedding_dim = config .embedding_dim
307306 self .num_classes = label_attention_config .num_classes
308307 self .n_head = label_attention_config .n_head
309- self .n_kv_head = label_attention_config .n_kv_head
310- self .enable_gqa = (
311- self .n_head != self .n_kv_head
312- ) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
313-
308+
314309 # Validate head configuration
315310 self .head_dim = self .embedding_dim // self .n_head
316-
311+
317312 if self .head_dim * self .n_head != self .embedding_dim :
318313 raise ValueError (
319314 f"embedding_dim ({ self .embedding_dim } ) must be divisible by n_head ({ self .n_head } ). "
320315 f"Got head_dim = { self .head_dim } with remainder { self .embedding_dim % self .n_head } "
321316 )
322-
323- if self .n_head % self .n_kv_head != 0 :
324- raise ValueError (
325- f"n_head ({ self .n_head } ) must be divisible by n_kv_head ({ self .n_kv_head } ) for Group Query Attention. "
326- f"Got remainder { self .n_head % self .n_kv_head } "
327- )
328317
329318 self .label_embeds = nn .Embedding (self .num_classes , self .embedding_dim )
330319
331320 self .c_q = nn .Linear (self .embedding_dim , self .n_head * self .head_dim , bias = False )
332- self .c_k = nn .Linear (self .embedding_dim , self .n_kv_head * self .head_dim , bias = False )
333- self .c_v = nn .Linear (self .embedding_dim , self .n_kv_head * self .head_dim , bias = False )
321+ self .c_k = nn .Linear (self .embedding_dim , self .n_head * self .head_dim , bias = False )
322+ self .c_v = nn .Linear (self .embedding_dim , self .n_head * self .head_dim , bias = False )
334323 self .c_proj = nn .Linear (self .embedding_dim , self .embedding_dim , bias = False )
335324
336- def forward (self , token_embeddings , attention_mask : Optional [torch .Tensor ] = None , compute_attention_matrix : Optional [bool ] = False ):
325+ def forward (
326+ self ,
327+ token_embeddings ,
328+ attention_mask : Optional [torch .Tensor ] = None ,
329+ compute_attention_matrix : Optional [bool ] = False ,
330+ ):
337331 """
338332 Args:
339333 token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input.
@@ -362,8 +356,8 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non
362356 all_label_embeddings = norm (all_label_embeddings )
363357
364358 q = self .c_q (all_label_embeddings ).view (B , self .num_classes , self .n_head , self .head_dim )
365- k = self .c_k (token_embeddings ).view (B , T , self .n_kv_head , self .head_dim )
366- v = self .c_v (token_embeddings ).view (B , T , self .n_kv_head , self .head_dim )
359+ k = self .c_k (token_embeddings ).view (B , T , self .n_head , self .head_dim )
360+ v = self .c_v (token_embeddings ).view (B , T , self .n_head , self .head_dim )
367361
368362 q , k = norm (q ), norm (k ) # QK norm
369363 q , k , v = (
@@ -379,11 +373,11 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non
379373 attn_mask = None
380374 if attention_mask is not None :
381375 # Convert: 0 (padding) -> True (mask out), 1 (real) -> False (attend to)
382- attn_mask = ( attention_mask == 0 ) # (B, T)
376+ attn_mask = attention_mask == 0 # (B, T)
383377 # Expand to (B, 1, 1, T) for broadcasting across heads and queries
384378 attn_mask = attn_mask .unsqueeze (1 ).unsqueeze (2 ) # (B, 1, 1, T)
385379
386- y = F .scaled_dot_product_attention (q , k , v , attn_mask = attn_mask , is_causal = False , enable_gqa = self . enable_gqa )
380+ y = F .scaled_dot_product_attention (q , k , v , attn_mask = attn_mask , is_causal = False )
387381
388382 # Re-assemble the heads side by side and project back to residual stream
389383 y = y .transpose (1 , 2 ).contiguous ().view (B , self .num_classes , - 1 ) # (bs, n_labels, d_model)
@@ -400,7 +394,7 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non
400394 # attn_mask is already in the right shape: (B, 1, 1, T)
401395 # We need to apply it to scores of shape (B, n_head, n_labels, T)
402396 # Set masked positions to -inf so they become 0 after softmax
403- attention_scores = attention_scores .masked_fill (attn_mask , float (' -inf' ))
397+ attention_scores = attention_scores .masked_fill (attn_mask , float (" -inf" ))
404398
405399 attention_matrix = torch .softmax (attention_scores , dim = - 1 )
406400
0 commit comments