forked from Gpialla/DataAugForTSC
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelper.py
More file actions
157 lines (134 loc) · 4.57 KB
/
Copy pathhelper.py
File metadata and controls
157 lines (134 loc) · 4.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Imports
from data.ucr_archive import load_dataset as load_ucr
from data.ucr_archive import UCR_ARCHIVE_2015_DATASETS, UCR_ARCHIVE_2018_DATASETS, UCR_VERSIONS
from data.uea_archive import load_dataset as load_uea
from data.uea_archive import UEA_ARCHIVE_2018_DATASETS
from data.digitsRTD import load_dataset as load_digits_dataset
from data.adv_p_dataset import load_dataset as load_adv_p
# Names accepted by get_preprocessing_by_name().
PREPROCESSINGS_NAMES = [
    "z_norm",
    "feature_scaling",
]
def load_ds_from__archive(archive_name, ds_name, ds_version):
    """Load a dataset, dispatching on the archive name.

    Args:
        archive_name (str): Archive identifier; "UCR" or "adv_p".
        ds_name (str): Name of the dataset inside the archive.
        ds_version: Archive version. Converted with ``int()`` for "UCR";
            forwarded as-is to ``load_adv_p_dataset`` for "adv_p".

    Returns:
        tuple: The dataset as returned by the archive-specific loader,
        presumably ``(x_train, y_train, x_test, y_test)``.

    Raises:
        ValueError: If ``archive_name`` is not a supported archive.
    """
    if archive_name == "UCR":
        return load_ucr_dataset(ds_name, int(ds_version))
    if archive_name == "adv_p":
        return load_adv_p_dataset(ds_name, ds_version)
    # Falling through used to return None silently, which only surfaced later
    # as a confusing unpacking error at the call site.
    raise ValueError(
        "Unknown archive {!r}; expected 'UCR' or 'adv_p'.".format(archive_name)
    )
def get_preprocessing_by_name(name):
    """Return the preprocessing function registered under ``name``.

    Args:
        name (str): The name of the method; "z_norm" or "feature_scaling".

    Returns:
        function: The corresponding function from ``data.data_preprocessing``.

    Raises:
        ValueError: If ``name`` is not a known preprocessing method.
    """
    # The imports stay local, as in the original, so data.data_preprocessing
    # is only imported when a method is actually requested.
    if name == "z_norm":
        from data.data_preprocessing import z_norm
        return z_norm
    if name == "feature_scaling":
        from data.data_preprocessing import feature_scaling
        return feature_scaling
    # Previously returned None, which failed later as "NoneType is not callable".
    raise ValueError(
        "Unknown preprocessing {!r}; expected one of 'z_norm', "
        "'feature_scaling'.".format(name)
    )
def load_adv_p_dataset(ds_name, DS_version):
    """Load the ``adv_p`` dataset ``ds_name`` at version ``DS_version``.

    NOTE(review): a second ``load_adv_p_dataset`` is defined later in this
    module. Because it is defined afterwards, it replaces this function at
    import time, so this body is not reachable through the module-level name.

    Returns:
        tuple: The dataset produced by ``data.adv_p_dataset.load_dataset``.
    """
    from data.adv_p_dataset import load_dataset as adv_p_loader
    dataset = adv_p_loader(ds_name, DS_version)
    return dataset
def load_ucr_dataset(ds_name, UCR_version):
    """Load a single dataset from the UCR archive.

    Args:
        ds_name (str): The name of the dataset.
        UCR_version (int): The UCR archive version (e.g. 2015 or 2018).

    Returns:
        tuple: The dataset, unpacked as
        ``x_train, y_train, x_test, y_test = load_ucr_dataset(ds_name, UCR_version)``.
    """
    dataset = load_ucr(ds_name, UCR_version)
    return dataset
def load_digits_dataset():
    """Load the digits RTD dataset.

    Returns:
        tuple: The dataset as returned by ``data.digitsRTD.load_dataset``,
        presumably ``(x_train, y_train, x_test, y_test)``.
    """
    # This function's own name shadows the module-level
    # ``from data.digitsRTD import load_dataset as load_digits_dataset``
    # import. Calling ``load_digits_dataset()`` here would therefore call this
    # function again and recurse until RecursionError. Import the real loader
    # under a private alias instead.
    from data.digitsRTD import load_dataset as _load_digits_rtd
    return _load_digits_rtd()
