diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 839d97619bd..99bf4fd4c5e 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -56,6 +56,23 @@ def _remove_suffix(string: str, suffix: str) -> str: return string[: -len(suffix)] if string.endswith(suffix) else string +def _equal_state_value(state1: Any, state2: Any) -> bool: + """Recursively compare metric state values while preserving structure checks.""" + if type(state1) is not type(state2): + return False + + if isinstance(state1, Tensor): + return state1.shape == state2.shape and allclose(state1, state2) + + if isinstance(state1, Mapping): + return state1.keys() == state2.keys() and all(_equal_state_value(state1[k], state2[k]) for k in state1) + + if isinstance(state1, Sequence) and not isinstance(state1, str): + return len(state1) == len(state2) and all(_equal_state_value(s1, s2) for s1, s2 in zip(state1, state2)) + + return state1 == state2 + + class MetricCollection(ModuleDict): """MetricCollection class can be used to chain metrics that have the same call pattern into one single class. @@ -316,21 +333,7 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool: state1 = getattr(metric1, key) state2 = getattr(metric2, key) - if type(state1) != type(state2): # noqa: E721 - return False - - if ( - isinstance(state1, Tensor) - and isinstance(state2, Tensor) - and not (state1.shape == state2.shape and allclose(state1, state2)) - ): - return False - - if ( - isinstance(state1, list) - and isinstance(state2, list) - and not (all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2))) - ): + if not _equal_state_value(state1, state2): return False return True diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 46de5556286..7acba11ce2d 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -855,3 +855,28 @@ def test_collection_state_being_re_established_after_copy(): assert not m12._state_is_copy assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location" assert m12._equal_metric_states(m12.m1, m12.m2) + + +def test_collection_compute_groups_with_nested_sequence_states(): + """Check that compute group merging handles nested sequence state values.""" + + class DummyNestedListMetric(Metric): + full_state_update = True + + def __init__(self) -> None: + super().__init__() + self.add_state("x", [], dist_reduce_fx=None) + + def update(self, x: torch.Tensor) -> None: + self.x.append((x, x + 1)) + + def compute(self) -> list[tuple[torch.Tensor, torch.Tensor]]: + return self.x + + m1, m2 = DummyNestedListMetric(), DummyNestedListMetric() + metrics = MetricCollection({"m1": m1, "m2": m2}, compute_groups=True) + + metrics.update(torch.tensor([1.0])) + + assert metrics.compute_groups == {0: ["m1", "m2"]} + assert metrics._equal_metric_states(metrics.m1, metrics.m2)