1- import copy
21from typing import Optional , Sequence , Tuple
32
43import torch
54
65from capymoa .stream import Stream , Schema
7- from capymoa .instance import (
8- LabeledInstance ,
9- )
6+ from capymoa .instance import LabeledInstance , RegressionInstance
107from torch .utils .data import Dataset
118
129
13- class TorchClassifyStream (Stream [LabeledInstance ]):
14- """TorchClassifyStream turns a PyTorch dataset into a classification stream.
def _shuffle_dataset(dataset: Dataset, seed: Optional[int] = None) -> Dataset:
    """Return a randomly permuted view of ``dataset``.

    The dataset is wrapped in a ``torch.utils.data.Subset``; it is neither
    copied nor mutated.

    :param dataset: The dataset to shuffle. Must support ``len()``.
    :param seed: Seed for the permutation generator. When ``None`` the
        generator is used as constructed, without an explicit seed.
    :return: A ``Subset`` that visits every element of ``dataset`` exactly
        once, in permuted order.
    """
    rng = torch.Generator()
    if seed is not None:
        rng.manual_seed(seed)
    # NOTE(review): an unseeded ``torch.Generator()`` presumably starts from
    # torch's fixed default seed, so ``seed=None`` may give the same order on
    # every run -- confirm whether that is intended.

    # Hand Subset plain Python ints instead of 0-dim tensors so the wrapped
    # dataset's ``__getitem__`` receives an ordinary integer index (tensors
    # are not usable as dict keys or with label-based indexers).
    indices = torch.randperm(len(dataset), generator=rng).tolist()
    return torch.utils.data.Subset(dataset, indices)
16+
17+
18+ class TorchStream (Stream ):
19+ """A stream adapter for PyTorch datasets.
20+
21+ This class converts PyTorch datasets into CapyMOA streams for both classification
22+ and regression tasks.
23+
24+ Creating a classification stream from a PyTorch dataset:
1525
16- >>> from capymoa.evaluation import ClassificationEvaluator
17- ...
1826 >>> from capymoa.datasets import get_download_dir
19- >>> from capymoa.stream import TorchClassifyStream
20- >>> from torchvision import datasets
21- >>> from torchvision.transforms import ToTensor
22- >>> print("Using PyTorch Dataset"); pytorchDataset = datasets.FashionMNIST( #doctest:+ELLIPSIS
27+ >>> from capymoa.stream import TorchStream
28+ >>> from torchvision import datasets, transforms
29+ >>>
30+ >>> dataset = datasets.FashionMNIST(
2331 ... root=get_download_dir(),
2432 ... train=True,
2533 ... download=True,
26- ... transform=ToTensor()
27- ... )
28- Using PyTorch Dataset...
29- >>> pytorch_stream = TorchClassifyStream(pytorchDataset, 10, class_names=pytorchDataset.classes)
30- >>> pytorch_stream.get_schema()
31- @relation PytorchDataset
32- <BLANKLINE>
33- @attribute 0 numeric
34- @attribute 1 numeric
35- ...
36- @attribute 783 numeric
37- @attribute class {T-shirt/top,Trouser,Pullover,Dress,Coat,Sandal,Shirt,Sneaker,Bag,'Ankle boot'}
38- <BLANKLINE>
39- @data
40- >>> pytorch_stream.next_instance()
41- LabeledInstance(
42- Schema(PytorchDataset),
43- x=[0. 0. 0. ... 0. 0. 0.],
44- y_index=9,
45- y_label='Ankle boot'
46- )
47-
48- You can construct :class:`TorchClassifyStream` using a random sampler by passing a sampler
49- to the constructor:
34+ ... transform=transforms.ToTensor()
35+ ... ) # doctest: +SKIP
36+ >>> stream = TorchStream.from_classification(
37+ ... dataset, num_classes=10, class_names=dataset.classes
38+ ... ) # doctest: +SKIP
39+ >>> stream.next_instance() # doctest: +SKIP
40+ LabeledInstance(...)
41+
42+ Creating a shuffled classification stream:
5043
5144 >>> import torch
52- >>> from torch.utils.data import RandomSampler, TensorDataset
45+ >>> from torch.utils.data import TensorDataset
46+ >>>
5347 >>> dataset = TensorDataset(
54- ... torch.tensor([[1], [2], [3]]), torch.tensor([0, 1, 2])
48+ ... torch.tensor([[1.0], [2.0], [3.0]]),
49+ ... torch.tensor([0, 1, 2])
5550 ... )
56- >>> pytorch_stream = TorchClassifyStream(dataset=dataset, num_classes=3, shuffle=True)
57- >>> for instance in pytorch_stream:
58- ... print(instance.x)
59- [3.]
60- [1.]
61- [2.]
62-
63- Importantly you can restart the stream to iterate over the dataset in
64- the same order again:
65-
66- >>> pytorch_stream.restart()
67- >>> for instance in pytorch_stream:
68- ... print(instance.x)
69- [3.]
70- [1.]
71- [2.]
51+ >>> stream = TorchStream.from_classification(
52+ ... dataset, num_classes=3, shuffle=True, shuffle_seed=0
53+ ... )
54+ >>> [float(inst.x[0]) for inst in stream]
55+ [3.0, 1.0, 2.0]
56+
57+ Streams can be restarted to iterate again:
58+
59+ >>> stream.restart()
60+ >>> [float(inst.x[0]) for inst in stream]
61+ [3.0, 1.0, 2.0]
62+
63+ Creating a regression stream:
64+
65+ >>> dataset = TensorDataset(
66+ ... torch.tensor([[1.0], [2.0], [3.0]]),
67+ ... torch.tensor([0.5, 1.5, 2.5])
68+ ... )
69+ >>> stream = TorchStream.from_regression(
70+ ... dataset, shuffle=True, shuffle_seed=0
71+ ... )
72+ >>> [(float(inst.x[0]), float(inst.y_value)) for inst in stream]
73+ [(3.0, 2.5), (1.0, 0.5), (2.0, 1.5)]
7274 """
7375
74- def __init__ (
75- self ,
76- dataset : Dataset [Tuple [torch .Tensor , torch .LongTensor ]],
77- num_classes : int ,
@staticmethod
def from_regression(
    dataset: Dataset[Tuple[torch.Tensor, torch.Tensor | float]],
    dataset_name: str = "TorchStream",
    shuffle: bool = False,
    shuffle_seed: Optional[int] = None,
    shape: Optional[Sequence[int]] = None,
) -> "TorchStream":
    """Construct a stream for regression from a PyTorch Dataset.

    The first element of ``dataset`` is read to count the features; one
    schema attribute is created per feature (named ``"0"``, ``"1"``, ...)
    plus a final ``"target"`` attribute.

    :param dataset: A PyTorch Dataset that yields tuples of (features, target) for
        regression tasks. Must contain at least one element.
    :param dataset_name: An optional name for the stream.
    :param shuffle: Whether to shuffle the dataset.
    :param shuffle_seed: An optional seed for shuffling the dataset.
    :param shape: An optional shape for the features. When omitted, the shape
        of the schema is left exactly as ``Schema.from_custom`` produced it.
    :return: A TorchStream instance for regression.
    """
    # Construct the schema based on the dataset and provided parameters
    X, _ = dataset[0]
    n_features = X.numel()
    features = [str(f) for f in range(n_features)] + ["target"]
    schema = Schema.from_custom(
        features=features,
        target="target",
        name=dataset_name,
    )
    # Only override the schema's shape on request so that callers that never
    # pass ``shape`` see no change in behaviour.
    if shape is not None:
        schema._shape = shape

    if shuffle:
        dataset = _shuffle_dataset(dataset, seed=shuffle_seed)
    return TorchStream(dataset, schema)
99107
100- self ._dataset = dataset
101- self ._index = 0
102- self ._permutation = torch .arange (len (dataset ))
@staticmethod
def from_classification(
    dataset: Dataset[Tuple[torch.Tensor, torch.Tensor | int]],
    num_classes: int,
    class_names: Optional[Sequence[str]] = None,
    dataset_name: str = "TorchStream",
    shape: Optional[Sequence[int]] = None,
    shuffle: bool = False,
    shuffle_seed: Optional[int] = None,
) -> "TorchStream":
    """Build a classification stream on top of a PyTorch Dataset.

    The first element of ``dataset`` is read to count the features; one
    schema attribute is created per feature (named ``"0"``, ``"1"``, ...)
    plus a final nominal ``"class"`` attribute.

    :param dataset: A PyTorch Dataset that yields (features, target) tuples.
    :param num_classes: How many classes the task has.
    :param class_names: Names for the classes, one per class index. When
        omitted, the stringified indices ``"0" .. str(num_classes - 1)`` are used.
    :param dataset_name: Name given to the schema.
    :param shape: Shape recorded on the schema for the features. When omitted,
        ``(n_features,)`` is recorded.
    :param shuffle: Whether to shuffle the dataset.
    :param shuffle_seed: Optional seed used when shuffling.
    :return: A TorchStream instance.
    :raises ValueError: If ``class_names`` does not have ``num_classes`` entries.
    """
    labels = class_names
    if labels is None:
        labels = [str(k) for k in range(num_classes)]
    if len(labels) != num_classes:
        raise ValueError("Length of class_names must match num_classes.")

    # The first sample tells us how many feature attributes the schema needs.
    first_x, _ = dataset[0]
    n_features = first_x.numel()
    column_names = [str(i) for i in range(n_features)]
    column_names.append("class")

    schema = Schema.from_custom(
        features=column_names,
        target="class",
        categories={"class": labels},
        name=dataset_name,
    )
    schema._shape = (n_features,) if shape is None else shape

    source = dataset
    if shuffle:
        source = _shuffle_dataset(dataset, seed=shuffle_seed)
    return TorchStream(source, schema)
150+
def __init__(
    self,
    dataset: Dataset,
    schema: Schema,
):
    """Wrap a PyTorch Dataset and its Schema in a stream.

    Prefer :meth:`from_classification` or :meth:`from_regression`, which
    build the schema for you.

    :param dataset: A PyTorch Dataset that yields tuples of (features, target).
    :param schema: The Schema describing the data; it is what
        :meth:`get_schema` returns.
    """
    # The read cursor starts at the first element of the dataset.
    self._dataset, self.schema, self._index = dataset, schema, 0
123167
def has_more_instances(self):
    """Return ``True`` while the read cursor is before the end of the dataset."""
    remaining = len(self._dataset) - self._index
    return remaining > 0
def next_instance(self):
    """Return the next dataset element as an instance and advance the cursor.

    The features are flattened to one dimension and converted to a NumPy
    array. Depending on the schema, the element becomes a
    ``LabeledInstance`` (classification) or a ``RegressionInstance``
    (regression).

    :return: The instance built from the element at the current position.
    :raises StopIteration: If every element has already been consumed.
    :raises ValueError: If the schema is neither classification nor regression.
    """
    if not self.has_more_instances():
        raise StopIteration()

    X, y = self._dataset[self._index]
    self._index += 1  # increment counter for next call

    # Tensors on the CPU and NumPy arrays share their underlying memory locations
    # We should prefer numpy over tensors in instances to improve compatibility
    # See: https://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html#bridge-to-np-label
    # ``reshape`` is used instead of ``view`` so that non-contiguous tensors
    # can be flattened too.
    X = X.reshape(-1).numpy()

    if self.schema.is_classification():
        # A NaN tensor label is replaced by the index -1
        # (presumably meaning "no label" -- confirm with the instance API).
        if isinstance(y, torch.Tensor) and torch.isnan(y):
            y = -1
        return LabeledInstance.from_array(self.schema, X, int(y))
    if self.schema.is_regression():
        # Targets may be single-value tensors or plain Python numbers; only
        # tensors need unwrapping to a Python scalar.
        if isinstance(y, torch.Tensor):
            y = y.item()
        return RegressionInstance.from_array(self.schema, X, y)
    raise ValueError("Schema must be either classification or regression.")
141192
def get_schema(self):
    """Return the Schema object stored on this stream."""
    schema = self.schema
    return schema
144195
def get_moa_stream(self):
    """Return ``None``; this stream exposes no MOA stream object.

    NOTE(review): presumably because the data comes from a PyTorch dataset
    rather than from MOA -- confirm against the ``Stream`` base class.
    """
    moa_stream = None
    return moa_stream
0 commit comments