Commit 231c2ff: feat(ocl): add EWC (#354)

Parent: 7bf795a

11 files changed: 502 additions & 29 deletions

docs/_templates/autosummary/class.rst

Lines changed: 0 additions & 2 deletions
@@ -7,6 +7,4 @@
    :show-inheritance:
    :special-members: __init__, __call__, __iter__, __next__
    :member-order: groupwise
-{%- if module not in inherited_members_module_denylist %}
    :inherited-members:
-{% endif %}

docs/conf.py

Lines changed: 27 additions & 6 deletions
@@ -8,6 +8,8 @@
 import os
 import sys
 from pathlib import Path
+from typing import Optional
+import re
 from capymoa.__about__ import __version__
 from docs.util.github_link import make_linkcode_resolve

@@ -57,6 +59,7 @@
     ("py:class", r"tqdm\..*"),
     ("py:class", r"torchvision\..*"),
     ("py:class", r"Tensor"),
+    ("py:class", r"nn\.Module"),
 ]

 # These warnings are usually false positives.
@@ -65,12 +68,6 @@
 toc_object_entries_show_parents = "hide"
 autosummary_ignore_module_all = False
 autosummary_generate = True
-autosummary_context = {
-    # List of modules that we do not include inherited members in. This is
-    # usually because they import from torch.nn.Module or similar large
-    # classes.
-    "inherited_members_module_denylist": ["capymoa.ann"]
-}

 autodoc_member_order = "groupwise"
 autodoc_class_signature = "separated"
@@ -128,6 +125,7 @@
 intersphinx_mapping = {
     "sklearn": ("https://scikit-learn.org/stable/", None),
     "torch": ("https://pytorch.org/docs/stable/", None),
+    "python": ("https://docs.python.org/3", None),
 }

 """ Options for linkcode extension ------------------------------------------
@@ -180,3 +178,26 @@
         },
     ],
 }
+
+autodoc_skip_member_patterns = [
+    # Inheriting from torch.nn.Module creates issues so we skip them.
+    r"torch\.nn\.modules\..*",
+]
+
+
+def autodoc_skip_member(app, obj_type, name, obj, skip, options) -> Optional[bool]:
+    if skip:
+        return None
+    if not hasattr(obj, "__module__") or not hasattr(obj, "__qualname__"):
+        return None
+    fqn = f"{obj.__module__}.{obj.__qualname__}"
+
+    for pattern in autodoc_skip_member_patterns:
+        if re.match(pattern, fqn):
+            return True
+
+    return None
+
+
+def setup(app):
+    app.connect("autodoc-skip-member", autodoc_skip_member)
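
As a quick illustration (not part of the commit) of what the hook filters: members inherited from torch.nn.Module resolve to fully qualified names under torch.nn.modules, which the pattern above matches, so autodoc skips them while regular capymoa members pass through. The names below are hypothetical examples of what autodoc_skip_member would build:

import re

pattern = r"torch\.nn\.modules\..*"
# Hypothetical fully qualified names as assembled by autodoc_skip_member.
assert re.match(pattern, "torch.nn.modules.module.Module.forward")
assert not re.match(pattern, "capymoa.ocl.strategy._ewc.EWC.batch_train")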

src/capymoa/ocl/strategy/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -4,5 +4,6 @@
 from ._gdumb import GDumb
 from ._rar import RAR
 from . import l2p
+from ._ewc import EWC

-__all__ = ["ExperienceReplay", "SLDA", "NCM", "GDumb", "RAR", "l2p"]
+__all__ = ["ExperienceReplay", "SLDA", "NCM", "GDumb", "RAR", "l2p", "EWC"]

src/capymoa/ocl/strategy/_ewc.py

Lines changed: 273 additions & 0 deletions
@@ -0,0 +1,273 @@
+from typing import Iterable, Iterator, Optional, Sequence, Tuple, Callable
+from capymoa.stream._stream import Schema
+from torch import Tensor, nn
+import torch
+from capymoa.base import BatchClassifier
+from capymoa.ocl.base import TrainTaskAware, TestTaskAware
+from capymoa.ocl.util._buffer_list import BufferList
+from capymoa.ocl.util._replay import SlidingWindow
+from torch.utils.data import DataLoader
+
+
+def weighted_l2_reg(
+    params: Iterable[Tensor],
+    anchor_params: Iterable[Tensor],
+    fisher_diagonals: Iterable[Tensor],
+    device: torch.device,
+) -> Tensor:
+    """Compute an EWC-style weighted L2 regularisation term.
+
+    :param params: Current model parameters.
+    :param anchor_params: Reference parameters from a previous task.
+    :param fisher_diagonals: Diagonal Fisher information weights.
+    :param device: Device used for the accumulator tensor.
+    :return: Weighted L2 penalty scaled by ``1/2``.
+    """
+    l2 = torch.tensor(0.0, device=device)
+    for param, anchor_param, fisher_diag in zip(
+        params, anchor_params, fisher_diagonals, strict=True
+    ):
+        assert param.shape == anchor_param.shape
+        l2 += (fisher_diag * (param - anchor_param) ** 2).sum()
+    return l2 / 2.0
+
+
+def fd_init(model: torch.nn.Module) -> Sequence[Tensor]:
+    """Initialise zero-valued Fisher diagonal tensors for a model.
+
+    :param model: Model whose parameters define the Fisher diagonal shapes.
+    :return: Zero tensors matching all model parameters.
+    """
+    return [torch.zeros_like(param) for param in model.parameters()]
+
+
+def fd_accumulate(
+    fisher_diagonals: Sequence[Tensor],
+    parameters: Iterator[Tensor],
+    alpha: Optional[float] = None,
+) -> Sequence[Tensor]:
+    """Accumulate the squared gradients into the Fisher diagonal estimates.
+
+    :param fisher_diagonals: A sequence of tensors representing the current estimates
+        of the Fisher diagonals.
+    :param parameters: A sequence of model parameters whose gradients have been
+        computed.
+    :param alpha: Decay factor for the accumulated Fisher diagonals. ``None`` (the
+        default) sums squared gradients as in standard EWC, while values in ``(0, 1)``
+        apply an exponential moving average as in Online EWC.
+    :return: Updated sequence of tensors representing the accumulated Fisher diagonals.
+    """
+    for fisher_diag, param in zip(fisher_diagonals, parameters, strict=True):
+        if param.grad is None:
+            raise ValueError(
+                "Parameter gradients must be computed before updating Fisher diagonals."
+            )
+        if alpha is not None:
+            fisher_diag.mul_(alpha).add_(param.grad.data.pow(2), alpha=(1 - alpha))
+        else:
+            fisher_diag.add_(param.grad.data.pow(2))
+    return fisher_diagonals
+
+
+def fd_compute(
+    model: torch.nn.Module,
+    forward_fn: Callable[[Tensor], Tensor],
+    dataloader: DataLoader[Tuple[Tensor, Tensor]],
+    device: torch.device,
+    criterion: torch.nn.Module,
+) -> Sequence[Tensor]:
+    """Compute a model's Fisher diagonals.
+
+    :param model: A PyTorch classifier model.
+    :param forward_fn: Function mapping a batch of inputs to logits; this lets the
+        caller apply task masking during the forward pass.
+    :param dataloader: A PyTorch dataloader for a classification task, yielding batches
+        of (inputs, labels).
+    :param device: Compute device.
+    :param criterion: The loss function to use.
+    :return: A sequence of tensors representing the computed Fisher diagonals.
+    """
+    model = model.eval().to(device)
+    criterion = criterion.eval().to(device)
+
+    fisher_diagonals = fd_init(model)
+    for inputs, labels in dataloader:
+        model.zero_grad()
+        inputs, labels = inputs.to(device), labels.to(device)
+        outputs = forward_fn(inputs)
+        loss = criterion(outputs, labels)
+        loss.backward()
+        fisher_diagonals = fd_accumulate(fisher_diagonals, model.parameters())
+    # Average the accumulated squared gradients over the number of batches
+    fisher_diagonals = [
+        fisher_diag / len(dataloader) for fisher_diag in fisher_diagonals
+    ]
+    return fisher_diagonals
+
+
+class EWC(BatchClassifier, nn.Module, TrainTaskAware, TestTaskAware):
+    """Elastic Weight Consolidation learner.
+
+    Elastic Weight Consolidation (EWC) is a regularisation-based continual learning
+    strategy that mitigates catastrophic forgetting by penalising changes to
+    parameters that are important for previous tasks [#f1]_. We incorporate Online
+    EWC-style [#f2]_ updates to the Fisher diagonals, which decay the importance of
+    previous tasks' parameters over time based on the ``gamma`` hyperparameter.
+
+    The EWC strategy usually has access to the entire active task's dataset when
+    estimating the Fisher diagonals; here we instead use a replay buffer to
+    approximate the active task distribution.
+
+    .. [#f1] Kirkpatrick, J., Pascanu, R., Rabinowitz, N., Veness, J., Desjardins, G.,
+        Rusu, A. A., Milan, K., Quan, J., Ramalho, T., Grabska-Barwinska, A., Hassabis,
+        D., Clopath, C., Kumaran, D., & Hadsell, R. (2017). Overcoming catastrophic
+        forgetting in neural networks. Proceedings of the National Academy of Sciences,
+        114(13), 3521–3526. https://doi.org/10.1073/pnas.1611835114
+
+    .. [#f2] Schwarz, J., Czarnecki, W., Luketina, J., Grabska-Barwinska, A., Teh, Y.
+        W., Pascanu, R., & Hadsell, R. (2018). Progress & Compress: A scalable framework
+        for continual learning. In J. G. Dy & A. Krause (Eds.), Proceedings of the 35th
+        International Conference on Machine Learning, ICML 2018, Stockholmsmässan,
+        Stockholm, Sweden, July 10-15, 2018 (Vol. 80, pp. 4535–4544). PMLR.
+        http://proceedings.mlr.press/v80/schwarz18a.html
+    """
+
+    def __init__(
+        self,
+        schema: Schema,
+        model: torch.nn.Module,
+        optimiser: torch.optim.Optimizer,
+        lambda_: float,
+        fim_buffer: int = 256,
+        fim_batch_size: int = 32,
+        device: torch.device = torch.device("cpu"),
+        mask_test: bool = False,
+        mask_train: bool = False,
+        gamma: float = 1.0,
+        task_mask: Optional[Tensor] = None,
+    ) -> None:
+        """Construct an EWC learner.
+
+        :param schema: Stream schema used by the classifier interface.
+        :param model: Torch model that outputs class logits.
+        :param optimiser: Optimiser used to update ``model`` parameters.
+        :param lambda_: Weight of the EWC regularisation term.
+        :param fim_buffer: Replay window size for Fisher estimation.
+        :param fim_batch_size: Mini-batch size used when estimating Fisher diagonals.
+        :param device: Compute device.
+        :param mask_test: Whether to apply per-task masking during testing. This is a
+            task incremental scenario.
+        :param mask_train: Whether to apply per-task masking during training. This is
+            also known as the labels trick.
+        :param gamma: Decay factor applied to the accumulated Fisher diagonals at each
+            task boundary; ``1.0`` gives standard accumulation, while values below
+            ``1.0`` decay old importance as in Online EWC.
+        :param task_mask: Optional per-task mask applied to output logits.
+        :raises ValueError: If task-specific masking is requested without ``task_mask``.
+        """
+        super().__init__(schema, 0)
+        nn.Module.__init__(self)
+        if (mask_train or mask_test) and task_mask is None:
+            raise ValueError(
+                "Task schedule must be provided for task incremental or labels trick scenarios."
+            )
+        self.device = device
+
+        # Hyperparameters
+        self._lambda = lambda_
+        self._gamma = gamma
+        self._fd_batch_size = fim_batch_size
+        self._mask_train = mask_train
+        self._mask_test = mask_test
+
+        # Modules
+        self._optimiser = optimiser
+        self._model = model
+        self._criterion = torch.nn.CrossEntropyLoss()
+        self._buffer = SlidingWindow(fim_buffer, schema.get_num_attributes())
+
+        # Buffers for anchoring the model
+        self._anchor_params = BufferList(
+            [param.clone().detach() for param in model.parameters()]
+        )
+        self._fisher_diags = BufferList(
+            [torch.zeros_like(param) for param in model.parameters()]
+        )
+
+        # Task tracking
+        self._train_task = 0
+        self._test_task = 0
+        if task_mask is None:
+            self._task_mask = None
+        else:
+            self._task_mask = nn.Buffer(task_mask)
+
+        # Move all model parameters and buffers to the specified device
+        self.to(device)
+
+    def batch_train(self, x: Tensor, y: Tensor) -> None:
+        self._buffer.update(x, y)
+        self._model.train()
+        self._optimiser.zero_grad()
+        y_hat = self._train_forward(x)
+        loss = self._criterion(y_hat, y)
+        total_loss = loss + self._lambda * self._regularisation_loss()
+        total_loss.backward()
+        self._optimiser.step()
+
+    @torch.no_grad()
+    def batch_predict_proba(self, x: Tensor) -> Tensor:
+        self._model.eval()
+        y_hat = self._test_forward(x)
+        return torch.softmax(y_hat, dim=1)
+
+    def on_train_task(self, task_id: int) -> None:
+        if task_id > 0:
+            self._update_fisher_diags()
+            self._update_anchor_params()
+        self._train_task = task_id
+
+    def on_test_task(self, task_id: int) -> None:
+        self._test_task = task_id
+
+    def _update_fisher_diags(self) -> None:
+        """Estimate and accumulate Fisher diagonals from the replay buffer."""
+        dataset = self._buffer.dataset_view()
+        dataloader = DataLoader(dataset, batch_size=self._fd_batch_size, shuffle=False)
+        task_fisher_diags = fd_compute(
+            self._model,
+            self._train_forward,
+            dataloader,  # type: ignore
+            self.device,
+            self._criterion,
+        )
+        # Update the Fisher diagonals buffer with the computed values
+        for i in range(len(self._fisher_diags)):
+            self._fisher_diags[i].mul_(self._gamma).add_(task_fisher_diags[i])
+
+    def _update_anchor_params(self) -> None:
+        """Update anchored parameters to the current model weights."""
+        for param, anchor_param in zip(
+            self._model.parameters(), self._anchor_params, strict=True
+        ):
+            anchor_param.copy_(param.detach())
+
+    def _test_forward(self, x: Tensor) -> Tensor:
+        """Compute logits for inference, optionally applying a test-task mask."""
+        y_hat = self._model(x)
+        if self._task_mask is not None and self._mask_test:
+            y_hat = self._task_mask[self._test_task] * y_hat
+        return y_hat
+
+    def _train_forward(self, x: Tensor) -> Tensor:
+        """Compute logits for training, optionally applying a train-task mask."""
+        y_hat = self._model(x)
+        if self._task_mask is not None and self._mask_train:
+            y_hat = self._task_mask[self._train_task] * y_hat
+        return y_hat
+
+    def _regularisation_loss(self) -> torch.Tensor:
+        """Return the EWC regularisation loss for the current task."""
+        if self._train_task < 1:
+            return torch.tensor(0.0, device=self.device)
+        return weighted_l2_reg(
+            self._model.parameters(),
+            self._anchor_params,
+            self._fisher_diags,
+            device=self.device,
+        )
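
For reference, the objective minimised in batch_train combines the task's cross-entropy loss with the penalty from weighted_l2_reg, and _update_fisher_diags decays past importance before adding the new estimate. In the notation of Kirkpatrick et al., with anchored parameters theta* and accumulated Fisher diagonals F, this is a sketch transcribed from the code above:

% Objective minimised in batch_train: cross-entropy plus the weighted
% L2 penalty from weighted_l2_reg, scaled by lambda.
\mathcal{L}(\theta) = \mathcal{L}_{\mathrm{CE}}(\theta)
    + \frac{\lambda}{2} \sum_i F_i \, (\theta_i - \theta_i^{*})^{2}
% Fisher update at each task boundary (_update_fisher_diags); gamma = 1
% accumulates as in standard EWC, gamma < 1 decays as in Online EWC.
F \leftarrow \gamma F + \hat{F}_{\mathrm{task}}

A minimal usage sketch follows. It is not part of the commit: the ElectricityTiny stream and the MLP layout are illustrative assumptions, and any capymoa stream plus any torch model that outputs class logits should fit the constructor signature above. In an OCL evaluation loop, on_train_task is invoked at task boundaries, which triggers the Fisher and anchor updates.

import torch
from torch import nn

from capymoa.datasets import ElectricityTiny  # illustrative; any stream works
from capymoa.ocl.strategy import EWC

stream = ElectricityTiny()
schema = stream.get_schema()

# A small MLP mapping attribute vectors to one logit per class.
model = nn.Sequential(
    nn.Linear(schema.get_num_attributes(), 64),
    nn.ReLU(),
    nn.Linear(64, schema.get_num_classes()),
)

learner = EWC(
    schema=schema,
    model=model,
    optimiser=torch.optim.SGD(model.parameters(), lr=0.01),
    lambda_=1.0,  # weight of the EWC penalty
    gamma=0.9,    # < 1.0 decays old Fisher information (Online EWC)
)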

src/capymoa/ocl/strategy/_experience_replay.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@

 from capymoa.base import BatchClassifier
 from capymoa.ocl.base import TrainTaskAware, TestTaskAware
-from capymoa.ocl.util._coreset import ReservoirSampler
+from capymoa.ocl.util._replay import ReservoirSampler


 class ExperienceReplay(BatchClassifier, TrainTaskAware, TestTaskAware):

src/capymoa/ocl/strategy/_gdumb.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from capymoa.ocl.util._coreset import GreedySampler
+from capymoa.ocl.util._replay import GreedySampler
 import torch
 from capymoa.base import BatchClassifier
 from capymoa.ocl.base import TestTaskAware

src/capymoa/ocl/strategy/_rar.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 from torch import Tensor

 from capymoa.base import BatchClassifier
-from capymoa.ocl.util._coreset import ReservoirSampler
+from capymoa.ocl.util._replay import ReservoirSampler
 from capymoa.ocl.base import TrainTaskAware, TestTaskAware

 from typing import Callable
