1818from copy import deepcopy
1919from pathlib import Path
2020from tempfile import NamedTemporaryFile
21- from flatten_dict import flatten , unflatten
2221
2322import numpy as np
2423import torch
2524import typing_extensions as tpe
25+ from flatten_dict import flatten , unflatten
2626from pydantic import BeforeValidator , PlainSerializer
2727from pytorch_lightning import Trainer
2828
@@ -608,8 +608,7 @@ def load_from_checkpoint(
608608 map_location : tp .Union [str , torch .device , None ] = None ,
609609 config_update : tp .Dict [str , tp .Any ] = {},
610610 ) -> tpe .Self :
611- """
612- Load model from Lightning checkpoint path.
611+ """Load model from Lightning checkpoint path.
613612
614613 Parameters
615614 ----------
@@ -619,16 +618,18 @@ def load_from_checkpoint(
619618 Target device to load the checkpoint (e.g., 'cpu', 'cuda:0').
620619 If None, will use the device the checkpoint was saved on.
621620 config_update: tp.Dict[str, tp.Any], default '{}'
622- Сontains custom values for checkpoint['hyper_parameters']
621+ Contains custom values for checkpoint['hyper_parameters'].
622+ Config_update has to be flattened with 'dot' reducer, before passed.
623+
623624 Returns
624625 -------
625626 Model instance.
626627 """
627628 checkpoint = torch .load (checkpoint_path , map_location = map_location , weights_only = False )
628629 prev_config = checkpoint ["hyper_parameters" ]
629- prev_config_flatten = flatten (prev_config , reducer = ' dot' )
630- prev_config_flatten .update (flatten ( config_update , reducer = 'dot' ) )
631- checkpoint ["hyper_parameters" ] = unflatten (prev_config_flatten , splitter = ' dot' )
630+ prev_config_flatten = flatten (prev_config , reducer = " dot" )
631+ prev_config_flatten .update (config_update )
632+ checkpoint ["hyper_parameters" ] = unflatten (prev_config_flatten , splitter = " dot" )
632633 loaded = cls ._model_from_checkpoint (checkpoint )
633634 return loaded
634635
0 commit comments