Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/torchjd/autojac/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@
from ._transform.ordered_set import OrderedSet


def _check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None:
def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None:
if not (parallel_chunk_size is None or parallel_chunk_size > 0):
raise ValueError(
"`parallel_chunk_size` should be `None` or greater than `0`. (got "
f"{parallel_chunk_size})"
)


def _as_tensor_list(tensors: Sequence[Tensor] | Tensor) -> list[Tensor]:
def as_tensor_list(tensors: Sequence[Tensor] | Tensor) -> list[Tensor]:
if isinstance(tensors, Tensor):
output = [tensors]
else:
output = list(tensors)
return output


def _get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]:
def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]:
"""
Gets the leaves of the autograd graph of all specified ``tensors``.

Expand Down
8 changes: 4 additions & 4 deletions src/torchjd/autojac/backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac, Transform
from ._transform.ordered_set import OrderedSet
from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors
from ._utils import as_tensor_list, check_optional_positive_chunk_size, get_leaf_tensors


def backward(
Expand Down Expand Up @@ -67,15 +67,15 @@ def backward(
experience issues with ``backward`` try to use ``parallel_chunk_size=1`` to avoid relying on
``torch.vmap``.
"""
_check_optional_positive_chunk_size(parallel_chunk_size)
check_optional_positive_chunk_size(parallel_chunk_size)

tensors = _as_tensor_list(tensors)
tensors = as_tensor_list(tensors)

if len(tensors) == 0:
raise ValueError("`tensors` cannot be empty")

if inputs is None:
inputs = _get_leaf_tensors(tensors=tensors, excluded=set())
inputs = get_leaf_tensors(tensors=tensors, excluded=set())
else:
inputs = OrderedSet(inputs)

Expand Down
10 changes: 5 additions & 5 deletions src/torchjd/autojac/mtl_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Transform,
)
from ._transform.ordered_set import OrderedSet
from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors
from ._utils import as_tensor_list, check_optional_positive_chunk_size, get_leaf_tensors


