Skip to content

Commit 008ceb8

Browse files
feat(stream): rename 'TorchClassifyStream' to 'TorchStream' and support regression (#334)
1 parent 5ad82eb commit 008ceb8

6 files changed

Lines changed: 172 additions & 115 deletions

File tree

notebooks/03_pytorch.ipynb

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -478,7 +478,7 @@
478478
},
479479
{
480480
"cell_type": "code",
481-
"execution_count": 6,
481+
"execution_count": null,
482482
"id": "8a6fdd873625b07b",
483483
"metadata": {
484484
"ExecuteTime": {
@@ -858,7 +858,7 @@
858858
],
859859
"source": [
860860
"from capymoa.classifier import OnlineBagging\n",
861-
"from capymoa.stream import TorchClassifyStream\n",
861+
"from capymoa.stream import TorchStream\n",
862862
"from capymoa.evaluation import prequential_evaluation\n",
863863
"from capymoa.evaluation.visualization import plot_windowed_results\n",
864864
"\n",
@@ -868,7 +868,9 @@
868868
"pytorch_dataset = datasets.FashionMNIST(\n",
869869
" root=\"data\", train=True, download=True, transform=ToTensor()\n",
870870
")\n",
871-
"pytorch_stream = TorchClassifyStream(dataset=pytorch_dataset, num_classes=10)\n",
871+
"pytorch_stream = TorchStream.from_classification(\n",
872+
" dataset=pytorch_dataset, num_classes=10\n",
873+
")\n",
872874
"\n",
873875
"# Creating a learner\n",
874876
"ob_learner = OnlineBagging(schema=pytorch_stream.get_schema(), ensemble_size=5)\n",

src/capymoa/ocl/datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
partition_by_schedule,
6666
class_schedule_to_task_mask,
6767
)
68-
from capymoa.stream import Stream, TorchClassifyStream
68+
from capymoa.stream import Stream, TorchStream
6969
from capymoa.stream._stream import Schema
7070

7171

@@ -227,7 +227,7 @@ def __init__(
227227
self.test_tasks = self._preload_datasets(self.test_tasks)
228228

229229
# Create streams for training and testing
230-
self.stream = TorchClassifyStream(
230+
self.stream = TorchStream.from_classification(
231231
ConcatDataset(self.train_tasks),
232232
num_classes=self.num_classes,
233233
shuffle=False,

src/capymoa/stream/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
)
88
from ._csv_stream import CSVStream
99
from ._stream_from_file import stream_from_file
10-
from .torch import TorchClassifyStream
10+
from .torch import TorchStream
1111
from . import drift, generator, preprocessing
1212

1313
__all__ = [
1414
"Stream",
1515
"Schema",
1616
"ARFFStream",
17-
"TorchClassifyStream",
17+
"TorchStream",
1818
"CSVStream",
1919
"drift",
2020
"generator",

src/capymoa/stream/_stream_from_file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def stream_from_file(
3434
* :class:`~capymoa.stream.CSVStream`
3535
* :class:`~capymoa.stream.ARFFStream`
3636
* :class:`~capymoa.stream.NumpyStream`
37-
* :class:`~capymoa.stream.TorchClassifyStream`
37+
* :class:`~capymoa.stream.TorchStream`
3838
* :class:`~capymoa.stream.Stream`
3939
4040
**CSV File Considerations:**

src/capymoa/stream/torch.py

Lines changed: 153 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,125 +1,169 @@
1-
import copy
21
from typing import Optional, Sequence, Tuple
32

43
import torch
54

65
from capymoa.stream import Stream, Schema
7-
from capymoa.instance import (
8-
LabeledInstance,
9-
)
6+
from capymoa.instance import LabeledInstance, RegressionInstance
107
from torch.utils.data import Dataset
118

129

13-
class TorchClassifyStream(Stream[LabeledInstance]):
14-
"""TorchClassifyStream turns a PyTorch dataset into a classification stream.
10+
def _shuffle_dataset(dataset: Dataset, seed: Optional[int] = None) -> Dataset:
11+
rng = torch.Generator()
12+
if seed is not None:
13+
rng.manual_seed(seed)
14+
indicies = torch.randperm(len(dataset), generator=rng)
15+
return torch.utils.data.Subset(dataset, indicies)
16+
17+
18+
class TorchStream(Stream):
19+
"""A stream adapter for PyTorch datasets.
20+
21+
This class converts PyTorch datasets into CapyMOA streams for both classification
22+
and regression tasks.
23+
24+
Creating a classification stream from a PyTorch dataset:
1525
16-
>>> from capymoa.evaluation import ClassificationEvaluator
17-
...
1826
>>> from capymoa.datasets import get_download_dir
19-
>>> from capymoa.stream import TorchClassifyStream
20-
>>> from torchvision import datasets
21-
>>> from torchvision.transforms import ToTensor
22-
>>> print("Using PyTorch Dataset"); pytorchDataset = datasets.FashionMNIST( #doctest:+ELLIPSIS
27+
>>> from capymoa.stream import TorchStream
28+
>>> from torchvision import datasets, transforms
29+
>>>
30+
>>> dataset = datasets.FashionMNIST(
2331
... root=get_download_dir(),
2432
... train=True,
2533
... download=True,
26-
... transform=ToTensor()
27-
... )
28-
Using PyTorch Dataset...
29-
>>> pytorch_stream = TorchClassifyStream(pytorchDataset, 10, class_names=pytorchDataset.classes)
30-
>>> pytorch_stream.get_schema()
31-
@relation PytorchDataset
32-
<BLANKLINE>
33-
@attribute 0 numeric
34-
@attribute 1 numeric
35-
...
36-
@attribute 783 numeric
37-
@attribute class {T-shirt/top,Trouser,Pullover,Dress,Coat,Sandal,Shirt,Sneaker,Bag,'Ankle boot'}
38-
<BLANKLINE>
39-
@data
40-
>>> pytorch_stream.next_instance()
41-
LabeledInstance(
42-
Schema(PytorchDataset),
43-
x=[0. 0. 0. ... 0. 0. 0.],
44-
y_index=9,
45-
y_label='Ankle boot'
46-
)
47-
48-
You can construct :class:`TorchClassifyStream` using a random sampler by passing a sampler
49-
to the constructor:
34+
... transform=transforms.ToTensor()
35+
... ) # doctest: +SKIP
36+
>>> stream = TorchStream.from_classification(
37+
... dataset, num_classes=10, class_names=dataset.classes
38+
... ) # doctest: +SKIP
39+
>>> stream.next_instance() # doctest: +SKIP
40+
LabeledInstance(...)
41+
42+
Creating a shuffled classification stream:
5043
5144
>>> import torch
52-
>>> from torch.utils.data import RandomSampler, TensorDataset
45+
>>> from torch.utils.data import TensorDataset
46+
>>>
5347
>>> dataset = TensorDataset(
54-
... torch.tensor([[1], [2], [3]]), torch.tensor([0, 1, 2])
48+
... torch.tensor([[1.0], [2.0], [3.0]]),
49+
... torch.tensor([0, 1, 2])
5550
... )
56-
>>> pytorch_stream = TorchClassifyStream(dataset=dataset, num_classes=3, shuffle=True)
57-
>>> for instance in pytorch_stream:
58-
... print(instance.x)
59-
[3.]
60-
[1.]
61-
[2.]
62-
63-
Importantly you can restart the stream to iterate over the dataset in
64-
the same order again:
65-
66-
>>> pytorch_stream.restart()
67-
>>> for instance in pytorch_stream:
68-
... print(instance.x)
69-
[3.]
70-
[1.]
71-
[2.]
51+
>>> stream = TorchStream.from_classification(
52+
... dataset, num_classes=3, shuffle=True, shuffle_seed=0
53+
... )
54+
>>> [float(inst.x[0]) for inst in stream]
55+
[3.0, 1.0, 2.0]
56+
57+
Streams can be restarted to iterate again:
58+
59+
>>> stream.restart()
60+
>>> [float(inst.x[0]) for inst in stream]
61+
[3.0, 1.0, 2.0]
62+
63+
Creating a regression stream:
64+
65+
>>> dataset = TensorDataset(
66+
... torch.tensor([[1.0], [2.0], [3.0]]),
67+
... torch.tensor([0.5, 1.5, 2.5])
68+
... )
69+
>>> stream = TorchStream.from_regression(
70+
... dataset, shuffle=True, shuffle_seed=0
71+
... )
72+
>>> [(float(inst.x[0]), float(inst.y_value)) for inst in stream]
73+
[(3.0, 2.5), (1.0, 0.5), (2.0, 1.5)]
7274
"""
7375

74-
def __init__(
75-
self,
76-
dataset: Dataset[Tuple[torch.Tensor, torch.LongTensor]],
77-
num_classes: int,
76+
@staticmethod
77+
def from_regression(
78+
dataset: Dataset[Tuple[torch.Tensor, torch.Tensor | float]],
79+
dataset_name: str = "TorchStream",
7880
shuffle: bool = False,
79-
shuffle_seed: int = 0,
80-
class_names: Optional[Sequence[str]] = None,
81-
dataset_name: str = "PytorchDataset",
82-
shape: Optional[Sequence[int]] = None,
83-
):
84-
"""Create a stream from a PyTorch dataset.
85-
86-
:param dataset: A PyTorch dataset
87-
:param num_classes: The number of classes in the dataset
88-
:param shuffle: Randomly sample with replacement, defaults to False
89-
:param shuffle_seed: Seed for shuffling, defaults to 0
90-
:param class_names: The names of the classes, defaults to None
91-
:param dataset_name: The name of the dataset, defaults to "PytorchDataset"
81+
shuffle_seed: Optional[int] = None,
82+
) -> "TorchStream":
83+
"""Construct a stream for regression from a PyTorch Dataset.
84+
85+
:param dataset: A PyTorch Dataset that yields tuples of (features, target) for
86+
regression tasks.
87+
:param dataset_name: An optional name for the stream.
88+
:param shape: An optional shape for the features. If not provided, features will
89+
be treated as flat vectors.
90+
:param shuffle: Whether to shuffle the dataset.
91+
:param shuffle_seed: An optional seed for shuffling the dataset.
92+
:return: A TorchStream instance for regression.
9293
"""
93-
if not (class_names is None or len(class_names) == num_classes):
94-
raise ValueError("Number of class labels must match the number of classes")
9594

96-
self.__init_args_kwargs__ = copy.copy(
97-
locals()
98-
) # save init args for recreation. not a deep copy to avoid unnecessary use of memory
95+
# Construct the schema based on the dataset and provided parameters
96+
X, _ = dataset[0]
97+
n_features = X.numel()
98+
features = [str(f) for f in range(n_features)] + ["target"]
99+
schema = Schema.from_custom(
100+
features=features,
101+
target="target",
102+
name=dataset_name,
103+
)
104+
105+
dataset = _shuffle_dataset(dataset, seed=shuffle_seed) if shuffle else dataset
106+
return TorchStream(dataset, schema)
99107

100-
self._dataset = dataset
101-
self._index = 0
102-
self._permutation = torch.arange(len(dataset))
108+
@staticmethod
109+
def from_classification(
110+
dataset: Dataset[Tuple[torch.Tensor, torch.Tensor | int]],
111+
num_classes: int,
112+
class_names: Optional[Sequence[str]] = None,
113+
dataset_name: str = "TorchStream",
114+
shape: Optional[Sequence[int]] = None,
115+
shuffle: bool = False,
116+
shuffle_seed: Optional[int] = None,
117+
) -> "TorchStream":
118+
"""Construct a stream for classification from a PyTorch Dataset.
119+
120+
:param dataset: A PyTorch Dataset that yields tuples of (features, target).
121+
:param num_classes: The number of classes in the classification task.
122+
:param class_names: An optional sequence of class names corresponding to the class indices.
123+
:param dataset_name: An optional name for the stream.
124+
:param shape: An optional shape for the features. If not provided, features will
125+
be treated as flat vectors.
126+
:param shuffle: Whether to shuffle the dataset.
127+
:param shuffle_seed: An optional seed for shuffling the dataset.
128+
:return: A TorchStream instance.
129+
"""
103130

104-
if shuffle:
105-
self._permutation = torch.randperm(
106-
len(dataset),
107-
generator=torch.Generator().manual_seed(shuffle_seed),
108-
)
131+
if class_names is None:
132+
class_names = [str(k) for k in range(num_classes)]
133+
if len(class_names) != num_classes:
134+
raise ValueError("Length of class_names must match num_classes.")
109135

110-
# Use the first instance to infer the number of attributes
111-
X, _ = self._dataset[0]
136+
# Construct the schema based on the dataset and provided parameters
137+
X, _ = dataset[0]
112138
n_features = X.numel()
113-
114-
# Create a header describing the dataset for MOA
115-
self._schema = Schema.from_custom(
116-
features=[f"{f}" for f in range(n_features)] + ["class"],
139+
features = [str(f) for f in range(n_features)] + ["class"]
140+
schema = Schema.from_custom(
141+
features=features,
117142
target="class",
118-
categories={"class": class_names or [str(i) for i in range(num_classes)]},
143+
categories={"class": class_names},
119144
name=dataset_name,
120145
)
121-
if shape is not None:
122-
self._schema.shape = shape
146+
schema._shape = shape if shape is not None else (n_features,)
147+
148+
dataset = _shuffle_dataset(dataset, seed=shuffle_seed) if shuffle else dataset
149+
return TorchStream(dataset, schema)
150+
151+
def __init__(
152+
self,
153+
dataset: Dataset,
154+
schema: Schema,
155+
):
156+
"""Construct a TorchStream from a PyTorch Dataset and a Schema.
157+
158+
Usually you want :meth:`from_classification` or :meth:`from_regression`.
159+
160+
:param dataset: A PyTorch Dataset that yields tuples of (features, target).
161+
:param schema: A Schema object that describes the structure of the data,
162+
including feature names and target information.
163+
"""
164+
self._dataset = dataset
165+
self.schema = schema
166+
self._index = 0
123167

124168
def has_more_instances(self):
125169
return len(self._dataset) > self._index
@@ -128,19 +172,26 @@ def next_instance(self):
128172
if not self.has_more_instances():
129173
raise StopIteration()
130174

131-
X, y = self._dataset[self._permutation[self._index]]
175+
X, y = self._dataset[self._index]
132176
self._index += 1 # increment counter for next call
133177

134-
# Tensors on the CPU and NumPy arrays share their underlying memory locations
135-
# We should prefer numpy over tensors in instances to improve compatibility
136-
# See: https://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html#bridge-to-np-label
137-
X = X.view(-1).numpy()
138-
if isinstance(y, torch.Tensor) and torch.isnan(y):
139-
y = -1
140-
return LabeledInstance.from_array(self._schema, X, int(y))
178+
if self.schema.is_classification():
179+
# Tensors on the CPU and NumPy arrays share their underlying memory locations
180+
# We should prefer numpy over tensors in instances to improve compatibility
181+
# See: https://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html#bridge-to-np-label
182+
X = X.view(-1).numpy()
183+
if isinstance(y, torch.Tensor) and torch.isnan(y):
184+
y = -1
185+
return LabeledInstance.from_array(self.schema, X, int(y))
186+
elif self.schema.is_regression():
187+
X = X.view(-1).numpy()
188+
y = y.item() # Convert single-value tensor to a Python scalar
189+
return RegressionInstance.from_array(self.schema, X, y)
190+
else:
191+
raise ValueError("Schema must be either classification or regression.")
141192

142193
def get_schema(self):
143-
return self._schema
194+
return self.schema
144195

145196
def get_moa_stream(self):
146197
return None

0 commit comments

Comments
 (0)