Skip to content

Commit e995d65

Browse files
committed
Make materialize public to the transform package
- materialize is in a protected file _materialize.py, which means that only members of the local transform package can import from it. There is thus no point in making the materialize function itself protected. Doing so would imply that the materialize function should not be used outside _materialize.py, while in fact it has to be used in grad.py, jac.py, and stack.py.
1 parent aeaddf5 commit e995d65

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

src/torchjd/autojac/_transform/_materialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor
55

66

7-
def _materialize(
7+
def materialize(
88
optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor]
99
) -> tuple[Tensor, ...]:
1010
"""

src/torchjd/autojac/_transform/grad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor
55

66
from ._differentiate import _Differentiate
7-
from ._materialize import _materialize
7+
from ._materialize import materialize
88
from .tensor_dict import Gradients
99

1010

@@ -47,5 +47,5 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
4747
create_graph=self.create_graph,
4848
allow_unused=True,
4949
)
50-
grads = _materialize(optional_grads, inputs)
50+
grads = materialize(optional_grads, inputs)
5151
return grads

src/torchjd/autojac/_transform/jac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch import Size, Tensor
88

99
from ._differentiate import _Differentiate
10-
from ._materialize import _materialize
10+
from ._materialize import materialize
1111
from .tensor_dict import Jacobians
1212

1313

@@ -60,7 +60,7 @@ def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
6060
create_graph=self.create_graph,
6161
allow_unused=True,
6262
)
63-
grads = _materialize(optional_grads, inputs=inputs)
63+
grads = materialize(optional_grads, inputs=inputs)
6464
return torch.concatenate([grad.reshape([-1]) for grad in grads])
6565

6666
# By the Jacobians constraint, this value should be the same for all jac_outputs.

src/torchjd/autojac/_transform/stack.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55

6-
from ._materialize import _materialize
6+
from ._materialize import materialize
77
from .base import Transform
88
from .tensor_dict import _A, Gradients, Jacobians
99

@@ -46,6 +46,6 @@ def _stack_one_key(gradient_dicts: list[Gradients], input: Tensor) -> Tensor:
4646
"""
4747

4848
optional_gradients = [gradients.get(input, None) for gradients in gradient_dicts]
49-
gradients = _materialize(optional_gradients, [input] * len(optional_gradients))
49+
gradients = materialize(optional_gradients, [input] * len(optional_gradients))
5050
jacobian = torch.stack(gradients, dim=0)
5151
return jacobian

0 commit comments

Comments
 (0)