Commit b1aaee9

Move newly added functions
1 parent 4f24e39 commit b1aaee9

16 files changed

Lines changed: 42 additions & 63 deletions

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ changelog does not include internal changes that do not affect the user.
 
 - **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the
   Jacobian. Now, these functions compute and populate the `.jac` fields of the parameters, and a new
-  function `torchjd.utils.jac_to_grad` should then be called to aggregate those `.jac` fields into
+  function `torchjd.autojac.jac_to_grad` should then be called to aggregate those `.jac` fields into
   `.grad` fields.
   This means that users now have more control on what they do with the Jacobians (they can easily
   aggregate them group by group or even param by param if they want), but it now requires an extra

docs/source/examples/amp.rst

Lines changed: 2 additions & 3 deletions
@@ -12,16 +12,15 @@ case, the losses) should preferably be scaled with a `GradScaler
 following example shows the resulting code for a multi-task learning use-case.
 
 .. code-block:: python
-    :emphasize-lines: 2, 18, 28, 35-36, 38-39
+    :emphasize-lines: 2, 17, 27, 34-35, 37-38
 
     import torch
     from torch.amp import GradScaler
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.optim import SGD
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
-    from torchjd.utils import jac_to_grad
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
     shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
     task1_module = Linear(3, 1)
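The amp example relies on a `GradScaler`, which multiplies the losses by a scale factor before the backward pass and divides the resulting gradients by the same factor before the optimizer step, so that small gradients survive float16 rounding. A torch-free numeric sketch of that invariant (names are illustrative; the real `GradScaler` also handles inf/nan checks and dynamic scale updates):

```python
SCALE = 2.0 ** 16  # a typical initial scale for float16 training


def grad_of_loss(w, x, y):
    # d/dw of (w * x - y)^2
    return 2.0 * x * (w * x - y)


def grad_of_scaled_loss(w, x, y, scale):
    # Scaling the loss scales its gradient by the same factor.
    return scale * grad_of_loss(w, x, y)


def unscale(grad, scale):
    # GradScaler-style unscaling before the optimizer step.
    return grad / scale


w, x, y = 0.5, 3.0, 2.0
scaled = grad_of_scaled_loss(w, x, y, SCALE)
# With a power-of-two scale, unscaling recovers the gradient exactly.
assert unscale(scaled, SCALE) == grad_of_loss(w, x, y)
```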

docs/source/examples/basic_usage.rst

Lines changed: 1 addition & 1 deletion
@@ -19,8 +19,8 @@ Import several classes from ``torch`` and ``torchjd``:
     from torch.optim import SGD
 
     from torchjd import autojac
     from torchjd.aggregation import UPGrad
-    from torchjd.utils import jac_to_grad
+    from torchjd.autojac import jac_to_grad
 
 Define the model and the optimizer, as usual:
 

docs/source/examples/iwrm.rst

Lines changed: 3 additions & 6 deletions
@@ -50,7 +50,6 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
 
 
 
-
     X = torch.randn(8, 16, 10)
     Y = torch.randn(8, 16)
 
@@ -78,15 +77,14 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
     .. tab-item:: autojac
 
         .. code-block:: python
-            :emphasize-lines: 5-7, 13, 17, 22-24
+            :emphasize-lines: 5-6, 12, 16, 21-23
 
             import torch
             from torch.nn import Linear, MSELoss, ReLU, Sequential
             from torch.optim import SGD
 
             from torchjd.aggregation import UPGrad
-            from torchjd.autojac import backward
-            from torchjd.utils import jac_to_grad
+            from torchjd.autojac import backward, jac_to_grad
 
             X = torch.randn(8, 16, 10)
             Y = torch.randn(8, 16)
@@ -115,7 +113,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
     .. tab-item:: autogram (recommended)
 
         .. code-block:: python
-            :emphasize-lines: 5-6, 13, 17-18, 22-25
+            :emphasize-lines: 5-6, 12, 16-17, 21-24
 
             import torch
             from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -124,7 +122,6 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             from torchjd.aggregation import UPGradWeighting
             from torchjd.autogram import Engine
 
-
            X = torch.randn(8, 16, 10)
            Y = torch.randn(8, 16)
 

docs/source/examples/lightning_integration.rst

Lines changed: 2 additions & 3 deletions
@@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using
 <../docs/autojac/mtl_backward>` at each training iteration.
 
 .. code-block:: python
-    :emphasize-lines: 9-11, 19, 32-33
+    :emphasize-lines: 9-10, 18, 31-32
 
     import torch
     from lightning import LightningModule, Trainer
@@ -22,8 +22,7 @@ The following code example demonstrates a basic multi-task learning setup using
     from torch.utils.data import DataLoader, TensorDataset
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
-    from torchjd.utils import jac_to_grad
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
     class Model(LightningModule):
         def __init__(self):

docs/source/examples/monitoring.rst

Lines changed: 2 additions & 3 deletions
@@ -15,16 +15,15 @@ Jacobian descent is doing something different than gradient descent. With
 they have a negative inner product).
 
 .. code-block:: python
-    :emphasize-lines: 10-12, 14-19, 34-35
+    :emphasize-lines: 9-11, 13-18, 33-34
 
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.nn.functional import cosine_similarity
    from torch.optim import SGD
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
-    from torchjd.utils import jac_to_grad
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
     def print_weights(_, __, weights: torch.Tensor) -> None:
         """Prints the extracted weights."""

docs/source/examples/mtl.rst

Lines changed: 2 additions & 3 deletions
@@ -19,15 +19,14 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
 
 
 .. code-block:: python
-    :emphasize-lines: 5-7, 20, 33-34
+    :emphasize-lines: 5-6, 19, 32-33
 
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.optim import SGD
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
-    from torchjd.utils import jac_to_grad
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
     shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
     task1_module = Linear(3, 1)

docs/source/examples/rnn.rst

Lines changed: 2 additions & 3 deletions
@@ -6,15 +6,14 @@ element of the output sequences. If the gradients of these losses are likely to
 descent can be leveraged to enhance optimization.
 
 .. code-block:: python
-    :emphasize-lines: 5-7, 11, 18, 20-21
+    :emphasize-lines: 5-6, 10, 17, 19-20
 
     import torch
     from torch.nn import RNN
     from torch.optim import SGD
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import backward
-    from torchjd.utils import jac_to_grad
+    from torchjd.autojac import backward, jac_to_grad
 
     rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
     optimizer = SGD(rnn.parameters(), lr=0.1)

src/torchjd/autojac/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 """
 
 from ._backward import backward
+from ._jac_to_grad import jac_to_grad
 from ._mtl_backward import mtl_backward
 
-__all__ = ["backward", "mtl_backward"]
+__all__ = ["backward", "jac_to_grad", "mtl_backward"]
Lines changed: 11 additions & 3 deletions
@@ -3,10 +3,18 @@
 
 from torch import Tensor
 
-from torchjd.utils._tensor_with_jac import TensorWithJac
+
+class TensorWithJac(Tensor):
+    """
+    Tensor known to have a populated jac field.
+
+    Should not be directly instantiated, but can be used as a type hint and can be cast to.
+    """
+
+    jac: Tensor
 
 
-def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
+def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
     for param, jac in zip(params, jacobians, strict=True):
         _check_expects_grad(param)
         # We check that the shape is correct to be consistent with torch, which checks that the grad
@@ -34,7 +42,7 @@ def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> N
         param.__setattr__("jac", jac)
 
 
-def _accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None:
+def accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None:
     for param, grad in zip(params, gradients, strict=True):
         _check_expects_grad(param)
         if hasattr(param, "grad") and param.grad is not None:
