Commit 70b79c9

doc: fix docstring on shape of label attention matrix
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Parent: 7a988c3

1 file changed: torchTextClassifiers/model/components/text_embedder.py (5 additions & 3 deletions)
@@ -142,12 +142,14 @@ def forward(

             - "sentence_embedding" (torch.Tensor): Text embeddings of shape
               (batch_size, embedding_dim) if ``self.enable_label_attention`` is False,
-              else (batch_size, num_labels, embedding_dim).
+              else (batch_size, num_classes, embedding_dim), where ``num_classes``
+              is the number of label classes.

             - "label_attention_matrix" (Optional[torch.Tensor]): Label attention
-              matrix of shape (batch_size, num_labels, seq_len) if
+              matrix of shape (batch_size, n_head, num_classes, seq_len) if
               ``return_label_attention_matrix`` is True and label attention is
-              enabled, otherwise ``None``.
+              enabled, otherwise ``None``. The dimensions correspond to
+              (batch_size, attention heads, label classes, sequence length).
             """

             encoded_text = input_ids  # clearer name

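The corrected shapes can be illustrated with a short, self-contained sketch. This is not the library's actual TextEmbedder implementation, only a minimal multi-head label attention computation assuming one learned query per label class; the dimension names (num_classes, n_head, seq_len, embedding_dim) come from the docstring, everything else is hypothetical:

```python
import torch

# Hypothetical example dimensions.
batch_size, seq_len, embedding_dim = 2, 7, 16
num_classes, n_head = 5, 4
head_dim = embedding_dim // n_head

token_embeddings = torch.randn(batch_size, seq_len, embedding_dim)
label_embeddings = torch.randn(num_classes, embedding_dim)  # one query per class

# Split into heads:
#   queries q: (1, n_head, num_classes, head_dim), broadcast over the batch
#   keys/values k, v: (batch_size, n_head, seq_len, head_dim)
q = label_embeddings.view(num_classes, n_head, head_dim).transpose(0, 1).unsqueeze(0)
k = token_embeddings.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
v = k

# Scaled dot-product attention of label queries over token keys.
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
attn = scores.softmax(dim=-1)   # (batch_size, n_head, num_classes, seq_len)
ctx = attn @ v                  # (batch_size, n_head, num_classes, head_dim)

# Merge heads back into a per-class sentence embedding.
sentence_embedding = ctx.transpose(1, 2).reshape(batch_size, num_classes, embedding_dim)

print(attn.shape)                # torch.Size([2, 4, 5, 7])
print(sentence_embedding.shape)  # torch.Size([2, 5, 16])
```

This makes the motivation for the fix visible: with multi-head attention, the raw attention weights carry a head dimension, so the matrix is (batch_size, n_head, num_classes, seq_len), not (batch_size, num_labels, seq_len) as the old docstring claimed.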