diff --git a/timm/train.py b/timm/train.py index 1ccf1e4e38..72ce79f0fc 100755 --- a/timm/train.py +++ b/timm/train.py @@ -937,6 +937,10 @@ def train(config: dict[str, t.Any]): if eval_metrics: mlflow.log_metric("val loss", eval_metrics["loss"], step=epoch) mlflow.log_metric("val accuracy", eval_metrics["top1"], step=epoch) + for vr in utils.EVAL_VERIFICATION_RATES: + mlflow.log_metric(f"FA at {int(100 * vr):03d}", eval_metrics[f"fa@{vr}"]) + mlflow.log_metric(f"AFA at {int(100 * vr):03d}", eval_metrics[f"afa@{vr}"]) + if output_dir is not None: lrs = [param_group['lr'] for param_group in optimizer.param_groups] @@ -1152,6 +1156,7 @@ def validate( losses_m = utils.AverageMeter() top1_m = utils.AverageMeter() top5_m = utils.AverageMeter() + correct_with_confidences_m = utils.CorrectnessOfPredictionsWithConfidencesMeter() model.eval() @@ -1193,6 +1198,7 @@ def validate( losses_m.update(reduced_loss.item(), input.size(0)) top1_m.update(acc1.item(), output.size(0)) top5_m.update(acc5.item(), output.size(0)) + correct_with_confidences_m.update(output, target) batch_time_m.update(time.time() - end) end = time.time() @@ -1206,7 +1212,32 @@ def validate( f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})' ) - metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) + metrics = OrderedDict( + [ + ("loss", losses_m.avg), + ("top1", top1_m.avg), + ("top5", top5_m.avg), + *[ + (f"fa@{vr}", fa) + for vr, fa in zip( + utils.EVAL_VERIFICATION_RATES, + correct_with_confidences_m.final_accuracy( + utils.EVAL_VERIFICATION_RATES + ), + ) + ], + *[ + (f"afa@{vr}", afa) + for vr, afa in zip( + utils.EVAL_VERIFICATION_RATES, + correct_with_confidences_m.average_final_accuracy( + utils.EVAL_VERIFICATION_RATES + ), + ) + ], + ] + ) + return metrics diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 9093b75a82..7afbd50029 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -8,7 +8,7 @@ world_info_from_env, is_distributed_env, 
# Verification rates at which the evaluation loop reports FA / AFA: each value
# is the fraction of evaluated samples (least confident first) treated as
# verified, i.e. counted as correct regardless of the model's prediction.
EVAL_VERIFICATION_RATES = [0.01, 0.02, 0.05, 0.1, 0.2]


class CorrectnessOfPredictionsWithConfidencesMeter:
    """Accumulates top-1 correctness and top-1 scores to compute FA / AFA.

    For every sample passed to :meth:`update` the meter stores whether the
    top-1 prediction equals the target and the top-1 score ("confidence").
    The metrics sort samples by ascending confidence and treat the
    ``n_verified = round(vr * N)`` least confident samples as verified
    (counted as correct):

    * final accuracy, ``FA(n) = (n + #correct among the rest) / N``
    * average final accuracy, the mean of ``FA(k)`` for ``k = 0 .. n - 1``
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.predictions_correct = []
        self.confidences = []

    def update(self, output, target):
        """Record one batch.

        Args:
            output: score tensor; top-1 is taken over the last dimension.
                # assumes shape (batch, num_classes) -- TODO confirm with caller
            target: class indices, one per sample (any shape, flattened).
        """
        confidences, preds = output.detach().topk(k=1)
        correct = preds.flatten().eq(target.detach().reshape(-1))

        self.predictions_correct.append(correct.cpu())
        self.confidences.append(confidences.cpu())

    def _sorted_correctness(self):
        """Return the correctness flags ordered by ascending confidence.

        Returns an empty bool tensor when no batch has been recorded, instead
        of letting ``torch.cat`` fail on an empty list.
        """
        if not self.predictions_correct:
            return torch.zeros(0, dtype=torch.bool)
        correct = torch.cat(self.predictions_correct)
        confidences = torch.cat(self.confidences).flatten()
        return correct[confidences.argsort()]

    @staticmethod
    def _num_verified(vr, total):
        """Number of verified samples for rate ``vr``, clamped to [0, total]."""
        return min(max(round(vr * total), 0), total)

    def final_accuracy(self, vrs):
        """Compute the final accuracy for every verification rate in ``vrs``.

        Returns:
            A list with one 0-dim tensor per rate, in the order of ``vrs``.
            Every entry is NaN when no samples have been recorded.
        """
        correct_sorted = self._sorted_correctness()
        total = len(correct_sorted)
        if total == 0:
            return [torch.tensor(float("nan")) for _ in vrs]

        def _fa(vr):
            n_verified = self._num_verified(vr, total)
            return (n_verified + correct_sorted[n_verified:].sum()) / total

        return [_fa(vr) for vr in vrs]

    def average_final_accuracy(self, vrs):
        """Compute the average final accuracy for every rate in ``vrs``.

        Returns:
            A list with one 0-dim tensor per rate, in the order of ``vrs``.
            Every entry is NaN when no samples have been recorded.
        """
        correct_sorted = self._sorted_correctness()
        total = len(correct_sorted)
        if total == 0:
            return [torch.tensor(float("nan")) for _ in vrs]

        def _afa(vr):
            # see https://drive.google.com/file/d/1Uag8VtD3RwsoS8hs59X6T5u_iwuqspkS/view
            # for derivation of this formula. In short, averaging
            # FA(k) = (k + sum_{i >= k} c_i) / N over k = 0 .. n - 1 gives
            #   ((n - 1) / 2 + sum_{i < n} (i + 1) / n * c_i + sum_{i >= n} c_i) / N
            n_verified = self._num_verified(vr, total)
            if n_verified == 0:
                # Nothing is verified (e.g. a small rate on a small set): the
                # average is over no verification steps, so fall back to the
                # plain accuracy FA(0) rather than dividing by zero.
                return correct_sorted.sum() / total
            afa_weights = torch.arange(1, n_verified + 1) / n_verified
            return (
                (n_verified - 1) / 2
                + (afa_weights * correct_sorted[:n_verified]).sum()
                + correct_sorted[n_verified:].sum()
            ) / total

        return [_afa(vr) for vr in vrs]