Skip to content

Commit d516e6b

Browse files
Improve validation to follow TextEmbedder pattern and clarify error messages
Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
1 parent 4aada37 commit d516e6b

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -299,19 +299,19 @@ def __init__(self, config: TextEmbedderConfig):
299299
) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
300300

301301
# Validate head configuration
302-
if self.embedding_dim % self.n_head != 0:
302+
self.head_dim = self.embedding_dim // self.n_head
303+
304+
if self.head_dim * self.n_head != self.embedding_dim:
303305
raise ValueError(
304306
f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). "
305-
f"Got head_dim = {self.embedding_dim / self.n_head}"
307+
f"Got head_dim = {self.head_dim} with remainder {self.embedding_dim % self.n_head}"
306308
)
307309

308310
if self.n_head % self.n_kv_head != 0:
309311
raise ValueError(
310312
f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head}) for Group Query Attention. "
311-
f"Got n_head / n_kv_head = {self.n_head / self.n_kv_head}"
313+
f"Got remainder {self.n_head % self.n_kv_head}"
312314
)
313-
314-
self.head_dim = self.embedding_dim // self.n_head
315315

316316
self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim)
317317

0 commit comments

Comments
 (0)