-
Notifications
You must be signed in to change notification settings - Fork 124
Expand file tree
/
Copy pathdatautils.py
More file actions
80 lines (64 loc) · 3.01 KB
/
datautils.py
File metadata and controls
80 lines (64 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import sys
from urllib import request

import numpy as np
import torch
from torch.utils.data import Dataset
# Evaluated once at import time: True when PyTorch reports a usable CUDA
# device. get_mnist forwards it to DataLoader(pin_memory=...).
cuda = torch.cuda.is_available()

# Number of digit classes (labels 0 .. n_labels - 1) kept by get_mnist's
# samplers and passed to utils.onehot for the targets.
n_labels = 10

# Presumably lets `from utils import onehot` (inside get_mnist) find the
# sibling "semi-supervised" directory — confirm.
# NOTE(review): this relative path is resolved against the current working
# directory, not against this file's location.
sys.path.append("../semi-supervised")
class SpriteDataset(Dataset):
    """
    A PyTorch wrapper for the dSprites dataset by
    Matthey et al. 2017. The dataset provides a 2D scene
    with a sprite under different transformations:
        * color
        * shape
        * scale
        * orientation
        * x-position
        * y-position

    Only the ``"imgs"`` array of the ``.npz`` archive is kept; this wrapper
    does not expose the latent factors listed above.

    Args:
        transform: Optional callable applied to each sample in
            ``__getitem__``.
        path: Location of the local ``.npz`` archive (default
            ``"./dsprites.npz"``). If no file exists there, the archive is
            downloaded from ``URL`` first.
    """

    URL = "https://github.com/deepmind/dsprites-dataset/raw/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"

    def __init__(self, transform=None, path="./dsprites.npz"):
        self.transform = transform
        try:
            self.dset = self._load_images(path)
        except FileNotFoundError:
            self._download(self.URL, path)
            self.dset = self._load_images(path)

    @staticmethod
    def _load_images(path):
        """Return the ``"imgs"`` array stored in the ``.npz`` archive at *path*."""
        # Indexing an NpzFile reads the member into memory, so the archive
        # can be closed immediately instead of leaking its file handle.
        with np.load(path, encoding="bytes") as archive:
            return archive["imgs"]

    @staticmethod
    def _download(url, path):
        """Download *url* to *path*, only creating *path* once the transfer completes."""
        # urlretrieve writes directly to its target and keeps partial data on
        # disk when a transfer is cut short. A truncated archive at *path*
        # would make every later load fail instead of triggering a fresh
        # download, so fetch into a side file and move it into place at the end.
        tmp_path = os.fspath(path) + ".part"
        try:
            request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, path)
        finally:
            # Only present if the download or the rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, idx):
        sample = self.dset[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
def get_mnist(location="./", batch_size=64, labels_per_class=100):
    """Build the MNIST DataLoaders used for semi-supervised training.

    Every image is converted with ``transforms.ToTensor()``, flattened with
    ``view(-1)`` and binarized with ``bernoulli()`` each time it is fetched.
    Targets go through ``utils.onehot(n_labels)`` (presumably a one-hot
    encoder — confirm in ``utils``).

    Args:
        location: Root directory passed to torchvision's ``MNIST``; the data
            is downloaded there if missing.
        batch_size: Batch size shared by all three loaders.
        labels_per_class: Maximum number of training images per class that
            the ``labelled`` loader draws from.

    Returns:
        ``(labelled, unlabelled, validation)`` DataLoaders:
        ``labelled`` samples from a random subset of at most
        ``labels_per_class`` training images per class; ``unlabelled`` from
        every training image whose label is in ``range(n_labels)``;
        ``validation`` from every such test-set image.
    """
    from torch.utils.data.sampler import SubsetRandomSampler
    from torchvision.datasets import MNIST
    import torchvision.transforms as transforms
    from utils import onehot

    # NOTE(review): DataLoader workers may need to pickle the dataset (e.g.
    # under the "spawn" start method), and a locally defined lambda is not
    # picklable — confirm num_workers=2 works on the target platforms.
    flatten_bernoulli = lambda x: transforms.ToTensor()(x).view(-1).bernoulli()

    mnist_train = MNIST(location, train=True, download=True,
                        transform=flatten_bernoulli, target_transform=onehot(n_labels))
    mnist_valid = MNIST(location, train=False, download=True,
                        transform=flatten_bernoulli, target_transform=onehot(n_labels))

    def get_sampler(labels, n=None):
        """Return a random sampler over up to ``n`` indices per class in ``range(n_labels)``."""
        labels = np.asarray(labels)
        # Only choose digits in n_labels.
        (indices,) = np.nonzero(np.isin(labels, np.arange(n_labels)))
        # Shuffle first so the per-class prefix taken below is a random subset.
        np.random.shuffle(indices)
        # First n shuffled indices of each class (n=None keeps all). Masking a
        # NumPy int array keeps the integer dtype even when a class is empty,
        # unlike hstack-ing empty Python lists, which yields float64.
        indices = np.concatenate([indices[labels[indices] == i][:n] for i in range(n_labels)])
        return SubsetRandomSampler(torch.from_numpy(indices))

    def make_loader(dataset, sampler):
        """Wrap *dataset* in a DataLoader with the settings shared by all splits."""
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=2,
                                           pin_memory=cuda, sampler=sampler)

    # `targets` replaces the deprecated `train_labels` / `test_labels`.
    train_labels = np.asarray(mnist_train.targets)
    valid_labels = np.asarray(mnist_valid.targets)

    # Samplers are built in the original order so a seeded global NumPy RNG
    # still produces the same subsets.
    labelled = make_loader(mnist_train, get_sampler(train_labels, labels_per_class))
    unlabelled = make_loader(mnist_train, get_sampler(train_labels))
    validation = make_loader(mnist_valid, get_sampler(valid_labels))
    return labelled, unlabelled, validation