Skip to content

Commit ec6742c

Browse files
fix docstring in TextEmbedder forward
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 3d034f4 commit ec6742c

1 file changed

Lines changed: 13 additions & 5 deletions

File tree

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,25 @@ def forward(
121121
input_ids: torch.Tensor,
122122
attention_mask: torch.Tensor,
123123
return_label_attention_matrix: bool = False,
124-
) -> torch.Tensor:
124+
) -> dict[str, Optional[torch.Tensor]]:
125125
"""Converts input token IDs to their corresponding embeddings.
126126
127127
Args:
128128
input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized
129129
attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
130-
return_label_attention_matrix (bool): Whether to return the label attention matrix
130+
return_label_attention_matrix (bool): Whether to return the label attention matrix.
131+
131132
Returns:
132-
torch.Tensor: Text embeddings, shape (batch_size, embedding_dim) if self.enable_label_attention is False, else (batch_size, num_labels, embedding_dim)
133-
torch.Tensor: Label attention matrix, shape (batch_size, num_labels, seq_len) if return_label_attention_matrix is True, else None.
134-
Also None if label attention is disabled (even if return_label_attention_matrix is True)
133+
dict: A dictionary with the following keys:
134+
135+
- "sentence_embedding" (torch.Tensor): Text embeddings of shape
136+
(batch_size, embedding_dim) if ``self.enable_label_attention`` is False,
137+
else (batch_size, num_labels, embedding_dim).
138+
139+
- "label_attention_matrix" (Optional[torch.Tensor]): Label attention
140+
matrix of shape (batch_size, num_labels, seq_len) if
141+
``return_label_attention_matrix`` is True and label attention is
142+
enabled, otherwise ``None``.
135143
"""
136144

137145
encoded_text = input_ids # clearer name

0 commit comments

Comments
 (0)