Skip to content

Commit 8031222

Browse files
edwardzhou130 authored and recml authors committed
Save variables with keys in the checkpoint
PiperOrigin-RevId: 756065918
1 parent cbf9a8f commit 8031222

4 files changed

Lines changed: 430 additions & 35 deletions

File tree

recml/core/training/keras_trainer.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,7 @@ def __init__(
118118
max_checkpoints_to_keep: int = 5,
119119
checkpoint_save_interval_epochs: int = 1,
120120
rng_seed: int = core.DEFAULT_RNG_SEED,
121+
legacy_format: bool = True,
121122
):
122123
"""Initializes the instance."""
123124

@@ -148,12 +149,14 @@ def __init__(
148149
model_dir, core.TRAINING_COMPLETE_MARKER_FILE
149150
)
150151
self._checkpoint_dir = os.path.join(model_dir, core.CHECKPOINT_DIR)
152+
self._legacy_format = legacy_format
151153

152154
if keras.backend.backend() == "jax":
153155
self._checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager(
154156
checkpoint_dir=self._checkpoint_dir,
155157
max_to_keep=max_checkpoints_to_keep,
156158
save_interval_epochs=checkpoint_save_interval_epochs,
159+
legacy_format=self._legacy_format,
157160
)
158161
self._train_callbacks = [
159162
keras_utils.EpochSummaryCallback(
@@ -314,13 +317,18 @@ def __init__(
314317
self,
315318
checkpoint_dir: str,
316319
epoch: int,
320+
legacy_format: bool,
317321
):
318322
self._checkpoint_dir = checkpoint_dir
319323
self._epoch = epoch
324+
self._legacy_format = legacy_format
320325

321326
def on_test_begin(self, logs: Mapping[str, Any] | None = None):
322327
keras_utils.restore_keras_model(
323-
model, self._checkpoint_dir, step=self._epoch
328+
model,
329+
self._checkpoint_dir,
330+
step=self._epoch,
331+
legacy_format=self._legacy_format,
324332
)
325333

326334
history = None
@@ -329,7 +337,9 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
329337
timeout=self._continuous_eval_timeout,
330338
timeout_fn=timeout_fn,
331339
):
332-
restore_callback = _RestoreCallback(self._checkpoint_dir, epoch)
340+
restore_callback = _RestoreCallback(
341+
self._checkpoint_dir, epoch, self._legacy_format
342+
)
333343
[tb_cbk] = [
334344
cbk
335345
for cbk in self._eval_callbacks

recml/core/training/keras_trainer_test.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,22 @@ def test_keras_task_and_trainer(self, mode: str):
9494
):
9595
self.assertEqual(history.history["num_params/trainable"][0], 2)
9696

97+
def test_new_checkpoint_format(self):
98+
if keras.backend.backend() != "jax":
99+
self.skipTest("Only supported on the Jax backend.")
100+
trainer = keras_trainer.KerasTrainer(
101+
distribution=keras.distribution.DataParallel(),
102+
train_steps=5,
103+
steps_per_eval=3,
104+
steps_per_loop=2,
105+
model_dir=self.create_tempdir().full_path,
106+
continuous_eval_timeout=5,
107+
legacy_format=False,
108+
)
109+
experiment = core.Experiment(_KerasTask(), trainer)
110+
core.run_experiment(experiment, core.Experiment.Mode.TRAIN)
111+
core.run_experiment(experiment, core.Experiment.Mode.CONTINUOUS_EVAL)
112+
97113

98114
if __name__ == "__main__":
99115
absltest.main()

0 commit comments

Comments
 (0)