|
23 | 23 | TrimmedMean, |
24 | 24 | UPGrad, |
25 | 25 | ) |
26 | | -from torchjd.aggregation._aggregator_bases import WeightedAggregator |
| 26 | +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator |
27 | 27 | from torchjd.autojac._jac_to_grad import ( |
28 | 28 | _can_skip_jacobian_combination, |
29 | 29 | _has_forward_hook, |
@@ -226,11 +226,22 @@ def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool) - |
226 | 226 | handle = aggregator.register_forward_hook(lambda _module, _input, output: output) |
227 | 227 | assert not _can_skip_jacobian_combination(aggregator) |
228 | 228 | handle.remove() |
| 229 | + assert _can_skip_jacobian_combination(aggregator) == expected |
229 | 230 | handle = aggregator.register_forward_pre_hook(lambda _module, input: input) |
230 | 231 | assert not _can_skip_jacobian_combination(aggregator) |
231 | 232 | handle.remove() |
232 | 233 | assert _can_skip_jacobian_combination(aggregator) == expected |
233 | 234 |
|
| 235 | + if isinstance(aggregator, GramianWeightedAggregator): |
| 236 | + handle = aggregator.weighting.register_forward_hook(lambda _module, _input, output: output) |
| 237 | + assert not _can_skip_jacobian_combination(aggregator) |
| 238 | + handle.remove() |
| 239 | + assert _can_skip_jacobian_combination(aggregator) == expected |
| 240 | + handle = aggregator.weighting.register_forward_pre_hook(lambda _module, input: input) |
| 241 | + assert not _can_skip_jacobian_combination(aggregator) |
| 242 | + handle.remove() |
| 243 | + assert _can_skip_jacobian_combination(aggregator) == expected |
| 244 | + |
234 | 245 |
|
235 | 246 | def test_noncontiguous_jac() -> None: |
236 | 247 | """Tests that jac_to_grad works when the .jac field is non-contiguous.""" |
|
0 commit comments