We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 525b482 commit 2374df8Copy full SHA for 2374df8
1 file changed
torchTextClassifiers/model/components/text_embedder.py
@@ -180,12 +180,14 @@ def forward(
180
181
token_embeddings = norm(token_embeddings)
182
183
- text_embedding, label_attention_matrix = self._get_sentence_embedding(
+ out = self._get_sentence_embedding(
184
token_embeddings=token_embeddings,
185
attention_mask=attention_mask,
186
return_label_attention_matrix=return_label_attention_matrix,
187
- ).values()
+ )
188
189
+ text_embedding = out["sentence_embedding"]
190
+ label_attention_matrix = out["label_attention_matrix"]
191
return {
192
"sentence_embedding": text_embedding,
193
"label_attention_matrix": label_attention_matrix,
0 commit comments