-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdataset.py
More file actions
59 lines (49 loc) · 1.89 KB
/
dataset.py
File metadata and controls
59 lines (49 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import math
import warnings

import numpy as np
import pandas as pd
import torch
import torch.utils.data as torch_data
class RefitDataset(torch_data.Dataset):
    """Map-style dataset pairing aggregate readings with individual-appliance targets.

    ``data`` is an indexable pair: ``data[0]`` holds the aggregate inputs and
    ``data[1]`` the individual-appliance targets, aligned by index (presumably
    NumPy arrays of windows -- TODO confirm against the loader). Every item is
    returned as a pair of ``float64`` tensors.

    Args:
        data: Pair ``(aggregate, individual)`` of index-aligned sequences.
        transform: Optional callable receiving
            ``{'Aggregate': aggregate, 'Individual': individual}`` and returning
            an ``(aggregate, individual)`` pair.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        # The aggregate sequence defines the length; data[1] is assumed to be
        # at least as long (not checked here).
        return len(self.data[0])

    def init_transformation(self, transform):
        """Install ``transform`` if none is set yet.

        If a transform is already present it is kept unchanged and a
        ``UserWarning`` is emitted instead of replacing it.
        """
        if not self.transform:
            self.transform = transform
        else:
            warnings.warn(
                'Transformations are already predefined and you cannot initialize other transformations',
                stacklevel=2,
            )

    def __getitem__(self, index):
        """Return the ``(aggregate, individual)`` pair at ``index`` as float64 tensors."""
        aggregate, iam = self.data[0][index], self.data[1][index]
        if self.transform:
            sample = {'Aggregate': aggregate, 'Individual': iam}
            aggregate, iam = self.transform(sample)
        # torch.from_numpy only accepts ndarrays; np.asarray also admits numpy
        # scalars (what indexing a 1-D array yields), lists and Python floats.
        # For an ndarray it is a no-op, so memory sharing and the .double()
        # cast behave exactly as before.
        aggregate = torch.from_numpy(np.asarray(aggregate)).double()
        iam = torch.from_numpy(np.asarray(iam)).double()
        return aggregate, iam
class REDDDataset(torch_data.Dataset):
    """Map-style dataset pairing aggregate readings with individual-appliance targets.

    ``data`` is an indexable pair: ``data[0]`` holds the aggregate inputs and
    ``data[1]`` the individual-appliance targets, aligned by index (presumably
    NumPy arrays of windows -- TODO confirm against the loader). Unlike
    ``RefitDataset``, items keep the dtype of the underlying data (no cast).

    Args:
        data: Pair ``(aggregate, individual)`` of index-aligned sequences.
        transform: Optional callable receiving
            ``{'Aggregate': aggregate, 'Individual': individual}`` and returning
            an ``(aggregate, individual)`` pair.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def get_mean_and_std(self):
        """Return ``(mean, std)`` over all aggregate values in ``data[0]``.

        Arrays and tensors use their own ``mean``/``std`` methods. Inputs that
        lack them (e.g. plain nested lists) are converted with ``np.asarray``
        first.
        """
        array = self.data[0]
        if not hasattr(array, 'mean'):
            array = np.asarray(array)
        return array.mean(), array.std()

    def __len__(self):
        # The aggregate sequence defines the length; data[1] is assumed to be
        # at least as long (not checked here).
        return len(self.data[0])

    def init_transformation(self, transform):
        """Install ``transform`` if none is set yet.

        If a transform is already present it is kept unchanged and a
        ``UserWarning`` is emitted instead of replacing it.
        """
        if not self.transform:
            self.transform = transform
        else:
            warnings.warn(
                'Transformations are already predefined and you cannot initialize other transformations',
                stacklevel=2,
            )

    def __getitem__(self, index):
        """Return the ``(aggregate, individual)`` pair at ``index`` as tensors."""
        aggregate, iam = self.data[0][index], self.data[1][index]
        if self.transform:
            sample = {'Aggregate': aggregate, 'Individual': iam}
            aggregate, iam = self.transform(sample)
        # torch.from_numpy only accepts ndarrays; np.asarray also admits numpy
        # scalars (what indexing a 1-D array yields), lists and Python floats.
        # For an ndarray it is a no-op, so dtype and memory sharing are unchanged.
        aggregate = torch.from_numpy(np.asarray(aggregate))
        iam = torch.from_numpy(np.asarray(iam))
        return aggregate, iam