-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathdataset.py
More file actions
111 lines (93 loc) · 3.58 KB
/
dataset.py
File metadata and controls
111 lines (93 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import json
import glob
from collections import defaultdict
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
# Per-channel normalization applied after ToTensor().  The mean/std pair is
# the ImageNet statistics that torchvision documents for its pretrained
# models; the first std entry was previously 0.228, which does not match the
# documented 0.229.
tr_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

# Training pipeline: random-scale/aspect crop resized to 224x224 plus a random
# horizontal flip, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    tr_normalize,
])

# Evaluation pipeline: deterministic resize (shorter side to 256) followed by
# a 224x224 center crop, then tensor conversion and normalization.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    tr_normalize,
])
class VOCDataset(Dataset):
    """Multi-label image dataset built from per-class VOC image-set files.

    Every ``ImageSets/Main/*_<mode>.txt`` file contributes one label column
    (columns follow the sorted file names).  Each line of such a file is
    ``<image name> <integer flag>``; the flag is stored as follows:

    * negative flag  -> 0
    * flag equal 0   -> 255
    * any other flag -> stored unchanged (1 in the usual case)

    Columns whose file never mentions an image also hold 255.
    NOTE(review): in the VOC devkit the flags are presumably -1 (negative),
    0 (difficult) and 1 (positive), which would make 255 an "ignore" marker
    -- confirm against the training loss that consumes these labels.
    """

    # Value stored for flag-0 entries and for columns with no entry at all.
    _UNKNOWN_LABEL = 255

    def __init__(self, voc_root, is_train=True):
        """Index all images listed in the image-set files under ``voc_root``.

        Args:
            voc_root: Directory containing ``ImageSets/Main`` and
                ``JPEGImages``.
            is_train: When true, read the ``*_trainval.txt`` files and use
                ``train_transform``; otherwise read ``*_test.txt`` and use
                ``val_transform``.

        Raises:
            AssertionError: If the number of matching image-set files is
                not exactly 20.
        """
        mode = 'trainval' if is_train else 'test'

        # find the image sets (sorted so the label column order is stable)
        image_set_dir = os.path.join(voc_root, 'ImageSets', 'Main')
        image_sets = sorted(
            glob.glob(os.path.join(image_set_dir, '*_' + mode + '.txt'))
        )
        assert len(image_sets) == 20, (
            'expected 20 image-set files in %s, found %d'
            % (image_set_dir, len(image_sets))
        )

        # read the labels
        n_labels = len(image_sets)
        self.n_labels = n_labels
        unknown = self._UNKNOWN_LABEL
        image_dir = os.path.join(voc_root, 'JPEGImages')
        # Explicit 255 fill; the previous ``-np.ones(..., dtype=np.uint8)``
        # produced the same values only through unsigned wrap-around.
        images = defaultdict(
            lambda: np.full(n_labels, unknown, dtype=np.uint8)
        )
        for column, set_path in enumerate(image_sets):
            # ``with`` guarantees the file handle is closed after reading.
            with open(set_path, 'r') as set_file:
                for line in set_file:
                    fields = line.split()
                    if not fields:
                        # tolerate blank lines instead of failing to unpack
                        continue
                    name, lbl = fields
                    lbl = int(lbl)
                    if lbl < 0:
                        lbl = 0
                    elif lbl == 0:
                        lbl = unknown
                    images[os.path.join(image_dir, name + '.jpg')][column] = lbl

        # list of (image path, uint8 label vector) pairs in random order
        self.images = list(images.items())
        np.random.shuffle(self.images)
        self.transform = train_transform if is_train else val_transform

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, i):
        """Return ``(image, label)`` for the ``i``-th indexed image.

        The image is loaded as RGB and passed through ``self.transform``;
        the label vector is converted to a float tensor.
        """
        path, label_vector = self.images[i]
        image = self.transform(Image.open(path).convert('RGB'))
        label = torch.tensor(label_vector).float()
        return image, label

    def _parse_voc_xml(self, node):
        """Build a 0/1 label vector from the ``object`` children of ``node``.

        For every child tagged ``object`` the text of its first sub-element
        is looked up in ``self.cat_map`` and that position is set to 1.

        NOTE(review): ``self.cat_map`` is not assigned by ``__init__``; it
        must be set on the instance before this helper is called.  The first
        sub-element is presumably the ``<name>`` tag -- confirm against the
        annotation files.
        """
        label = torch.zeros(len(self.cat_map))
        for child in node:
            if child.tag == 'object':
                label[self.cat_map[list(child)[0].text]] = 1.
        return label
class COCODataset(Dataset):
    """Multi-label image dataset derived from COCO 2014 instance annotations.

    Each image gets one row in ``self.labels``; a column is set to 1 when at
    least one annotation of the corresponding category refers to the image.
    Rows follow the order of ``images`` and columns the order of
    ``categories`` in the annotation file.
    """

    def __init__(self, data_dir, is_train=True):
        """Load the annotation JSON and build image paths and label matrix.

        Args:
            data_dir: Directory holding ``annotations/instances_<mode>2014.json``
                and the ``<mode>2014`` image folder.
            is_train: Selects the ``train`` split and ``train_transform`` when
                true, otherwise the ``val`` split and ``val_transform``.
        """
        split = 'train' if is_train else 'val'

        # parse the instance annotation file for the chosen split
        json_path = os.path.join(
            data_dir, 'annotations', f'instances_{split}2014.json'
        )
        with open(json_path, 'r') as handle:
            coco = json.load(handle)

        # image id -> row index, plus the on-disk path of every image
        image_records = coco['images']
        image_dir = os.path.join(data_dir, f'{split}2014')
        row_of_image = {
            record['id']: row for row, record in enumerate(image_records)
        }
        self.image_paths = [
            os.path.join(image_dir, record['file_name'])
            for record in image_records
        ]

        # category id -> column index
        col_of_category = {
            category['id']: col
            for col, category in enumerate(coco['categories'])
        }

        # mark every (image, category) pair that has at least one instance
        self.labels = torch.zeros(len(image_records), len(col_of_category))
        for ann in coco['annotations']:
            row = row_of_image[ann['image_id']]
            col = col_of_category[ann['category_id']]
            self.labels[row, col] = 1.

        # pick the preprocessing pipeline for the split
        if is_train:
            self.transform = train_transform
        else:
            self.transform = val_transform

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(transformed RGB image, label row)`` for index ``idx``."""
        target = self.labels[idx]
        rgb = Image.open(self.image_paths[idx]).convert('RGB')
        return self.transform(rgb), target