Skip to content

Commit 8537bd8

Browse files
refactor(autojac): Make utility functions public (#308)
* Since they are located in a protected file, they are already only visible within their package. There is no need to additionally protect them within their own file.
1 parent 069cda6 commit 8537bd8

File tree

4 files changed

+27
-27
lines changed

4 files changed

+27
-27
lines changed

src/torchjd/autojac/_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,23 @@
77
from ._transform.ordered_set import OrderedSet
88

99

10-
def _check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None:
10+
def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None:
1111
if not (parallel_chunk_size is None or parallel_chunk_size > 0):
1212
raise ValueError(
1313
"`parallel_chunk_size` should be `None` or greater than `0`. (got "
1414
f"{parallel_chunk_size})"
1515
)
1616

1717

18-
def _as_tensor_list(tensors: Sequence[Tensor] | Tensor) -> list[Tensor]:
18+
def as_tensor_list(tensors: Sequence[Tensor] | Tensor) -> list[Tensor]:
1919
if isinstance(tensors, Tensor):
2020
output = [tensors]
2121
else:
2222
output = list(tensors)
2323
return output
2424

2525

26-
def _get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]:
26+
def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]:
2727
"""
2828
Gets the leaves of the autograd graph of all specified ``tensors``.
2929

src/torchjd/autojac/backward.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac, Transform
88
from ._transform.ordered_set import OrderedSet
9-
from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors
9+
from ._utils import as_tensor_list, check_optional_positive_chunk_size, get_leaf_tensors
1010

1111

1212
def backward(
@@ -67,15 +67,15 @@ def backward(
6767
experience issues with ``backward`` try to use ``parallel_chunk_size=1`` to avoid relying on
6868
``torch.vmap``.
6969
"""
70-
_check_optional_positive_chunk_size(parallel_chunk_size)
70+
check_optional_positive_chunk_size(parallel_chunk_size)
7171

72-
tensors = _as_tensor_list(tensors)
72+
tensors = as_tensor_list(tensors)
7373

7474
if len(tensors) == 0:
7575
raise ValueError("`tensors` cannot be empty")
7676

7777
if inputs is None:
78-
inputs = _get_leaf_tensors(tensors=tensors, excluded=set())
78+
inputs = get_leaf_tensors(tensors=tensors, excluded=set())
7979
else:
8080
inputs = OrderedSet(inputs)
8181

src/torchjd/autojac/mtl_backward.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Transform,
1818
)
1919
from ._transform.ordered_set import OrderedSet
20-
from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors
20+
from ._utils import as_tensor_list, check_optional_positive_chunk_size, get_leaf_tensors
2121

2222

