From 03b2b88436b5b27c9aac4f519dc7f85fa1fdbdb0 Mon Sep 17 00:00:00 2001 From: ShanuDey Date: Mon, 25 Nov 2019 15:56:36 +0530 Subject: [PATCH 1/2] TensorFlow 2.0 above support is added. From TensorFlow 2.0 contrib is deprecated. So the code need necessary changes to run with TF 2.0. Learn More: https://www.tensorflow.org/guide/upgrade Signed-off-by: ShanuDey --- darkflow/net/build.py | 42 ++++++++++++++++----------------- darkflow/net/help.py | 8 +++---- darkflow/net/ops/baseop.py | 8 +++---- darkflow/net/ops/convolution.py | 14 +++++------ darkflow/net/ops/simple.py | 14 +++++------ darkflow/net/vanilla/train.py | 4 ++-- darkflow/net/yolo/train.py | 26 ++++++++++---------- darkflow/net/yolov2/train.py | 26 ++++++++++---------- darkflow/utils/loader.py | 6 ++--- 9 files changed, 74 insertions(+), 74 deletions(-) diff --git a/darkflow/net/build.py b/darkflow/net/build.py index 1359f9f12..8fd8847bb 100644 --- a/darkflow/net/build.py +++ b/darkflow/net/build.py @@ -12,14 +12,14 @@ class TFNet(object): _TRAINER = dict({ - 'rmsprop': tf.train.RMSPropOptimizer, - 'adadelta': tf.train.AdadeltaOptimizer, - 'adagrad': tf.train.AdagradOptimizer, - 'adagradDA': tf.train.AdagradDAOptimizer, - 'momentum': tf.train.MomentumOptimizer, - 'adam': tf.train.AdamOptimizer, - 'ftrl': tf.train.FtrlOptimizer, - 'sgd': tf.train.GradientDescentOptimizer + 'rmsprop': tf.compat.v1.train.RMSPropOptimizer, + 'adadelta': tf.compat.v1.train.AdadeltaOptimizer, + 'adagrad': tf.compat.v1.train.AdagradOptimizer, + 'adagradDA': tf.compat.v1.train.AdagradDAOptimizer, + 'momentum': tf.compat.v1.train.MomentumOptimizer, + 'adam': tf.compat.v1.train.AdamOptimizer, + 'ftrl': tf.compat.v1.train.FtrlOptimizer, + 'sgd': tf.compat.v1.train.GradientDescentOptimizer }) # imported methods @@ -78,8 +78,8 @@ def __init__(self, FLAGS, darknet = None): time.time() - start)) def build_from_pb(self): - with tf.gfile.FastGFile(self.FLAGS.pbLoad, "rb") as f: - graph_def = tf.GraphDef() + with tf.compat.v1.gfile.FastGFile(self.FLAGS.pbLoad, "rb") as f: + graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) tf.import_graph_def( @@ -91,9 +91,9 @@ def build_from_pb(self): self.framework = create_framework(self.meta, self.FLAGS) # Placeholders - self.inp = tf.get_default_graph().get_tensor_by_name('input:0') + self.inp = tf.compat.v1.get_default_graph().get_tensor_by_name('input:0') self.feed = dict() # other placeholders - self.out = tf.get_default_graph().get_tensor_by_name('output:0') + self.out = tf.compat.v1.get_default_graph().get_tensor_by_name('output:0') self.setup_meta_ops() @@ -102,7 +102,7 @@ def build_forward(self): # Placeholders inp_size = [None] + self.meta['inp_size'] - self.inp = tf.placeholder(tf.float32, inp_size, 'input') + self.inp = tf.compat.v1.placeholder(tf.float32, inp_size, 'input') self.feed = dict() # other placeholders # Build the forward pass @@ -129,7 +129,7 @@ def setup_meta_ops(self): utility = min(self.FLAGS.gpu, 1.) if utility > 0.0: self.say('GPU mode with {} usage'.format(utility)) - cfg['gpu_options'] = tf.GPUOptions( + cfg['gpu_options'] = tf.compat.v1.GPUOptions( per_process_gpu_memory_fraction = utility) cfg['allow_soft_placement'] = True else: @@ -139,14 +139,14 @@ def setup_meta_ops(self): if self.FLAGS.train: self.build_train_op() if self.FLAGS.summary: - self.summary_op = tf.summary.merge_all() - self.writer = tf.summary.FileWriter(self.FLAGS.summary + 'train') + self.summary_op = tf.compat.v1.summary.merge_all() + self.writer = tf.compat.v1.summary.FileWriter(self.FLAGS.summary + 'train') - self.sess = tf.Session(config = tf.ConfigProto(**cfg)) - self.sess.run(tf.global_variables_initializer()) + self.sess = tf.compat.v1.Session(config = tf.compat.v1.ConfigProto(**cfg)) + self.sess.run(tf.compat.v1.global_variables_initializer()) if not self.ntrain: return - self.saver = tf.train.Saver(tf.global_variables(), + self.saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep = self.FLAGS.keep) if self.FLAGS.load != 0: self.load_from_ckpt() @@ -165,7 +165,7 @@ def savepb(self): flags_pb.train = False # rebuild another tfnet. all const. tfnet_pb = TFNet(flags_pb, darknet_pb) - tfnet_pb.sess = tf.Session(graph = tfnet_pb.graph) + tfnet_pb.sess = tf.compat.v1.Session(graph = tfnet_pb.graph) # tfnet_pb.predict() # uncomment for unit testing name = 'built_graph/{}.pb'.format(self.meta['name']) os.makedirs(os.path.dirname(name), exist_ok=True) @@ -174,4 +174,4 @@ def savepb(self): json.dump(self.meta, fp) self.say('Saving const graph def to {}'.format(name)) graph_def = tfnet_pb.sess.graph_def - tf.train.write_graph(graph_def,'./', name, False) \ No newline at end of file + tf.io.write_graph(graph_def,'./', name, False) \ No newline at end of file diff --git a/darkflow/net/help.py b/darkflow/net/help.py index 616e85bf6..13a5fa727 100644 --- a/darkflow/net/help.py +++ b/darkflow/net/help.py @@ -45,15 +45,15 @@ def load_old_graph(self, ckpt): ckpt_loader = create_loader(ckpt) self.say(old_graph_msg.format(ckpt)) - for var in tf.global_variables(): + for var in tf.compat.v1.global_variables(): name = var.name.split(':')[0] args = [name, var.get_shape()] val = ckpt_loader(args) assert val is not None, \ 'Cannot find and load {}'.format(var.name) shp = val.shape - plh = tf.placeholder(tf.float32, shp) - op = tf.assign(var, plh) + plh = tf.compat.v1.placeholder(tf.float32, shp) + op = tf.compat.v1.assign(var, plh) self.sess.run(op, {plh: val}) def _get_fps(self, frame): @@ -156,7 +156,7 @@ def to_darknet(self): darknet_ckpt = self.darknet with self.graph.as_default() as g: - for var in tf.global_variables(): + for var in tf.compat.v1.global_variables(): name = var.name.split(':')[0] var_name = name.split('-') l_idx = int(var_name[0]) diff --git a/darkflow/net/ops/baseop.py b/darkflow/net/ops/baseop.py index 8992aca1a..ba75f2729 100644 --- a/darkflow/net/ops/baseop.py +++ b/darkflow/net/ops/baseop.py @@ -65,10 +65,10 @@ def wrap_variable(self, var): if not self.var: return val = self.lay.w[var] - self.lay.w[var] = tf.constant_initializer(val) + self.lay.w[var] = tf.compat.v1.constant_initializer(val) if var in self._SLIM: return - with tf.variable_scope(self.scope): - self.lay.w[var] = tf.get_variable(var, + with tf.compat.v1.variable_scope(self.scope): + self.lay.w[var] = tf.compat.v1.get_variable(var, shape = self.lay.wshape[var], dtype = tf.float32, initializer = self.lay.w[var]) @@ -81,7 +81,7 @@ def wrap_pholder(self, ph, feed): sig = '{}/{}'.format(self.scope, ph) val = self.lay.h[ph] - self.lay.h[ph] = tf.placeholder_with_default( + self.lay.h[ph] = tf.compat.v1.placeholder_with_default( val['dfault'], val['shape'], name = sig) feed[self.lay.h[ph]] = val['feed'] diff --git a/darkflow/net/ops/convolution.py b/darkflow/net/ops/convolution.py index 167b0fd78..930ad02cc 100644 --- a/darkflow/net/ops/convolution.py +++ b/darkflow/net/ops/convolution.py @@ -1,4 +1,4 @@ -import tensorflow.contrib.slim as slim +import tf_slim as slim from .baseop import BaseOp import tensorflow as tf import numpy as np @@ -24,7 +24,7 @@ def _forward(self): def forward(self): inp = self.inp.out s = self.lay.stride - self.out = tf.extract_image_patches( + self.out = tf.image.extract_patches( inp, [1,s,s,1], [1,s,s,1], [1,1,1,1], 'VALID') def speak(self): @@ -36,7 +36,7 @@ def speak(self): class local(BaseOp): def forward(self): pad = [[self.lay.pad, self.lay.pad]] * 2; - temp = tf.pad(self.inp.out, [[0, 0]] + pad + [[0, 0]]) + temp = tf.pad(tensor=self.inp.out, paddings=[[0, 0]] + pad + [[0, 0]]) k = self.lay.w['kernels'] ksz = self.lay.ksize @@ -49,7 +49,7 @@ def forward(self): i_, j_ = i + 1 - half, j + 1 - half tij = temp[:, i_ : i_ + ksz, j_ : j_ + ksz,:] row_i.append( - tf.nn.conv2d(tij, kij, + tf.nn.conv2d(input=tij, filters=kij, padding = 'VALID', strides = [1] * 4)) out += [tf.concat(row_i, 2)] @@ -66,8 +66,8 @@ def speak(self): class convolutional(BaseOp): def forward(self): pad = [[self.lay.pad, self.lay.pad]] * 2; - temp = tf.pad(self.inp.out, [[0, 0]] + pad + [[0, 0]]) - temp = tf.nn.conv2d(temp, self.lay.w['kernel'], padding = 'VALID', + temp = tf.pad(tensor=self.inp.out, paddings=[[0, 0]] + pad + [[0, 0]]) + temp = tf.nn.conv2d(input=temp, filters=self.lay.w['kernel'], padding = 'VALID', name = self.scope, strides = [1] + [self.lay.stride] * 2 + [1]) if self.lay.batch_norm: temp = self.batchnorm(self.lay, temp) @@ -113,4 +113,4 @@ def speak(self): args += [l.batch_norm * '+bnorm'] args += [l.activation] msg = 'extr {}x{}p{}_{} {} {}'.format(*args) - return msg \ No newline at end of file + return msg diff --git a/darkflow/net/ops/simple.py b/darkflow/net/ops/simple.py index 01e28ced6..4c3eccc65 100644 --- a/darkflow/net/ops/simple.py +++ b/darkflow/net/ops/simple.py @@ -1,4 +1,4 @@ -import tensorflow.contrib.slim as slim +import tf_slim as slim from .baseop import BaseOp import tensorflow as tf from distutils.version import StrictVersion @@ -22,7 +22,7 @@ def speak(self): class connected(BaseOp): def forward(self): - self.out = tf.nn.xw_plus_b( + self.out = tf.compat.v1.nn.xw_plus_b( self.inp.out, self.lay.w['weights'], self.lay.w['biases'], @@ -56,7 +56,7 @@ def speak(self): class flatten(BaseOp): def forward(self): temp = tf.transpose( - self.inp.out, [0,3,1,2]) + a=self.inp.out, perm=[0,3,1,2]) self.out = slim.flatten( temp, scope = self.scope) @@ -73,7 +73,7 @@ def speak(self): return 'softmax()' class avgpool(BaseOp): def forward(self): self.out = tf.reduce_mean( - self.inp.out, [1, 2], + input_tensor=self.inp.out, axis=[1, 2], name = self.scope ) @@ -86,7 +86,7 @@ def forward(self): self.lay.h['pdrop'] = 1.0 self.out = tf.nn.dropout( self.inp.out, - self.lay.h['pdrop'], + 1 - (self.lay.h['pdrop']), name = self.scope ) @@ -103,8 +103,8 @@ def speak(self): class maxpool(BaseOp): def forward(self): - self.out = tf.nn.max_pool( - self.inp.out, padding = 'SAME', + self.out = tf.nn.max_pool2d( + input=self.inp.out, padding = 'SAME', ksize = [1] + [self.lay.ksize]*2 + [1], strides = [1] + [self.lay.stride]*2 + [1], name = self.scope diff --git a/darkflow/net/vanilla/train.py b/darkflow/net/vanilla/train.py index 13de7a454..a951a4693 100644 --- a/darkflow/net/vanilla/train.py +++ b/darkflow/net/vanilla/train.py @@ -34,8 +34,8 @@ def loss(self, net_out): loss = l1_loss(diff) elif loss_type == 'softmax': - loss = tf.nn.softmax_cross_entropy_with_logits(logits, y) - loss = tf.reduce_mean(loss) + loss = tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(y)) + loss = tf.reduce_mean(input_tensor=loss) elif loss_type == 'svm': assert 'train_size' in m, \ diff --git a/darkflow/net/yolo/train.py b/darkflow/net/yolo/train.py index 78d8a2f11..0b2164b5e 100644 --- a/darkflow/net/yolo/train.py +++ b/darkflow/net/yolo/train.py @@ -1,4 +1,4 @@ -import tensorflow.contrib.slim as slim +import tf_slim as slim import pickle import tensorflow as tf from .misc import show @@ -30,15 +30,15 @@ def loss(self, net_out): size2 = [None, SS, B] # return the below placeholders - _probs = tf.placeholder(tf.float32, size1) - _confs = tf.placeholder(tf.float32, size2) - _coord = tf.placeholder(tf.float32, size2 + [4]) + _probs = tf.compat.v1.placeholder(tf.float32, size1) + _confs = tf.compat.v1.placeholder(tf.float32, size2) + _coord = tf.compat.v1.placeholder(tf.float32, size2 + [4]) # weights term for L2 loss - _proid = tf.placeholder(tf.float32, size1) + _proid = tf.compat.v1.placeholder(tf.float32, size1) # material calculating IOU - _areas = tf.placeholder(tf.float32, size2) - _upleft = tf.placeholder(tf.float32, size2 + [2]) - _botright = tf.placeholder(tf.float32, size2 + [2]) + _areas = tf.compat.v1.placeholder(tf.float32, size2) + _upleft = tf.compat.v1.placeholder(tf.float32, size2 + [2]) + _botright = tf.compat.v1.placeholder(tf.float32, size2 + [2]) self.placeholders = { 'probs':_probs, 'confs':_confs, 'coord':_coord, 'proid':_proid, @@ -63,8 +63,8 @@ def loss(self, net_out): # calculate the best IOU, set 0.0 confidence for worse boxes iou = tf.truediv(intersect, _areas + area_pred - intersect) - best_box = tf.equal(iou, tf.reduce_max(iou, [2], True)) - best_box = tf.to_float(best_box) + best_box = tf.equal(iou, tf.reduce_max(input_tensor=iou, axis=[2], keepdims=True)) + best_box = tf.cast(best_box, dtype=tf.float32) confs = tf.multiply(best_box, _confs) # take care of the weight terms @@ -87,6 +87,6 @@ def loss(self, net_out): print('Building {} loss'.format(m['model'])) loss = tf.pow(net_out - true, 2) loss = tf.multiply(loss, wght) - loss = tf.reduce_sum(loss, 1) - self.loss = .5 * tf.reduce_mean(loss) - tf.summary.scalar('{} loss'.format(m['model']), self.loss) + loss = tf.reduce_sum(input_tensor=loss, axis=1) + self.loss = .5 * tf.reduce_mean(input_tensor=loss) + tf.compat.v1.summary.scalar('{} loss'.format(m['model']), self.loss) diff --git a/darkflow/net/yolov2/train.py b/darkflow/net/yolov2/train.py index a0bf7154e..c26e36693 100644 --- a/darkflow/net/yolov2/train.py +++ b/darkflow/net/yolov2/train.py @@ -1,4 +1,4 @@ -import tensorflow.contrib.slim as slim +import tf_slim as slim import pickle import tensorflow as tf from ..yolo.misc import show @@ -37,15 +37,15 @@ def loss(self, net_out): size2 = [None, HW, B] # return the below placeholders - _probs = tf.placeholder(tf.float32, size1) - _confs = tf.placeholder(tf.float32, size2) - _coord = tf.placeholder(tf.float32, size2 + [4]) + _probs = tf.compat.v1.placeholder(tf.float32, size1) + _confs = tf.compat.v1.placeholder(tf.float32, size2) + _coord = tf.compat.v1.placeholder(tf.float32, size2 + [4]) # weights term for L2 loss - _proid = tf.placeholder(tf.float32, size1) + _proid = tf.compat.v1.placeholder(tf.float32, size1) # material calculating IOU - _areas = tf.placeholder(tf.float32, size2) - _upleft = tf.placeholder(tf.float32, size2 + [2]) - _botright = tf.placeholder(tf.float32, size2 + [2]) + _areas = tf.compat.v1.placeholder(tf.float32, size2) + _upleft = tf.compat.v1.placeholder(tf.float32, size2 + [2]) + _botright = tf.compat.v1.placeholder(tf.float32, size2 + [2]) self.placeholders = { 'probs':_probs, 'confs':_confs, 'coord':_coord, 'proid':_proid, @@ -83,8 +83,8 @@ def loss(self, net_out): # calculate the best IOU, set 0.0 confidence for worse boxes iou = tf.truediv(intersect, _areas + area_pred - intersect) - best_box = tf.equal(iou, tf.reduce_max(iou, [2], True)) - best_box = tf.to_float(best_box) + best_box = tf.equal(iou, tf.reduce_max(input_tensor=iou, axis=[2], keepdims=True)) + best_box = tf.cast(best_box, dtype=tf.float32) confs = tf.multiply(best_box, _confs) # take care of the weight terms @@ -102,6 +102,6 @@ def loss(self, net_out): loss = tf.pow(adjusted_net_out - true, 2) loss = tf.multiply(loss, wght) loss = tf.reshape(loss, [-1, H*W*B*(4 + 1 + C)]) - loss = tf.reduce_sum(loss, 1) - self.loss = .5 * tf.reduce_mean(loss) - tf.summary.scalar('{} loss'.format(m['model']), self.loss) \ No newline at end of file + loss = tf.reduce_sum(input_tensor=loss, axis=1) + self.loss = .5 * tf.reduce_mean(input_tensor=loss) + tf.compat.v1.summary.scalar('{} loss'.format(m['model']), self.loss) diff --git a/darkflow/utils/loader.py b/darkflow/utils/loader.py index 723560df5..08ae70df8 100644 --- a/darkflow/utils/loader.py +++ b/darkflow/utils/loader.py @@ -85,10 +85,10 @@ class checkpoint_loader(loader): def load(self, ckpt, ignore): meta = ckpt + '.meta' with tf.Graph().as_default() as graph: - with tf.Session().as_default() as sess: - saver = tf.train.import_meta_graph(meta) + with tf.compat.v1.Session().as_default() as sess: + saver = tf.compat.v1.train.import_meta_graph(meta) saver.restore(sess, ckpt) - for var in tf.global_variables(): + for var in tf.compat.v1.global_variables(): name = var.name.split(':')[0] packet = [name, var.get_shape().as_list()] self.src_key += [packet] From 8d9cfed844e78567e624ec0ad4986ffd4f6098a6 Mon Sep 17 00:00:00 2001 From: ShanuDey Date: Mon, 25 Nov 2019 16:58:11 +0530 Subject: [PATCH 2/2] tf-slim is require to use slim from tf2+ contib is deprecated so we need to use tf-slim module for using slim Signed-off-by: ShanuDey --- README.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 9232904df..fd666a792 100644 --- a/README.md +++ b/README.md @@ -17,19 +17,22 @@ Python3, tensorflow 1.0, numpy, opencv 3. ### Getting started -You can choose _one_ of the following three ways to get started with darkflow. +- Install tf_slim module since contrib is deprecated from tf2+ + ```pip install git+https://github.com/ShanuDey/tf-slim.git``` -1. Just build the Cython extensions in place. NOTE: If installing this way you will have to use `./flow` in the cloned darkflow directory instead of `flow` as darkflow is not installed globally. +- You can choose _one_ of the following three ways to get started with darkflow. + + 1. Just build the Cython extensions in place. NOTE: If installing this way you will have to use `./flow` in the cloned darkflow directory instead of `flow` as darkflow is not installed globally. ``` python3 setup.py build_ext --inplace ``` -2. Let pip install darkflow globally in dev mode (still globally accessible, but changes to the code immediately take effect) + 2. Let pip install darkflow globally in dev mode (still globally accessible, but changes to the code immediately take effect) ``` pip install -e . ``` -3. Install with pip globally + 3. Install with pip globally ``` pip install . ```