@@ -195,7 +195,13 @@ def evaluation_step(
195195 query_embeddings /= query_embeddings .norm (p = 2 , dim = - 1 , keepdim = True )
196196 query_embeddings = query_embeddings [matching_indices ]
197197
198- logits = 100.0 * _safe_matmul (query_embeddings , class_embeddings )
198+ if self .all_dataset_info [dataset_index ]["num_classes" ] == 2 :
199+ softmax_output = _safe_matmul (
200+ query_embeddings , class_embeddings
201+ ).softmax (dim = - 1 )
202+ logits = softmax_output [:, 1 ] - softmax_output [:, 0 ]
203+ else :
204+ logits = 100.0 * _safe_matmul (query_embeddings , class_embeddings )
199205 targets = batch [Modalities .get_modality (query_modality ).target ][
200206 matching_indices
201207 ]
@@ -233,27 +239,36 @@ def _create_metrics(
233239 num_classes : int , top_k : List [int ], prefix : str , postfix : str
234240 ) -> MetricCollection :
235241 """Create a collection of classification metrics."""
242+ task_type = "binary" if num_classes == 2 else "multiclass"
243+ acc_metrics = (
244+ {
245+ f"top{ k } _accuracy" : Accuracy (
246+ task = task_type , num_classes = num_classes , top_k = k , average = "micro"
247+ )
248+ for k in top_k
249+ }
250+ if num_classes > 2
251+ else {"accuracy" : Accuracy (task = task_type , num_classes = num_classes )}
252+ )
236253 return MetricCollection (
237254 {
238255 "precision" : Precision (
239- task = "multiclass" , num_classes = num_classes , average = "macro"
256+ task = task_type ,
257+ num_classes = num_classes ,
258+ average = "macro" if num_classes > 2 else "micro" ,
240259 ),
241260 "recall" : Recall (
242- task = "multiclass" , num_classes = num_classes , average = "macro"
261+ task = task_type ,
262+ num_classes = num_classes ,
263+ average = "macro" if num_classes > 2 else "micro" ,
243264 ),
244265 "f1_score_macro" : F1Score (
245- task = "multiclass" , num_classes = num_classes , average = "macro"
266+ task = task_type ,
267+ num_classes = num_classes ,
268+ average = "macro" if num_classes > 2 else "micro" ,
246269 ),
247- "aucroc" : AUROC (task = "multiclass" , num_classes = num_classes ),
248- ** {
249- f"top{ k } _accuracy" : Accuracy (
250- task = "multiclass" ,
251- num_classes = num_classes ,
252- top_k = k ,
253- average = "micro" ,
254- )
255- for k in top_k
256- },
270+ "aucroc" : AUROC (task = task_type , num_classes = num_classes ),
271+ ** acc_metrics ,
257272 },
258273 prefix = prefix ,
259274 postfix = postfix ,
0 commit comments