@@ -1,6 +1,6 @@
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
@@ -200,15 +200,18 @@ def _get_sentence_embedding(
         token_embeddings: torch.Tensor,
         attention_mask: torch.Tensor,
         return_label_attention_matrix: bool = False,
-    ) -> torch.Tensor:
+    ) -> Dict[str, Optional[torch.Tensor]]:
         """
         Compute sentence embedding from embedded tokens - "remove" second dimension.
 
         Args (output from dataset collate_fn):
             token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
             attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+            return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
         Returns:
-            torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim)
+            Dict[str, Optional[torch.Tensor]]: A dictionary containing:
+                - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled
+                - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
         """
 
         # average over non-pad token embeddings
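For context, a minimal sketch of what a pooling function with this signature might do. Only the argument names, the tensor shapes, and the returned dictionary keys are taken from the docstring above; the standalone helper name, the optional label_embeddings parameter, and the dot-product label-attention formulation are illustrative assumptions, not code from this PR.

import torch
import torch.nn.functional as F
from typing import Dict, Optional

def get_sentence_embedding(
    token_embeddings: torch.Tensor,                    # (batch_size, seq_len, embedding_dim)
    attention_mask: torch.Tensor,                      # (batch_size, seq_len), 1 for real tokens, 0 for padding
    label_embeddings: Optional[torch.Tensor] = None,   # (n_labels, embedding_dim); hypothetical, not in the PR
    return_label_attention_matrix: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
    if label_embeddings is None:
        # Plain path: average over non-pad token embeddings.
        mask = attention_mask.unsqueeze(-1).float()        # (batch_size, seq_len, 1)
        summed = (token_embeddings * mask).sum(dim=1)      # (batch_size, embedding_dim)
        counts = mask.sum(dim=1).clamp(min=1e-9)           # (batch_size, 1), avoid division by zero
        return {
            "sentence_embedding": summed / counts,          # (batch_size, embedding_dim)
            "label_attention_matrix": None,
        }

    # Label-attention path (assumed formulation): each label attends over the token embeddings.
    scores = torch.einsum("ld,bsd->bls", label_embeddings, token_embeddings)   # (batch_size, n_labels, seq_len)
    scores = scores.masked_fill(attention_mask.unsqueeze(1) == 0, float("-inf"))
    attn = F.softmax(scores, dim=-1)                                            # (batch_size, n_labels, seq_len)
    label_specific = torch.einsum("bls,bsd->bld", attn, token_embeddings)       # (batch_size, n_labels, embedding_dim)
    return {
        "sentence_embedding": label_specific,
        "label_attention_matrix": attn if return_label_attention_matrix else None,
    }

Returning a dictionary rather than a bare tensor lets both paths share one signature: callers always read 'sentence_embedding', and the attention matrix only appears in the output when it was requested.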