Skip to content

Commit 819ffbc

Browse files
Address code review feedback: fix trailing whitespace and NameError
Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
1 parent 2c4f2b4 commit 819ffbc

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

torchTextClassifiers/model/model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def forward(
132132
133133
Returns:
134134
Union[torch.Tensor, dict[str, torch.Tensor]]:
135-
- If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes)
135+
- If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes)
136136
containing raw logits (not softmaxed)
137137
- If return_label_attention_matrix is True: dict with keys:
138138
- "logits": torch.Tensor of shape (batch_size, num_classes)
@@ -142,6 +142,10 @@ def forward(
142142
label_attention_matrix = None
143143
if self.text_embedder is None:
144144
x_text = encoded_text.float()
145+
if return_label_attention_matrix:
146+
raise ValueError(
147+
"return_label_attention_matrix=True requires a text_embedder with label attention enabled"
148+
)
145149
else:
146150
text_embed_output = self.text_embedder(
147151
input_ids=encoded_text,

0 commit comments

Comments
 (0)