Skip to content

Commit 0f85811

Browse files
committed
Improve test_can_skip_jacobian_combination
1 parent 63c9dde commit 0f85811

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

tests/unit/autojac/test_jac_to_grad.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
TrimmedMean,
2424
UPGrad,
2525
)
26-
from torchjd.aggregation._aggregator_bases import WeightedAggregator
26+
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator
2727
from torchjd.autojac._jac_to_grad import (
2828
_can_skip_jacobian_combination,
2929
_has_forward_hook,
@@ -226,11 +226,22 @@ def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool) -
226226
handle = aggregator.register_forward_hook(lambda _module, _input, output: output)
227227
assert not _can_skip_jacobian_combination(aggregator)
228228
handle.remove()
229+
assert _can_skip_jacobian_combination(aggregator) == expected
229230
handle = aggregator.register_forward_pre_hook(lambda _module, input: input)
230231
assert not _can_skip_jacobian_combination(aggregator)
231232
handle.remove()
232233
assert _can_skip_jacobian_combination(aggregator) == expected
233234

235+
if isinstance(aggregator, GramianWeightedAggregator):
236+
handle = aggregator.weighting.register_forward_hook(lambda _module, _input, output: output)
237+
assert not _can_skip_jacobian_combination(aggregator)
238+
handle.remove()
239+
assert _can_skip_jacobian_combination(aggregator) == expected
240+
handle = aggregator.weighting.register_forward_pre_hook(lambda _module, input: input)
241+
assert not _can_skip_jacobian_combination(aggregator)
242+
handle.remove()
243+
assert _can_skip_jacobian_combination(aggregator) == expected
244+
234245

235246
def test_noncontiguous_jac() -> None:
236247
"""Tests that jac_to_grad works when the .jac field is non-contiguous."""

0 commit comments

Comments
 (0)