Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@ def _remove_suffix(string: str, suffix: str) -> str:
return string[: -len(suffix)] if string.endswith(suffix) else string


def _equal_state_value(state1: Any, state2: Any) -> bool:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _equal_state_value docstring is one line. Consider listing supported types for callers: e.g., "Recursively compare metric state values. Supports: Tensor (shape+value), Mapping (key+recursive value), Sequence/str-excluded (length+recursive element), and primitives (direct ==)."

"""Recursively compare metric state values while preserving structure checks."""
if type(state1) is not type(state2):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type(state1) is not type(state2) is intentionally strict (exact type match), which is the right call here -- but worth a comment since most Python code uses != or isinstance. Consider # noqa: E721 or a short note.

return False

if isinstance(state1, Tensor):
return state1.shape == state2.shape and allclose(state1, state2)

if isinstance(state1, Mapping):
return state1.keys() == state2.keys() and all(_equal_state_value(state1[k], state2[k]) for k in state1)

if isinstance(state1, Sequence) and not isinstance(state1, str):
return len(state1) == len(state2) and all(_equal_state_value(s1, s2) for s1, s2 in zip(state1, state2))
Comment on lines +67 to +71
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Mapping and Sequence branches use generator expressions with all(). For very deep nesting (e.g., list of list of list of ...) this could hit Python recursion limits. In practice metric states are rarely >3 levels deep, so this is theoretical -- but note the limit if it matters for your use case.


return state1 == state2


class MetricCollection(ModuleDict):
"""MetricCollection class can be used to chain metrics that have the same call pattern into one single class.

Expand Down Expand Up @@ -316,21 +333,7 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
state1 = getattr(metric1, key)
state2 = getattr(metric2, key)

if type(state1) != type(state2): # noqa: E721
return False

if (
isinstance(state1, Tensor)
and isinstance(state2, Tensor)
and not (state1.shape == state2.shape and allclose(state1, state2))
):
return False

if (
isinstance(state1, list)
and isinstance(state2, list)
and not (all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2)))
):
if not _equal_state_value(state1, state2):
return False

return True
Expand Down
25 changes: 25 additions & 0 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,3 +855,28 @@ def test_collection_state_being_re_established_after_copy():
assert not m12._state_is_copy
assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
assert m12._equal_metric_states(m12.m1, m12.m2)


def test_collection_compute_groups_with_nested_sequence_states():
"""Check that compute group merging handles nested sequence state values."""

class DummyNestedListMetric(Metric):
full_state_update = True

def __init__(self) -> None:
super().__init__()
self.add_state("x", [], dist_reduce_fx=None)

def update(self, x: torch.Tensor) -> None:
self.x.append((x, x + 1))

def compute(self) -> list[tuple[torch.Tensor, torch.Tensor]]:
return self.x

m1, m2 = DummyNestedListMetric(), DummyNestedListMetric()
metrics = MetricCollection({"m1": m1, "m2": m2}, compute_groups=True)

metrics.update(torch.tensor([1.0]))

assert metrics.compute_groups == {0: ["m1", "m2"]}
assert metrics._equal_metric_states(metrics.m1, metrics.m2)
Loading