This repository was archived by the owner on Dec 8, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtf2oda_model_main_training.py
More file actions
132 lines (118 loc) · 5.53 KB
/
Copy pathtf2oda_model_main_training.py
File metadata and controls
132 lines (118 loc) · 5.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Creates and runs TF2 object detection models.
For local training/evaluation run:
    PIPELINE_CONFIG_PATH=path/to/pipeline.config
    MODEL_DIR=/tmp/model_outputs
    NUM_TRAIN_STEPS=10000
    SAMPLE_1_OF_N_EVAL_EXAMPLES=1
    python tf2oda_model_main_training.py -- \
      --model_dir=$MODEL_DIR --num_train_steps=$NUM_TRAIN_STEPS \
      --sample_1_of_n_eval_examples=$SAMPLE_1_OF_N_EVAL_EXAMPLES \
      --pipeline_config_path=$PIPELINE_CONFIG_PATH \
      --alsologtostderr

Optionally add --time_measurement_path=<file> to save the run duration
(in seconds) to that file.
"""
from absl import flags
import tensorflow.compat.v2 as tf
from object_detection import model_lib_v2
import time
import os
# Command-line flags (absl). `model_dir` and `pipeline_config_path` default to
# None here but are marked as required inside main().
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')
flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
# NOTE(review): `eval_on_train_data` is defined but not read by main() in this
# file — confirm whether it is still needed.
flags.DEFINE_bool('eval_on_train_data', False, 'Enable evaluating on train '
                  'data (only supported in distributed training).')
flags.DEFINE_integer('sample_1_of_n_eval_examples', None, 'Will sample one of '
                     'every n eval input examples, where n is provided.')
# Help text refers to the flag actually defined above (`eval_on_train_data`).
flags.DEFINE_integer('sample_1_of_n_eval_on_train_examples', 5, 'Will sample '
                     'one of every n train input examples for evaluation, '
                     'where n is provided. This is only used if '
                     '`eval_on_train_data` is True.')
flags.DEFINE_string(
    'model_dir', None, 'Path to output model directory '
    'where event and checkpoint files will be written.')
flags.DEFINE_string(
    'checkpoint_dir', None, 'Path to directory holding a checkpoint. If '
    '`checkpoint_dir` is provided, this binary operates in eval-only mode, '
    'writing resulting metrics to `model_dir`.')
# Trailing space after "an" so the concatenated help text reads correctly.
flags.DEFINE_integer('eval_timeout', 3600, 'Number of seconds to wait for an '
                     'evaluation checkpoint before exiting.')
flags.DEFINE_bool('use_tpu', False, 'Whether the job is executing on a TPU.')
flags.DEFINE_string(
    'tpu_name',
    default=None,
    help='Name of the Cloud TPU for Cluster Resolvers.')
flags.DEFINE_integer(
    'num_workers', 1, 'When num_workers > 1, training uses '
    'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
    'MirroredStrategy.')
flags.DEFINE_integer(
    'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
flags.DEFINE_boolean('record_summaries', True,
                     ('Whether or not to record summaries during'
                      ' training.'))
flags.DEFINE_string(
    'time_measurement_path',
    default=None,
    help='Save the training duration in s in a file. '
    'Provide the path. '
    'If None, training duration will not be saved.')

FLAGS = flags.FLAGS
def main(unused_argv):
  """Runs training or continuous evaluation and reports the elapsed time.

  When `--checkpoint_dir` is set, `model_lib_v2.eval_continuously` is run.
  Otherwise a distribution strategy is chosen from `--use_tpu` and
  `--num_workers`, and `model_lib_v2.train_loop` is run inside its scope.
  Afterwards the wall-clock duration of the run is printed and, when
  `--time_measurement_path` is set, written (in seconds) to that file.

  Args:
    unused_argv: Ignored.
  """
  start_time = time.time()

  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  tf.config.set_soft_device_placement(True)

  if FLAGS.checkpoint_dir:
    # Eval-only mode: evaluate checkpoints found in `checkpoint_dir`.
    model_lib_v2.eval_continuously(
        pipeline_config_path=FLAGS.pipeline_config_path,
        model_dir=FLAGS.model_dir,
        train_steps=FLAGS.num_train_steps,
        sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
        sample_1_of_n_eval_on_train_examples=(
            FLAGS.sample_1_of_n_eval_on_train_examples),
        checkpoint_dir=FLAGS.checkpoint_dir,
        wait_interval=300, timeout=FLAGS.eval_timeout)
  else:
    if FLAGS.use_tpu:
      # TPU is automatically inferred if tpu_name is None and
      # we are running under cloud ai-platform.
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tpu_name)
      tf.config.experimental_connect_to_cluster(resolver)
      tf.tpu.experimental.initialize_tpu_system(resolver)
      strategy = tf.distribute.experimental.TPUStrategy(resolver)
    elif FLAGS.num_workers > 1:
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    else:
      strategy = tf.compat.v2.distribute.MirroredStrategy()

    with strategy.scope():
      model_lib_v2.train_loop(
          pipeline_config_path=FLAGS.pipeline_config_path,
          model_dir=FLAGS.model_dir,
          train_steps=FLAGS.num_train_steps,
          use_tpu=FLAGS.use_tpu,
          checkpoint_every_n=FLAGS.checkpoint_every_n,
          record_summaries=FLAGS.record_summaries)

  elapsed_time = time.time() - start_time  # in seconds
  print('Finished. Elapsed time: {:.0f}s'.format(elapsed_time))

  if FLAGS.time_measurement_path:
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    target_dir = os.path.dirname(FLAGS.time_measurement_path)
    if target_dir:
      os.makedirs(target_dir, exist_ok=True)
    # The context manager closes (and thereby flushes) the file before the
    # confirmation message is printed.
    with open(FLAGS.time_measurement_path, 'w') as duration_file:
      duration_file.write(str(elapsed_time))
    print("Saved elapsed time {}s to {}".format(elapsed_time, FLAGS.time_measurement_path))
# Script entry point: hand control to the TensorFlow app runner only when this
# file is executed directly, not when it is imported.
if __name__ == '__main__':
  tf.compat.v1.app.run()