def mtl_backward(
Expand Down Expand Up @@ -79,16 +79,16 @@ def mtl_backward(
``torch.vmap``.
"""

_check_optional_positive_chunk_size(parallel_chunk_size)
check_optional_positive_chunk_size(parallel_chunk_size)

features = _as_tensor_list(features)
features = as_tensor_list(features)

if shared_params is None:
shared_params = _get_leaf_tensors(tensors=features, excluded=[])
shared_params = get_leaf_tensors(tensors=features, excluded=[])
else:
shared_params = OrderedSet(shared_params)
if tasks_params is None:
tasks_params = [_get_leaf_tensors(tensors=[loss], excluded=features) for loss in losses]
tasks_params = [get_leaf_tensors(tensors=[loss], excluded=features) for loss in losses]
else:
tasks_params = [OrderedSet(task_params) for task_params in tasks_params]

Expand Down
30 changes: 15 additions & 15 deletions tests/unit/autojac/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pytest import mark, raises
from torch.nn import Linear, MSELoss, ReLU, Sequential

from torchjd.autojac._utils import _get_leaf_tensors
from torchjd.autojac._utils import get_leaf_tensors


def test_simple_get_leaf_tensors():
Expand All @@ -14,7 +14,7 @@ def test_simple_get_leaf_tensors():
y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
y2 = (a1**2).sum() + a2.norm()

leaves = _get_leaf_tensors(tensors=[y1, y2], excluded=set())
leaves = get_leaf_tensors(tensors=[y1, y2], excluded=set())
assert set(leaves) == {a1, a2}


Expand All @@ -35,7 +35,7 @@ def test_get_leaf_tensors_excluded_1():
y1 = torch.tensor([-1.0, 1.0]) @ a1 + b2
y2 = b1

leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2})
leaves = get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2})
assert set(leaves) == {a1}


Expand All @@ -56,7 +56,7 @@ def test_get_leaf_tensors_excluded_2():
y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
y2 = b1

leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2})
leaves = get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2})
assert set(leaves) == {a1, a2}


Expand All @@ -71,7 +71,7 @@ def test_get_leaf_tensors_leaf_not_requiring_grad():
y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
y2 = (a1**2).sum() + a2.norm()

leaves = _get_leaf_tensors(tensors=[y1, y2], excluded=set())
leaves = get_leaf_tensors(tensors=[y1, y2], excluded=set())
assert set(leaves) == {a1}


Expand All @@ -90,7 +90,7 @@ def test_get_leaf_tensors_model():
y_hat = model(x)
losses = loss_fn(y_hat, y)

leaves = _get_leaf_tensors(tensors=[losses], excluded=set())
leaves = get_leaf_tensors(tensors=[losses], excluded=set())
assert set(leaves) == set(model.parameters())


Expand All @@ -111,7 +111,7 @@ def test_get_leaf_tensors_model_excluded_2():
z_hat = model2(y)
losses = loss_fn(z_hat, z)

leaves = _get_leaf_tensors(tensors=[losses], excluded={y})
leaves = get_leaf_tensors(tensors=[losses], excluded={y})
assert set(leaves) == set(model2.parameters())


Expand All @@ -121,14 +121,14 @@ def test_get_leaf_tensors_single_root():
p = torch.tensor([1.0, 2.0], requires_grad=True)
y = p * 2

leaves = _get_leaf_tensors(tensors=[y], excluded=set())
leaves = get_leaf_tensors(tensors=[y], excluded=set())
assert set(leaves) == {p}


def test_get_leaf_tensors_empty_roots():
    """Tests that get_leaf_tensors returns no leaves when ``tensors`` is empty."""

    # An empty set of roots has an empty autograd graph, so there are no leaves.
    leaves = get_leaf_tensors(tensors=[], excluded=set())
    assert set(leaves) == set()


Expand All @@ -141,7 +141,7 @@ def test_get_leaf_tensors_excluded_root():
y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
y2 = (a1**2).sum()

leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={y1})
leaves = get_leaf_tensors(tensors=[y1, y2], excluded={y1})
assert set(leaves) == {a1}


Expand All @@ -154,7 +154,7 @@ def test_get_leaf_tensors_deep(depth: int):
for i in range(depth):
sum_ = sum_ + one

leaves = _get_leaf_tensors(tensors=[sum_], excluded=set())
leaves = get_leaf_tensors(tensors=[sum_], excluded=set())
assert set(leaves) == {one}


Expand All @@ -163,7 +163,7 @@ def test_get_leaf_tensors_leaf():

a = torch.tensor(1.0, requires_grad=True)
with raises(ValueError):
_ = _get_leaf_tensors(tensors=[a], excluded=set())
_ = get_leaf_tensors(tensors=[a], excluded=set())


def test_get_leaf_tensors_tensor_not_requiring_grad():
Expand All @@ -173,7 +173,7 @@ def test_get_leaf_tensors_tensor_not_requiring_grad():

a = torch.tensor(1.0, requires_grad=False) * 2
with raises(ValueError):
_ = _get_leaf_tensors(tensors=[a], excluded=set())
_ = get_leaf_tensors(tensors=[a], excluded=set())


def test_get_leaf_tensors_excluded_leaf():
Expand All @@ -182,7 +182,7 @@ def test_get_leaf_tensors_excluded_leaf():
a = torch.tensor(1.0, requires_grad=True) * 2
b = torch.tensor(2.0, requires_grad=True)
with raises(ValueError):
_ = _get_leaf_tensors(tensors=[a], excluded={b})
_ = get_leaf_tensors(tensors=[a], excluded={b})


def test_get_leaf_tensors_excluded_not_requiring_grad():
Expand All @@ -193,4 +193,4 @@ def test_get_leaf_tensors_excluded_not_requiring_grad():
a = torch.tensor(1.0, requires_grad=True) * 2
b = torch.tensor(2.0, requires_grad=False) * 2
with raises(ValueError):
_ = _get_leaf_tensors(tensors=[a], excluded={b})
_ = get_leaf_tensors(tensors=[a], excluded={b})