-
Notifications
You must be signed in to change notification settings - Fork 486
Fix MetricCollection state comparison for nested sequence states #3337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
b0a5e37
31a497b
a7620d8
f1e2149
1f374cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,6 +56,23 @@ def _remove_suffix(string: str, suffix: str) -> str: | |
| return string[: -len(suffix)] if string.endswith(suffix) else string | ||
|
|
||
|
|
||
| def _equal_state_value(state1: Any, state2: Any) -> bool: | ||
| """Recursively compare metric state values while preserving structure checks.""" | ||
| if type(state1) is not type(state2): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| return False | ||
|
|
||
| if isinstance(state1, Tensor): | ||
| return state1.shape == state2.shape and allclose(state1, state2) | ||
|
|
||
| if isinstance(state1, Mapping): | ||
| return state1.keys() == state2.keys() and all(_equal_state_value(state1[k], state2[k]) for k in state1) | ||
|
|
||
| if isinstance(state1, Sequence) and not isinstance(state1, str): | ||
| return len(state1) == len(state2) and all(_equal_state_value(s1, s2) for s1, s2 in zip(state1, state2)) | ||
|
Comment on lines
+67
to
+71
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Mapping and Sequence branches use generator expressions with all(). For very deep nesting (e.g., list of list of list of ...) this could hit Python recursion limits. In practice metric states are rarely >3 levels deep, so this is theoretical -- but note the limit if it matters for your use case. |
||
|
|
||
| return state1 == state2 | ||
|
|
||
|
|
||
| class MetricCollection(ModuleDict): | ||
| """MetricCollection class can be used to chain metrics that have the same call pattern into one single class. | ||
|
|
||
|
|
@@ -316,21 +333,7 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool: | |
| state1 = getattr(metric1, key) | ||
| state2 = getattr(metric2, key) | ||
|
|
||
| if type(state1) != type(state2): # noqa: E721 | ||
| return False | ||
|
|
||
| if ( | ||
| isinstance(state1, Tensor) | ||
| and isinstance(state2, Tensor) | ||
| and not (state1.shape == state2.shape and allclose(state1, state2)) | ||
| ): | ||
| return False | ||
|
|
||
| if ( | ||
| isinstance(state1, list) | ||
| and isinstance(state2, list) | ||
| and not (all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2))) | ||
| ): | ||
| if not _equal_state_value(state1, state2): | ||
| return False | ||
|
|
||
| return True | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The _equal_state_value docstring is one line. Consider listing supported types for callers: e.g., "Recursively compare metric state values. Supports: Tensor (shape+value), Mapping (key+recursive value), Sequence/str-excluded (length+recursive element), and primitives (direct ==)."