File tree Expand file tree Collapse file tree
torchTextClassifiers/model/components Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -299,19 +299,19 @@ def __init__(self, config: TextEmbedderConfig):
299299 ) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
300300
301301 # Validate head configuration
302- if self .embedding_dim % self .n_head != 0 :
302+ self .head_dim = self .embedding_dim // self .n_head
303+
304+ if self .head_dim * self .n_head != self .embedding_dim :
303305 raise ValueError (
304306 f"embedding_dim ({ self .embedding_dim } ) must be divisible by n_head ({ self .n_head } ). "
305- f"Got head_dim = { self .embedding_dim / self .n_head } "
307+ f"Got head_dim = { self .head_dim } with remainder { self . embedding_dim % self .n_head } "
306308 )
307309
308310 if self .n_head % self .n_kv_head != 0 :
309311 raise ValueError (
310312 f"n_head ({ self .n_head } ) must be divisible by n_kv_head ({ self .n_kv_head } ) for Group Query Attention. "
311- f"Got n_head / n_kv_head = { self .n_head / self .n_kv_head } "
313+ f"Got remainder { self .n_head % self .n_kv_head } "
312314 )
313-
314- self .head_dim = self .embedding_dim // self .n_head
315315
316316 self .label_embeds = nn .Embedding (self .num_classes , self .embedding_dim )
317317
You can’t perform that action at this time.
0 commit comments