|
27 | 27 |
|
28 | 28 | import os |
29 | 29 | import shutil |
30 | | -import tensorflow as tf |
31 | | -from tensorflow.examples.tutorials.mnist import input_data |
| 30 | + |
| 31 | +import numpy as np |
| 32 | +import tensorflow.compat.v1 as tf |
| 33 | + |
| 34 | +tf.disable_v2_behavior() |
| 35 | + |
| 36 | + |
| 37 | +class _MnistSplit: |
| 38 | + def __init__(self, images, labels): |
| 39 | + self.images = images.reshape([-1, 784]).astype(np.float32) / 255.0 |
| 40 | + self.labels = labels.astype(np.int64) |
| 41 | + self._index = 0 |
| 42 | + |
| 43 | + def next_batch(self, batch_size): |
| 44 | + start = self._index |
| 45 | + self._index += batch_size |
| 46 | + if self._index > len(self.images): |
| 47 | + indices = np.arange(len(self.images)) |
| 48 | + np.random.shuffle(indices) |
| 49 | + self.images = self.images[indices] |
| 50 | + self.labels = self.labels[indices] |
| 51 | + start = 0 |
| 52 | + self._index = batch_size |
| 53 | + end = self._index |
| 54 | + return self.images[start:end], self.labels[start:end] |
| 55 | + |
| 56 | + |
| 57 | +class _MnistData: |
| 58 | + def __init__(self, train, test): |
| 59 | + self.train = train |
| 60 | + self.test = test |
| 61 | + |
| 62 | + |
| 63 | +def read_mnist_data_sets(data_dir): |
| 64 | + dataset_path = os.path.join(data_dir, "mnist.npz") |
| 65 | + (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data( |
| 66 | + path=dataset_path) |
| 67 | + return _MnistData( |
| 68 | + _MnistSplit(train_images, train_labels), |
| 69 | + _MnistSplit(test_images, test_labels)) |
32 | 70 |
|
33 | 71 |
|
34 | 72 | def add(x, y): |
@@ -121,7 +159,7 @@ def create_and_train_mnist(): |
121 | 159 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' |
122 | 160 | # Import data |
123 | 161 | data_dir = r"/tmp/tensorflow/mnist/input_data" |
124 | | - mnist = input_data.read_data_sets(data_dir) |
| 162 | + mnist = read_mnist_data_sets(data_dir) |
125 | 163 | # Create the model |
126 | 164 | tf.reset_default_graph() |
127 | 165 | input_tensor = tf.placeholder(tf.float32, [None, 784], name="input") |
@@ -178,8 +216,8 @@ def save_model_to_frozen_proto(sess): |
178 | 216 |
|
179 | 217 | def save_model_to_saved_model(sess, input_tensor, output_tensor): |
180 | 218 | print('save model to saved_model') |
181 | | - from tensorflow.saved_model import simple_save |
182 | 219 | save_path = r"./output/saved_model" |
183 | 220 | if os.path.exists(save_path): |
184 | 221 | shutil.rmtree(save_path) |
185 | | - simple_save(sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor}) |
| 222 | + tf.saved_model.simple_save( |
| 223 | + sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor}) |
0 commit comments