Skip to content

Commit 86a0715

Browse files
Fix early returns to match dictionary return type
Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
1 parent 44d9345 commit 86a0715

1 file changed

Lines changed: 13 additions & 7 deletions

File tree

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ def forward(
129129
input_ids: torch.Tensor,
130130
attention_mask: torch.Tensor,
131131
return_label_attention_matrix: bool = False,
132-
) -> dict[str, Optional[torch.Tensor]]:
132+
) -> Dict[str, Optional[torch.Tensor]]:
133133
"""Converts input token IDs to their corresponding embeddings.
134134
135135
Args:
@@ -222,14 +222,20 @@ def _get_sentence_embedding(
222222
if self.attention_config is not None:
223223
if self.attention_config.aggregation_method is not None: # default is "mean"
224224
if self.attention_config.aggregation_method == "first":
225-
return token_embeddings[:, 0, :]
225+
return {
226+
"sentence_embedding": token_embeddings[:, 0, :],
227+
"label_attention_matrix": None,
228+
}
226229
elif self.attention_config.aggregation_method == "last":
227230
lengths = attention_mask.sum(dim=1).clamp(min=1) # last non-pad token index + 1
228-
return token_embeddings[
229-
torch.arange(token_embeddings.size(0)),
230-
lengths - 1,
231-
:,
232-
]
231+
return {
232+
"sentence_embedding": token_embeddings[
233+
torch.arange(token_embeddings.size(0)),
234+
lengths - 1,
235+
:,
236+
],
237+
"label_attention_matrix": None,
238+
}
233239
else:
234240
if self.attention_config.aggregation_method != "mean":
235241
raise ValueError(

0 commit comments

Comments (0)