@@ -118,6 +118,7 @@ def __init__(
118118 max_checkpoints_to_keep : int = 5 ,
119119 checkpoint_save_interval_epochs : int = 1 ,
120120 rng_seed : int = core .DEFAULT_RNG_SEED ,
121+ legacy_format : bool = True ,
121122 ):
122123 """Initializes the instance."""
123124
@@ -148,12 +149,14 @@ def __init__(
148149 model_dir , core .TRAINING_COMPLETE_MARKER_FILE
149150 )
150151 self ._checkpoint_dir = os .path .join (model_dir , core .CHECKPOINT_DIR )
152+ self ._legacy_format = legacy_format
151153
152154 if keras .backend .backend () == "jax" :
153155 self ._checkpoint_manager = keras_utils .KerasOrbaxCheckpointManager (
154156 checkpoint_dir = self ._checkpoint_dir ,
155157 max_to_keep = max_checkpoints_to_keep ,
156158 save_interval_epochs = checkpoint_save_interval_epochs ,
159+ legacy_format = self ._legacy_format ,
157160 )
158161 self ._train_callbacks = [
159162 keras_utils .EpochSummaryCallback (
@@ -314,13 +317,18 @@ def __init__(
314317 self ,
315318 checkpoint_dir : str ,
316319 epoch : int ,
320+ legacy_format : bool ,
317321 ):
318322 self ._checkpoint_dir = checkpoint_dir
319323 self ._epoch = epoch
324+ self ._legacy_format = legacy_format
320325
321326 def on_test_begin (self , logs : Mapping [str , Any ] | None = None ):
322327 keras_utils .restore_keras_model (
323- model , self ._checkpoint_dir , step = self ._epoch
328+ model ,
329+ self ._checkpoint_dir ,
330+ step = self ._epoch ,
331+ legacy_format = self ._legacy_format ,
324332 )
325333
326334 history = None
@@ -329,7 +337,9 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
329337 timeout = self ._continuous_eval_timeout ,
330338 timeout_fn = timeout_fn ,
331339 ):
332- restore_callback = _RestoreCallback (self ._checkpoint_dir , epoch )
340+ restore_callback = _RestoreCallback (
341+ self ._checkpoint_dir , epoch , self ._legacy_format
342+ )
333343 [tb_cbk ] = [
334344 cbk
335345 for cbk in self ._eval_callbacks
0 commit comments