-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdatasets.py
More file actions
20 lines (14 loc) · 799 Bytes
/
datasets.py
File metadata and controls
20 lines (14 loc) · 799 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
def get_dataset(path, batch_size, ih, iw, *, dataset="mnist"):
    """Build train/test DataLoaders for a torchvision image dataset.

    Every image is resized to ``(ih, iw)`` and converted to a tensor with
    ``transforms.ToTensor()``. The dataset is downloaded into ``path`` if it
    is not already there (``download=True``).

    Args:
        path: Root directory handed to the torchvision dataset constructor.
        batch_size: Number of samples per batch for both loaders.
        ih: Target image height after resizing.
        iw: Target image width after resizing.
        dataset: Which dataset to load: ``"mnist"`` (default, the original
            behavior) or ``"stl10"``. Case-insensitive.

    Returns:
        ``(train_loader, test_loader)``. Both loaders shuffle and drop the
        final incomplete batch, so every yielded batch has exactly
        ``batch_size`` samples.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    trf = transforms.Compose([
        transforms.Resize((ih, iw)),
        transforms.ToTensor(),
    ])

    name = dataset.lower()
    if name == "mnist":
        train_set = datasets.MNIST(path, train=True, transform=trf, download=True)
        test_set = datasets.MNIST(path, train=False, transform=trf, download=True)
    elif name == "stl10":
        # STL10 selects its split by name rather than by a boolean flag.
        train_set = datasets.STL10(path, split="train", transform=trf, download=True)
        test_set = datasets.STL10(path, split="test", transform=trf, download=True)
    else:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected 'mnist' or 'stl10'"
        )

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    # NOTE(review): shuffling plus drop_last on the test loader means up to
    # batch_size - 1 test samples are skipped each pass. This is kept as-is in
    # case callers rely on fixed-size batches. For a full, reproducible
    # evaluation pass, consider shuffle=False, drop_last=False.
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last=True)
    return train_loader, test_loader