@@ -121,17 +121,25 @@ def forward(
121 121     input_ids: torch.Tensor,
122 122     attention_mask: torch.Tensor,
123 123     return_label_attention_matrix: bool = False,
124     - ) -> torch.Tensor:
    124 + ) -> dict[str, Optional[torch.Tensor]]:
125 125 """Converts input token IDs to their corresponding embeddings.
126 126
127 127 Args:
128 128     input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized
129 129     attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
130     -     return_label_attention_matrix (bool): Whether to return the label attention matrix
    130 +     return_label_attention_matrix (bool): Whether to return the label attention matrix.
    131 +
131 132 Returns:
132     -     torch.Tensor: Text embeddings, shape (batch_size, embedding_dim) if self.enable_label_attention is False, else (batch_size, num_labels, embedding_dim)
133     -     torch.Tensor: Label attention matrix, shape (batch_size, num_labels, seq_len) if return_label_attention_matrix is True, else None.
134     -         Also None if label attention is disabled (even if return_label_attention_matrix is True)
    133 +     dict: A dictionary with the following keys:
    134 +
    135 +         - "sentence_embedding" (torch.Tensor): Text embeddings of shape
    136 +           (batch_size, embedding_dim) if ``self.enable_label_attention`` is False,
    137 +           else (batch_size, num_labels, embedding_dim).
    138 +
    139 +         - "label_attention_matrix" (Optional[torch.Tensor]): Label attention
    140 +           matrix of shape (batch_size, num_labels, seq_len) if
    141 +           ``return_label_attention_matrix`` is True and label attention is
    142 +           enabled, otherwise ``None``.
135 143 """
136 144
137 145 encoded_text = input_ids  # clearer name
0 commit comments