def load_adv_p_dataset(ds_name, UCR_version):
    """Load the ``adv_p`` dataset named ``ds_name``.

    This definition comes after an earlier ``load_adv_p_dataset`` in this
    module and therefore is the one in effect.

    Args:
        ds_name (str): The name of the dataset.
        UCR_version: Currently unused.
            NOTE(review): it is not forwarded to ``load_adv_p``; confirm
            whether the adv_p loader is meant to receive a version.

    Returns:
        tuple: The dataset produced by ``load_adv_p(ds_name)``.
    """
    dataset = load_adv_p(ds_name)
    return dataset
def load_uea_dataset(ds_name):
    """Load a single dataset from the UEA archive.

    Args:
        ds_name (str): The name of the dataset.

    Returns:
        tuple: The dataset, unpacked as
        ``x_train, y_train, x_test, y_test = load_uea(ds_name)``.
    """
    dataset = load_uea(ds_name)
    return dataset
def get_ucr_list_datasets(list_ds_name, UCR_version):
    """Load several UCR datasets in one call.

    Args:
        list_ds_name (list): Names of the datasets to load, loaded in order.
        UCR_version (int): The UCR archive version.

    Returns:
        dict: Maps each dataset name to its dataset, so that
        ``x_train, y_train, x_test, y_test = result[ds_name]``.
    """
    return {name: load_ucr(name, UCR_version) for name in list_ds_name}
def get_uea_list_datasets(list_ds_name):
    """Load several UEA datasets in one call.

    Args:
        list_ds_name (list): Names of the datasets to load, loaded in order.

    Returns:
        dict: Maps each dataset name to its dataset, so that
        ``x_train, y_train, x_test, y_test = result[ds_name]``.
    """
    return {name: load_uea(name) for name in list_ds_name}
def get_ucr_all(UCR_version):
    """Load every dataset of the given UCR archive version.

    Args:
        UCR_version (int): The UCR archive version; 2015 or 2018.

    Returns:
        dict: Maps each dataset name to its dataset, so that
        ``x_train, y_train, x_test, y_test = result[ds_name]``.

    Raises:
        ValueError: If ``UCR_version`` is neither 2015 nor 2018.
    """
    if UCR_version == 2015:
        return get_ucr_list_datasets(UCR_ARCHIVE_2015_DATASETS, UCR_version)
    if UCR_version == 2018:
        return get_ucr_list_datasets(UCR_ARCHIVE_2018_DATASETS, UCR_version)
    # Previously fell through and returned None silently.
    raise ValueError(
        "Unsupported UCR version {!r}; expected 2015 or 2018.".format(UCR_version)
    )
def get_uea_all():
    """Load every dataset listed in the UEA 2018 archive.

    Returns:
        dict: Maps each dataset name to its dataset, so that
        ``x_train, y_train, x_test, y_test = result[ds_name]``.
    """
    names = UEA_ARCHIVE_2018_DATASETS
    return get_uea_list_datasets(names)
def get_ucr_first_10(UCR_version):
    """Load the first 10 datasets of the given UCR archive version.

    The first 10 entries of the archive's dataset list are taken, then
    sorted by name before loading.

    Args:
        UCR_version (int): The UCR archive version; 2015 or 2018.

    Returns:
        dict: Maps each dataset name to its dataset, so that
        ``x_train, y_train, x_test, y_test = result[ds_name]``.

    Raises:
        ValueError: If ``UCR_version`` is neither 2015 nor 2018.
    """
    if UCR_version == 2015:
        return get_ucr_list_datasets(sorted(UCR_ARCHIVE_2015_DATASETS[:10]), UCR_version)
    if UCR_version == 2018:
        return get_ucr_list_datasets(sorted(UCR_ARCHIVE_2018_DATASETS[:10]), UCR_version)
    # Previously fell through and returned None silently.
    raise ValueError(
        "Unsupported UCR version {!r}; expected 2015 or 2018.".format(UCR_version)
    )
def get_ucr_last_10(UCR_version):
    """Load the last 10 datasets of the given UCR archive version.

    The last 10 entries of the archive's dataset list are taken, then
    sorted by name before loading.

    Args:
        UCR_version (int): The UCR archive version; 2015 or 2018.

    Returns:
        dict: Maps each dataset name to its dataset, so that
        ``x_train, y_train, x_test, y_test = result[ds_name]``.

    Raises:
        ValueError: If ``UCR_version`` is neither 2015 nor 2018.
    """
    if UCR_version == 2015:
        return get_ucr_list_datasets(sorted(UCR_ARCHIVE_2015_DATASETS[-10:]), UCR_version)
    if UCR_version == 2018:
        return get_ucr_list_datasets(sorted(UCR_ARCHIVE_2018_DATASETS[-10:]), UCR_version)
    # Previously fell through and returned None silently.
    raise ValueError(
        "Unsupported UCR version {!r}; expected 2015 or 2018.".format(UCR_version)
    )