Skip to content

Commit ec1f50f

Browse files
feat: add domain incremental ocl datasets (#344)
1 parent 008ceb8 commit ec1f50f

17 files changed

Lines changed: 881 additions & 319 deletions

File tree

docs/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
("py:class", r"sklearn\..*"),
5656
("py:class", r"torch\..*"),
5757
("py:class", r"tqdm\..*"),
58+
("py:class", r"torchvision\..*"),
59+
("py:class", r"Tensor"),
5860
]
5961

6062
# These warnings are usually false positives.

docs/images/DomainCIFAR100.jpg

352 KB
Loading
180 KB
Loading

docs/images/RotatedMNIST.jpg

154 KB
Loading

src/capymoa/datasets/_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,12 @@ def __init__(
168168
data: torch.Tensor,
169169
targets: torch.Tensor,
170170
transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
171+
target_transform: Optional[Callable[[object], object]] = None,
171172
):
172173
self.data = data
173174
self.targets = targets
174175
self.transform = transform
176+
self.target_transform = target_transform
175177

176178
def __len__(self):
177179
return len(self.data)
@@ -182,5 +184,7 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
182184

183185
if self.transform:
184186
x = self.transform(x)
187+
if self.target_transform:
188+
y = self.target_transform(y)
185189

186190
return x, y

src/capymoa/ocl/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,19 @@
3434
3535
>>> with np.printoptions(precision=2):
3636
... print(metrics.accuracy_matrix)
37-
[[0.88 0.17 0. 0.05 0.08]
38-
[0.85 0.85 0. 0.05 0.08]
39-
[0.75 0.8 0.6 0.08 0.03]
40-
[0.73 0.77 0.52 0.38 0.1 ]
41-
[0.75 0.75 0.5 0.38 0.57]]
37+
[[0.88 0.17 0.08 0.05 0.08]
38+
[0.85 0.85 0.03 0.05 0.08]
39+
[0.75 0.77 0.6 0.08 0.03]
40+
[0.75 0.77 0.52 0.38 0.1 ]
41+
[0.75 0.75 0.52 0.38 0.57]]
4242
4343
Notice that the accuracies in the upper triangle are close to zero because the
4444
learner has not trained on those tasks yet. The diagonal contains the accuracy
4545
on each task after training on that task. The lower triangle contains the
4646
accuracy on each task after training on all tasks.
4747
4848
>>> print(f"Forward Transfer: {metrics.forward_transfer:0.2f}")
49-
Forward Transfer: 0.06
49+
Forward Transfer: 0.07
5050
5151
>>> print(f"Backward Transfer: {metrics.backward_transfer:0.2f}")
5252
Backward Transfer: -0.08
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Use built-in datasets for online continual learning.
2+
3+
In OCL datastreams are irreversible sequences of examples following a
4+
non-stationary data distribution. Learners in OCL can only learn from a single
5+
pass through the datastream but are expected to perform well on any portion of
6+
the datastream.
7+
8+
Portions of the datastream where the data distribution is relatively stationary
9+
are called *tasks*.
10+
11+
A common way to construct an OCL dataset for experimentation is to group the
12+
classes of a classification dataset into tasks. Known as the *class-incremental*
13+
scenario, the learner is presented with a sequence of tasks where each task
14+
contains a new subset of the classes.
15+
16+
For example :class:`SplitMNIST` splits the MNIST dataset into five tasks where
17+
each task contains two classes:
18+
19+
>>> from capymoa.ocl.datasets import SplitMNIST
20+
>>> scenario = SplitMNIST(preload_test=False)
21+
>>> scenario.task_schedule
22+
[{1, 4}, {5, 7}, {9, 3}, {0, 8}, {2, 6}]
23+
24+
25+
To get the usual CapyMOA stream object for training:
26+
27+
>>> instance = scenario.stream.next_instance()
28+
>>> instance
29+
LabeledInstance(
30+
Schema(SplitMNIST10/5),
31+
x=[0. 0. 0. ... 0. 0. 0.],
32+
y_index=4,
33+
y_label='4'
34+
)
35+
36+
CapyMOA streams flatten the data into a feature vector:
37+
38+
>>> instance.x.shape
39+
(784,)
40+
41+
You can access the PyTorch datasets for each task:
42+
43+
>>> x, y = scenario.test_tasks[0][0]
44+
>>> x.shape
45+
torch.Size([1, 28, 28])
46+
>>> y
47+
1
48+
"""
49+
50+
from ._base import _BuiltInCIScenario
51+
from ._tiny import RotatedTinyMNIST, TinySplitMNIST
52+
from ._vision import (
53+
DomainCIFAR100,
54+
RotatedFashionMNIST,
55+
RotatedMNIST,
56+
SplitCIFAR10,
57+
SplitCIFAR100,
58+
SplitFashionMNIST,
59+
SplitMNIST,
60+
)
61+
from ._vit import DomainCIFAR100ViT, SplitCIFAR10ViT, SplitCIFAR100ViT
62+
63+
__all__ = [
64+
"_BuiltInCIScenario",
65+
"SplitMNIST",
66+
"RotatedMNIST",
67+
"TinySplitMNIST",
68+
"RotatedTinyMNIST",
69+
"SplitCIFAR100ViT",
70+
"SplitCIFAR10ViT",
71+
"DomainCIFAR100ViT",
72+
"SplitFashionMNIST",
73+
"RotatedFashionMNIST",
74+
"SplitCIFAR10",
75+
"SplitCIFAR100",
76+
"DomainCIFAR100",
77+
]

0 commit comments

Comments
 (0)