Commit e42d677
authored
fix(aggregation): Fix type mismatch bug in
* Move cast of alpha to torch.Tensor out of the condition on the step value
* Add changelog entry
Notes:
* self.prvs_alpha is always a numpy array, so in both cases (if (self.step % self.update_weights_every) == 0: and else), we have to cast alpha to a Tensor.
* In the original implementation of https://github.com/AvivNavon/nash-mtl/blob/main/methods/weight_methods.py#L238, there was already a mismatch of type, with alpha being a tensor when entering the condition (if (self.step % self.update_weights_every) == 0), and being a numpy array otherwise, but the following line (weighted_loss = sum([losses[i] * alpha[i] for i in range(len(alpha))])) made it work regardless.NashMTL (#317)1 parent 19a4ed4 commit e42d677
2 files changed
+4
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
35 | 35 | | |
36 | 36 | | |
37 | 37 | | |
| 38 | + | |
| 39 | + | |
38 | 40 | | |
39 | 41 | | |
40 | 42 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
197 | 197 | | |
198 | 198 | | |
199 | 199 | | |
200 | | - | |
201 | 200 | | |
202 | 201 | | |
203 | 202 | | |
204 | 203 | | |
| 204 | + | |
| 205 | + | |
205 | 206 | | |
206 | 207 | | |
207 | 208 | | |
| |||
0 commit comments