Skip to content

Commit d266572

Browse files
fix: ensure label_indices uses correct device and dtype
Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
1 parent 2374df8 commit d266572

2 files changed

Lines changed: 4 additions & 1 deletion

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,4 @@ example_files/
183183
_site/
184184
.quarto/
185185
**/*.quarto_ipynb
186+
my_ttc/

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,9 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = F
324324
compute_attention_matrix = bool(compute_attention_matrix)
325325

326326
# 1. Create label indices [0, 1, ..., C-1] for the whole batch
327-
label_indices = torch.arange(self.num_classes).expand(B, -1)
327+
label_indices = torch.arange(
328+
self.num_classes, dtype=torch.long, device=token_embeddings.device
329+
).expand(B, -1)
328330

329331
all_label_embeddings = self.label_embeds(
330332
label_indices

0 commit comments

Comments
 (0)