-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcustom_dataset_loader.py
More file actions
46 lines (37 loc) · 1.43 KB
/
custom_dataset_loader.py
File metadata and controls
46 lines (37 loc) · 1.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms, utils
from torch.utils import data
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils
from os.path import join
from os import listdir
def is_image_file(filename, extensions=(".png", ".jpg", ".jpeg")):
    """Return True if ``filename`` ends with one of the image ``extensions``.

    The comparison is case-insensitive, so ``photo.PNG`` and ``IMG_01.JPG`` are
    accepted as well as their lower-case forms.

    Args:
        filename: File name (or path) to test.
        extensions: Iterable of accepted suffixes, including the leading dot.
            Defaults to the PNG/JPEG suffixes this module has always accepted.

    Returns:
        bool: True when the lower-cased name ends with any lower-cased suffix.
    """
    # str.endswith accepts a tuple, which checks all suffixes in one C-level call.
    suffixes = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(suffixes)
class DatasetFromFolder(data.Dataset):
    """Unlabelled image dataset read from the ``train_imgs`` sub-folder of a directory.

    Each item is a single image tensor produced by ``transforms.ToTensor()``
    followed by ``transforms.Normalize([0.5], [0.5])``; no label is returned.

    Args:
        image_dir: Root directory; images are read from ``<image_dir>/train_imgs``.
            Only files accepted by ``is_image_file`` are included.
        img_size: Stored as ``self.img_size``.
            NOTE(review): nothing in this class resizes to ``img_size``. Images
            come back at their on-disk size, so batching with batch_size > 1
            requires them to already share one size. Confirm whether a
            ``transforms.Resize`` was intended.
    """

    def __init__(self, image_dir, img_size):
        super().__init__()
        self.b_path = join(image_dir, "train_imgs")
        # listdir() order is filesystem-dependent; sort so the index -> file
        # mapping is reproducible across machines and runs.
        self.image_filenames = sorted(
            name for name in listdir(self.b_path) if is_image_file(name)
        )
        self.img_size = img_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # Computes (x - 0.5) / 0.5, i.e. maps ToTensor's [0, 1] range
            # (for 8-bit images) to [-1, 1].
            transforms.Normalize([0.5], [0.5]),
        ])

    def __getitem__(self, index):
        """Load the image at ``index``, apply ``self.transform`` and return it."""
        path = join(self.b_path, self.image_filenames[index])
        # Use a context manager so the file handle is closed once ToTensor has
        # read the pixel data, instead of lingering until garbage collection.
        with Image.open(path) as img:
            return self.transform(img)

    def __len__(self):
        """Return the number of image files found when the dataset was built."""
        return len(self.image_filenames)
def plot_data(loader, max_images=64):
    """Print the shape of the first batch from ``loader`` and display it as a grid.

    Args:
        loader: Iterable whose first item is a batch of image tensors (e.g. a
            DataLoader over DatasetFromFolder). Presumably shaped
            (batch, channels, H, W), the layout ``make_grid`` expects — confirm
            for other loaders.
        max_images: Maximum number of images from the batch to draw. The default
            of 64 matches the previously hard-coded value.

    Raises:
        StopIteration: if ``loader`` yields no batches.
    """
    image = next(iter(loader))
    print("image shape", image.shape)
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("training images")
    grid = vutils.make_grid(image[:max_images], padding=2, normalize=True).cpu()
    # make_grid returns a (C, H, W) tensor; imshow expects (H, W, C).
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    plt.show()