-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpredict.py
More file actions
131 lines (108 loc) · 5.11 KB
/
predict.py
File metadata and controls
131 lines (108 loc) · 5.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#Author: Udit Maherwal
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn,optim
import torch.nn.functional as F
import torchvision
import json
import os
from torchvision import datasets,transforms,models
from collections import OrderedDict
import time
import PIL
from PIL import Image
import seaborn as sb
import train
def arguments_parser(argv=None):
    """Build the command-line parser for flower prediction and parse arguments.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (argparse's behaviour when ``None``).

    Returns:
        argparse.Namespace with attributes:
            checkpoint -- path of the saved model checkpoint (``--checkpoint_01``)
            image_path -- path of the image to classify
            gpu        -- raw string given to ``--device``
            k          -- number of top classes to report (int)
    """
    parser = argparse.ArgumentParser("Flowers Prediction - Probabilities and image class")
    parser.add_argument('--checkpoint_01', dest='checkpoint',
                        help="stored the snapshot of your trained model",
                        default="checkpoint_01.pth")
    parser.add_argument('--image_path',
                        help="image for which we have to find probabilities and image_class",
                        default="flowers/test/1/image_06752.jpg")
    parser.add_argument('--device', dest='gpu', type=str,
                        help="Either GPU(Faster) or CPU(Prefer only if you have GPU driver installed)",
                        default='gpu')
    # type=int: without it a value supplied on the command line stays a
    # string (only the default was an int), which breaks Tensor.topk(k).
    parser.add_argument('--k', dest='k', type=int,
                        help='top labels taken till', default=5)
    return parser.parse_args(argv)
def process_image(image, resize_to=256, crop_size=224):
    '''Scale, center-crop, and normalize an image for a PyTorch model.

    The shorter side is resized to ``resize_to`` pixels (aspect ratio kept),
    the central ``crop_size`` x ``crop_size`` square is cut out, pixel values
    are scaled to [0, 1] and each channel is normalized with mean
    [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225] (the standard
    ImageNet statistics).

    Args:
        image: a path / file object accepted by ``PIL.Image.open``, or an
            already-opened ``PIL.Image.Image`` (it is not modified).
        resize_to: target length of the shorter side before cropping.
        crop_size: side length of the square center crop.

    Returns:
        numpy float64 array of shape (3, crop_size, crop_size), channels first.

    Raises:
        ValueError: if ``crop_size`` is larger than ``resize_to`` (the crop
            would extend past the image and be padded with black).
    '''
    if crop_size > resize_to:
        raise ValueError("crop_size must not exceed resize_to")

    # convert('RGB') always returns a loaded copy, so grayscale/RGBA inputs
    # get 3 channels and the source file can be closed right away.
    if isinstance(image, Image.Image):
        im = image.convert('RGB')
    else:
        with Image.open(image) as src:
            im = src.convert('RGB')

    # Resize so the shorter side is exactly resize_to. resize() (unlike
    # thumbnail()) also upscales small images, so the crop never runs out of
    # bounds. LANCZOS replaces ANTIALIAS, which Pillow 10 removed.
    width, height = im.size
    scale = resize_to / min(width, height)
    new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
    im = im.resize(new_size, Image.LANCZOS)

    # Center crop computed from the *resized* dimensions.
    new_width, new_height = im.size
    left = (new_width - crop_size) // 2
    top = (new_height - crop_size) // 2
    im = im.crop((left, top, left + crop_size, top + crop_size))

    color_channel = np.asarray(im, dtype=np.float64) / 255
    means = np.array([0.485, 0.456, 0.406])
    deviation = np.array([0.229, 0.224, 0.225])
    # HWC -> CHW, the layout PyTorch convolution layers expect.
    normalized_image = ((color_channel - means) / deviation).transpose((2, 0, 1))
    return normalized_image
def predict(image_path, model, k=5, device=None):
    '''Predict the top-k classes of an image using a trained deep learning model.

    Args:
        image_path: path of the image file; it is preprocessed with
            ``process_image``.
        model: torch module. ``torch.exp`` is applied to its top-k outputs,
            so it is assumed to return log-probabilities (e.g. a LogSoftmax
            head) — TODO confirm against the training code. It must carry a
            ``class_to_idx`` dict mapping class label -> output index.
        k: number of top classes to return (coerced with ``int``).
        device: torch device or device string to run on; defaults to CUDA
            when available, otherwise CPU.

    Returns:
        (probs, labels): numpy array of the k probabilities in descending
        order, and the list of corresponding class labels (keys of
        ``model.class_to_idx``).

    Side effects:
        Moves ``model`` to ``device`` and switches it to eval mode.
    '''
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)
    k = int(k)

    model.to(device)
    # eval() disables dropout / uses running batch-norm statistics so the
    # prediction is deterministic.
    model.eval()

    image = torch.from_numpy(
        np.expand_dims(process_image(image_path), axis=0)
    ).float().to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        output = model(image)
    log_probs, indices = output.topk(k)

    # .cpu() is required: NumPy cannot read CUDA tensors directly.
    probs = torch.exp(log_probs)[0].cpu().numpy()
    indices = indices[0].cpu().tolist()

    idx_to_class = {value: key for key, value in model.class_to_idx.items()}
    tlabels = [idx_to_class[index] for index in indices]
    return probs, tlabels
def main():
    '''Command-line entry point: classify one image and print its top-k classes.

    Parses the CLI arguments, rebuilds ``class_to_idx`` from the folder layout
    of ``flowers/train``, reads ``cat_to_name.json`` (class label -> flower
    name), loads the model with ``train.load_checkpoint`` and prints one line
    per predicted class.
    '''
    arguments = arguments_parser()
    # NOTE(review): --device (stored as arguments.gpu) is parsed but not acted
    # on here.

    data_dir = 'flowers'
    train_dir = os.path.join(data_dir, 'train')
    # Only class_to_idx is needed: ImageFolder derives it from the class
    # sub-directories of train_dir. No transforms, DataLoaders or valid/test
    # datasets are required for a single prediction.
    class_to_idx = datasets.ImageFolder(train_dir).class_to_idx

    with open('cat_to_name.json', 'r') as f:
        cat_to_name = json.load(f)

    model = train.load_checkpoint(arguments.checkpoint)
    model.class_to_idx = class_to_idx

    image = arguments.image_path
    # int() keeps this working even if k arrives as a string.
    k = int(arguments.k)
    probabilities, image_class = predict(image, model, k)
    labels = [cat_to_name[str(index)] for index in image_class]

    print('Image: ' + image)
    # Print every returned class: a fixed count of 5 raised IndexError for
    # k < 5 and silently dropped results for k > 5.
    for label, prob in zip(labels, probabilities):
        print("{} with a probability of {}".format(label, prob))


if __name__ == "__main__":
    main()