Commit 44d9345

Update _get_sentence_embedding return type annotation and docstring

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
1 parent 0558f97

1 file changed: torchTextClassifiers/model/components/text_embedder.py (6 additions, 3 deletions)
@@ -1,6 +1,6 @@
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional

 import torch
 import torch.nn as nn
@@ -200,15 +200,18 @@ def _get_sentence_embedding(
         token_embeddings: torch.Tensor,
         attention_mask: torch.Tensor,
         return_label_attention_matrix: bool = False,
-    ) -> torch.Tensor:
+    ) -> Dict[str, Optional[torch.Tensor]]:
         """
         Compute sentence embedding from embedded tokens - "remove" second dimension.

         Args (output from dataset collate_fn):
             token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
             attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+            return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
         Returns:
-            torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim)
+            Dict[str, Optional[torch.Tensor]]: A dictionary containing:
+                - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled
+                - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
         """

         # average over non-pad token embeddings
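For context, the new return contract can be sketched as a standalone function. This is a minimal illustration only, not the library's actual implementation: it performs the masked-mean pooling hinted at by the "average over non-pad token embeddings" comment and returns the dict shape described in the updated docstring; the label-attention path is omitted, so `label_attention_matrix` is always `None` here.

```python
from typing import Dict, Optional

import torch


def get_sentence_embedding_sketch(
    token_embeddings: torch.Tensor,
    attention_mask: torch.Tensor,
    return_label_attention_matrix: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
    """Masked-mean pooling sketch matching the commit's new return type."""
    # Expand mask to (batch_size, seq_len, 1) so it broadcasts over embedding_dim.
    mask = attention_mask.unsqueeze(-1).float()
    # Sum only non-pad token embeddings, then divide by the non-pad count.
    summed = (token_embeddings * mask).sum(dim=1)           # (batch, embedding_dim)
    counts = mask.sum(dim=1).clamp(min=1.0)                 # avoid division by zero
    sentence_embedding = summed / counts
    # Label attention is not modeled in this sketch, so the matrix is None.
    return {
        "sentence_embedding": sentence_embedding,
        "label_attention_matrix": None,
    }


emb = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
out = get_sentence_embedding_sketch(emb, mask)
print(out["sentence_embedding"].shape)  # torch.Size([2, 8])
```

Returning a dict instead of a bare tensor lets callers opt into the attention matrix without changing the call signature for existing code paths.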
