Skip to content

Commit 850b370

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1a2ae67 commit 850b370

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/torchjd/aggregation/_stch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,9 @@ def forward(self, gramian: PSDMatrix) -> Tensor:
190190
if self.nadir_accumulator is None:
191191
self.nadir_accumulator = grad_norms.detach().clone()
192192
else:
193-
self.nadir_accumulator = self.nadir_accumulator.to(
194-
device=device, dtype=dtype
195-
) + grad_norms.detach()
193+
self.nadir_accumulator = (
194+
self.nadir_accumulator.to(device=device, dtype=dtype) + grad_norms.detach()
195+
)
196196

197197
self.step += 1
198198

0 commit comments

Comments
 (0)