MinMaxMetric state not saved in the checkpoint #3323

@patrontheo

🐛 Bug

MinMaxMetric stores min_val and max_val as plain tensor attributes rather than as registered buffers. As a result, they are not included in state_dict() and are lost when a checkpoint is saved and loaded. A freshly constructed metric starts with min_val = inf and max_val = -inf, and load_state_dict() does not restore the saved values, so the tracked minimum and maximum no longer reflect the full history of the experiment.

To Reproduce

Code sample
import torch
from torchmetrics.wrappers import MinMaxMetric
from torchmetrics.classification import BinaryAccuracy

# Create metric and simulate two epochs
metric = MinMaxMetric(BinaryAccuracy())

# Epoch 1: perfect predictions
metric.update(torch.tensor([1, 0, 1, 0]), torch.tensor([1, 0, 1, 0]))
result = metric.compute()
print(f"Epoch 1: {result}")  # raw=1.0, min=1.0, max=1.0
metric.reset()

# Epoch 2: worse predictions
metric.update(torch.tensor([1, 1, 0, 0]), torch.tensor([1, 0, 1, 0]))
result = metric.compute()
print(f"Epoch 2: {result}")  # raw=0.5, min=0.5, max=1.0

# Save and reload state dict (simulates checkpoint resume)
state_dict = metric.state_dict()
print(f"\nstate_dict keys: {list(state_dict.keys())}")  # [] — empty!
print(f"min_val in state_dict: {'min_val' in state_dict}")  # False
print(f"max_val in state_dict: {'max_val' in state_dict}")  # False

# Load into a fresh metric
metric2 = MinMaxMetric(BinaryAccuracy())
metric2.load_state_dict(state_dict)

# Epoch 3 after resume: min/max are lost
metric2.reset()
metric2.update(torch.tensor([1, 0, 0, 0]), torch.tensor([1, 0, 1, 0]))
result = metric2.compute()
print(f"\nEpoch 3 (after resume): {result}")
# raw=0.75, min=0.75, max=0.75
# Expected:  min=0.5 (from epoch 2), max=1.0 (from epoch 1)

Expected behavior

After loading a checkpoint, min_val and max_val should be restored to their saved values so that the min/max tracking reflects the entire experiment, not just the portion after the resume.
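To make the expected numbers concrete, here is a plain-Python model of the tracking (no torch or torchmetrics involved), using the accuracies from the reproduction above: 1.0 and 0.5 before the checkpoint, 0.75 after the resume. The `track` helper and the dict-based "checkpoint" are hypothetical illustrations, not part of any library.

```python
import math

def track(values, state=None):
    """Hypothetical helper: running min/max over `values`, optionally resuming from `state`."""
    lo, hi = (math.inf, -math.inf) if state is None else (state["min"], state["max"])
    for v in values:
        lo, hi = min(lo, v), max(hi, v)
    return {"min": lo, "max": hi}

# Epochs 1 and 2, then "save" the tracked values
checkpoint = track([1.0, 0.5])

# Expected behavior: epoch 3 resumes from the saved min/max
resumed = track([0.75], state=checkpoint)
# Current behavior: the saved values are lost, so tracking restarts from inf / -inf
restarted = track([0.75])

print(resumed)    # {'min': 0.5, 'max': 1.0}
print(restarted)  # {'min': 0.75, 'max': 0.75}
```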

Actual behavior

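In the reproduction, state_dict() returns an empty dict: the wrapped BinaryAccuracy's states are added with persistent=False by default, and min_val / max_val are not registered at all. After loading into a fresh metric, min_val and max_val keep their initial values of inf / -inf, so the historical min/max are lost.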

Root cause

In MinMaxMetric.__init__, min_val and max_val are assigned as plain tensor attributes:

self.min_val = torch.tensor(float("inf"))
self.max_val = torch.tensor(float("-inf"))

They are not registered via register_buffer(), so nn.Module.state_dict() does not include them. They also cannot use add_state() because Metric.reset() would then reset them to their defaults on every epoch, which would defeat their purpose.
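A minimal sketch of the nn.Module behavior described here, using two hypothetical toy modules (`PlainAttrs`, `Buffers`) rather than MinMaxMetric itself: only the version that calls register_buffer() exposes the values through state_dict().

```python
import torch
from torch import nn

class PlainAttrs(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attributes (the current MinMaxMetric pattern): invisible to state_dict()
        self.min_val = torch.tensor(float("inf"))
        self.max_val = torch.tensor(float("-inf"))

class Buffers(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffers (persistent by default): included in state_dict()
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))

print(list(PlainAttrs().state_dict().keys()))  # []
print(list(Buffers().state_dict().keys()))     # ['min_val', 'max_val']
```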

Suggested fix

Register them as persistent buffers:

self.register_buffer("min_val", torch.tensor(float("inf")))
self.register_buffer("max_val", torch.tensor(float("-inf")))
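A sketch of the proposed pattern, written against plain torch.nn.Module so it runs without torchmetrics. `MinMaxTracker` and `observe()` are hypothetical names, and torch.minimum / torch.maximum stand in for the conditional assignments in MinMaxMetric.compute(); the point is only that a registered buffer stays in state_dict() after being reassigned, so the min/max survive a save/load round trip.

```python
import torch
from torch import nn

class MinMaxTracker(nn.Module):
    """Hypothetical stand-in for MinMaxMetric's min/max bookkeeping (not the torchmetrics class)."""

    def __init__(self):
        super().__init__()
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))

    def observe(self, val: torch.Tensor) -> dict:
        # Assigning a tensor to an already-registered buffer name replaces the buffer in place
        self.min_val = torch.minimum(self.min_val, val)
        self.max_val = torch.maximum(self.max_val, val)
        return {"raw": val, "min": self.min_val, "max": self.max_val}

tracker = MinMaxTracker()
for acc in (1.0, 0.5):  # epochs 1 and 2 from the reproduction
    tracker.observe(torch.tensor(acc))

checkpoint = tracker.state_dict()  # now contains min_val and max_val

resumed = MinMaxTracker()
resumed.load_state_dict(checkpoint)
out = resumed.observe(torch.tensor(0.75))  # epoch 3 after the resume
print(out["min"].item(), out["max"].item())  # 0.5 1.0
```

Reassigning `self.min_val = ...` works because nn.Module.__setattr__ updates the existing buffer entry when the name is already registered, so compute() could keep its current assignment style.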

This would:

  • Include them in state_dict() / load_state_dict() for correct checkpoint behavior
  • Not affect Metric.reset(), which only resets states in _defaults (populated by add_state)
  • Automatically move them with the module on .to(device) / .cuda(), making the manual .to(val.device) calls in compute() redundant (but still correct as no-ops)
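A quick check of the third point, again on a hypothetical toy module rather than MinMaxMetric: a dtype conversion goes through the same module-wide conversion path as .to(device) and .cuda(), so the registered buffer follows it while the plain tensor attribute is left behind. The CUDA branch only runs if a GPU is available.

```python
import torch
from torch import nn

class Mixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("max_val", torch.tensor(float("-inf")))  # registered buffer
        self.plain_max = torch.tensor(float("-inf"))                   # plain tensor attribute

m = Mixed().to(torch.float64)
print(m.max_val.dtype)    # torch.float64 (converted with the module)
print(m.plain_max.dtype)  # torch.float32 (untouched)

if torch.cuda.is_available():
    m = m.cuda()
    print(m.max_val.device, m.plain_max.device)  # cuda:0 cpu
```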

Environment
  • TorchMetrics version: 1.8.2
  • Python: 3.11.14
  • PyTorch: 2.9.0+cu128
  • OS: Linux

Additional context

Happy to open a PR if you agree with the fix or if you have another idea :)

Labels: bug / fix, help wanted
