Skip to content

Commit 2374df8

Browse files
chore: better dict handling in TextEmbedder forward output
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 525b482 commit 2374df8

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,14 @@ def forward(
180180

181181
token_embeddings = norm(token_embeddings)
182182

183-
text_embedding, label_attention_matrix = self._get_sentence_embedding(
183+
out = self._get_sentence_embedding(
184184
token_embeddings=token_embeddings,
185185
attention_mask=attention_mask,
186186
return_label_attention_matrix=return_label_attention_matrix,
187-
).values()
187+
)
188188

189+
text_embedding = out["sentence_embedding"]
190+
label_attention_matrix = out["label_attention_matrix"]
189191
return {
190192
"sentence_embedding": text_embedding,
191193
"label_attention_matrix": label_attention_matrix,

0 commit comments

Comments
 (0)