Skip to content

Commit 33a646d

Browse files
authored
Merge pull request #1089 from AndreaCossu/master
AvalancheDatasetType removal, dataset collate_fn used in dataloading during training and eval, sequential benchmark
2 parents 5359cee + aaa8fdd commit 33a646d

30 files changed

Lines changed: 392 additions & 434 deletions

avalanche/benchmarks/classic/stream51.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
import math
2222
import os
2323

24-
from avalanche.benchmarks.utils import AvalancheDatasetType
25-
2624
_mu = [0.485, 0.456, 0.406]
2725
_std = [0.229, 0.224, 0.225]
2826
_default_stream51_transform = transforms.Compose(
@@ -286,8 +284,7 @@ def CLStream51(
286284
task_labels=[0 for _ in range(num_tasks)],
287285
complete_test_set_only=scenario == "instance",
288286
train_transform=train_transform,
289-
eval_transform=eval_transform,
290-
dataset_type=AvalancheDatasetType.CLASSIFICATION,
287+
eval_transform=eval_transform
291288
)
292289

293290
return benchmark_obj
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
################################################################################
2+
# Copyright (c) 2022 ContinualAI. #
3+
# Copyrights licensed under the MIT License. #
4+
# See the accompanying LICENSE file for terms. #
5+
# #
6+
# Author(s): Andrea Cossu #
7+
# E-mail: contact@continualai.org #
8+
# Website: www.continualai.org #
9+
################################################################################
10+
11+
""" This module conveniently wraps TorchAudio Datasets for using a clean and
12+
comprehensive Avalanche API."""
13+
14+
try:
15+
import torchaudio
16+
except ImportError:
17+
raise ModuleNotFoundError(
18+
"TorchAudio package is required to load its dataset. "
19+
"You can install it as extra dependency with "
20+
"`pip install avalanche-lib[extra]`")
21+
from torchaudio.datasets import SPEECHCOMMANDS
22+
from avalanche.benchmarks.utils import AvalancheDataset
23+
from avalanche.benchmarks.datasets import default_dataset_location
24+
import torch
25+
26+
27+
def speech_commands_collate(batch):
28+
tensors, targets, t_labels = [], [], []
29+
for waveform, label, rate, sid, uid, t_label in batch:
30+
tensors += [waveform]
31+
targets += [torch.tensor(label)]
32+
t_labels += [torch.tensor(t_label)]
33+
tensors = [item.t() for item in tensors]
34+
tensors = torch.nn.utils.rnn.pad_sequence(tensors,
35+
batch_first=True,
36+
padding_value=0.)
37+
if len(tensors.size()) == 2: # no MFCC, add feature dimension
38+
tensors = tensors.unsqueeze(-1)
39+
targets = torch.stack(targets)
40+
t_labels = torch.stack(t_labels)
41+
return tensors, targets, t_labels
42+
43+
44+
class SpeechCommandsData(SPEECHCOMMANDS):
45+
def __init__(self, root, url, download, subset, mfcc_preprocessing):
46+
super().__init__(root=root, download=download,
47+
subset=subset, url=url)
48+
self.labels_names = ['backward', 'bed', 'bird', 'cat', 'dog', 'down',
49+
'eight', 'five', 'follow', 'forward', 'four',
50+
'go', 'happy', 'house', 'learn', 'left',
51+
'marvin', 'nine', 'no', 'off', 'on', 'one',
52+
'right', 'seven', 'sheila', 'six', 'stop',
53+
'three', 'tree', 'two', 'up', 'visual',
54+
'wow', 'yes', 'zero']
55+
self.mfcc_preprocessing = mfcc_preprocessing
56+
57+
def __getitem__(self, item):
58+
wave, rate, label, speaker_id, ut_number = super().__getitem__(item)
59+
label = self.labels_names.index(label)
60+
wave = wave.squeeze(0) # (T,)
61+
if self.mfcc_preprocessing is not None:
62+
assert rate == self.mfcc_preprocessing.sample_rate
63+
# (T, MFCC)
64+
wave = self.mfcc_preprocessing(wave).permute(1, 0)
65+
return wave, label, rate, speaker_id, ut_number
66+
67+
68+
def SpeechCommands(root=default_dataset_location(''),
69+
url='speech_commands_v0.02',
70+
download=True, subset=None,
71+
mfcc_preprocessing=None):
72+
"""
73+
root: dataset root location
74+
url: version name of the dataset
75+
download: automatically download the dataset, if not present
76+
subset: one of 'training', 'validation', 'testing'
77+
mfcc_preprocessing: an optional torchaudio.transforms.MFCC instance
78+
to preprocess each audio. Warning: this may slow down the execution
79+
since preprocessing is applied on-the-fly each time a sample is
80+
retrieved from the dataset.
81+
"""
82+
dataset = SpeechCommandsData(root=root, download=download,
83+
subset=subset, url=url,
84+
mfcc_preprocessing=mfcc_preprocessing)
85+
labels = [datapoint[1] for datapoint in dataset]
86+
return AvalancheDataset(dataset,
87+
collate_fn=speech_commands_collate,
88+
targets=labels)
89+
90+
91+
__all__ = ['SpeechCommands']

avalanche/benchmarks/generators/benchmark_generators.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
from avalanche.benchmarks.utils.avalanche_dataset import (
5050
SupportedDataset,
5151
AvalancheDataset,
52-
AvalancheDatasetType,
5352
AvalancheSubset,
5453
)
5554

@@ -224,14 +223,14 @@ class "34" will be mapped to "1", class "11" to "2" and so on.
224223
train_dataset,
225224
transform_groups=transform_groups,
226225
initial_transform_group="train",
227-
dataset_type=AvalancheDatasetType.CLASSIFICATION,
226+
targets_adapter=int
228227
)
229228

