@@ -584,8 +584,25 @@ def _get_config(self) -> TransformerModelConfig_T:
584584 return self .config_class (** params )
585585
586586 @classmethod
587- def _model_from_checkpoint (cls , checkpoint : tp .Dict [str , tp .Any ]) -> tpe .Self :
588- """Create model from loaded Lightning checkpoint."""
587+ def _model_from_checkpoint (
588+ cls , checkpoint : tp .Dict [str , tp .Any ], ckpt_path : tp .Optional [tp .Union [str , Path ]] = None
589+ ) -> tpe .Self :
590+ """
591+ Create model from loaded Lightning checkpoint.
592+
593+ Parameters
594+ ----------
595+ checkpoint: Dict[str, tp.Any]
596+ Checkpoint object (pl/torch like)
597+ ckpt_path: Union[str, Path], optional
598+ Path to checkpoint location.
599+ If specified should be a path to `checkpoint` arg file.
600+ `checkpoint` is saved to temp file if not specified.
601+
602+ Returns
603+ -------
604+ Model instance.
605+ """
589606 model_config = checkpoint ["hyper_parameters" ]["model_config" ]
590607 loaded = cls .from_config (model_config )
591608 loaded .is_fitted = True
@@ -607,17 +624,26 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
607624 model_config = model_config ,
608625 )
609626
610- # save checkpoint to temp file to be able to use it in trainer
611- with NamedTemporaryFile () as f :
612- torch .save (checkpoint , f .name )
627+ try :
628+ temp_file = None
629+ actual_ckpt_path = ckpt_path
630+ if actual_ckpt_path is None :
631+ temp_file = NamedTemporaryFile ()
632+ actual_ckpt_path = temp_file .name
633+ torch .save (checkpoint , actual_ckpt_path )
634+
613635 loaded .fit_trainer = deepcopy (loaded ._trainer )
614636 # use stub dataset to load trainer state
615637 loaded .fit_trainer .fit (
616638 loaded .lightning_model ,
617- ckpt_path = f . name ,
639+ ckpt_path = actual_ckpt_path ,
618640 train_dataloaders = DataLoader (TensorDataset (torch .Tensor ())),
619641 )
620642
643+ finally :
644+ if temp_file is not None :
645+ temp_file .close ()
646+
621647 loaded .lightning_model .is_fitted = True
622648
623649 return loaded
@@ -675,7 +701,7 @@ def load_from_checkpoint(
675701 prev_config_flatten = make_dict_flat (prev_model_config )
676702 prev_config_flatten .update (model_params_update )
677703 checkpoint ["hyper_parameters" ]["model_config" ] = unflatten_dict (prev_config_flatten )
678- loaded = cls ._model_from_checkpoint (checkpoint )
704+ loaded = cls ._model_from_checkpoint (checkpoint , ckpt_path = checkpoint_path )
679705 return loaded
680706
681707 def load_weights_from_checkpoint (self , checkpoint_path : tp .Union [str , Path ]) -> None :
0 commit comments