Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/_templates/autosummary/class.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,4 @@
:show-inheritance:
:special-members: __init__, __call__, __iter__, __next__
:member-order: groupwise
{%- if module not in inherited_members_module_denylist %}
:inherited-members:
{% endif %}
33 changes: 27 additions & 6 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import os
import sys
from pathlib import Path
from typing import Optional
import re
from capymoa.__about__ import __version__
from docs.util.github_link import make_linkcode_resolve

Expand Down Expand Up @@ -57,6 +59,7 @@
("py:class", r"tqdm\..*"),
("py:class", r"torchvision\..*"),
("py:class", r"Tensor"),
("py:class", r"nn\.Module"),
]

# These warnings are usually false positives.
Expand All @@ -65,12 +68,6 @@
toc_object_entries_show_parents = "hide"
autosummary_ignore_module_all = False
autosummary_generate = True
autosummary_context = {
# List of modules that we do not include inherited members in. This is
# usually because they import from torch.nn.Module or similar large
# classes.
"inherited_members_module_denylist": ["capymoa.ann"]
}

autodoc_member_order = "groupwise"
autodoc_class_signature = "separated"
Expand Down Expand Up @@ -128,6 +125,7 @@
intersphinx_mapping = {
"sklearn": ("https://scikit-learn.org/stable/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"python": ("https://docs.python.org/3", None),
}

""" Options for linkcode extension ------------------------------------------
Expand Down Expand Up @@ -180,3 +178,26 @@
},
],
}

# Regular expressions consulted by ``autodoc_skip_member``. Each pattern is
# tried with ``re.match`` (i.e. anchored at the start) against the
# ``"<__module__>.<__qualname__>"`` name of a candidate member; a member that
# matches any pattern is skipped.
autodoc_skip_member_patterns = [
    # Inheriting from torch.nn.Module creates issues, so members that are
    # defined under ``torch.nn.modules`` are skipped.
    r"torch\.nn\.modules\..*",
]


def autodoc_skip_member(app, obj_type, name, obj, skip, options) -> Optional[bool]:
    """Handler for the ``autodoc-skip-member`` event.

    Builds ``"<obj.__module__>.<obj.__qualname__>"`` for the member and skips
    it when that name matches one of ``autodoc_skip_member_patterns``.

    :param app: Application object passed by the event (unused).
    :param obj_type: Kind of the documented object (unused).
    :param name: Member name (unused; the qualified name is taken from ``obj``).
    :param obj: The member under consideration.
    :param skip: Whether the member is already going to be skipped.
    :param options: Directive options (unused).
    :return: ``True`` when the member matches a skip pattern, otherwise
        ``None`` so that no decision is made by this handler.
    """
    # Nothing to decide if the member is already skipped or cannot be named.
    if skip or not (hasattr(obj, "__module__") and hasattr(obj, "__qualname__")):
        return None

    qualified_name = f"{obj.__module__}.{obj.__qualname__}"
    matches_any = any(
        re.match(regex, qualified_name) for regex in autodoc_skip_member_patterns
    )
    return True if matches_any else None


def setup(app):
    """Register this file's event handlers on ``app``.

    Connects ``autodoc_skip_member`` to the ``autodoc-skip-member`` event.

    :param app: Application object exposing ``connect(event, callback)``.
    """
    event_name = "autodoc-skip-member"
    handler = autodoc_skip_member
    app.connect(event_name, handler)
3 changes: 2 additions & 1 deletion src/capymoa/ocl/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from ._gdumb import GDumb
from ._rar import RAR
from . import l2p
from ._ewc import EWC

__all__ = ["ExperienceReplay", "SLDA", "NCM", "GDumb", "RAR", "l2p"]
__all__ = ["ExperienceReplay", "SLDA", "NCM", "GDumb", "RAR", "l2p", "EWC"]
273 changes: 273 additions & 0 deletions src/capymoa/ocl/strategy/_ewc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
from typing import Iterable, Iterator, Optional, Sequence, Tuple, Callable
from capymoa.stream._stream import Schema
from torch import Tensor, nn
import torch
from capymoa.base import BatchClassifier
from capymoa.ocl.base import TrainTaskAware, TestTaskAware
from capymoa.ocl.util._buffer_list import BufferList
from capymoa.ocl.util._replay import SlidingWindow
from torch.utils.data import DataLoader


def weighted_l2_reg(
params: Iterable[Tensor],
anchor_params: Iterable[Tensor],
fisher_diagonals: Iterable[Tensor],
device: torch.device,
) -> Tensor:
"""Compute an EWC-style weighted L2 regularisation term.

:param params: Current model parameters.
:param anchor_params: Reference parameters from a previous task.
:param fisher_diagonals: Diagonal Fisher information weights.
:param device: Device used for the accumulator tensor.
:return: Weighted L2 penalty scaled by ``1/2``.
"""
l2 = torch.tensor(0.0, device=device)
for param, anchor_param, fisher_diag in zip(
params, anchor_params, fisher_diagonals, strict=True
):
assert param.shape == anchor_param.shape
l2 += (fisher_diag * (param - anchor_param) ** 2).sum()
return l2 / 2.0


def fd_init(model: torch.nn.Module) -> Sequence[Tensor]:
    """Create an all-zero Fisher diagonal for every parameter of ``model``.

    :param model: Model whose parameters define the Fisher diagonal shapes.
    :return: List of zero tensors, one per model parameter, in the order of
        ``model.parameters()``.
    """
    return list(map(torch.zeros_like, model.parameters()))


def fd_accumulate(
fisher_diagonals: Sequence[Tensor],
parameters: Iterator[Tensor],
alpha: Optional[float] = None,
) -> Sequence[Tensor]:
"""Accumulates the squared gradients into the Fisher diagonal estimates.

:param fisher_diagonals: A sequence of tensors representing the current estimates of
the Fisher diagonals.
:param parameters: A sequence of model parameters whose gradients have been
computed.
:param alpha: Decay factor for the accumulated Fisher diagonals. A value of 1.0
corresponds to standard EWC accumulation, while values less than 1.0 implement
a decay as in Online EWC.
:return: Updated sequence of tensors representing the accumulated Fisher diagonals.
"""
for fisher_diag, param in zip(fisher_diagonals, parameters, strict=True):
if param.grad is None:
raise ValueError(
"Parameter gradients must be computed before updating Fisher diagonals."
)
if alpha is not None:
fisher_diag.mul_(alpha).add_(param.grad.data.pow(2), alpha=(1 - alpha))
else:
fisher_diag.add_(param.grad.data.pow(2))
return fisher_diagonals


def fd_compute(
    model: torch.nn.Module,
    forward_fn: Callable[[Tensor], Tensor],
    dataloader: DataLoader[Tuple[Tensor, Tensor]],
    device: torch.device,
    criterion: torch.nn.Module,
) -> Sequence[Tensor]:
    """Compute module fisher diagonals.

    For every batch the loss is back-propagated and the squared parameter
    gradients are summed; the sum is finally divided by the number of batches.

    Side effects: ``model`` and ``criterion`` are switched to evaluation mode
    and moved to ``device``, and the gradients of the last batch are left on
    the model parameters.

    :param model: A PyTorch classifier model.
    :param forward_fn: Callable mapping a batch of inputs to the outputs passed to
        ``criterion``. Its backward pass must populate the gradients of
        ``model.parameters()``.
    :param dataloader: A PyTorch dataloader for a classification task, yielding batches
        of (inputs, labels).
    :param device: Compute device.
    :param criterion: The loss function to use.
    :return: A sequence of tensors representing the computed Fisher diagonals. All
        zeros when ``dataloader`` yields no batches.
    """
    model = model.eval().to(device)
    criterion = criterion.eval().to(device)

    fisher_diagonals = fd_init(model)
    # Count batches while iterating instead of calling ``len(dataloader)``,
    # which is unavailable for iterable-style datasets.
    num_batches = 0
    for inputs, labels in dataloader:
        model.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = forward_fn(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        fisher_diagonals = fd_accumulate(fisher_diagonals, model.parameters())
        num_batches += 1

    # Without data there is nothing to average; dividing the zero tensors by
    # zero would turn every Fisher diagonal into NaN.
    if num_batches == 0:
        return fisher_diagonals

    # Average the accumulated squared gradients over the number of batches
    return [fisher_diag / num_batches for fisher_diag in fisher_diagonals]


class EWC(BatchClassifier, nn.Module, TrainTaskAware, TestTaskAware):
    """Elastic Weight Consolidation learner.

    Elastic Weight Consolidation (EWC) is a regularisation-based continual learning
    strategy that mitigates catastrophic forgetting by penalising changes to important
    parameters for previous tasks [#f1]_. We incorporate Online EWC-style [#f2]_ updates
    to the Fisher diagonals, which decay the importance of previous tasks' parameters
    over time based on the ``gamma`` hyperparameter.

    Usually the EWC strategy has access to the entire active task's data when estimating
    the Fisher diagonals, but instead we use a replay buffer to approximate the active
    task distribution.

    .. [#f1] Kirkpatrick, J., Pascanu, R., Rabinowitz, N., Veness, J., Desjardins, G.,
       Rusu, A. A., Milan, K., Quan, J., Ramalho, T., Grabska-Barwinska, A., Hassabis,
       D., Clopath, C., Kumaran, D., & Hadsell, R. (2017). Overcoming catastrophic
       forgetting in neural networks. Proceedings of the National Academy of Sciences,
       114(13), 3521–3526. https://doi.org/10.1073/pnas.1611835114

    .. [#f2] Schwarz, J., Czarnecki, W., Luketina, J., Grabska-Barwinska, A., Teh, Y.
       W., Pascanu, R., & Hadsell, R. (2018). Progress & Compress: A scalable framework
       for continual learning. In J. G. Dy & A. Krause (Eds.), Proceedings of the 35th
       International Conference on Machine Learning, ICML 2018, Stockholmsmässan,
       Stockholm, Sweden, July 10-15, 2018 (Vol. 80, pp. 4535–4544). PMLR.
       http://proceedings.mlr.press/v80/schwarz18a.html
    """

    def __init__(
        self,
        schema: Schema,
        model: torch.nn.Module,
        optimiser: torch.optim.Optimizer,
        lambda_: float,
        fim_buffer: int = 256,
        fim_batch_size: int = 32,
        device: torch.device = torch.device("cpu"),
        mask_test: bool = False,
        mask_train: bool = False,
        gamma: float = 1.0,
        task_mask: Optional[Tensor] = None,
    ) -> None:
        """Construct an EWC learner.

        :param schema: Stream schema used by the classifier interface.
        :param model: Torch model that outputs class logits.
        :param optimiser: Optimiser used to update ``model`` parameters.
        :param lambda_: Weight of the EWC regularisation term.
        :param fim_buffer: Replay window size for Fisher estimation.
        :param fim_batch_size: Mini-batch size used when estimating Fisher diagonals.
        :param device: Compute device.
        :param mask_test: Whether to apply per-task masking during testing. This is a
            task incremental scenario.
        :param mask_train: Whether to apply per-task masking during training. This is
            also known as the labels trick.
        :param gamma: Factor the accumulated Fisher diagonals are multiplied by before
            a new task's estimate is added. ``1.0`` keeps the full importance of
            earlier tasks; smaller values decay it.
        :param task_mask: Optional per-task mask applied to output logits. It is
            indexed by task id and multiplied with the logits.
        :raises ValueError: If task-specific masking is requested without ``task_mask``.
        """
        super().__init__(schema, 0)
        nn.Module.__init__(self)
        if (mask_train or mask_test) and task_mask is None:
            raise ValueError(
                "``task_mask`` must be provided for task incremental or labels trick scenarios."
            )
        self.device = device

        # Hyperparameters
        self._lambda = lambda_
        self._gamma = gamma
        self._fd_batch_size = fim_batch_size
        self._mask_train = mask_train
        self._mask_test = mask_test

        # Modules
        self._optimiser = optimiser
        self._model = model
        self._criterion = torch.nn.CrossEntropyLoss()
        self._buffer = SlidingWindow(fim_buffer, schema.get_num_attributes())

        # Buffers for anchoring the model
        self._anchor_params = BufferList(
            [param.clone().detach() for param in model.parameters()]
        )
        self._fisher_diags = BufferList(
            [torch.zeros_like(param) for param in model.parameters()]
        )

        # Task tracking
        self._train_task = 0
        self._test_task = 0
        if task_mask is None:
            self._task_mask = None
        else:
            self._task_mask = nn.Buffer(task_mask)

        # Move all model parameters and buffers to the specified device
        self.to(device)

    def batch_train(self, x: Tensor, y: Tensor) -> None:
        """Run one optimisation step on a batch and record it in the replay window.

        The loss is the cross-entropy on the batch plus ``lambda_`` times the EWC
        regularisation term.

        :param x: Batch of inputs.
        :param y: Batch of target labels.
        """
        self._buffer.update(x, y)
        self._model.train()
        self._optimiser.zero_grad()
        y_hat = self._train_forward(x)
        loss = self._criterion(y_hat, y)
        total_loss = loss + self._lambda * self._regularisation_loss()
        total_loss.backward()
        self._optimiser.step()

    @torch.no_grad()
    def batch_predict_proba(self, x: Tensor) -> Tensor:
        """Return class probabilities (softmax over ``dim=1``) for a batch.

        :param x: Batch of inputs.
        :return: Softmax of the (optionally test-masked) logits.
        """
        self._model.eval()
        y_hat = self._test_forward(x)
        return torch.softmax(y_hat, dim=1)

    def on_train_task(self, task_id: int) -> None:
        """Handle the start of a training task.

        For every task after the first (``task_id > 0``) the Fisher diagonals are
        updated from the replay window and the anchor is moved to the current
        weights. Both happen before ``task_id`` is recorded, so the Fisher estimate
        still uses the train mask of the task that just ended.

        :param task_id: Identifier of the task that is about to be trained.
        """
        if task_id > 0:
            self._update_fisher_diags()
            self._update_anchor_params()
        self._train_task = task_id

    def on_test_task(self, task_id: int) -> None:
        """Record the task that is about to be evaluated.

        :param task_id: Identifier of the task used to select the test mask.
        """
        self._test_task = task_id

    def _update_fisher_diags(self) -> None:
        """Estimate and accumulate Fisher diagonals from the replay buffer."""
        dataset = self._buffer.dataset_view()
        dataloader = DataLoader(dataset, batch_size=self._fd_batch_size, shuffle=False)
        task_fisher_diags = fd_compute(
            self._model,
            self._train_forward,
            dataloader,  # type: ignore
            self.device,
            self._criterion,
        )
        # Decay the accumulated importance by ``gamma`` and add the new estimate.
        for accumulated, task_estimate in zip(
            self._fisher_diags, task_fisher_diags, strict=True
        ):
            accumulated.mul_(self._gamma).add_(task_estimate)

    def _update_anchor_params(self) -> None:
        """Update anchored parameters to the current model weights."""
        for param, anchor_param in zip(
            self._model.parameters(), self._anchor_params, strict=True
        ):
            anchor_param.copy_(param.detach())

    def _test_forward(self, x: Tensor) -> Tensor:
        """Compute logits for inference, optionally applying a test-task mask."""
        y_hat = self._model(x)
        if self._task_mask is not None and self._mask_test:
            # NOTE(review): the mask is multiplied with the logits, so a masked
            # class ends up with logit 0 rather than -inf and still receives
            # probability mass after softmax -- confirm this is intended.
            y_hat = self._task_mask[self._test_task] * y_hat
        return y_hat

    def _train_forward(self, x: Tensor) -> Tensor:
        """Compute logits for training, optionally applying a train-task mask."""
        y_hat = self._model(x)
        if self._task_mask is not None and self._mask_train:
            y_hat = self._task_mask[self._train_task] * y_hat
        return y_hat

    def _regularisation_loss(self) -> torch.Tensor:
        """Return the EWC regularisation loss for the current task.

        The penalty is zero while the first task (``_train_task < 1``) is being
        trained, because no anchor has been consolidated yet.
        """
        if self._train_task < 1:
            return torch.tensor(0.0, device=self.device)
        return weighted_l2_reg(
            self._model.parameters(),
            self._anchor_params,
            self._fisher_diags,
            device=self.device,
        )
2 changes: 1 addition & 1 deletion src/capymoa/ocl/strategy/_experience_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from capymoa.base import BatchClassifier
from capymoa.ocl.base import TrainTaskAware, TestTaskAware
from capymoa.ocl.util._coreset import ReservoirSampler
from capymoa.ocl.util._replay import ReservoirSampler


class ExperienceReplay(BatchClassifier, TrainTaskAware, TestTaskAware):
Expand Down
2 changes: 1 addition & 1 deletion src/capymoa/ocl/strategy/_gdumb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from capymoa.ocl.util._coreset import GreedySampler
from capymoa.ocl.util._replay import GreedySampler
import torch
from capymoa.base import BatchClassifier
from capymoa.ocl.base import TestTaskAware
Expand Down
2 changes: 1 addition & 1 deletion src/capymoa/ocl/strategy/_rar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import Tensor

from capymoa.base import BatchClassifier
from capymoa.ocl.util._coreset import ReservoirSampler
from capymoa.ocl.util._replay import ReservoirSampler
from capymoa.ocl.base import TrainTaskAware, TestTaskAware

from typing import Callable
Expand Down
Loading
Loading