Skip to content

Commit bc217c5

Browse files
committed
Use vgp_from_module_2
1 parent 8d4fad0 commit bc217c5

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/arena/interfaces.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,11 @@ def forward_backward(model: Module, input: Tensor, aggregator: GramianWeightedAg
111111

112112
class ForwardBackwardAutogramInterface(Interface):
113113
def __call__(self, _: str):
114-
from torchjd._autogram._vgp import get_gramian, vgp_from_module_1
114+
from torchjd._autogram._vgp import get_gramian, vgp_from_module_2
115115
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator
116116

117117
def forward_backward(model: Module, input: Tensor, aggregator: GramianWeightedAggregator) -> None:
118-
output, vgp_fn = vgp_from_module_1(model, input)
118+
output, vgp_fn = vgp_from_module_2(model, input)
119119
gramian = get_gramian(vgp_fn, output)
120120
weights = aggregator.weighting.weighting(gramian)
121121
output.backward(weights)

0 commit comments

Comments
 (0)