2323
def mtl_backward(
@@ -79,16 +79,16 @@ def mtl_backward(
7979
``torch.vmap``.
8080
"""
8181

82-
_check_optional_positive_chunk_size(parallel_chunk_size)
82+
check_optional_positive_chunk_size(parallel_chunk_size)
8383

84-
features = _as_tensor_list(features)
84+
features = as_tensor_list(features)
8585

8686
if shared_params is None:
87-
shared_params = _get_leaf_tensors(tensors=features, excluded=[])
87+
shared_params = get_leaf_tensors(tensors=features, excluded=[])
8888
else:
8989
shared_params = OrderedSet(shared_params)
9090
if tasks_params is None:
91-
tasks_params = [_get_leaf_tensors(tensors=[loss], excluded=features) for loss in losses]
91+
tasks_params = [get_leaf_tensors(tensors=[loss], excluded=features) for loss in losses]
9292
else:
9393
tasks_params = [OrderedSet(task_params) for task_params in tasks_params]
9494

tests/unit/autojac/test_utils.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytest import mark, raises
33
from torch.nn import Linear, MSELoss, ReLU, Sequential
44

5-
from torchjd.autojac._utils import _get_leaf_tensors
5+
from torchjd.autojac._utils import get_leaf_tensors
66

77

88
def test_simple_get_leaf_tensors():
@@ -14,7 +14,7 @@ def test_simple_get_leaf_tensors():
1414
y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
1515
y2 = (a1**2).sum() + a2.norm()
1616

17-
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded=set())
17+
leaves = get_leaf_tensors(tensors=[y1, y2], excluded=set())
1818
assert set(leaves) == {a1, a2}
1919

2020

@@ -35,7 +35,7 @@ def test_get_leaf_tensors_excluded_1():
3535
y1 = torch.tensor([-1.0, 1.0]) @ a1 + b2
3636
y2 = b1
3737

38-
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2})
38+
leaves = get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2})
3939
assert set(leaves) == {a1}
4040

4141

@@ -56,7 +56,7 @@ def test_get_leaf_tensors_excluded_2():
5656
y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
5757
y2 = b1
5858

59-
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2})
59+
leaves = get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2})
6060
assert set(leaves) == {a1, a2}
6161

6262

@@ -71,7 +71,7 @@ def test_get_leaf_tensors_leaf_not_requiring_grad():
7171
y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
7272
y2 = (a1**2).sum() + a2.norm()
7373

74-
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded=set())
74+
leaves = get_leaf_tensors(tensors=[y1, y2], excluded=set())
7575
assert set(leaves) == {a1}
7676

7777

@@ -90,7 +90,7 @@ def test_get_leaf_tensors_model():
9090
y_hat = model(x)
9191
losses = loss_fn(y_hat, y)
9292

93-
leaves = _get_leaf_tensors(tensors=[losses], excluded=set())
93+
leaves = get_leaf_tensors(tensors=[losses], excluded=set())
9494
assert set(leaves) == set(model.parameters())
9595

9696

@@ -111,7 +111,7 @@ def test_get_leaf_tensors_model_excluded_2():
111111
z_hat = model2(y)
112112
losses = loss_fn(z_hat, z)
113113

114-
leaves = _get_leaf_tensors(tensors=[losses], excluded={y})
114+
leaves = get_leaf_tensors(tensors=[losses], excluded={y})
115115
assert set(leaves) == set(model2.parameters())
116116

117117

@@ -121,14 +121,14 @@ def test_get_leaf_tensors_single_root():
121121
p = torch.tensor([1.0, 2.0], requires_grad=True)
122122
y = p * 2
123123

124-
leaves = _get_leaf_tensors(tensors=[y], excluded=set())
124+
leaves = get_leaf_tensors(tensors=[y], excluded=set())
125125
assert set(leaves) == {p}
126126

127127

128128
def test_get_leaf_tensors_empty_roots():
129129
"""Tests that _get_leaf_tensors returns no leaves when roots is the empty set."""
130130

131-
leaves = _get_leaf_tensors(tensors=[], excluded=set())
131+
leaves = get_leaf_tensors(tensors=[], excluded=set())
132132
assert set(leaves) == set()
133133

134134

@@ -141,7 +141,7 @@ def test_get_leaf_tensors_excluded_root():
141141
y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
142142
y2 = (a1**2).sum()
143143

144-
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={y1})
144+
leaves = get_leaf_tensors(tensors=[y1, y2], excluded={y1})
145145
assert set(leaves) == {a1}
146146

147147

@@ -154,7 +154,7 @@ def test_get_leaf_tensors_deep(depth: int):
154154
for i in range(depth):
155155
sum_ = sum_ + one
156156

157-
leaves = _get_leaf_tensors(tensors=[sum_], excluded=set())
157+
leaves = get_leaf_tensors(tensors=[sum_], excluded=set())
158158
assert set(leaves) == {one}
159159

160160

@@ -163,7 +163,7 @@ def test_get_leaf_tensors_leaf():
163163

164164
a = torch.tensor(1.0, requires_grad=True)
165165
with raises(ValueError):
166-
_ = _get_leaf_tensors(tensors=[a], excluded=set())
166+
_ = get_leaf_tensors(tensors=[a], excluded=set())
167167

168168

169169
def test_get_leaf_tensors_tensor_not_requiring_grad():
@@ -173,7 +173,7 @@ def test_get_leaf_tensors_tensor_not_requiring_grad():
173173

174174
a = torch.tensor(1.0, requires_grad=False) * 2
175175
with raises(ValueError):
176-
_ = _get_leaf_tensors(tensors=[a], excluded=set())
176+
_ = get_leaf_tensors(tensors=[a], excluded=set())
177177

178178

179179
def test_get_leaf_tensors_excluded_leaf():
@@ -182,7 +182,7 @@ def test_get_leaf_tensors_excluded_leaf():
182182
a = torch.tensor(1.0, requires_grad=True) * 2
183183
b = torch.tensor(2.0, requires_grad=True)
184184
with raises(ValueError):
185-
_ = _get_leaf_tensors(tensors=[a], excluded={b})
185+
_ = get_leaf_tensors(tensors=[a], excluded={b})
186186

187187

188188
def test_get_leaf_tensors_excluded_not_requiring_grad():
@@ -193,4 +193,4 @@ def test_get_leaf_tensors_excluded_not_requiring_grad():
193193
a = torch.tensor(1.0, requires_grad=True) * 2
194194
b = torch.tensor(2.0, requires_grad=False) * 2
195195
with raises(ValueError):
196-
_ = _get_leaf_tensors(tensors=[a], excluded={b})
196+
_ = get_leaf_tensors(tensors=[a], excluded={b})

0 commit comments

Comments
 (0)