
Commit 2c4f2b4

Fix return type annotation for TextClassificationModel.forward
Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
1 parent 33dc833 commit 2c4f2b4

1 file changed

Lines changed: 9 additions & 4 deletions

File tree

torchTextClassifiers/model/model.py

@@ -6,7 +6,7 @@
 """
 
 import logging
-from typing import Annotated, Optional
+from typing import Annotated, Optional, Union
 
 import torch
 from torch import nn
@@ -120,18 +120,23 @@ def forward(
         categorical_vars: Annotated[torch.Tensor, "batch num_cats"],
         return_label_attention_matrix: bool = False,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
         """
         Memory-efficient forward pass implementation.
 
         Args: output from dataset collate_fn
             input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text
             attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
             categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features)
+            return_label_attention_matrix (bool): If True, returns a dict with logits and label_attention_matrix
 
         Returns:
-            torch.Tensor: Model output scores for each class - shape (batch_size, num_classes)
-                Raw, not softmaxed.
+            Union[torch.Tensor, dict[str, torch.Tensor]]:
+                - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes)
+                  containing raw logits (not softmaxed)
+                - If return_label_attention_matrix is True: dict with keys:
+                    - "logits": torch.Tensor of shape (batch_size, num_classes)
+                    - "label_attention_matrix": torch.Tensor of shape (batch_size, num_classes, seq_len)
         """
         encoded_text = input_ids  # clearer name
         label_attention_matrix = None
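With the widened return annotation, callers must branch on `return_label_attention_matrix` to know which shape comes back. A minimal sketch of the calling pattern, using a toy stand-in module (this is not the actual torchTextClassifiers implementation; the label-attention math here is illustrative only):

```python
from typing import Annotated, Union

import torch
from torch import nn


class ToyTextClassifier(nn.Module):
    """Toy stand-in mirroring the forward() signature changed in this commit."""

    def __init__(self, vocab_size: int = 100, num_classes: int = 4, dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim, padding_idx=0)
        self.label_queries = nn.Parameter(torch.randn(num_classes, dim))

    def forward(
        self,
        input_ids: Annotated[torch.Tensor, "batch seq_len"],
        attention_mask: Annotated[torch.Tensor, "batch seq_len"],
        return_label_attention_matrix: bool = False,
        **kwargs,
    ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
        tokens = self.embed(input_ids)  # (batch, seq, dim)
        # One attention distribution per class over the token positions.
        scores = torch.einsum("cd,bsd->bcs", self.label_queries, tokens)
        scores = scores.masked_fill(attention_mask.unsqueeze(1) == 0, float("-inf"))
        attn = scores.softmax(dim=-1)  # (batch, num_classes, seq)
        context = torch.einsum("bcs,bsd->bcd", attn, tokens)
        logits = context.sum(dim=-1)  # raw scores, not softmaxed
        if return_label_attention_matrix:
            return {"logits": logits, "label_attention_matrix": attn}
        return logits


model = ToyTextClassifier()
ids = torch.tensor([[1, 2, 3, 0]])
mask = torch.tensor([[1, 1, 1, 0]])

plain = model(ids, mask)  # torch.Tensor, shape (1, 4)
rich = model(ids, mask, return_label_attention_matrix=True)  # dict
print(plain.shape)  # torch.Size([1, 4])
print(rich["label_attention_matrix"].shape)  # torch.Size([1, 4, 4])
```

The union return keeps the common path cheap (a bare logits tensor) while letting interpretability code request the per-class attention weights in the same call.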
