@@ -2,6 +2,7 @@
 
 import os
 import sys
+from collections import defaultdict
 from datetime import datetime
 from typing import Any, Callable, Iterable
 
@@ -444,7 +445,7 @@ def results_multitask(  # noqa: C901
         save_results (bool, optional): Whether to save results dict. Defaults to True.
 
     Returns:
-        dict[str, dict[str, list | np.ndarray]]: Dictionary of predicted results for each
+        dict[str, dict[str, list | np.array]]: Dictionary of predicted results for each
             task.
     """
     if not (print_results or save_results):
@@ -465,27 +466,22 @@ def results_multitask(  # noqa: C901
     test_loader = DataLoader(test_set, **data_params)
     print(f"Testing on {len(test_set):,} samples")
 
-    results_dict: dict[str, dict[str, list | np.ndarray]] = {n: {} for n in task_dict}
-    for name, task in task_dict.items():
-        if task == "regression":
-            results_dict[name]["pred"] = np.zeros((ensemble_folds, len(test_set)))
-            if robust:
-                results_dict[name]["ale"] = np.zeros((ensemble_folds, len(test_set)))
-
-        elif task == "classification":
-            results_dict[name]["logits"] = []
-            results_dict[name]["pre-logits"] = []
-            if robust:
-                results_dict[name]["pre-logits_ale"] = []
+    results_dict: dict[str, dict[str, list | np.ndarray]] = {}
+    for target_name, task_type in task_dict.items():
+        results_dict[target_name] = defaultdict(
+            list
+            if task_type == "classification"
+            else lambda: np.zeros((ensemble_folds, len(test_set)))  # type: ignore
+        )
 
-    for j in range(ensemble_folds):
+    for ens_idx in range(ensemble_folds):
 
         if ensemble_folds == 1:
             resume = f"{ROOT}/models/{model_name}/{eval_type}-r{run_id}.pth.tar"
             print("Evaluating Model")
         else:
-            resume = f"{ROOT}/models/{model_name}/{eval_type}-r{j}.pth.tar"
-            print(f"Evaluating Model {j + 1}/{ensemble_folds}")
+            resume = f"{ROOT}/models/{model_name}/{eval_type}-r{ens_idx}.pth.tar"
+            print(f"Evaluating Model {ens_idx + 1}/{ensemble_folds}")
 
         if not os.path.isfile(resume):
            raise FileNotFoundError(f"no checkpoint found at '{resume}'")
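Note: a minimal sketch of the per-task container this hunk builds, assuming two hypothetical tasks, ensemble_folds=2 and a 5-sample test set (the task names and sizes are illustrative, not from the commit). Classification tasks get list-valued keys on first access; regression tasks get pre-allocated (n_folds, n_test) arrays:

import numpy as np
from collections import defaultdict

ensemble_folds, n_test = 2, 5
task_dict = {"e_form": "regression", "phase": "classification"}  # hypothetical tasks

results_dict = {}
for target_name, task_type in task_dict.items():
    results_dict[target_name] = defaultdict(
        list
        if task_type == "classification"
        else lambda: np.zeros((ensemble_folds, n_test))
    )

results_dict["e_form"]["preds"][0, :] = 1.0  # key auto-creates a (2, 5) zeros array
results_dict["phase"]["logits"].append(np.ones((n_test, 3)))  # key auto-creates a list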
@@ -506,45 +502,47 @@ def results_multitask(  # noqa: C901
         model.load_state_dict(checkpoint["state_dict"])
 
         normalizer_dict: dict[str, Normalizer | None] = {}
-        for task, state_dict in checkpoint["normalizer_dict"].items():
+        for task_type, state_dict in checkpoint["normalizer_dict"].items():
             if state_dict is not None:
-                normalizer_dict[task] = Normalizer.from_state_dict(state_dict)
+                normalizer_dict[task_type] = Normalizer.from_state_dict(state_dict)
             else:
-                normalizer_dict[task] = None
+                normalizer_dict[task_type] = None
 
-        y_test, output, *ids = model.predict(test_loader)
+        y_test, outputs, *ids = model.predict(test_loader)
 
-        for pred, target, (name, task) in zip(output, y_test, model.task_dict.items()):
-            if task == "regression":
-                normalizer = normalizer_dict[name]
+        for preds, targets, (target_name, task_type), res_dict in zip(
+            outputs, y_test, model.task_dict.items(), results_dict.values()
+        ):
+            if task_type == "regression":
+                normalizer = normalizer_dict[target_name]
                 assert isinstance(normalizer, Normalizer)
                 if model.robust:
-                    mean, log_std = pred.unbind(dim=1)
-                    pred = normalizer.denorm(mean.data.cpu())
+                    mean, log_std = preds.unbind(dim=1)
+                    preds = normalizer.denorm(mean.data.cpu())
                     ale_std = torch.exp(log_std).data.cpu() * normalizer.std
-                    results_dict[name]["ale"][j, :] = ale_std.view(-1).numpy()  # type: ignore
+                    res_dict["ale"][ens_idx, :] = ale_std.view(-1).numpy()  # type: ignore
                 else:
-                    pred = normalizer.denorm(pred.data.cpu())
+                    preds = normalizer.denorm(preds.data.cpu())
 
-                results_dict[name]["pred"][j, :] = pred.view(-1).numpy()  # type: ignore
+                res_dict["preds"][ens_idx, :] = preds.view(-1).numpy()  # type: ignore
 
-            elif task == "classification":
+            elif task_type == "classification":
                 if model.robust:
-                    mean, log_std = pred.chunk(2, dim=1)
+                    mean, log_std = preds.chunk(2, dim=1)
                     logits = (
                         sampled_softmax(mean, log_std, samples=10).data.cpu().numpy()
                     )
                     pre_logits = mean.data.cpu().numpy()
                     pre_logits_std = torch.exp(log_std).data.cpu().numpy()
-                    results_dict[name]["pre-logits_ale"].append(pre_logits_std)  # type: ignore
+                    res_dict["pre-logits_ale"].append(pre_logits_std)  # type: ignore
                 else:
-                    pre_logits = pred.data.cpu().numpy()
+                    pre_logits = preds.data.cpu().numpy()
                     logits = softmax(pre_logits, axis=1)
 
-                results_dict[name]["pre-logits"].append(pre_logits)  # type: ignore
-                results_dict[name]["logits"].append(logits)  # type: ignore
+                res_dict["pre-logits"].append(pre_logits)  # type: ignore
+                res_dict["logits"].append(logits)  # type: ignore
 
-            results_dict[name]["target"] = target
+            res_dict["targets"] = targets
 
     # TODO cleaner way to get identifier names
    if save_results:
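Note: for the robust regression branch above, a minimal sketch of how the (mean, log_std) pairs are split and rescaled. The Normalizer arithmetic here is an assumption inferred from the .denorm/.std usage, and the statistics are made up:

import torch

preds = torch.tensor([[0.2, -1.0], [0.5, -0.3]])  # (n_test, 2): normalised mean, log_std
mean, log_std = preds.unbind(dim=1)

target_mean, target_std = 1.5, 0.8           # assumed Normalizer state
denormed = mean * target_std + target_mean   # assumed equivalent of Normalizer.denorm
ale_std = torch.exp(log_std) * target_std    # aleatoric std back in target units
print(denormed.view(-1).numpy(), ale_std.view(-1).numpy())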
@@ -555,23 +553,23 @@ def results_multitask(  # noqa: C901
         )
 
     if print_results:
-        for name, task in task_dict.items():
-            print(f"\nTask: '{name}' on test set")
-            if task == "regression":
-                print_metrics_regression(**results_dict[name])  # type: ignore
-            elif task == "classification":
-                print_metrics_classification(**results_dict[name])  # type: ignore
+        for target_name, task_type in task_dict.items():
+            print(f"\nTask: '{target_name}' on test set")
+            if task_type == "regression":
+                print_metrics_regression(**results_dict[target_name])  # type: ignore
+            elif task_type == "classification":
+                print_metrics_classification(**results_dict[target_name])  # type: ignore
 
     return results_dict
 
 
-def print_metrics_regression(targets: Tensor, preds: Tensor, **kwargs) -> None:
-    """Print out metrics for a regression task.
+def print_metrics_regression(targets: np.ndarray, preds: np.ndarray, **kwargs) -> None:
+    """Print out single model and/or ensemble metrics for a regression task.
 
     Args:
-        targets (ndarray(n_test)): targets for regression task
-        preds (ndarray(n_ensemble, n_test)): model predictions
-        kwargs: unused entries from the results dictionary
+        targets (np.array): Targets for regression task. Shape (n_test,).
+        preds (np.array): Model predictions. Shape (n_ensemble, n_test).
+        kwargs: unused entries from the results dictionary.
     """
     ensemble_folds = preds.shape[0]
    res = preds - targets
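Note: res = preds - targets relies on NumPy broadcasting, since (n_ensemble, n_test) minus (n_test,) yields per-fold residuals. A small self-contained example with made-up numbers:

import numpy as np

preds = np.array([[1.0, 2.0, 3.0], [1.5, 2.5, 2.5]])  # (n_ensemble=2, n_test=3)
targets = np.array([1.0, 2.0, 2.0])                   # (n_test,)

res = preds - targets                  # broadcasts to (2, 3) residuals per fold
mae = np.abs(res).mean(axis=1)         # per-fold mean absolute error
rmse = np.sqrt((res**2).mean(axis=1))  # per-fold root mean squared error
print(mae, rmse)  # [0.333 0.5] [0.577 0.5]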
@@ -620,7 +618,7 @@ def print_metrics_regression(targets: Tensor, preds: Tensor, **kwargs) -> None:
 
 
 def print_metrics_classification(
-    target: LongTensor,
+    targets: LongTensor,
     logits: Tensor,
     average: Literal["micro", "macro", "samples", "weighted"] = "micro",
     **kwargs,
@@ -632,8 +630,8 @@ def print_metrics_classification(
     to multi-task automatically?
 
     Args:
-        target (ndarray(n_test)): categorical encoding of the tasks
-        logits (list[n_ens * ndarray(n_targets, n_test)]): logits predicted by the model
+        targets (np.array): Categorical encoding of the tasks. Shape (n_test,).
+        logits (list[n_ens * np.array(n_targets, n_test)]): logits predicted by the model.
         average ("micro" | "macro" | "samples" | "weighted"): Determines the type of
             data averaging. Defaults to 'micro' which calculates metrics globally by
             considering each element of the label indicator matrix as a label.
@@ -652,16 +650,16 @@ def print_metrics_classification(
     fscore = np.zeros(len(logits))
 
     target_ohe = np.zeros_like(logits[0])
-    target_ohe[np.arange(target.size), target] = 1
+    target_ohe[np.arange(targets.size), targets] = 1
 
     for j, y_logit in enumerate(logits):
 
         y_pred = np.argmax(y_logit, axis=1)
 
-        acc[j] = accuracy_score(target, y_pred)
+        acc[j] = accuracy_score(targets, y_pred)
         roc_auc[j] = roc_auc_score(target_ohe, y_logit, average=average)
         precision[j], recall[j], fscore[j], _ = precision_recall_fscore_support(
-            target, y_pred, average=average
+            targets, y_pred, average=average
         )
 
    if len(logits) == 1:
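Note: the target_ohe lines above one-hot encode integer class labels via NumPy fancy indexing, setting one (row, class) entry per sample in a single vectorised assignment. In isolation:

import numpy as np

targets = np.array([0, 2, 1])             # class indices, shape (n_test,)
target_ohe = np.zeros((targets.size, 3))  # (n_test, n_classes)
target_ohe[np.arange(targets.size), targets] = 1
print(target_ohe)  # [[1. 0. 0.] [0. 0. 1.] [0. 1. 0.]]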
@@ -699,10 +697,10 @@ def print_metrics_classification(
 
         y_pred = np.argmax(ens_logits, axis=1)
 
-        ens_acc = accuracy_score(target, y_pred)
+        ens_acc = accuracy_score(targets, y_pred)
         ens_roc_auc = roc_auc_score(target_ohe, ens_logits, average=average)
         ens_prec, ens_recall, ens_fscore, _ = precision_recall_fscore_support(
-            target, y_pred, average=average
+            targets, y_pred, average=average
         )
 
        print("\nEnsemble Performance Metrics:")
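Note: how ens_logits is formed isn't shown in this hunk; assuming it is the mean over the per-fold logits (a common ensembling choice, not confirmed by this diff), the ensemble prediction reduces to:

import numpy as np
from sklearn.metrics import accuracy_score

rng = np.random.default_rng(0)
logits = [rng.random((4, 3)) for _ in range(5)]  # n_ens arrays of (n_test, n_classes)
targets = np.array([0, 1, 2, 1])

ens_logits = np.mean(logits, axis=0)    # assumed fold-averaging, (n_test, n_classes)
y_pred = np.argmax(ens_logits, axis=1)  # ensemble class prediction per sample
print(accuracy_score(targets, y_pred))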
@@ -779,8 +777,8 @@ def get_metrics(
     """Get performance metrics for model predictions.
 
     Args:
-        targets (np.ndarray): Ground truth values.
-        preds (np.ndarray): Model predictions. Should be class probabilities for classification
+        targets (np.array): Ground truth values.
+        preds (np.array): Model predictions. Should be class probabilities for classification
            (i.e. model output after applying softmax/sigmoid). Same shape as targets for
            regression, and [len(targets), n_classes] for classification.
        type ('regression' | 'classification'): Task type.
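Note: a hedged usage sketch for get_metrics based only on the documented parameters above; the exact signature beyond targets, preds and type is not shown in this diff, and the snippet assumes get_metrics from this module is in scope:

import numpy as np

# regression: preds has the same shape as targets
targets = np.array([0.1, 0.5, 0.9])
preds = np.array([0.2, 0.4, 1.0])
reg_metrics = get_metrics(targets, preds, type="regression")

# classification: preds holds class probabilities, [len(targets), n_classes]
cls_targets = np.array([0, 1, 1])
cls_probs = np.array([[0.8, 0.2], [0.3, 0.7], [0.4, 0.6]])
cls_metrics = get_metrics(cls_targets, cls_probs, type="classification")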