Skip to content

Commit 7d8e554

Browse files
committed
End on correct epoch
1 parent 0a16842 commit 7d8e554

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

tensorflow_benchmark/tf_word_language_model/data_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def file_stream():
135135

136136
def iterate_forever(self, batch_size, num_steps):
137137
def file_stream():
138-
epoch_num = 0
138+
epoch_num = 1
139139
while True:
140140
file_patterns = glob.glob(self._file_pattern)
141141
if not self._deterministic:

tensorflow_benchmark/tf_word_language_model/run_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def run_train(dataset, hps, logdir, ps_device, eval_dataset, task=0, master=""):
5555
cur_global_step = 0
5656
prev_time = time.time()
5757
data_iterator = dataset.iterate_forever(hps.batch_size * hps.num_gpus, hps.num_steps)
58-
while not sv.should_stop() and cur_epoch < hps.epochs:
58+
while not sv.should_stop():
5959
fetches = [model.global_step, model.loss, model.train_op]
6060
# Chief worker computes summaries every 100 steps.
6161
should_compute_summary = (task == 0 and local_step % 100 == 0)
@@ -81,8 +81,12 @@ def run_train(dataset, hps, logdir, ps_device, eval_dataset, task=0, master=""):
8181
log_perplexity = loss_nom / loss_den
8282
print("Results after epoch %d: log_perplexity = %.3f perplexity = %.3f" % (
8383
cur_epoch, log_perplexity, np.exp(log_perplexity)))
84+
8485
x, y = next(data_iterator)
8586

87+
if cur_epoch >= hps.epochs:
88+
break
89+
8690
should_run_profiler = (hps.run_profiler and task == 0 and local_step % 1000 == 13)
8791
if should_run_profiler:
8892
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)

0 commit comments

Comments
 (0)