Commit e42d677

fix(aggregation): Fix type mismatch bug in NashMTL (#317)
* Move the cast of `alpha` to `torch.Tensor` out of the condition on the step value
* Add changelog entry

Notes:
* `self.prvs_alpha` is always a NumPy array, so in both cases (`if (self.step % self.update_weights_every) == 0:` and `else`), `alpha` has to be cast to a Tensor.
* In the original implementation (https://github.com/AvivNavon/nash-mtl/blob/main/methods/weight_methods.py#L238), there was already a type mismatch: `alpha` was a tensor when entering the condition `if (self.step % self.update_weights_every) == 0`, and a NumPy array otherwise. However, the following line (`weighted_loss = sum([losses[i] * alpha[i] for i in range(len(alpha))])`) made it work regardless.
1 parent 19a4ed4 commit e42d677
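The bug and its fix can be illustrated with a minimal, self-contained sketch. The class, the stand-in solver, and the update logic below are simplified assumptions for illustration, not the actual `torchjd` implementation:

```python
import numpy as np
import torch


class Weighting:
    """Simplified stand-in for the NashMTL weighting logic (not the real class)."""

    def __init__(self, n_tasks: int, update_weights_every: int = 2):
        self.step = 0
        self.update_weights_every = update_weights_every
        # prvs_alpha is always stored as a NumPy array between steps.
        self.prvs_alpha = np.ones(n_tasks) / n_tasks

    def _solve(self) -> np.ndarray:
        # Stand-in for the optimization solver, which also returns a NumPy array.
        return self.prvs_alpha

    def forward(self, matrix: torch.Tensor) -> torch.Tensor:
        if self.step % self.update_weights_every == 0:
            alpha = self._solve()  # NumPy array from the solver
        else:
            alpha = self.prvs_alpha  # NumPy array reused from the previous step
        self.step += 1  # simplified: incremented unconditionally in this sketch
        # The fix: cast after BOTH branches. When the cast lived only inside the
        # `if` branch, `alpha` stayed a NumPy array on reuse steps, so the
        # tensor operations below failed with a type mismatch.
        alpha = torch.from_numpy(alpha).to(device=matrix.device, dtype=matrix.dtype)
        self.prvs_alpha = alpha.detach().cpu().numpy()
        return alpha @ matrix


w = Weighting(n_tasks=2, update_weights_every=2)
m = torch.ones(2, 3)
out1 = w.forward(m)  # solver step
out2 = w.forward(m)  # reuse step: this is the case that failed before the fix
```

With `update_weights_every=1`, the `else` branch is never taken, which is why the bug only surfaced for values greater than 1.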

File tree

2 files changed: +4 -1 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ changes that do not affect the user.
 - Removed arbitrary exception handling in `IMTLG` and `AlignedMTL` when the computation fails. In
   practice, this fix should only affect some matrices with extremely large values, which should
   not usually happen.
+- Fixed a bug in `NashMTL` that made it fail (due to a type mismatch) when `update_weights_every`
+  was more than 1.
 
 ## [0.5.0] - 2025-02-01

src/torchjd/aggregation/nash_mtl.py

Lines changed: 2 additions & 1 deletion
@@ -197,11 +197,12 @@ def forward(self, matrix: Tensor) -> Tensor:
             self.normalization_factor = torch.norm(GTG).detach().cpu().numpy().reshape((1,))
             GTG = GTG / self.normalization_factor.item()
             alpha = self._solve_optimization(GTG.cpu().detach().numpy())
-            alpha = torch.from_numpy(alpha).to(device=matrix.device, dtype=matrix.dtype)
         else:
             self.step += 1
             alpha = self.prvs_alpha
 
+        alpha = torch.from_numpy(alpha).to(device=matrix.device, dtype=matrix.dtype)
+
         if self.max_norm > 0:
             norm = torch.linalg.norm(alpha @ matrix)
             if norm > self.max_norm:
