|
5 | 5 | See: http://pytorch.org/docs/0.3.1/data.html |
6 | 6 | ''' |
7 | 7 |
|
8 | | -import os |
9 | | - |
10 | 8 | import torch |
11 | 9 | from PIL import Image |
12 | 10 | import torch.utils.data |
@@ -75,39 +73,7 @@ def __getitem__(self, i): |
75 | 73 | assert len(set(tiles)) == 1, 'all images are for the same tile' |
76 | 74 | assert tiles[0] == mask_tile, 'image tile is the same as mask tile' |
77 | 75 |
|
78 | | - return (torch.cat(images, dim=0), mask, tiles) |
79 | | - |
80 | | - |
81 | | -class ImageDirectory(torch.utils.data.Dataset): |
82 | | - '''Dataset to read images from a single directory. |
83 | | - ''' |
84 | | - |
85 | | - def __init__(self, root, transform=None): |
86 | | - '''Creates an `ImageDirectory` instance. |
87 | | -
|
88 | | - Args: |
89 | | - root: the base directory where images reside. |
90 | | - transform: the transformation to run on each image. |
91 | | - ''' |
92 | | - |
93 | | - super().__init__() |
94 | | - |
95 | | - self.root = root |
96 | | - self.transform = transform |
97 | | - self.file_names = os.listdir(root) |
98 | | - |
99 | | - def __len__(self): |
100 | | - return len(self.file_names) |
101 | | - |
102 | | - def __getitem__(self, i): |
103 | | - name = self.file_names[i] |
104 | | - path = os.path.join(self.root, name) |
105 | | - image = Image.open(path).convert('RGB') |
106 | | - |
107 | | - if self.transform is not None: |
108 | | - image = self.transform(image) |
109 | | - |
110 | | - return image, name |
| 76 | + return torch.cat(images, dim=0), mask, tiles |
111 | 77 |
|
112 | 78 |
|
113 | 79 | # Todo: once we have the SlippyMapDataset this dataset should wrap |
|
0 commit comments