2626from ..utils .constants import UtilsConstants
2727from ..utils .calibration_metrics import compute_calibration_metrics
2828from ..utils .data import PatchDataset
29- from ..utils .downstream_metrics import compute_metrics
29+ from ..utils .downstream_metrics import compute_metric , compute_metrics
3030from ..utils .utils import (
3131 get_hyperaparams_dict ,
3232 local_seed ,
@@ -272,13 +272,13 @@ def train_probe(
272272 # Updating best ckpt
273273 if (
274274 task_type == "linear_probing"
275- and metrics [i ]["f1" ] > best_val_perf
275+ and metrics [i ]["f1" ][ "metric_score" ] > best_val_perf
276276 ) or (
277277 task_type == "segmentation"
278278 and np .array (losses [i ]).mean ().item () < best_val_perf
279279 ):
280280 if task_type == "linear_probing" :
281- best_val_perf = metrics [i ]["f1" ]
281+ best_val_perf = metrics [i ]["f1" ][ "metric_score" ]
282282 elif task_type == "segmentation" :
283283 best_val_perf = np .array (losses [i ]).mean ().item ()
284284 best_ckpt_hyperparam_id = i
@@ -538,7 +538,9 @@ def train_eval(
538538 tot_loss = []
539539 all_out = []
540540 all_label = []
541- for batch_id , batch in tqdm (enumerate (dataloader ), total = len (dataloader )):
541+ for batch_id , batch in tqdm (
542+ enumerate (dataloader ), total = len (dataloader ), disable = hyperparam_search
543+ ):
542544 # Batch data
543545 if "emb" in batch .keys ():
544546 emb = batch ["emb" ].to (device )
@@ -591,8 +593,14 @@ def train_eval(
591593 out = []
592594 if task_type == "segmentation" :
593595 for o , m in zip (output , unmasked_label ):
594- out .append (torch .cat ([o [c ][m ].unsqueeze (- 1 ) for c in range (o .shape [0 ])], dim = - 1 ))
595- curr_loss = sum ([criterion (o , l ) for o , l in zip (out , label )]) / len (out )
596+ out .append (
597+ torch .cat (
598+ [o [c ][m ].unsqueeze (- 1 ) for c in range (o .shape [0 ])], dim = - 1
599+ )
600+ )
601+ curr_loss = sum ([criterion (o , l ) for o , l in zip (out , label )]) / len (
602+ out
603+ )
596604 else :
597605 for c in range (output .shape [1 ]):
598606 out .append (output [:, c ].unsqueeze (- 1 ))
@@ -604,13 +612,19 @@ def train_eval(
604612 if batch_id == 0 :
605613 tot_loss .append ([curr_loss .item ()])
606614 if comp_metrics :
607- all_out .append ([[o .detach ().cpu () for o in out ]] if task_type == 'segmentation'
608- else [out .detach ().cpu ()])
615+ all_out .append (
616+ [[o .detach ().cpu () for o in out ]]
617+ if task_type == "segmentation"
618+ else [out .detach ().cpu ()]
619+ )
609620 else :
610621 tot_loss [i ].append (curr_loss .item ())
611622 if comp_metrics :
612- all_out [i ].append ([o .detach ().cpu () for o in out ] if task_type == 'segmentation'
613- else out .detach ().cpu ())
623+ all_out [i ].append (
624+ [o .detach ().cpu () for o in out ]
625+ if task_type == "segmentation"
626+ else out .detach ().cpu ()
627+ )
614628 # Logging
615629 if comp_metrics :
616630 if task_type == "segmentation" :
@@ -641,20 +655,42 @@ def train_eval(
641655 if task_type == "segmentation" :
642656 metrics = []
643657 for i in range (len (all_out )):
644- all_out [i ] = [F .softmax (item , dim = 1 ) for batch in all_out [i ] for item in batch ]
645- all_metrics = [compute_metrics (o , None , l , True ) for o , l in zip (all_out [i ], all_label ) if len (l ) > 0 ]
658+ all_out [i ] = [
659+ F .softmax (item , dim = 1 ) for batch in all_out [i ] for item in batch
660+ ]
661+ all_metrics = [
662+ compute_metrics (o , None , l , True , compute_ci = False )
663+ for o , l in zip (all_out [i ], all_label )
664+ if len (l ) > 0
665+ ]
646666 weights = [len (l ) for l in all_label if len (l ) > 0 ]
647- metrics .append ({key : np .average ([d [key ] for d in all_metrics ], weights = weights ) for key in all_metrics [0 ]} |
648- {f'{ key } _per_sample' : [d [key ] for d in all_metrics ] for key in all_metrics [0 ]})
667+
668+ # Averagin per-image performance and computing confidence intervals
669+ all_metrics_out = {}
670+ for key in all_metrics [0 ]:
671+ metric_vals = [d [key ]["metric_score" ] for d in all_metrics ]
672+ all_metrics_out [key ] = compute_metric (
673+ weights ,
674+ metric_vals ,
675+ lambda weights , metric_vals : np .average (
676+ metric_vals , weights = weights
677+ ),
678+ )
679+ all_metrics_out [f"per_sample_{ key } " ] = metric_vals
680+ metrics .append (all_metrics_out )
649681 else :
650682 # Computing metrics
651683 all_label = torch .cat (all_label )
652684 metrics = []
653685 for i in range (len (all_out )):
654686 all_out [i ] = torch .cat (all_out [i ])
655687 all_out [i ] = F .softmax (all_out [i ], dim = 1 )
656- classification_metrics = compute_metrics (all_out [i ], None , all_label )
657- conformal_metrics = compute_calibration_metrics (all_out [i ], all_label )
688+ classification_metrics = compute_metrics (
689+ all_out [i ], None , all_label , compute_ci = (not hyperparam_search )
690+ )
691+ conformal_metrics = compute_calibration_metrics (
692+ all_out [i ], all_label , compute_ci = (not hyperparam_search )
693+ )
658694 curr_metrics = (
659695 classification_metrics | conformal_metrics
660696 ) # merging dictionaries
0 commit comments