Skip to content

Commit c8f7366

Browse files
committed
use direct checkpoint path in load_from_checkpoint
1 parent cacd3c2 commit c8f7366

1 file changed

Lines changed: 33 additions & 7 deletions

File tree

  • rectools/models/nn/transformers

rectools/models/nn/transformers/base.py

Lines changed: 33 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -584,8 +584,25 @@ def _get_config(self) -> TransformerModelConfig_T:
584584
return self.config_class(**params)
585585

586586
@classmethod
587-
def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
588-
"""Create model from loaded Lightning checkpoint."""
587+
def _model_from_checkpoint(
588+
cls, checkpoint: tp.Dict[str, tp.Any], ckpt_path: tp.Optional[tp.Union[str, Path]] = None
589+
) -> tpe.Self:
590+
"""
591+
Create model from loaded Lightning checkpoint.
592+
593+
Parameters
594+
----------
595+
checkpoint: Dict[str, tp.Any]
596+
Checkpoint object (pl/torch like)
597+
ckpt_path: Union[str, Path], optional
598+
Path to checkpoint location.
599+
If specified should be a path to `checkpoint` arg file.
600+
`checkpoint` is saved to temp file if not specified.
601+
602+
Returns
603+
-------
604+
Model instance.
605+
"""
589606
model_config = checkpoint["hyper_parameters"]["model_config"]
590607
loaded = cls.from_config(model_config)
591608
loaded.is_fitted = True
@@ -607,17 +624,26 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
607624
model_config=model_config,
608625
)
609626

610-
# save checkpoint to temp file to be able to use it in trainer
611-
with NamedTemporaryFile() as f:
612-
torch.save(checkpoint, f.name)
627+
try:
628+
temp_file = None
629+
actual_ckpt_path = ckpt_path
630+
if actual_ckpt_path is None:
631+
temp_file = NamedTemporaryFile()
632+
actual_ckpt_path = temp_file.name
633+
torch.save(checkpoint, actual_ckpt_path)
634+
613635
loaded.fit_trainer = deepcopy(loaded._trainer)
614636
# use stub dataset to load trainer state
615637
loaded.fit_trainer.fit(
616638
loaded.lightning_model,
617-
ckpt_path=f.name,
639+
ckpt_path=actual_ckpt_path,
618640
train_dataloaders=DataLoader(TensorDataset(torch.Tensor())),
619641
)
620642

643+
finally:
644+
if temp_file is not None:
645+
temp_file.close()
646+
621647
loaded.lightning_model.is_fitted = True
622648

623649
return loaded
@@ -675,7 +701,7 @@ def load_from_checkpoint(
675701
prev_config_flatten = make_dict_flat(prev_model_config)
676702
prev_config_flatten.update(model_params_update)
677703
checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten)
678-
loaded = cls._model_from_checkpoint(checkpoint)
704+
loaded = cls._model_from_checkpoint(checkpoint, ckpt_path=checkpoint_path)
679705
return loaded
680706

681707
def load_weights_from_checkpoint(self, checkpoint_path: tp.Union[str, Path]) -> None:

0 commit comments

Comments
 (0)