Skip to content

Commit 7db95ff

Browse files
authored
Fix backward and mtl_backward bug with some tensor shapes (#227)
* Add failing parametrizations to test_value_is_correct * Fix bug in _get_jac_matrix_chunk * Add [YANKED] tag next to [0.4.0] header * Add changelog entry
1 parent 1a8454e commit 7db95ff

File tree

4 files changed

+9
-4
lines changed

4 files changed

+9
-4
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@ changes that do not affect the user.
88

99
## [Unreleased]
1010

11-
## [0.4.0] - 2025-01-02
11+
### Fixed
12+
13+
- Fixed a bug introduced in v0.4.0 that could cause `backward` and `mtl_backward` to fail with some
14+
tensor shapes.
15+
16+
## [0.4.0] - 2025-01-02 [YANKED]
1217

1318
### Changed
1419

src/torchjd/autojac/_transform/jac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _get_jac_matrix_chunk(
114114

115115
chunk_size = jac_outputs_chunk[0].shape[0]
116116
if chunk_size == 1:
117-
grad_outputs = [tensor.squeeze() for tensor in jac_outputs_chunk]
117+
grad_outputs = [tensor.squeeze(0) for tensor in jac_outputs_chunk]
118118
gradient_vector = get_vjp(grad_outputs)
119119
return gradient_vector.unsqueeze(0)
120120
else:

tests/unit/autojac/test_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_various_aggregators(aggregator: Aggregator):
2424

2525

2626
@mark.parametrize("aggregator", [Mean(), UPGrad()])
27-
@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (20, 55)])
27+
@mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)])
2828
@mark.parametrize("manually_specify_inputs", [True, False])
2929
@mark.parametrize("chunk_size", [1, 2, None])
3030
def test_value_is_correct(

tests/unit/autojac/test_mtl_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_various_aggregators(aggregator: Aggregator):
2727

2828

2929
@mark.parametrize("aggregator", [Mean(), UPGrad()])
30-
@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (20, 55)])
30+
@mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)])
3131
@mark.parametrize("manually_specify_shared_params", [True, False])
3232
@mark.parametrize("manually_specify_tasks_params", [True, False])
3333
@mark.parametrize("chunk_size", [1, 2, None])

0 commit comments

Comments
 (0)