Skip to content

Commit 798e21f

Browse files
committed
bug #295 solved
Signed-off-by: rly09 <yogiroshan2005@gmail.com>
1 parent 3a0d50a commit 798e21f

4 files changed

Lines changed: 90 additions & 14 deletions

File tree

tutorials/TensorflowToOnnx-1.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,14 @@
5757
"source": [
5858
"import os\n",
5959
"import shutil\n",
60-
"import tensorflow as tf\n",
60+
"import tensorflow.compat.v1 as tf\n",
61+
"tf.disable_v2_behavior()\n",
6162
"from assets.tensorflow_to_onnx_example import create_and_train_mnist\n",
6263
"def save_model_to_saved_model(sess, input_tensor, output_tensor):\n",
63-
" from tensorflow.saved_model import simple_save\n",
6464
" save_path = r\"./output/saved_model\"\n",
6565
" if os.path.exists(save_path):\n",
6666
" shutil.rmtree(save_path)\n",
67-
" simple_save(sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor})\n",
67+
" tf.saved_model.simple_save(sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor})\n",
6868
"\n",
6969
"print(\"please wait for a while, because the script will train MNIST from scratch\")\n",
7070
"tf.reset_default_graph()\n",

tutorials/TensorflowToOnnx-2.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@
4848
}
4949
],
5050
"source": [
51-
"import tensorflow as tf\n",
51+
"import tensorflow.compat.v1 as tf\n",
52+
"tf.disable_v2_behavior()\n",
5253
"from assets.tensorflow_to_onnx_example import create_and_train_mnist\n",
5354
"\n",
5455
"def save_model_to_frozen_proto(sess):\n",
@@ -166,8 +167,8 @@
166167
],
167168
"source": [
168169
"from tf2onnx.tfonnx import process_tf_graph, tf_optimize\n",
169-
"import tensorflow as tf\n",
170-
"from tensorflow.graph_util import convert_variables_to_constants as freeze_graph\n",
170+
"import tensorflow.compat.v1 as tf\n",
171+
"from tensorflow.compat.v1.graph_util import convert_variables_to_constants as freeze_graph\n",
171172
"\n",
172173
"print(\"generating mnist.onnx in python script\")\n",
173174
"graph_def = freeze_graph(sess_tf, sess_tf.graph_def, [output_tensor.name[:-2]])\n",

tutorials/assets/tensorflow_to_onnx_example.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,46 @@
2727

2828
import os
2929
import shutil
30-
import tensorflow as tf
31-
from tensorflow.examples.tutorials.mnist import input_data
30+
31+
import numpy as np
32+
import tensorflow.compat.v1 as tf
33+
34+
tf.disable_v2_behavior()
35+
36+
37+
class _MnistSplit:
38+
def __init__(self, images, labels):
39+
self.images = images.reshape([-1, 784]).astype(np.float32) / 255.0
40+
self.labels = labels.astype(np.int64)
41+
self._index = 0
42+
43+
def next_batch(self, batch_size):
44+
start = self._index
45+
self._index += batch_size
46+
if self._index > len(self.images):
47+
indices = np.arange(len(self.images))
48+
np.random.shuffle(indices)
49+
self.images = self.images[indices]
50+
self.labels = self.labels[indices]
51+
start = 0
52+
self._index = batch_size
53+
end = self._index
54+
return self.images[start:end], self.labels[start:end]
55+
56+
57+
class _MnistData:
58+
def __init__(self, train, test):
59+
self.train = train
60+
self.test = test
61+
62+
63+
def read_mnist_data_sets(data_dir):
64+
dataset_path = os.path.join(data_dir, "mnist.npz")
65+
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data(
66+
path=dataset_path)
67+
return _MnistData(
68+
_MnistSplit(train_images, train_labels),
69+
_MnistSplit(test_images, test_labels))
3270

3371

3472
def add(x, y):
@@ -121,7 +159,7 @@ def create_and_train_mnist():
121159
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
122160
# Import data
123161
data_dir = r"/tmp/tensorflow/mnist/input_data"
124-
mnist = input_data.read_data_sets(data_dir)
162+
mnist = read_mnist_data_sets(data_dir)
125163
# Create the model
126164
tf.reset_default_graph()
127165
input_tensor = tf.placeholder(tf.float32, [None, 784], name="input")
@@ -178,8 +216,8 @@ def save_model_to_frozen_proto(sess):
178216

179217
def save_model_to_saved_model(sess, input_tensor, output_tensor):
180218
print('save model to saved_model')
181-
from tensorflow.saved_model import simple_save
182219
save_path = r"./output/saved_model"
183220
if os.path.exists(save_path):
184221
shutil.rmtree(save_path)
185-
simple_save(sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor})
222+
tf.saved_model.simple_save(
223+
sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor})

tutorials/assets/tf-train-mnist.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,53 @@
2929
from __future__ import print_function
3030

3131
import argparse
32+
import os
3233
import sys
3334
import tempfile
3435

35-
from tensorflow.examples.tutorials.mnist import input_data
36+
import numpy as np
37+
import tensorflow.compat.v1 as tf
3638

37-
import tensorflow as tf
39+
tf.disable_v2_behavior()
3840

3941
FLAGS = None
4042

4143

44+
class _MnistSplit:
45+
def __init__(self, images, labels):
46+
self.images = images.reshape([-1, 784]).astype(np.float32) / 255.0
47+
self.labels = labels.astype(np.int64)
48+
self._index = 0
49+
50+
def next_batch(self, batch_size):
51+
start = self._index
52+
self._index += batch_size
53+
if self._index > len(self.images):
54+
indices = np.arange(len(self.images))
55+
np.random.shuffle(indices)
56+
self.images = self.images[indices]
57+
self.labels = self.labels[indices]
58+
start = 0
59+
self._index = batch_size
60+
end = self._index
61+
return self.images[start:end], self.labels[start:end]
62+
63+
64+
class _MnistData:
65+
def __init__(self, train, test):
66+
self.train = train
67+
self.test = test
68+
69+
70+
def read_mnist_data_sets(data_dir):
71+
dataset_path = os.path.join(data_dir, "mnist.npz")
72+
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data(
73+
path=dataset_path)
74+
return _MnistData(
75+
_MnistSplit(train_images, train_labels),
76+
_MnistSplit(test_images, test_labels))
77+
78+
4279
def add(x, y):
4380
return tf.nn.bias_add(x, y, data_format="NCHW")
4481

@@ -137,7 +174,7 @@ def bias_variable(shape):
137174

138175
def main(_):
139176
# Import data
140-
mnist = input_data.read_data_sets(FLAGS.data_dir)
177+
mnist = read_mnist_data_sets(FLAGS.data_dir)
141178

142179
# Create the model
143180
x = tf.placeholder(tf.float32, [None, 784])

0 commit comments

Comments
 (0)