|
22 | 22 | import numpy as np |
23 | 23 | import torch |
24 | 24 | import typing_extensions as tpe |
25 | | -from flatten_dict import flatten, unflatten |
26 | 25 | from pydantic import BeforeValidator, PlainSerializer |
27 | 26 | from pytorch_lightning import Trainer |
28 | 27 |
|
29 | 28 | from rectools import ExternalIds |
30 | 29 | from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap |
31 | 30 | from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig |
32 | 31 | from rectools.types import InternalIdsArray |
33 | | -from rectools.utils.misc import get_class_or_function_full_path, import_object |
| 32 | +from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict |
34 | 33 |
|
35 | 34 | from ..item_net import ( |
36 | 35 | CatFeaturesItemNet, |
@@ -605,31 +604,33 @@ def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None: |
605 | 604 | def load_from_checkpoint( |
606 | 605 | cls, |
607 | 606 | checkpoint_path: tp.Union[str, Path], |
608 | | - map_location: tp.Union[str, torch.device, None] = None, |
609 | | - config_update: tp.Dict[str, tp.Any] = {}, |
| 607 | + map_location: tp.Optional[tp.Union[str, torch.device]] = None, |
| 608 | + model_params_update: tp.Optional[tp.Dict[str, tp.Any]] = None, |
610 | 609 | ) -> tpe.Self: |
611 | 610 | """Load model from Lightning checkpoint path. |
612 | 611 |
|
613 | 612 | Parameters |
614 | 613 | ---------- |
615 | 614 | checkpoint_path: Union[str, Path] |
616 | 615 | Path to checkpoint location. |
617 | | - map_location: Union[str, torch.device, None], default None |
| 616 | + map_location: Union[str, torch.device], optional |
618 | 617 | Target device to load the checkpoint (e.g., 'cpu', 'cuda:0'). |
619 | 618 | If None, will use the device the checkpoint was saved on. |
620 | | - config_update: tp.Dict[str, tp.Any], default '{}' |
621 | | - Contains custom values for checkpoint['hyper_parameters']. |
622 | | - Config_update has to be flattened with 'dot' reducer, before passed. |
623 | | -
|
| 619 | + model_params_update: Dict[str, tp.Any], optional |
| 620 | + Contains custom values for checkpoint['hyper_parameters']['model_config']. |
| 621 | + Has to be flattened with 'dot' reducer, before passed. |
| 622 | + You can use this argument to remove training-specific parameters that are not needed anymore. |
| 623 | + e.g. 'get_trainer_func' |
624 | 624 | Returns |
625 | 625 | ------- |
626 | 626 | Model instance. |
627 | 627 | """ |
628 | 628 | checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False) |
629 | | - prev_config = checkpoint["hyper_parameters"] |
630 | | - prev_config_flatten = flatten(prev_config, reducer="dot") |
631 | | - prev_config_flatten.update(config_update) |
632 | | - checkpoint["hyper_parameters"] = unflatten(prev_config_flatten, splitter="dot") |
| 629 | + prev_model_config = checkpoint["hyper_parameters"]["model_config"] |
| 630 | + if model_params_update: |
| 631 | + prev_config_flatten = make_dict_flat(prev_model_config) |
| 632 | + prev_config_flatten.update(model_params_update) |
| 633 | + checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten) |
633 | 634 | loaded = cls._model_from_checkpoint(checkpoint) |
634 | 635 | return loaded |
635 | 636 |
|
|
0 commit comments