Skip to content

Commit 4f4b50b

Browse files
committed
Refactor out shared LSTM/GGNN training loop.
github.com/ChrisCummins/ProGraML/issues/69
1 parent 9c4c442 commit 4f4b50b

4 files changed

Lines changed: 78 additions & 89 deletions

File tree

programl/task/dataflow/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ py_library(
5555
deps = [
5656
":graph_loader",
5757
"//programl/models:async_batch_builder",
58+
"//programl/models:base_batch_builder",
59+
"//programl/models:model",
5860
"//programl/models/ggnn",
5961
"//programl/proto:checkpoint_py",
6062
"//programl/proto:epoch_py",

programl/task/dataflow/dataflow.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import warnings
2222
from typing import Tuple
2323

24-
from labm8.py import app, pbutil
24+
from labm8.py import app, humanize, pbutil
2525
from sklearn.exceptions import UndefinedMetricWarning
2626

27+
from programl.models.base_batch_builder import BaseBatchBuilder
28+
from programl.models.model import Model
2729
from programl.proto import checkpoint_pb2, epoch_pb2
2830

2931
app.DEFINE_string(
@@ -208,3 +210,70 @@ def CreateLoggingDirectories(
208210
(log_dir / "checkpoints").mkdir()
209211
(log_dir / "graph_loader").mkdir()
210212
return log_dir
213+
214+
215+
def run_training_loop(
    log_dir: pathlib.Path,
    epochs,
    val_batches: BaseBatchBuilder,
    start_epoch_step: int,
    model: Model,
    val_graph_count: int,
) -> pathlib.Path:
    """Run a train / validate / checkpoint loop for a dataflow model.

    For every epoch yielded by `epochs`, train the model on that epoch's
    batches, run a validation pass, log the combined results to file as an
    EpochList proto, and save a model checkpoint.

    Args:
        log_dir: The logging directory. Expected to contain "epochs" and
            "checkpoints" subdirectories (as produced by
            CreateLoggingDirectories()).
        epochs: An iterable of (train_graph_count, train_graph_cumsum,
            train_batches) tuples, one per training epoch.
        val_batches: A batch builder for validation batches.
        start_epoch_step: The initial step count.
        model: The model to train.
        val_graph_count: The number of validation graphs.

    Returns:
        The log_dir first argument.
    """
    for (
        epoch_step,
        (train_graph_count, train_graph_cumsum, train_batches),
    ) in enumerate(epochs, start=start_epoch_step):
        start_time = time.time()
        hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"

        train_results = model.RunBatches(
            epoch_pb2.TRAIN,
            train_batches,
            log_prefix=f"Train to {hr_graph_cumsum}",
            total_graph_count=train_graph_count,
        )
        val_results = model.RunBatches(
            epoch_pb2.VAL,
            val_batches.batches,
            log_prefix=f"Val at {hr_graph_cumsum}",
            total_graph_count=val_graph_count,
        )

        # Write the epoch to file as an epoch list. This may seem redundant since
        # epoch list contains a single item, but it means that we can easily
        # concatenate a sequence of these epoch protos to produce a valid epoch
        # list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
        epoch = epoch_pb2.EpochList(
            epoch=[
                epoch_pb2.Epoch(
                    walltime_seconds=time.time() - start_time,
                    epoch_num=epoch_step,
                    train_results=train_results,
                    val_results=val_results,
                )
            ]
        )
        print(epoch, end="")

        epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
        pbutil.ToFile(epoch, epoch_path)
        app.Log(1, "Wrote %s", epoch_path)

        checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
        pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

    return log_dir

programl/task/dataflow/ggnn.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -173,50 +173,9 @@ def TrainDataflowGGNN(
173173
)
174174
)
175175

176-
for (
177-
epoch_step,
178-
(train_graph_count, train_graph_cumsum, train_batches),
179-
) in enumerate(epochs, start=start_epoch_step):
180-
start_time = time.time()
181-
hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"
182-
183-
train_results = model.RunBatches(
184-
epoch_pb2.TRAIN,
185-
train_batches,
186-
log_prefix=f"Train to {hr_graph_cumsum}",
187-
total_graph_count=train_graph_count,
188-
)
189-
val_results = model.RunBatches(
190-
epoch_pb2.VAL,
191-
val_batches.batches,
192-
log_prefix=f"Val at {hr_graph_cumsum}",
193-
total_graph_count=val_graph_count,
194-
)
195-
196-
# Write the epoch to file as an epoch list. This may seem redundant since
197-
# epoch list contains a single item, but it means that we can easily
198-
# concatenate a sequence of these epoch protos to produce a valid epoch
199-
# list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
200-
epoch = epoch_pb2.EpochList(
201-
epoch=[
202-
epoch_pb2.Epoch(
203-
walltime_seconds=time.time() - start_time,
204-
epoch_num=epoch_step,
205-
train_results=train_results,
206-
val_results=val_results,
207-
)
208-
]
209-
)
210-
print(epoch, end="")
211-
212-
epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
213-
pbutil.ToFile(epoch, epoch_path)
214-
app.Log(1, "Wrote %s", epoch_path)
215-
216-
checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
217-
pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)
218-
219-
return log_dir
176+
return dataflow.run_training_loop(
177+
log_dir, epochs, val_batches, start_epoch_step, model, val_graph_count
178+
)
220179

221180

222181
def TestDataflowGGNN(

programl/task/dataflow/train_lstm.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -160,50 +160,9 @@ def TrainDataflowLSTM(
160160
)
161161
)
162162

163-
for (
164-
epoch_step,
165-
(train_graph_count, train_graph_cumsum, train_batches),
166-
) in enumerate(epochs, start=start_epoch_step):
167-
start_time = time.time()
168-
hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"
169-
170-
train_results = model.RunBatches(
171-
epoch_pb2.TRAIN,
172-
train_batches,
173-
log_prefix=f"Train to {hr_graph_cumsum}",
174-
total_graph_count=train_graph_count,
175-
)
176-
val_results = model.RunBatches(
177-
epoch_pb2.VAL,
178-
val_batches.batches,
179-
log_prefix=f"Val at {hr_graph_cumsum}",
180-
total_graph_count=FLAGS.val_graph_count,
181-
)
182-
183-
# Write the epoch to file as an epoch list. This may seem redundant since
184-
# epoch list contains a single item, but it means that we can easily
185-
# concatenate a sequence of these epoch protos to produce a valid epoch
186-
# list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
187-
epoch = epoch_pb2.EpochList(
188-
epoch=[
189-
epoch_pb2.Epoch(
190-
walltime_seconds=time.time() - start_time,
191-
epoch_num=epoch_step,
192-
train_results=train_results,
193-
val_results=val_results,
194-
)
195-
]
196-
)
197-
print(epoch, end="")
198-
199-
epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
200-
pbutil.ToFile(epoch, epoch_path)
201-
app.Log(1, "Wrote %s", epoch_path)
202-
203-
checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
204-
pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)
205-
206-
return log_dir
163+
return dataflow.run_training_loop(
164+
log_dir, epochs, val_batches, start_epoch_step, model, FLAGS.val_graph_count
165+
)
207166

208167

209168
def TestDataflowLSTM(

0 commit comments

Comments
 (0)