Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tutorials/TensorflowToOnnx-1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@
"source": [
"import os\n",
"import shutil\n",
"import tensorflow as tf\n",
"import tensorflow.compat.v1 as tf\n",
"tf.disable_v2_behavior()\n",
"from assets.tensorflow_to_onnx_example import create_and_train_mnist\n",
"def save_model_to_saved_model(sess, input_tensor, output_tensor):\n",
" from tensorflow.saved_model import simple_save\n",
" save_path = r\"./output/saved_model\"\n",
" if os.path.exists(save_path):\n",
" shutil.rmtree(save_path)\n",
" simple_save(sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor})\n",
" tf.saved_model.simple_save(sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor})\n",
"\n",
"print(\"please wait for a while, because the script will train MNIST from scratch\")\n",
"tf.reset_default_graph()\n",
Expand Down
7 changes: 4 additions & 3 deletions tutorials/TensorflowToOnnx-2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
}
],
"source": [
"import tensorflow as tf\n",
"import tensorflow.compat.v1 as tf\n",
"tf.disable_v2_behavior()\n",
"from assets.tensorflow_to_onnx_example import create_and_train_mnist\n",
"\n",
"def save_model_to_frozen_proto(sess):\n",
Expand Down Expand Up @@ -166,8 +167,8 @@
],
"source": [
"from tf2onnx.tfonnx import process_tf_graph, tf_optimize\n",
"import tensorflow as tf\n",
"from tensorflow.graph_util import convert_variables_to_constants as freeze_graph\n",
"import tensorflow.compat.v1 as tf\n",
"from tensorflow.compat.v1.graph_util import convert_variables_to_constants as freeze_graph\n",
"\n",
"print(\"generating mnist.onnx in python script\")\n",
"graph_def = freeze_graph(sess_tf, sess_tf.graph_def, [output_tensor.name[:-2]])\n",
Expand Down
48 changes: 43 additions & 5 deletions tutorials/assets/tensorflow_to_onnx_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,46 @@

import os
import shutil
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


class _MnistSplit:
def __init__(self, images, labels):
self.images = images.reshape([-1, 784]).astype(np.float32) / 255.0
self.labels = labels.astype(np.int64)
self._index = 0

def next_batch(self, batch_size):
start = self._index
self._index += batch_size
if self._index > len(self.images):
indices = np.arange(len(self.images))
np.random.shuffle(indices)
self.images = self.images[indices]
self.labels = self.labels[indices]
start = 0
self._index = batch_size
end = self._index
return self.images[start:end], self.labels[start:end]


class _MnistData:
def __init__(self, train, test):
self.train = train
self.test = test


def read_mnist_data_sets(data_dir):
dataset_path = os.path.join(data_dir, "mnist.npz")
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data(
path=dataset_path)
return _MnistData(
_MnistSplit(train_images, train_labels),
_MnistSplit(test_images, test_labels))


def add(x, y):
Expand Down Expand Up @@ -121,7 +159,7 @@ def create_and_train_mnist():
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Import data
data_dir = r"/tmp/tensorflow/mnist/input_data"
mnist = input_data.read_data_sets(data_dir)
mnist = read_mnist_data_sets(data_dir)
# Create the model
tf.reset_default_graph()
input_tensor = tf.placeholder(tf.float32, [None, 784], name="input")
Expand Down Expand Up @@ -178,8 +216,8 @@ def save_model_to_frozen_proto(sess):

def save_model_to_saved_model(sess, input_tensor, output_tensor):
print('save model to saved_model')
from tensorflow.saved_model import simple_save
save_path = r"./output/saved_model"
if os.path.exists(save_path):
shutil.rmtree(save_path)
simple_save(sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor})
tf.saved_model.simple_save(
sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor})
43 changes: 40 additions & 3 deletions tutorials/assets/tf-train-mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,53 @@
from __future__ import print_function

import argparse
import os
import sys
import tempfile

from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import tensorflow.compat.v1 as tf

import tensorflow as tf
tf.disable_v2_behavior()

FLAGS = None


class _MnistSplit:
def __init__(self, images, labels):
self.images = images.reshape([-1, 784]).astype(np.float32) / 255.0
self.labels = labels.astype(np.int64)
self._index = 0

def next_batch(self, batch_size):
start = self._index
self._index += batch_size
if self._index > len(self.images):
indices = np.arange(len(self.images))
np.random.shuffle(indices)
self.images = self.images[indices]
self.labels = self.labels[indices]
start = 0
self._index = batch_size
end = self._index
return self.images[start:end], self.labels[start:end]


class _MnistData:
def __init__(self, train, test):
self.train = train
self.test = test


def read_mnist_data_sets(data_dir):
dataset_path = os.path.join(data_dir, "mnist.npz")
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data(
path=dataset_path)
return _MnistData(
_MnistSplit(train_images, train_labels),
_MnistSplit(test_images, test_labels))


def add(x, y):
return tf.nn.bias_add(x, y, data_format="NCHW")

Expand Down Expand Up @@ -137,7 +174,7 @@ def bias_variable(shape):

def main(_):
# Import data
mnist = input_data.read_data_sets(FLAGS.data_dir)
mnist = read_mnist_data_sets(FLAGS.data_dir)

# Create the model
x = tf.placeholder(tf.float32, [None, 784])
Expand Down