|
6 | 6 | """ |
7 | 7 |
|
8 | 8 | import logging |
9 | | -from typing import Annotated, Optional |
| 9 | +from typing import Annotated, Optional, Union |
10 | 10 |
|
11 | 11 | import torch |
12 | 12 | from torch import nn |
@@ -120,18 +120,23 @@ def forward( |
120 | 120 | categorical_vars: Annotated[torch.Tensor, "batch num_cats"], |
121 | 121 | return_label_attention_matrix: bool = False, |
122 | 122 | **kwargs, |
123 | | - ) -> torch.Tensor: |
| 123 | + ) -> Union[torch.Tensor, dict[str, torch.Tensor]]: |
124 | 124 | """ |
125 | 125 | Memory-efficient forward pass implementation. |
126 | 126 |
|
127 | 127 | Args: output from dataset collate_fn |
128 | 128 | input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text |
129 | 129 | attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens |
130 | 130 | categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features) |
| 131 | + return_label_attention_matrix (bool): If True, returns a dict with logits and label_attention_matrix |
131 | 132 |
|
132 | 133 | Returns: |
133 | | - torch.Tensor: Model output scores for each class - shape (batch_size, num_classes) |
134 | | - Raw, not softmaxed. |
| 134 | + Union[torch.Tensor, dict[str, torch.Tensor]]: |
| 135 | + - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes) |
| 136 | + containing raw logits (not softmaxed) |
| 137 | + - If return_label_attention_matrix is True: dict with keys: |
| 138 | + - "logits": torch.Tensor of shape (batch_size, num_classes) |
| 139 | + - "label_attention_matrix": torch.Tensor of shape (batch_size, num_classes, seq_len) |
135 | 140 | """ |
136 | 141 | encoded_text = input_ids # clearer name |
137 | 142 | label_attention_matrix = None |
|
0 commit comments