We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b714253 commit 2bb8ab1Copy full SHA for 2bb8ab1
1 file changed
src/torchjd/autojac/_jac_to_grad.py
@@ -113,12 +113,15 @@ def jac_to_grad(
113
114
115
def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]:
116
- return isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator)
+ return (
117
+ isinstance(aggregator, GramianWeightedAggregator)
118
+ and not _has_forward_hook(aggregator)
119
+ and not _has_forward_hook(aggregator.weighting)
120
+ )
121
122
123
def _has_forward_hook(module: nn.Module) -> bool:
124
"""Return whether the module has any forward hook registered."""
- # TODO: also check hooks on the outer weighting
125
return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0
126
127
0 commit comments