Skip to content

Commit b90928e

Browse files
Merge branch 'main' into aggregastion-explicit-gramians
2 parents 8c45168 + cec84ea commit b90928e

File tree

14 files changed

+162
-72
lines changed

14 files changed

+162
-72
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
name: Build and Deploy Documentation

on:
  push:
    branches: [ main ]
    tags:
      - 'v[0-9]*.[0-9]*.[0-9]*'
  pull_request:
    types: [opened, synchronize, reopened]

# Fix: without a concurrency group, two runs triggered in quick succession
# (e.g. rapid pushes to a PR) can race while pushing to the external
# TorchJD/documentation repository and one push can fail or land stale output.
# Cancelling the in-progress run for the same ref keeps only the latest build.
concurrency:
  group: docs-${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  Build_and_deploy_doc:
    runs-on: ubuntu-latest
    permissions:
      contents: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Setup PDM
        uses: pdm-project/setup-pdm@v4
        with:
          python-version: '3.13'

      - name: Install dependencies (default & doc)
        run: pdm install --group doc --frozen-lockfile

      - name: Build Documentation
        working-directory: docs
        run: pdm run make dirhtml

      # Compute the target sub-directory in TorchJD/documentation:
      #   pr/<number> for pull requests, <tag> for version tags, main otherwise.
      - name: Determine deployment folder
        id: deploy_folder
        run: |
          echo "Determining deployment folder..."
          if [ "${{ github.event_name }}" = "pull_request" ]; then
            echo "Deploying to target pr/${{ github.event.number }}"
            echo "DEPLOY_DIR=pr/${{ github.event.number }}" >> $GITHUB_OUTPUT
          elif [[ "${{ github.ref }}" == refs/tags/* ]]; then
            echo "Deploying to target ${{ github.ref_name }}"
            echo "DEPLOY_DIR=${{ github.ref_name }}" >> $GITHUB_OUTPUT
          else
            echo "Deploying to target main"
            echo "DEPLOY_DIR=main" >> $GITHUB_OUTPUT
          fi

      # NOTE(review): on pull_request events from forks, repository secrets such
      # as DOCUMENTATION_DEPLOY_KEY are not available, so this step would fail —
      # confirm whether fork PRs are expected here.
      - name: Deploy to DEPLOY_DIR of TorchJD/documentation
        uses: peaceiris/actions-gh-pages@v4
        with:
          deploy_key: ${{ secrets.DOCUMENTATION_DEPLOY_KEY }}
          publish_dir: docs/build/dirhtml
          destination_dir: ${{ steps.deploy_folder.outputs.DEPLOY_DIR }}
          external_repository: TorchJD/documentation
          publish_branch: main

      # Tag builds are additionally published under "stable" (the versioned
      # copy above is kept as well).
      - name: Deploy to stable of TorchJD/documentation
        if: startsWith(github.ref, 'refs/tags/')
        uses: peaceiris/actions-gh-pages@v4
        with:
          deploy_key: ${{ secrets.DOCUMENTATION_DEPLOY_KEY }}
          publish_dir: docs/build/dirhtml
          destination_dir: stable
          external_repository: TorchJD/documentation
          publish_branch: main

      - name: Add documentation link to summary
        run: |
          echo "### 📄 [View Deployed Documentation](https://torchjd.github.io/documentation/${{ steps.deploy_folder.outputs.DEPLOY_DIR }})" >> $GITHUB_STEP_SUMMARY
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
name: Cleanup PR Documentation

on:
  pull_request:
    types: [closed]

jobs:
  Cleanup_documentation:
    runs-on: ubuntu-latest
    permissions:
      contents: write
    steps:
      # Fix: this step was named "Checkout gh-pages branch", but it actually
      # checks out the `main` branch of the TorchJD/documentation repository.
      - name: Checkout main branch of TorchJD/documentation
        uses: actions/checkout@v4
        with:
          repository: TorchJD/documentation
          ref: main
          ssh-key: ${{ secrets.DOCUMENTATION_DEPLOY_KEY }}

      - name: Remove PR documentation for closed PR
        run: |
          PR_NUMBER="${{ github.event.number }}"
          # Safety guard: if PR_NUMBER ever expanded empty, the unguarded
          # command would be `rm -rf pr/` and delete every PR's docs.
          if [ -z "${PR_NUMBER}" ]; then
            echo "No PR number in event payload; refusing to delete." >&2
            exit 1
          fi
          echo "Removing documentation for PR #${PR_NUMBER}"
          rm -rf "pr/${PR_NUMBER}"

      # `|| echo` keeps the job green when the PR directory did not exist
      # (nothing staged, so `git commit` exits non-zero).
      - name: Commit and push cleanup
        run: |
          git config user.name "github-actions"
          git config user.email "github-actions@github.com"
          git add .
          git commit -m "Cleanup documentation for closed PR #${{ github.event.number }}" || echo "No changes to commit"
          git push origin HEAD:main

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
from torch import Tensor
55

6-
from .base import _A, RequirementError, Transform
6+
from .base import RequirementError, Transform
77
from .ordered_set import OrderedSet
8+
from .tensor_dict import _A
89

910

10-
class _Differentiate(Transform[_A, _A], ABC):
11+
class Differentiate(Transform[_A, _A], ABC):
1112
def __init__(
1213
self,
1314
outputs: Iterable[Tensor],
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Sequence
2+
3+
import torch
4+
from torch import Tensor
5+
6+
7+
def materialize(
8+
optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor]
9+
) -> tuple[Tensor, ...]:
10+
"""
11+
Transforms a sequence of optional tensors by changing each None by a tensor of zeros of the same
12+
shape as the corresponding input. Returns the obtained sequence as a tuple.
13+
14+
Note that the name "materialize" comes from the flag `materialize_grads` from
15+
`torch.autograd.grad`, which will be available in future torch releases.
16+
"""
17+
18+
tensors = []
19+
for optional_tensor, input in zip(optional_tensors, inputs):
20+
if optional_tensor is None:
21+
tensors.append(torch.zeros_like(input))
22+
else:
23+
tensors.append(optional_tensor)
24+
return tuple(tensors)

src/torchjd/autojac/_transform/_utils.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

src/torchjd/autojac/_transform/aggregate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
3232

3333
class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
3434
def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
35-
self.key_order = OrderedSet(key_order)
35+
self.key_order = key_order
3636
self.aggregator = aggregator
3737

3838
def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:

src/torchjd/autojac/_transform/base.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from torch import Tensor
77

8-
from ._utils import _A, _B, _C, _union
8+
from .tensor_dict import _A, _B, _C, EmptyTensorDict, _least_common_ancestor
99

1010

1111
class RequirementError(ValueError):
@@ -99,8 +99,13 @@ def __str__(self) -> str:
9999
return "(" + " | ".join(strings) + ")"
100100

101101
def __call__(self, tensor_dict: _A) -> _B:
102-
output = _union([transform(tensor_dict) for transform in self.transforms])
103-
return output
102+
tensor_dicts = [transform(tensor_dict) for transform in self.transforms]
103+
output_type: type[_A] = EmptyTensorDict
104+
output: _A = EmptyTensorDict()
105+
for tensor_dict in tensor_dicts:
106+
output_type = _least_common_ancestor(output_type, type(tensor_dict))
107+
output |= tensor_dict
108+
return output_type(output)
104109

105110
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
106111
output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]

src/torchjd/autojac/_transform/grad.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import torch
44
from torch import Tensor
55

6-
from ._differentiate import _Differentiate
7-
from ._utils import _materialize
6+
from ._differentiate import Differentiate
7+
from ._materialize import materialize
88
from .tensor_dict import Gradients
99

1010

11-
class Grad(_Differentiate[Gradients]):
11+
class Grad(Differentiate[Gradients]):
1212
def __init__(
1313
self,
1414
outputs: Iterable[Tensor],
@@ -47,5 +47,5 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
4747
create_graph=self.create_graph,
4848
allow_unused=True,
4949
)
50-
grads = _materialize(optional_grads, inputs)
50+
grads = materialize(optional_grads, inputs)
5151
return grads

src/torchjd/autojac/_transform/jac.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import torch
77
from torch import Size, Tensor
88

9-
from ._differentiate import _Differentiate
10-
from ._utils import _materialize
9+
from ._differentiate import Differentiate
10+
from ._materialize import materialize
1111
from .tensor_dict import Jacobians
1212

1313

14-
class Jac(_Differentiate[Jacobians]):
14+
class Jac(Differentiate[Jacobians]):
1515
def __init__(
1616
self,
1717
outputs: Iterable[Tensor],
@@ -60,7 +60,7 @@ def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
6060
create_graph=self.create_graph,
6161
allow_unused=True,
6262
)
63-
grads = _materialize(optional_grads, inputs=inputs)
63+
grads = materialize(optional_grads, inputs=inputs)
6464
return torch.concatenate([grad.reshape([-1]) for grad in grads])
6565

6666
# By the Jacobians constraint, this value should be the same for all jac_outputs.

src/torchjd/autojac/_transform/ordered_set.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections import OrderedDict
2-
from typing import Iterable
2+
from typing import Hashable, Iterable, TypeVar
33

4-
from torchjd.autojac._transform._utils import _KeyType
4+
_KeyType = TypeVar("_KeyType", bound=Hashable)
55

66

77
class OrderedSet(OrderedDict[_KeyType, None]):

0 commit comments

Comments
 (0)