@@ -129,7 +129,7 @@ def forward(
129129 input_ids : torch .Tensor ,
130130 attention_mask : torch .Tensor ,
131131 return_label_attention_matrix : bool = False ,
132- ) -> dict [str , Optional [torch .Tensor ]]:
132+ ) -> Dict [str , Optional [torch .Tensor ]]:
133133 """Converts input token IDs to their corresponding embeddings.
134134
135135 Args:
@@ -222,14 +222,20 @@ def _get_sentence_embedding(
222222 if self .attention_config is not None :
223223 if self .attention_config .aggregation_method is not None : # default is "mean"
224224 if self .attention_config .aggregation_method == "first" :
225- return token_embeddings [:, 0 , :]
225+ return {
226+ "sentence_embedding" : token_embeddings [:, 0 , :],
227+ "label_attention_matrix" : None ,
228+ }
226229 elif self .attention_config .aggregation_method == "last" :
227230 lengths = attention_mask .sum (dim = 1 ).clamp (min = 1 ) # last non-pad token index + 1
228- return token_embeddings [
229- torch .arange (token_embeddings .size (0 )),
230- lengths - 1 ,
231- :,
232- ]
231+ return {
232+ "sentence_embedding" : token_embeddings [
233+ torch .arange (token_embeddings .size (0 )),
234+ lengths - 1 ,
235+ :,
236+ ],
237+ "label_attention_matrix" : None ,
238+ }
233239 else :
234240 if self .attention_config .aggregation_method != "mean" :
235241 raise ValueError (
0 commit comments