Skip to content

Commit 2bb8ab1

Browse files
committed
Fix _can_skip_jacobian_combination
1 parent b714253 commit 2bb8ab1

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,15 @@ def jac_to_grad(
113113

114114

115115
def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]:
116-
return isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator)
116+
return (
117+
isinstance(aggregator, GramianWeightedAggregator)
118+
and not _has_forward_hook(aggregator)
119+
and not _has_forward_hook(aggregator.weighting)
120+
)
117121

118122

119123
def _has_forward_hook(module: nn.Module) -> bool:
120124
"""Return whether the module has any forward hook registered."""
121-
# TODO: also check hooks on the outer weighting
122125
return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0
123126

124127

0 commit comments

Comments
 (0)