230229
test_dataset = AvalancheDataset(
231230
test_dataset,
232231
transform_groups=transform_groups,
233232
initial_transform_group="eval",
234-
dataset_type=AvalancheDatasetType.CLASSIFICATION,
233+
targets_adapter=int
235234
)
236235

237236
return NCScenario(
@@ -348,14 +347,14 @@ def ni_benchmark(
348347
seq_train_dataset,
349348
transform_groups=transform_groups,
350349
initial_transform_group="train",
351-
dataset_type=AvalancheDatasetType.CLASSIFICATION,
350+
targets_adapter=int
352351
)
353352

354353
seq_test_dataset = AvalancheDataset(
355354
seq_test_dataset,
356355
transform_groups=transform_groups,
357356
initial_transform_group="eval",
358-
dataset_type=AvalancheDatasetType.CLASSIFICATION,
357+
targets_adapter=int
359358
)
360359

361360
return NIScenario(

avalanche/benchmarks/generators/scenario_generators.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@
3838
from avalanche.benchmarks.utils import concat_datasets_sequentially
3939
from avalanche.benchmarks.utils.avalanche_dataset import (
4040
SupportedDataset,
41-
as_classification_dataset,
42-
AvalancheDatasetType,
41+
as_classification_dataset
4342
)
4443

4544

@@ -325,8 +324,7 @@ def dataset_scenario(
325324
test_dataset_list: Sequence[SupportedDataset],
326325
task_labels: Sequence[int],
327326
*,
328-
complete_test_set_only: bool = False,
329-
dataset_type: AvalancheDatasetType = AvalancheDatasetType.UNDEFINED
327+
complete_test_set_only: bool = False
330328
) -> GenericCLScenario:
331329
"""
332330
This helper function is DEPRECATED in favor of `dataset_benchmark`.
@@ -363,10 +361,6 @@ def dataset_scenario(
363361
parameter must be list with a single element (the complete test set).
364362
Defaults to False, which means that ``train_dataset_list`` and
365363
``test_dataset_list`` must contain the same amount of datasets.
366-
:param dataset_type: The type of the dataset. Defaults to None, which
367-
means that the type will be obtained from the input datasets. If input
368-
datasets are not instances of :class:`AvalancheDataset`, the type
369-
UNDEFINED will be used.
370364
371365
:returns: A properly initialized :class:`GenericCLScenario` instance.
372366
"""
@@ -380,8 +374,7 @@ def dataset_scenario(
380374
train_dataset_list=train_dataset_list,
381375
test_dataset_list=test_dataset_list,
382376
task_labels=task_labels,
383-
complete_test_set_only=complete_test_set_only,
384-
dataset_type=dataset_type,
377+
complete_test_set_only=complete_test_set_only
385378
)
386379

387380

@@ -482,8 +475,7 @@ def paths_scenario(
482475
train_transform=None,
483476
train_target_transform=None,
484477
eval_transform=None,
485-
eval_target_transform=None,
486-
dataset_type: AvalancheDatasetType = AvalancheDatasetType.UNDEFINED
478+
eval_target_transform=None
487479
) -> GenericCLScenario:
488480
"""
489481
This helper function is DEPRECATED in favor of `paths_benchmark`.
@@ -545,7 +537,6 @@ def paths_scenario(
545537
comprehensive list of possible transformations). Defaults to None.
546538
:param eval_target_transform: The transformation to apply to test
547539
patterns targets. Defaults to None.
548-
:param dataset_type: The type of the dataset. Defaults to UNDEFINED.
549540
550541
:returns: A properly initialized :class:`GenericCLScenario` instance.
551542
"""
@@ -563,8 +554,7 @@ def paths_scenario(
563554
train_transform=train_transform,
564555
train_target_transform=train_target_transform,
565556
eval_transform=eval_transform,
566-
eval_target_transform=eval_target_transform,
567-
dataset_type=dataset_type,
557+
eval_target_transform=eval_target_transform
568558
)
569559

570560

@@ -577,8 +567,7 @@ def tensors_scenario(
577567
train_transform=None,
578568
train_target_transform=None,
579569
eval_transform=None,
580-
eval_target_transform=None,
581-
dataset_type: AvalancheDatasetType = AvalancheDatasetType.UNDEFINED
570+
eval_target_transform=None
582571
) -> GenericCLScenario:
583572
"""
584573
This helper function is DEPRECATED in favor of `tensors_benchmark`.
@@ -635,7 +624,6 @@ def tensors_scenario(
635624
comprehensive list of possible transformations). Defaults to None.
636625
:param eval_target_transform: The transformation to apply to test
637626
patterns targets. Defaults to None.
638-
:param dataset_type: The type of the dataset. Defaults to UNDEFINED.
639627
640628
:returns: A properly initialized :class:`GenericCLScenario` instance.
641629
"""
@@ -653,8 +641,7 @@ def tensors_scenario(
653641
train_transform=train_transform,
654642
train_target_transform=train_target_transform,
655643
eval_transform=eval_transform,
656-
eval_target_transform=eval_target_transform,
657-
dataset_type=dataset_type,
644+
eval_target_transform=eval_target_transform
658645
)
659646

660647

@@ -669,8 +656,7 @@ def tensor_scenario(
669656
train_transform=None,
670657
train_target_transform=None,
671658
eval_transform=None,
672-
eval_target_transform=None,
673-
dataset_type: AvalancheDatasetType = AvalancheDatasetType.UNDEFINED
659+
eval_target_transform=None
674660
) -> GenericCLScenario:
675661
"""
676662
This helper function is DEPRECATED in favor of `tensors_benchmark`.
@@ -731,7 +717,6 @@ def tensor_scenario(
731717
comprehensive list of possible transformations). Defaults to None.
732718
:param eval_target_transform: The transformation to apply to test
733719
patterns targets. Defaults to None.
734-
:param dataset_type: The type of the dataset. Defaults to UNDEFINED.
735720
736721
:returns: A properly initialized :class:`GenericCLScenario` instance.
737722
"""
@@ -782,8 +767,7 @@ def tensor_scenario(
782767
train_transform=train_transform,
783768
train_target_transform=train_target_transform,
784769
eval_transform=eval_transform,
785-
eval_target_transform=eval_target_transform,
786-
dataset_type=dataset_type,
770+
eval_target_transform=eval_target_transform
787771
)
788772

789773

0 commit comments

Comments
 (0)