from typing import Iterable, Iterator, Optional, Sequence, Tuple, Callable
from capymoa.stream._stream import Schema
from torch import Tensor, nn
import torch
from capymoa.base import BatchClassifier
from capymoa.ocl.base import TrainTaskAware, TestTaskAware
from capymoa.ocl.util._buffer_list import BufferList
from capymoa.ocl.util._replay import SlidingWindow
from torch.utils.data import DataLoader


def weighted_l2_reg(
    params: Iterable[Tensor],
    anchor_params: Iterable[Tensor],
    fisher_diagonals: Iterable[Tensor],
    device: torch.device,
) -> Tensor:
    """Compute an EWC-style weighted L2 regularisation term.

    :param params: Current model parameters.
    :param anchor_params: Reference parameters from a previous task.
    :param fisher_diagonals: Diagonal Fisher information weights.
    :param device: Device used for the accumulator tensor.
    :return: Weighted L2 penalty scaled by ``1/2``.
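
    A small illustrative example with hand-picked tensors (the values are
    arbitrary and chosen only to make the arithmetic easy to follow):

    >>> p = [torch.tensor([1.0, 2.0])]
    >>> a = [torch.tensor([0.0, 0.0])]
    >>> f = [torch.tensor([1.0, 0.5])]
    >>> weighted_l2_reg(p, a, f, torch.device("cpu"))
    tensor(1.5000)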
    """
    l2 = torch.tensor(0.0, device=device)
    for param, anchor_param, fisher_diag in zip(
        params, anchor_params, fisher_diagonals, strict=True
    ):
        assert param.shape == anchor_param.shape
        l2 += (fisher_diag * (param - anchor_param) ** 2).sum()
    return l2 / 2.0


def fd_init(model: torch.nn.Module) -> Sequence[Tensor]:
    """Initialise zero-valued Fisher diagonal tensors for a model.

    :param model: Model whose parameters define the Fisher diagonal shapes.
    :return: Zero tensors matching all model parameters.
    """
    return [torch.zeros_like(param) for param in model.parameters()]


def fd_accumulate(
    fisher_diagonals: Sequence[Tensor],
    parameters: Iterator[Tensor],
    alpha: Optional[float] = None,
) -> Sequence[Tensor]:
    """Accumulate the squared parameter gradients into the Fisher diagonal estimates.

    :param fisher_diagonals: Tensors holding the current estimates of the Fisher
        diagonals; they are updated in place.
    :param parameters: Model parameters whose gradients have already been computed
        (for example via ``loss.backward()``).
    :param alpha: Optional decay factor. When ``None`` the squared gradients are
        simply summed, as in standard EWC. Otherwise the exponential moving average
        ``alpha * old + (1 - alpha) * grad ** 2`` is used, giving the decaying
        accumulation of Online EWC.
    :return: Updated sequence of tensors representing the accumulated Fisher diagonals.
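
    A small illustrative example (a single scalar parameter whose gradient comes
    from a toy loss):

    >>> w = torch.nn.Parameter(torch.tensor([2.0]))
    >>> (w ** 2).sum().backward()
    >>> fd_accumulate([torch.zeros(1)], iter([w]))
    [tensor([16.])]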
    """
    for fisher_diag, param in zip(fisher_diagonals, parameters, strict=True):
        if param.grad is None:
            raise ValueError(
                "Parameter gradients must be computed before updating Fisher diagonals."
            )
        if alpha is not None:
            fisher_diag.mul_(alpha).add_(param.grad.data.pow(2), alpha=(1 - alpha))
        else:
            fisher_diag.add_(param.grad.data.pow(2))
    return fisher_diagonals


def fd_compute(
    model: torch.nn.Module,
    forward_fn: Callable[[Tensor], Tensor],
    dataloader: DataLoader[Tuple[Tensor, Tensor]],
    device: torch.device,
    criterion: torch.nn.Module,
) -> Sequence[Tensor]:
    """Compute the Fisher diagonals of a model over a dataset.

    :param model: A PyTorch classifier model.
    :param forward_fn: Function mapping a batch of inputs to logits; typically the
        model itself or a wrapper that applies a task mask to its outputs.
    :param dataloader: A PyTorch dataloader for a classification task, yielding
        batches of ``(inputs, labels)``.
    :param device: Compute device.
    :param criterion: The loss function to use.
    :return: A sequence of tensors representing the computed Fisher diagonals.
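
    For example, a task-masked forward pass can be supplied through ``forward_fn``
    (an illustrative sketch only; ``task_mask`` is a placeholder tensor of
    per-class weights)::

        fim = fd_compute(
            model,
            lambda x: task_mask * model(x),
            dataloader,
            torch.device("cpu"),
            torch.nn.CrossEntropyLoss(),
        )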
    """
    model = model.eval().to(device)
    criterion = criterion.eval().to(device)

    fisher_diagonals = fd_init(model)
    for inputs, labels in dataloader:
        model.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = forward_fn(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        fisher_diagonals = fd_accumulate(fisher_diagonals, model.parameters())
    # Average the accumulated squared gradients over the number of batches
    fisher_diagonals = [
        fisher_diag / len(dataloader) for fisher_diag in fisher_diagonals
    ]
    return fisher_diagonals


class EWC(BatchClassifier, nn.Module, TrainTaskAware, TestTaskAware):
    """Elastic Weight Consolidation learner.

    Elastic Weight Consolidation (EWC) is a regularisation-based continual learning
    strategy that mitigates catastrophic forgetting by penalising changes to
    parameters that are important for previous tasks [#f1]_. We incorporate Online
    EWC-style [#f2]_ updates to the Fisher diagonals, which decay the importance of
    previous tasks' parameters over time based on the ``gamma`` hyperparameter.

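    Concretely, on each training batch the model minimises the cross-entropy loss
    plus the weighted L2 penalty implemented by ``weighted_l2_reg`` (applied in
    ``batch_train``):

    .. math::

        \mathcal{L}(\theta) = \mathcal{L}_{\text{CE}}(\theta)
            + \frac{\lambda}{2} \sum_i F_i \left(\theta_i - \theta_i^{*}\right)^2,

    where :math:`\theta_i^{*}` are the parameters anchored at the end of the
    previous task and :math:`F_i` are the accumulated Fisher diagonals, updated
    after each task as :math:`F_i \leftarrow \gamma F_i + F_i^{\text{task}}`.
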
    The EWC strategy usually has access to the entire active task's data when
    estimating the Fisher diagonals; here we instead use a replay buffer to
    approximate the active task's distribution.

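    A minimal construction sketch (illustrative only: the two-layer network, the
    optimiser settings, ``x_batch``/``y_batch``, and the assumption that ``schema``
    provides ``get_num_classes()`` are placeholders, not fixed requirements)::

        model = torch.nn.Sequential(
            torch.nn.Linear(schema.get_num_attributes(), 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, schema.get_num_classes()),
        )
        learner = EWC(
            schema=schema,
            model=model,
            optimiser=torch.optim.SGD(model.parameters(), lr=0.01),
            lambda_=1.0,
        )
        learner.batch_train(x_batch, y_batch)
        probs = learner.batch_predict_proba(x_batch)
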
    .. [#f1] Kirkpatrick, J., Pascanu, R., Rabinowitz, N., Veness, J., Desjardins, G.,
        Rusu, A. A., Milan, K., Quan, J., Ramalho, T., Grabska-Barwinska, A., Hassabis,
        D., Clopath, C., Kumaran, D., & Hadsell, R. (2017). Overcoming catastrophic
        forgetting in neural networks. Proceedings of the National Academy of Sciences,
        114(13), 3521–3526. https://doi.org/10.1073/pnas.1611835114

    .. [#f2] Schwarz, J., Czarnecki, W., Luketina, J., Grabska-Barwinska, A., Teh, Y.
        W., Pascanu, R., & Hadsell, R. (2018). Progress & Compress: A scalable framework
        for continual learning. In J. G. Dy & A. Krause (Eds.), Proceedings of the 35th
        International Conference on Machine Learning, ICML 2018, Stockholmsmässan,
        Stockholm, Sweden, July 10-15, 2018 (Vol. 80, pp. 4535–4544). PMLR.
        http://proceedings.mlr.press/v80/schwarz18a.html
    """

    def __init__(
        self,
        schema: Schema,
        model: torch.nn.Module,
        optimiser: torch.optim.Optimizer,
        lambda_: float,
        fim_buffer: int = 256,
        fim_batch_size: int = 32,
        device: torch.device = torch.device("cpu"),
        mask_test: bool = False,
        mask_train: bool = False,
        gamma: float = 1.0,
        task_mask: Optional[Tensor] = None,
    ) -> None:
        """Construct an EWC learner.

        :param schema: Stream schema used by the classifier interface.
        :param model: Torch model that outputs class logits.
        :param optimiser: Optimiser used to update ``model`` parameters.
        :param lambda_: Weight of the EWC regularisation term.
        :param fim_buffer: Replay window size for Fisher estimation.
        :param fim_batch_size: Mini-batch size used when estimating Fisher diagonals.
        :param device: Compute device.
        :param mask_test: Whether to apply per-task masking during testing. This is a
            task incremental scenario.
        :param mask_train: Whether to apply per-task masking during training. This is
            also known as the labels trick.
        :param gamma: Decay applied to the previously accumulated Fisher diagonals
            when a new task's estimates are added. ``1.0`` keeps past estimates
            intact (standard EWC accumulation); values below ``1.0`` down-weight
            older tasks as in Online EWC.
        :param task_mask: Optional per-task mask applied to output logits.
        :raises ValueError: If task-specific masking is requested without ``task_mask``.
        """
        super().__init__(schema, 0)
        nn.Module.__init__(self)
        if (mask_train or mask_test) and task_mask is None:
            raise ValueError(
                "A task mask must be provided for task-incremental or labels-trick scenarios."
            )
        self.device = device

        # Hyperparameters
        self._lambda = lambda_
        self._gamma = gamma
        self._fd_batch_size = fim_batch_size
        self._mask_train = mask_train
        self._mask_test = mask_test

        # Modules
        self._optimiser = optimiser
        self._model = model
        self._criterion = torch.nn.CrossEntropyLoss()
        self._buffer = SlidingWindow(fim_buffer, schema.get_num_attributes())

        # Buffers for anchoring the model
        self._anchor_params = BufferList(
            [param.clone().detach() for param in model.parameters()]
        )
        self._fisher_diags = BufferList(
            [torch.zeros_like(param) for param in model.parameters()]
        )

        # Task tracking
        self._train_task = 0
        self._test_task = 0
        if task_mask is None:
            self._task_mask = None
        else:
            self._task_mask = nn.Buffer(task_mask)

        # Move all model parameters and buffers to the specified device
        self.to(device)

    def batch_train(self, x: Tensor, y: Tensor) -> None:
        self._buffer.update(x, y)
        self._model.train()
        self._optimiser.zero_grad()
        y_hat = self._train_forward(x)
        loss = self._criterion(y_hat, y)
        total_loss = loss + self._lambda * self._regularisation_loss()
        total_loss.backward()
        self._optimiser.step()

    @torch.no_grad()
    def batch_predict_proba(self, x: Tensor) -> Tensor:
        self._model.eval()
        y_hat = self._test_forward(x)
        return torch.softmax(y_hat, dim=1)

    def on_train_task(self, task_id: int) -> None:
        # Consolidate the previous task's Fisher estimates and parameter anchors
        # before switching to the new task.
        if task_id > 0:
            self._update_fisher_diags()
            self._update_anchor_params()
        self._train_task = task_id

    def on_test_task(self, task_id: int) -> None:
        self._test_task = task_id

    def _update_fisher_diags(self) -> None:
        """Estimate and accumulate Fisher diagonals from the replay buffer."""
        dataset = self._buffer.dataset_view()
        dataloader = DataLoader(dataset, batch_size=self._fd_batch_size, shuffle=False)
        task_fisher_diags = fd_compute(
            self._model,
            self._train_forward,
            dataloader,  # type: ignore
            self.device,
            self._criterion,
        )
        # Decay the stored Fisher diagonals by gamma and add the estimates for the
        # task that just finished
        for i in range(len(self._fisher_diags)):
            self._fisher_diags[i].mul_(self._gamma).add_(task_fisher_diags[i])

    def _update_anchor_params(self) -> None:
        """Update anchored parameters to the current model weights."""
        for param, anchor_param in zip(
            self._model.parameters(), self._anchor_params, strict=True
        ):
            anchor_param.copy_(param.detach())

    def _test_forward(self, x: Tensor) -> Tensor:
        """Compute logits for inference, optionally applying a test-task mask."""
        y_hat = self._model(x)
        if self._task_mask is not None and self._mask_test:
            y_hat = self._task_mask[self._test_task] * y_hat
        return y_hat

    def _train_forward(self, x: Tensor) -> Tensor:
        """Compute logits for training, optionally applying a train-task mask."""
        y_hat = self._model(x)
        if self._task_mask is not None and self._mask_train:
            y_hat = self._task_mask[self._train_task] * y_hat
        return y_hat

    def _regularisation_loss(self) -> torch.Tensor:
        """Return the EWC regularisation loss for the current task."""
        if self._train_task < 1:
            return torch.tensor(0.0, device=self.device)
        return weighted_l2_reg(
            self._model.parameters(),
            self._anchor_params,
            self._fisher_diags,
            device=self.device,
        )