🐛 Bug
MinMaxMetric stores min_val and max_val as plain tensor attributes rather than registered buffers (or add_state with persistent=True). This means they are not included in state_dict() and are therefore lost when saving and loading a checkpoint. After loading, min_val resets to inf and max_val resets to -inf, so the tracked minimum and maximum no longer reflect the full history of the experiment.
To Reproduce
Code sample
```python
import torch
from torchmetrics.wrappers import MinMaxMetric
from torchmetrics.classification import BinaryAccuracy

# Create metric and simulate two epochs
metric = MinMaxMetric(BinaryAccuracy())

# Epoch 1: perfect predictions
metric.update(torch.tensor([1, 0, 1, 0]), torch.tensor([1, 0, 1, 0]))
result = metric.compute()
print(f"Epoch 1: {result}")  # raw=1.0, min=1.0, max=1.0
metric.reset()

# Epoch 2: worse predictions
metric.update(torch.tensor([1, 1, 0, 0]), torch.tensor([1, 0, 1, 0]))
result = metric.compute()
print(f"Epoch 2: {result}")  # raw=0.5, min=0.5, max=1.0

# Save and reload state dict (simulates checkpoint resume)
state_dict = metric.state_dict()
print(f"\nstate_dict keys: {list(state_dict.keys())}")  # [] (empty!)
print(f"min_val in state_dict: {'min_val' in state_dict}")  # False
print(f"max_val in state_dict: {'max_val' in state_dict}")  # False

# Load into a fresh metric
metric2 = MinMaxMetric(BinaryAccuracy())
metric2.load_state_dict(state_dict)

# Epoch 3 after resume: min/max are lost
metric2.reset()
metric2.update(torch.tensor([1, 0, 0, 0]), torch.tensor([1, 0, 1, 0]))
result = metric2.compute()
print(f"\nEpoch 3 (after resume): {result}")
# raw=0.75, min=0.75, max=0.75
# Expected: min=0.5 (from epoch 2), max=1.0 (from epoch 1)
```
Expected behavior
After loading a checkpoint, min_val and max_val should be restored to their saved values so that the min/max tracking reflects the entire experiment, not just the portion after the resume.
Actual behavior
state_dict() returns an empty dict. After loading, min_val and max_val are re-initialized to inf / -inf, so the historical min/max are lost.
Root cause
In MinMaxMetric.__init__, min_val and max_val are assigned as plain tensor attributes:
```python
self.min_val = torch.tensor(float("inf"))
self.max_val = torch.tensor(float("-inf"))
```
They are not registered via register_buffer(), so nn.Module.state_dict() does not include them. They also cannot use add_state() because Metric.reset() would then reset them to their defaults on every epoch, which would defeat their purpose.
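The difference can be reproduced with plain torch.nn.Module, independent of torchmetrics. The sketch below uses two hypothetical toy modules, PlainAttrs and WithBuffers (not part of either library), to compare what state_dict() returns for a plain tensor attribute versus a registered buffer.

```python
import torch
from torch import nn


class PlainAttrs(nn.Module):
    """Hypothetical toy module mirroring the current MinMaxMetric pattern."""

    def __init__(self) -> None:
        super().__init__()
        # Plain tensor attributes: nn.Module does not track these at all.
        self.min_val = torch.tensor(float("inf"))
        self.max_val = torch.tensor(float("-inf"))


class WithBuffers(nn.Module):
    """Hypothetical toy module using the register_buffer approach."""

    def __init__(self) -> None:
        super().__init__()
        # Buffers are persistent by default, so they appear in state_dict().
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))


plain_keys = list(PlainAttrs().state_dict().keys())
buffer_keys = list(WithBuffers().state_dict().keys())
print(plain_keys)   # []
print(buffer_keys)  # ['min_val', 'max_val']
```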
Suggested fix
Register them as persistent buffers:
```python
self.register_buffer("min_val", torch.tensor(float("inf")))
self.register_buffer("max_val", torch.tensor(float("-inf")))
```
This would:
- Include them in state_dict() / load_state_dict() for correct checkpoint behavior
- Not affect Metric.reset(), which only resets states in _defaults (populated by add_state)
- Automatically move them with the module on .to(device) / .cuda(), making the manual .to(val.device) calls in compute() redundant (but still correct as no-ops)
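The intended behavior can be sketched in plain PyTorch, assuming a hypothetical MinMaxTracker module that stands in for the extrema bookkeeping only (it is not torchmetrics API and does not reproduce the wrapped metric's states). It shows three things: reassigning a tensor to a registered buffer name, as compute() does with self.max_val = ..., keeps it registered; a reset() that only clears per-epoch state leaves the extrema alone; and a state_dict() round trip restores the min/max from the reproduction above.

```python
import torch
from torch import nn


class MinMaxTracker(nn.Module):
    """Hypothetical stand-in for MinMaxMetric's min/max tracking (not torchmetrics API)."""

    def __init__(self) -> None:
        super().__init__()
        # Persistent buffers: saved by state_dict() and moved by .to(device).
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))
        # Per-epoch value; plays the role of the wrapped metric's add_state() states.
        self.last_val = None

    def record(self, val: torch.Tensor) -> None:
        self.last_val = val
        # Assigning a tensor to a registered buffer name replaces the buffer's
        # value, so it stays part of state_dict().
        self.min_val = torch.minimum(self.min_val, val)
        self.max_val = torch.maximum(self.max_val, val)

    def reset(self) -> None:
        # Only per-epoch state is cleared; the extrema survive the reset.
        self.last_val = None


tracker = MinMaxTracker()
for epoch_val in (1.0, 0.5):  # epoch 1 and epoch 2 accuracies from the report
    tracker.record(torch.tensor(epoch_val))
    tracker.reset()

# Simulate a checkpoint resume into a fresh instance.
resumed = MinMaxTracker()
resumed.load_state_dict(tracker.state_dict())
resumed.record(torch.tensor(0.75))  # epoch 3 after resume

print(resumed.min_val.item(), resumed.max_val.item())  # 0.5 1.0
```

Unlike the current behavior in the reproduction (min=0.75, max=0.75 after resume), the extrema here carry over the values from epochs 1 and 2.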
Environment
- TorchMetrics version: 1.8.2
- Python: 3.11.14
- PyTorch: 2.9.0+cu128
- OS: Linux
Additional context
Happy to open a PR if you agree with the fix or if you have another idea :)