diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index 9183bbac..173b66da 100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -7,7 +7,7 @@ from ._transform.ordered_set import OrderedSet -def _check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None: +def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None: if not (parallel_chunk_size is None or parallel_chunk_size > 0): raise ValueError( "`parallel_chunk_size` should be `None` or greater than `0`. (got " @@ -15,7 +15,7 @@ def _check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None ) -def _as_tensor_list(tensors: Sequence[Tensor] | Tensor) -> list[Tensor]: +def as_tensor_list(tensors: Sequence[Tensor] | Tensor) -> list[Tensor]: if isinstance(tensors, Tensor): output = [tensors] else: @@ -23,7 +23,7 @@ def _as_tensor_list(tensors: Sequence[Tensor] | Tensor) -> list[Tensor]: return output -def _get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]: +def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]: """ Gets the leaves of the autograd graph of all specified ``tensors``. diff --git a/src/torchjd/autojac/backward.py b/src/torchjd/autojac/backward.py index 6478bd57..cb7226fe 100644 --- a/src/torchjd/autojac/backward.py +++ b/src/torchjd/autojac/backward.py @@ -6,7 +6,7 @@ from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac, Transform from ._transform.ordered_set import OrderedSet -from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors +from ._utils import as_tensor_list, check_optional_positive_chunk_size, get_leaf_tensors def backward( @@ -67,15 +67,15 @@ def backward( experience issues with ``backward`` try to use ``parallel_chunk_size=1`` to avoid relying on ``torch.vmap``. 
""" - _check_optional_positive_chunk_size(parallel_chunk_size) + check_optional_positive_chunk_size(parallel_chunk_size) - tensors = _as_tensor_list(tensors) + tensors = as_tensor_list(tensors) if len(tensors) == 0: raise ValueError("`tensors` cannot be empty") if inputs is None: - inputs = _get_leaf_tensors(tensors=tensors, excluded=set()) + inputs = get_leaf_tensors(tensors=tensors, excluded=set()) else: inputs = OrderedSet(inputs) diff --git a/src/torchjd/autojac/mtl_backward.py b/src/torchjd/autojac/mtl_backward.py index bfd56030..e11b25cc 100644 --- a/src/torchjd/autojac/mtl_backward.py +++ b/src/torchjd/autojac/mtl_backward.py @@ -17,7 +17,7 @@ Transform, ) from ._transform.ordered_set import OrderedSet -from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors +from ._utils import as_tensor_list, check_optional_positive_chunk_size, get_leaf_tensors def mtl_backward( @@ -79,16 +79,16 @@ def mtl_backward( ``torch.vmap``. """ - _check_optional_positive_chunk_size(parallel_chunk_size) + check_optional_positive_chunk_size(parallel_chunk_size) - features = _as_tensor_list(features) + features = as_tensor_list(features) if shared_params is None: - shared_params = _get_leaf_tensors(tensors=features, excluded=[]) + shared_params = get_leaf_tensors(tensors=features, excluded=[]) else: shared_params = OrderedSet(shared_params) if tasks_params is None: - tasks_params = [_get_leaf_tensors(tensors=[loss], excluded=features) for loss in losses] + tasks_params = [get_leaf_tensors(tensors=[loss], excluded=features) for loss in losses] else: tasks_params = [OrderedSet(task_params) for task_params in tasks_params] diff --git a/tests/unit/autojac/test_utils.py b/tests/unit/autojac/test_utils.py index 91fa4521..24443d60 100644 --- a/tests/unit/autojac/test_utils.py +++ b/tests/unit/autojac/test_utils.py @@ -2,7 +2,7 @@ from pytest import mark, raises from torch.nn import Linear, MSELoss, ReLU, Sequential -from torchjd.autojac._utils import 
_get_leaf_tensors +from torchjd.autojac._utils import get_leaf_tensors def test_simple_get_leaf_tensors(): @@ -14,7 +14,7 @@ def test_simple_get_leaf_tensors(): y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - leaves = _get_leaf_tensors(tensors=[y1, y2], excluded=set()) + leaves = get_leaf_tensors(tensors=[y1, y2], excluded=set()) assert set(leaves) == {a1, a2} @@ -35,7 +35,7 @@ def test_get_leaf_tensors_excluded_1(): y1 = torch.tensor([-1.0, 1.0]) @ a1 + b2 y2 = b1 - leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2}) + leaves = get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2}) assert set(leaves) == {a1} @@ -56,7 +56,7 @@ def test_get_leaf_tensors_excluded_2(): y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum() y2 = b1 - leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2}) + leaves = get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2}) assert set(leaves) == {a1, a2} @@ -71,7 +71,7 @@ def test_get_leaf_tensors_leaf_not_requiring_grad(): y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - leaves = _get_leaf_tensors(tensors=[y1, y2], excluded=set()) + leaves = get_leaf_tensors(tensors=[y1, y2], excluded=set()) assert set(leaves) == {a1} @@ -90,7 +90,7 @@ def test_get_leaf_tensors_model(): y_hat = model(x) losses = loss_fn(y_hat, y) - leaves = _get_leaf_tensors(tensors=[losses], excluded=set()) + leaves = get_leaf_tensors(tensors=[losses], excluded=set()) assert set(leaves) == set(model.parameters()) @@ -111,7 +111,7 @@ def test_get_leaf_tensors_model_excluded_2(): z_hat = model2(y) losses = loss_fn(z_hat, z) - leaves = _get_leaf_tensors(tensors=[losses], excluded={y}) + leaves = get_leaf_tensors(tensors=[losses], excluded={y}) assert set(leaves) == set(model2.parameters()) @@ -121,14 +121,14 @@ def test_get_leaf_tensors_single_root(): p = torch.tensor([1.0, 2.0], requires_grad=True) y = p * 2 - leaves = _get_leaf_tensors(tensors=[y], excluded=set()) + leaves = 
get_leaf_tensors(tensors=[y], excluded=set()) assert set(leaves) == {p} def test_get_leaf_tensors_empty_roots(): - """Tests that _get_leaf_tensors returns no leaves when roots is the empty set.""" + """Tests that get_leaf_tensors returns no leaves when roots is the empty set.""" - leaves = _get_leaf_tensors(tensors=[], excluded=set()) + leaves = get_leaf_tensors(tensors=[], excluded=set()) assert set(leaves) == set() @@ -141,7 +141,7 @@ def test_get_leaf_tensors_excluded_root(): y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() - leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={y1}) + leaves = get_leaf_tensors(tensors=[y1, y2], excluded={y1}) assert set(leaves) == {a1} @@ -154,7 +154,7 @@ def test_get_leaf_tensors_deep(depth: int): for i in range(depth): sum_ = sum_ + one - leaves = _get_leaf_tensors(tensors=[sum_], excluded=set()) + leaves = get_leaf_tensors(tensors=[sum_], excluded=set()) assert set(leaves) == {one} @@ -163,7 +163,7 @@ def test_get_leaf_tensors_leaf(): a = torch.tensor(1.0, requires_grad=True) with raises(ValueError): - _ = _get_leaf_tensors(tensors=[a], excluded=set()) + _ = get_leaf_tensors(tensors=[a], excluded=set()) def test_get_leaf_tensors_tensor_not_requiring_grad(): @@ -173,7 +173,7 @@ def test_get_leaf_tensors_tensor_not_requiring_grad(): a = torch.tensor(1.0, requires_grad=False) * 2 with raises(ValueError): - _ = _get_leaf_tensors(tensors=[a], excluded=set()) + _ = get_leaf_tensors(tensors=[a], excluded=set()) def test_get_leaf_tensors_excluded_leaf(): @@ -182,7 +182,7 @@ def test_get_leaf_tensors_excluded_leaf(): a = torch.tensor(1.0, requires_grad=True) * 2 b = torch.tensor(2.0, requires_grad=True) with raises(ValueError): - _ = _get_leaf_tensors(tensors=[a], excluded={b}) + _ = get_leaf_tensors(tensors=[a], excluded={b}) def test_get_leaf_tensors_excluded_not_requiring_grad(): @@ -193,4 +193,4 @@ def test_get_leaf_tensors_excluded_not_requiring_grad(): a = torch.tensor(1.0, requires_grad=True) * 2 b = torch.tensor(2.0, requires_grad=False) * 2 with raises(ValueError): - _ = 
_get_leaf_tensors(tensors=[a], excluded={b}) + _ = get_leaf_tensors(tensors=[a], excluded={b})