Skip to content

Commit e08e729

Browse files
authored
Change TransformState to NamedTuple (#106)
* Change TransformState to NamedTuple
* Change class type docstring from Args to Attributes
* Update inplace gotcha
* Update docstring
1 parent a89667f commit e08e729

22 files changed

Lines changed: 183 additions & 183 deletions

docs/getting_started.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Here:
6464
- `build` is a function that loads `config_args` into the `init` and `update` functions
6565
and stores them within the `transform` instance. The `init` and `update`
6666
functions then conform to a preset signature allowing for easy switching between algorithms.
67-
- `state` is a [`dataclass`](https://docs.python.org/3/library/dataclasses.html)
67+
- `state` is a [`NamedTuple`](https://docs.python.org/3/library/typing.html#typing.NamedTuple)
6868
encoding the state of the algorithm, including `params` and `aux` attributes.
6969
- `init` constructs the iteration-varying `state` based on the model parameters `params`.
7070
- `update` updates the `state` based on a new `batch` of data.

docs/gotchas.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,29 @@ state2 = transform.update(state, batch, inplace=True)
7070
# the tensors in state are updated in-place and state2 shares the same tensor memory as state
7171
```
7272

73+
When adding a new algorithm, in-place support can be achieved by modifying `TensorTree`s
74+
via the [`flexi_tree_map`](https://normal-computing.github.io/posteriors/api/tree_utils/#posteriors.tree_utils.flexi_tree_map) function:
75+
76+
```python
77+
from posteriors.tree_utils import flexi_tree_map
78+
79+
new_state = flexi_tree_map(lambda x: x + 1, state, inplace=True)
80+
```
81+
82+
As `posteriors` transform states are immutable `NamedTuple`s, in-place modification of
83+
`TensorTree` leaves can be achieved by modifying the data of the tensor directly with [`tree_insert_`](https://normal-computing.github.io/posteriors/api/tree_utils/#posteriors.tree_utils.tree_insert_):
84+
85+
```python
86+
from posteriors.tree_utils import tree_insert_
87+
88+
tree_insert_(state.log_posterior, log_post.detach())
89+
```
90+
91+
However, the `aux` component of the `TransformState` is not guaranteed to be a `TensorTree`,
92+
and so in-place modification of `aux` is not supported. Using `state._replace(aux=aux)`
93+
will return a state with all `TensorTree`s pointing to the same memory as the input `state`,
94+
but with a new `aux` component (`aux` is not modified in the input `state` object).
95+
7396

7497
## `torch.tensor` with autograd
7598

docs/tutorials/lightning_autoencoder.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ from torchvision.datasets import MNIST
1616
from torchvision.transforms import ToTensor
1717
import lightning as L
1818
import torchopt
19-
from dataclasses import asdict
2019

2120
import posteriors
2221

@@ -100,7 +99,7 @@ class LitAutoEncoderUQ(L.LightningModule):
10099
# it is independent of forward
101100
self.state = self.transform.update(self.state, batch, inplace=True)
102101
# Logging to TensorBoard (if installed) by default
103-
for k, v in asdict(self.state).items():
102+
for k, v in self.state._asdict().items():
104103
if isinstance(v, float) or (isinstance(v, torch.Tensor) and v.numel() == 1):
105104
self.log(k, v)
106105

examples/lightning_autoencoder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torchvision.transforms import ToTensor
66
import lightning as L
77
import torchopt
8-
from dataclasses import asdict
98

109
import posteriors
1110

@@ -54,7 +53,7 @@ def training_step(self, batch, batch_idx):
5453
# it is independent of forward
5554
self.state = self.transform.update(self.state, batch, inplace=True)
5655
# Logging to TensorBoard (if installed) by default
57-
for k, v in asdict(self.state).items():
56+
for k, v in self.state._asdict().items():
5857
if isinstance(v, float) or (isinstance(v, torch.Tensor) and v.numel() == 1):
5958
self.log(k, v)
6059

posteriors/ekf/dense_fisher.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
from typing import Any
1+
from typing import Any, NamedTuple
22
from functools import partial
33
import torch
44
from torch.func import grad_and_value
5-
from dataclasses import dataclass
65
from optree.integration.torch import tree_ravel
76

8-
from posteriors.tree_utils import tree_size
7+
from posteriors.tree_utils import tree_size, tree_insert_
98

10-
from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
9+
from posteriors.types import TensorTree, Transform, LogProbFn
1110
from posteriors.utils import (
1211
per_samplify,
1312
empirical_fisher,
@@ -67,11 +66,10 @@ def build(
6766
return Transform(init_fn, update_fn)
6867

6968

70-
@dataclass
71-
class EKFDenseState(TransformState):
69+
class EKFDenseState(NamedTuple):
7270
"""State encoding a Normal distribution over parameters.
7371
74-
Args:
72+
Attributes:
7573
params: Mean of the Normal distribution.
7674
cov: Covariance matrix of the
7775
Normal distribution.
@@ -81,7 +79,7 @@ class EKFDenseState(TransformState):
8179

8280
params: TensorTree
8381
cov: torch.Tensor
84-
log_likelihood: float = 0
82+
log_likelihood: torch.Tensor = torch.tensor([])
8583
aux: Any = None
8684

8785

@@ -170,11 +168,11 @@ def log_likelihood_reduced(params, batch):
170168
update_mean = mu_unravel_f(update_mean)
171169

172170
if inplace:
173-
state.params = update_mean
174-
state.cov = update_cov
175-
state.log_likelihood = log_liks.mean().detach()
176-
state.aux = aux
177-
return state
171+
tree_insert_(state.params, update_mean)
172+
tree_insert_(state.cov, update_cov)
173+
tree_insert_(state.log_likelihood, log_liks.mean().detach())
174+
return state._replace(aux=aux)
175+
178176
return EKFDenseState(update_mean, update_cov, log_liks.mean().detach(), aux)
179177

180178

posteriors/ekf/diag_fisher.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from typing import Any
1+
from typing import Any, NamedTuple
22
from functools import partial
33
import torch
44
from torch.func import jacrev
55
from optree import tree_map
6-
from dataclasses import dataclass
76

8-
from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
9-
from posteriors.tree_utils import flexi_tree_map
7+
from posteriors.types import TensorTree, Transform, LogProbFn
8+
from posteriors.tree_utils import flexi_tree_map, tree_insert_
109
from posteriors.utils import (
1110
diag_normal_sample,
1211
per_samplify,
@@ -68,11 +67,10 @@ def build(
6867
return Transform(init_fn, update_fn)
6968

7069

71-
@dataclass
72-
class EKFDiagState(TransformState):
70+
class EKFDiagState(NamedTuple):
7371
"""State encoding a diagonal Normal distribution over parameters.
7472
75-
Args:
73+
Attributes:
7674
params: Mean of the Normal distribution.
7775
sd_diag: Square-root diagonal of the covariance matrix of the
7876
Normal distribution.
@@ -82,7 +80,7 @@ class EKFDiagState(TransformState):
8280

8381
params: TensorTree
8482
sd_diag: TensorTree
85-
log_likelihood: float = 0
83+
log_likelihood: torch.Tensor = torch.tensor([])
8684
aux: Any = None
8785

8886

@@ -176,9 +174,9 @@ def update(
176174
)
177175

178176
if inplace:
179-
state.log_likelihood = log_liks.mean().detach()
180-
state.aux = aux
181-
return state
177+
tree_insert_(state.log_likelihood, log_liks.mean().detach())
178+
return state._replace(aux=aux)
179+
182180
return EKFDiagState(update_mean, update_sd_diag, log_liks.mean().detach(), aux)
183181

184182

posteriors/laplace/dense_fisher.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from typing import Any
2-
from dataclasses import dataclass
1+
from typing import Any, NamedTuple
32
from functools import partial
43
import torch
54
from optree import tree_map
65
from optree.integration.torch import tree_ravel
76

8-
from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
7+
from posteriors.types import TensorTree, Transform, LogProbFn
98
from posteriors.tree_utils import tree_size
109
from posteriors.utils import (
1110
per_samplify,
@@ -55,12 +54,11 @@ def build(
5554
return Transform(init_fn, update_fn)
5655

5756

58-
@dataclass
59-
class DenseLaplaceState(TransformState):
57+
class DenseLaplaceState(NamedTuple):
6058
"""State encoding a Normal distribution over parameters,
6159
with a dense precision matrix
6260
63-
Args:
61+
Attributes:
6462
params: Mean of the Normal distribution.
6563
prec: Precision matrix of the Normal distribution.
6664
aux: Auxiliary information from the log_posterior call.
@@ -130,9 +128,8 @@ def update(
130128
)(state.params)
131129

132130
if inplace:
133-
state.prec += fisher
134-
state.aux = aux
135-
return state
131+
state.prec.data += fisher
132+
return state._replace(aux=aux)
136133
else:
137134
return DenseLaplaceState(state.params, state.prec + fisher, aux)
138135

posteriors/laplace/dense_ggn.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
from functools import partial
2-
from typing import Any
2+
from typing import Any, NamedTuple
33
import torch
44
from optree import tree_map
5-
from dataclasses import dataclass
65
from optree.integration.torch import tree_ravel
76

87
from posteriors.types import (
98
TensorTree,
109
Transform,
1110
ForwardFn,
1211
OuterLogProbFn,
13-
TransformState,
1412
)
1513
from posteriors.utils import (
1614
tree_size,
@@ -67,12 +65,11 @@ def build(
6765
return Transform(init_fn, update_fn)
6866

6967

70-
@dataclass
71-
class DenseLaplaceState(TransformState):
68+
class DenseLaplaceState(NamedTuple):
7269
"""State encoding a Normal distribution over parameters,
7370
with a dense precision matrix
7471
75-
Args:
72+
Attributes:
7673
params: Mean of the Normal distribution.
7774
prec: Precision matrix of the Normal distribution.
7875
aux: Auxiliary information from the log_posterior call.
@@ -145,9 +142,8 @@ def outer_loss(z, batch):
145142
)(state.params)
146143

147144
if inplace:
148-
state.prec += ggn_batch
149-
state.aux = aux
150-
return state
145+
state.prec.data += ggn_batch
146+
return state._replace(aux=aux)
151147
else:
152148
return DenseLaplaceState(state.params, state.prec + ggn_batch, aux)
153149

posteriors/laplace/diag_fisher.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from functools import partial
2-
from typing import Any
2+
from typing import Any, NamedTuple
33
import torch
44
from torch.func import jacrev
55
from optree import tree_map
6-
from dataclasses import dataclass
76

8-
from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
7+
from posteriors.types import TensorTree, Transform, LogProbFn
98
from posteriors.tree_utils import flexi_tree_map
109
from posteriors.utils import (
1110
diag_normal_sample,
@@ -54,11 +53,10 @@ def build(
5453
return Transform(init_fn, update_fn)
5554

5655

57-
@dataclass
58-
class DiagLaplaceState(TransformState):
56+
class DiagLaplaceState(NamedTuple):
5957
"""State encoding a diagonal Normal distribution over parameters.
6058
61-
Args:
59+
Attributes:
6260
params: Mean of the Normal distribution.
6361
prec_diag: Diagonal of the precision matrix of the Normal distribution.
6462
aux: Auxiliary information from the log_posterior call.
@@ -134,8 +132,7 @@ def update_func(x, y):
134132
)
135133

136134
if inplace:
137-
state.aux = aux
138-
return state
135+
return state._replace(aux=aux)
139136
return DiagLaplaceState(state.params, prec_diag, aux)
140137

141138

posteriors/laplace/diag_ggn.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from functools import partial
2-
from typing import Any
2+
from typing import Any, NamedTuple
33
import torch
44
from optree import tree_map
5-
from dataclasses import dataclass
65

76
from posteriors.types import (
87
TensorTree,
98
Transform,
109
ForwardFn,
1110
OuterLogProbFn,
12-
TransformState,
1311
)
1412
from posteriors.tree_utils import flexi_tree_map
1513
from posteriors.utils import (
@@ -66,11 +64,10 @@ def build(
6664
return Transform(init_fn, update_fn)
6765

6866

69-
@dataclass
70-
class DiagLaplaceState(TransformState):
67+
class DiagLaplaceState(NamedTuple):
7168
"""State encoding a diagonal Normal distribution over parameters.
7269
73-
Args:
70+
Attributes:
7471
params: Mean of the Normal distribution.
7572
prec_diag: Diagonal of the precision matrix of the Normal distribution.
7673
aux: Auxiliary information from the log_posterior call.
@@ -149,8 +146,7 @@ def update_func(x, y):
149146
)
150147

151148
if inplace:
152-
state.aux = aux
153-
return state
149+
return state._replace(aux=aux)
154150
return DiagLaplaceState(state.params, prec_diag, aux)
155151

156152

0 commit comments

Comments
 (0)