-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathdataUtils.jl
More file actions
49 lines (41 loc) · 1.76 KB
/
dataUtils.jl
File metadata and controls
49 lines (41 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
using PyCall
torchvision = pyimport("torchvision")
transforms = pyimport("torchvision.transforms")
datasets = pyimport("torchvision.datasets")
"""
    getDataLoaders(trainData, testData, batchsize)

Wrap a training dataset and a test dataset in PyTorch `DataLoader`s that both use
batch size `batchsize`.

The training loader is built with `shuffle=true`; the test loader with
`shuffle=false`, so it iterates in dataset order.

# Arguments
- `trainData`: Python dataset object (e.g. a torchvision dataset) for training.
- `testData`: Python dataset object for evaluation.
- `batchsize`: number of samples per batch, passed as `batch_size`.

# Returns
A tuple `(trainLoader, testLoader)` of Python `DataLoader` objects.
"""
function getDataLoaders(trainData, testData, batchsize)
    # Resolve DataLoader locally: this file never binds a global `torch`, so a bare
    # `torch.utils.data` reference only worked if an including script defined it.
    DataLoader = pyimport("torch.utils.data").DataLoader
    trainLoader = DataLoader(dataset=trainData, batch_size=batchsize, shuffle=true)
    testLoader = DataLoader(dataset=testData, batch_size=batchsize, shuffle=false)
    return trainLoader, testLoader
end
"""
    getmnistDataLoaders(batchsize; root="../data")

Build training and test `DataLoader`s for torchvision's MNIST dataset.

Both splits are converted with `transforms.ToTensor()`. The training split is
constructed with `download=true`; the test split is constructed afterwards from
the same `root` without `download`.

# Arguments
- `batchsize`: number of samples per batch for both loaders.

# Keywords
- `root`: directory holding the dataset files. Defaults to `"../data"`, which is
  resolved against the current working directory, not this file's location.

# Returns
`(trainLoader, testLoader)`, as produced by `getDataLoaders`.
"""
function getmnistDataLoaders(batchsize; root="../data")
    trainData = datasets.MNIST(root=root,
                               train=true,
                               transform=transforms.ToTensor(),
                               download=true)
    # NOTE(review): relies on the training-split download above having fetched the
    # test files too (torchvision's MNIST download fetches all resources) — confirm.
    testData = datasets.MNIST(root=root,
                              train=false,
                              transform=transforms.ToTensor())
    return getDataLoaders(trainData, testData, batchsize)
end
"""
    getcifar10DataLoaders(batchsize; root="../data")

Build training and test `DataLoader`s for torchvision's CIFAR-10 dataset.

Training images go through an augmentation pipeline before tensor conversion:
`Pad(4)`, `RandomHorizontalFlip()`, then `RandomCrop(32)`. Test images are only
converted with `transforms.ToTensor()`. The training split is constructed with
`download=true`; the test split is constructed afterwards from the same `root`.

# Arguments
- `batchsize`: number of samples per batch for both loaders.

# Keywords
- `root`: directory holding the dataset files. Defaults to `"../data"`, which is
  resolved against the current working directory, not this file's location.

# Returns
`(trainLoader, testLoader)`, as produced by `getDataLoaders`.
"""
function getcifar10DataLoaders(batchsize; root="../data")
    # Pad then randomly crop back to 32 (the crop size used here), plus random
    # horizontal flips: augmentation applied to the training split only.
    trainTransform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor()])
    trainData = datasets.CIFAR10(root=root,
                                 train=true,
                                 transform=trainTransform,
                                 download=true)
    # NOTE(review): relies on the training-split download above having fetched the
    # archive that also contains the test batches — confirm.
    testData = datasets.CIFAR10(root=root,
                                train=false,
                                transform=transforms.ToTensor())
    return getDataLoaders(trainData, testData, batchsize)
end
#trainLoader, testLoader = getcifar10DataLoaders(128)
#x = trainLoader[:dataset][1] |> first
#x[:shape]