-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Expand file tree
/
Copy pathdigit_identification.py
More file actions
43 lines (33 loc) · 1.39 KB
/
digit_identification.py
File metadata and controls
43 lines (33 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license.
import os
import numpy as np
import tensorflow as tf
from PIL import Image
from azureml.core import Model
# Disable eager execution: switch TensorFlow into graph mode, because init()
# and run() below rely on TF1-style graph/session APIs
# (tf.compat.v1.train.import_meta_graph, tf.compat.v1.Session, and
# Session.run with feed_dict). This must run at import time, before any graph
# is built.
tf.compat.v1.disable_eager_execution()
def init():
    """Restore the registered "mnist-prs" model into the global ``g_tf_sess``.

    Resolves the local directory of the workspace model via azureml's
    ``Model.get_model_path``. It then clears the default graph, rebuilds the
    graph from the saved ``mnist-tf.model.meta`` meta-graph, opens a CPU-only
    session, and restores the checkpointed variables from ``mnist-tf.model``.
    ``run`` reads ``g_tf_sess``, so this must be called before it.
    """
    global g_tf_sess

    v1 = tf.compat.v1

    # Locate the registered model's files on local disk.
    model_dir = Model.get_model_path("mnist-prs")
    meta_graph_file = os.path.join(model_dir, 'mnist-tf.model.meta')
    checkpoint_file = os.path.join(model_dir, 'mnist-tf.model')

    # Start from an empty default graph, then rebuild it from the meta-graph;
    # import_meta_graph hands back a Saver bound to that graph's variables.
    v1.reset_default_graph()
    restorer = v1.train.import_meta_graph(meta_graph_file)

    # device_count={'GPU': 0} keeps the session from using any GPU.
    cpu_only = v1.ConfigProto(device_count={'GPU': 0})
    g_tf_sess = v1.Session(config=cpu_only)

    # Load the trained variable values into the freshly created session.
    restorer.restore(g_tf_sess, checkpoint_file)
def run(mini_batch):
    """Classify each image file in ``mini_batch`` with the restored model.

    Reads the module-level session ``g_tf_sess``, so ``init()`` must have been
    called first.

    Args:
        mini_batch: Iterable of image file paths. Each image's pixel array is
            reshaped to (1, 784) before being fed to the graph. It must
            therefore hold exactly 784 values (e.g. a 28x28 single-channel
            image), otherwise ``reshape`` raises ``ValueError``.

    Returns:
        A list of strings of the form "<file basename>: <predicted class
        index>", one per input image, in input order.
    """
    print(f'run method start: {__file__}, run({mini_batch})')

    # Look up the graph's input placeholder and output tensor once per batch.
    graph = g_tf_sess.graph
    in_tensor = graph.get_tensor_by_name("network/X:0")
    output = graph.get_tensor_by_name("network/output/MatMul:0")

    results = []
    for image in mini_batch:
        # Use a context manager so the image file handle is released as soon
        # as the pixels have been copied into a NumPy array, rather than
        # whenever the PIL object happens to be garbage-collected.
        with Image.open(image) as data:
            np_im = np.array(data).reshape((1, 784))

        # Run the graph on this single image.
        inference_result = g_tf_sess.run(output, feed_dict={in_tensor: np_im})

        # The fetched tensor is a raw MatMul output (per-class scores, not
        # probabilities). Its argmax is the same class a softmax would pick.
        best_result = np.argmax(inference_result)
        results.append(f"{os.path.basename(image)}: {best_result}")
    return results