File tree Expand file tree Collapse file tree
torchTextClassifiers/model Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -132,7 +132,7 @@ def forward(
132132
133133 Returns:
134134 Union[torch.Tensor, dict[str, torch.Tensor]]:
135- - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes)
135+ - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes)
136136 containing raw logits (not softmaxed)
137137 - If return_label_attention_matrix is True: dict with keys:
138138 - "logits": torch.Tensor of shape (batch_size, num_classes)
@@ -142,6 +142,10 @@ def forward(
142142 label_attention_matrix = None
143143 if self .text_embedder is None :
144144 x_text = encoded_text .float ()
145+ if return_label_attention_matrix :
146+ raise ValueError (
147+ "return_label_attention_matrix=True requires a text_embedder with label attention enabled"
148+ )
145149 else :
146150 text_embed_output = self .text_embedder (
147151 input_ids = encoded_text ,
You can’t perform that action at this time.
0 commit comments