Skip to content

Commit f1c3cc9

Browse files
feat(ocl): add lwf
1 parent 2ef7fe7 commit f1c3cc9

4 files changed

Lines changed: 148 additions & 2 deletions

File tree

src/capymoa/ocl/strategy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@
55
from ._rar import RAR
66
from . import l2p
77
from ._ewc import EWC
8+
from ._lwf import LWF
89

9-
__all__ = ["ExperienceReplay", "SLDA", "NCM", "GDumb", "RAR", "l2p", "EWC"]
10+
__all__ = ["ExperienceReplay", "SLDA", "NCM", "GDumb", "RAR", "l2p", "EWC", "LWF"]

src/capymoa/ocl/strategy/_lwf.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from copy import deepcopy
2+
from typing import Optional
3+
4+
import torch
5+
from torch import Tensor, nn
6+
7+
from capymoa.base import BatchClassifier
8+
from capymoa.ocl.base import TrainTaskAware
9+
from capymoa.ocl.util.functional import hinton_distillation_loss
10+
from capymoa.stream._stream import Schema
11+
12+
13+
class LWF(BatchClassifier, nn.Module, TrainTaskAware):
14+
"""Learning Without Forgetting (LwF) [#f1]_ .
15+
16+
LwF is a regularisation-based continual learning strategy that distils predictions
17+
from a frozen teacher snapshot of the previous task while learning the current task.
18+
19+
.. [#f1] Li, Z., & Hoiem, D. (2016). Learning without forgetting. CoRR,
20+
abs/1606.09282. http://arxiv.org/abs/1606.09282
21+
"""
22+
23+
def __init__(
24+
self,
25+
schema: Schema,
26+
model: torch.nn.Module,
27+
optimiser: torch.optim.Optimizer,
28+
alpha: float = 1.0,
29+
temperature: float = 2.0,
30+
device: torch.device = torch.device("cpu"),
31+
) -> None:
32+
"""Construct an LWF learner.
33+
34+
:param schema: Stream schema used by the classifier interface.
35+
:param model: Torch model that outputs class logits.
36+
:param optimiser: Optimiser used to update ``model`` parameters.
37+
:param alpha: Weight of the distillation loss term.
38+
:param temperature: Distillation temperature.
39+
:param device: Compute device.
40+
"""
41+
super().__init__(schema, 0)
42+
nn.Module.__init__(self)
43+
if alpha < 0:
44+
raise ValueError("alpha must be non-negative.")
45+
if temperature <= 0:
46+
raise ValueError("temperature must be greater than zero.")
47+
48+
self.device = device
49+
50+
self._alpha = alpha
51+
self._temperature = temperature
52+
53+
self._optimiser = optimiser
54+
self._model = model
55+
self._criterion = torch.nn.CrossEntropyLoss()
56+
57+
self._teacher: Optional[torch.nn.Module] = None
58+
self._train_task = 0
59+
60+
def batch_train(self, x: Tensor, y: Tensor) -> None:
61+
self._model.train()
62+
self._optimiser.zero_grad()
63+
64+
student_logits = self._model(x)
65+
task_loss = self._criterion(student_logits, y)
66+
total_loss = task_loss + self._alpha * self._distillation_loss(
67+
x, student_logits
68+
)
69+
70+
total_loss.backward()
71+
self._optimiser.step()
72+
73+
@torch.no_grad()
74+
def batch_predict_proba(self, x: Tensor) -> Tensor:
75+
self._model.eval()
76+
y_hat = self._model(x)
77+
return torch.softmax(y_hat, dim=1)
78+
79+
def on_train_task(self, task_id: int) -> None:
80+
if task_id > 0:
81+
self._teacher = (
82+
deepcopy(self._model).to(self.device).eval().requires_grad_(False)
83+
)
84+
self._train_task = task_id
85+
86+
@torch.no_grad()
87+
def _teacher_forward(self, x: Tensor) -> Tensor:
88+
if self._teacher is None:
89+
raise RuntimeError("Teacher model is not available before task 1.")
90+
return self._teacher(x)
91+
92+
def _distillation_loss(self, x: Tensor, student_logits: Tensor) -> Tensor:
93+
if self._teacher is None:
94+
return torch.tensor(0.0, device=self.device)
95+
96+
teacher_logits = self._teacher_forward(x)
97+
98+
return hinton_distillation_loss(
99+
teacher_logits=teacher_logits,
100+
student_logits=student_logits,
101+
temperature=self._temperature,
102+
)
103+
104+
def __str__(self) -> str:
105+
return f"LWF(alpha={self._alpha}, temperature={self._temperature})"

src/capymoa/ocl/util/functional.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""A collection of functional utilities for OCL."""
2+
3+
from torch import Tensor
4+
from torch.nn.functional import kl_div, log_softmax
5+
6+
7+
def hinton_distillation_loss(
8+
teacher_logits: Tensor, student_logits: Tensor, temperature: float = 1.0
9+
) -> Tensor:
10+
"""Hinton's distillation loss [#f1].
11+
12+
.. math::
13+
L_{KD} = T^2 KL(softmax(z_s / T), softmax(z_t / T))
14+
15+
where :math:`T` is the temperature, :math:`z_s` are the student logits, and
16+
:math:`z_t` are the teacher logits.
17+
18+
.. [#f1] Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a
19+
Neural Network. arXiv:1503.02531 [Cs, Stat]. http://arxiv.org/abs/1503.02531
20+
21+
:param teacher_logits: Teacher logits of shape ``(batch_size, num_classes)``.
22+
:param student_logits: Student logits of shape ``(batch_size, num_classes)``.
23+
:param temperature: Temperature for distillation. Higher values produce softer
24+
probability distributions.
25+
:return: The distillation loss as a scalar tensor.
26+
"""
27+
return (
28+
kl_div(
29+
log_softmax(student_logits / temperature, dim=1), # Soft predictions
30+
log_softmax(teacher_logits / temperature, dim=1), # Soft targets
31+
log_target=True,
32+
reduction="batchmean", # Mathematically correct unlike the default
33+
)
34+
* temperature**2
35+
)

tests/ocl/test_strategy.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from capymoa.classifier import Finetune, HoeffdingTree
1212
from capymoa.ocl.datasets import TinySplitMNIST
1313
from capymoa.ocl.evaluation import ocl_train_eval_loop
14-
from capymoa.ocl.strategy import ExperienceReplay, SLDA, NCM, GDumb, RAR, EWC
14+
from capymoa.ocl.strategy import ExperienceReplay, SLDA, NCM, GDumb, RAR, EWC, LWF
1515
from capymoa.stream import Schema
1616

1717
import torch
@@ -127,6 +127,11 @@ def _new_rar(schema):
127127
Result(71.99, 47.20, 25.3),
128128
task_mask=True,
129129
),
130+
Case(
131+
"LWF",
132+
new_constructor(LWF, lr=0.10, alpha=4.66, temperature=1.67),
133+
Result(36.49, 25.09, 17.59),
134+
),
130135
]
131136

132137

0 commit comments

Comments
 (0)