Skip to content

Commit c7d720f

Browse files
committed
fix flaws
1 parent 37ebe2c commit c7d720f

3 files changed

Lines changed: 24 additions & 23 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ ipywidgets = {version = ">=7.7,<8.2", optional = true}
8585
plotly = {version="^5.22.0", optional = true}
8686
nbformat = {version = ">=4.2.0", optional = true}
8787
cupy-cuda12x = {version = "^13.3.0", python = "<3.13", optional = true}
88-
flatten-dict = "^0.4.2"
8988

9089

9190
[tool.poetry.extras]

rectools/models/nn/transformers/base.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,14 @@
2222
import numpy as np
2323
import torch
2424
import typing_extensions as tpe
25-
from flatten_dict import flatten, unflatten
2625
from pydantic import BeforeValidator, PlainSerializer
2726
from pytorch_lightning import Trainer
2827

2928
from rectools import ExternalIds
3029
from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
3130
from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig
3231
from rectools.types import InternalIdsArray
33-
from rectools.utils.misc import get_class_or_function_full_path, import_object
32+
from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict
3433

3534
from ..item_net import (
3635
CatFeaturesItemNet,
@@ -605,31 +604,33 @@ def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None:
605604
def load_from_checkpoint(
606605
cls,
607606
checkpoint_path: tp.Union[str, Path],
608-
map_location: tp.Union[str, torch.device, None] = None,
609-
config_update: tp.Dict[str, tp.Any] = {},
607+
map_location: tp.Optional[tp.Union[str, torch.device]] = None,
608+
model_params_update: tp.Optional[tp.Dict[str, tp.Any]] = None,
610609
) -> tpe.Self:
611610
"""Load model from Lightning checkpoint path.
612611
613612
Parameters
614613
----------
615614
checkpoint_path: Union[str, Path]
616615
Path to checkpoint location.
617-
map_location: Union[str, torch.device, None], default None
616+
map_location: Union[str, torch.device], optional
618617
Target device to load the checkpoint (e.g., 'cpu', 'cuda:0').
619618
If None, will use the device the checkpoint was saved on.
620-
config_update: tp.Dict[str, tp.Any], default '{}'
621-
Contains custom values for checkpoint['hyper_parameters'].
622-
Config_update has to be flattened with 'dot' reducer, before passed.
623-
619+
model_params_update: Dict[str, tp.Any], optional
620+
Contains custom values for checkpoint['hyper_parameters']['model_config'].
621+
Has to be flattened with 'dot' reducer, before passed.
622+
You can use this argument to remove training-specific parameters that are not needed anymore.
623+
e.g. 'get_trainer_func'
624624
Returns
625625
-------
626626
Model instance.
627627
"""
628628
checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
629-
prev_config = checkpoint["hyper_parameters"]
630-
prev_config_flatten = flatten(prev_config, reducer="dot")
631-
prev_config_flatten.update(config_update)
632-
checkpoint["hyper_parameters"] = unflatten(prev_config_flatten, splitter="dot")
629+
prev_model_config = checkpoint["hyper_parameters"]["model_config"]
630+
if model_params_update:
631+
prev_config_flatten = make_dict_flat(prev_model_config)
632+
prev_config_flatten.update(model_params_update)
633+
checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten)
633634
loaded = cls._model_from_checkpoint(checkpoint)
634635
return loaded
635636

tests/models/nn/transformers/test_base.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import pandas as pd
2020
import pytest
2121
import torch
22-
from flatten_dict import flatten
2322
from pytest import FixtureRequest
2423
from pytorch_lightning import Trainer, seed_everything
2524
from pytorch_lightning.loggers import CSVLogger
@@ -155,22 +154,25 @@ def test_save_load_for_fitted_model(
155154
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
156155
@pytest.mark.parametrize("map_location", ("cpu", torch.device("cuda:0"), None))
157156
@pytest.mark.parametrize(
158-
"config_update",
157+
"model_params_update",
159158
(
160159
{
161-
"model_config": {
162-
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
163-
"get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
164-
}
160+
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
161+
"get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
165162
},
163+
{
164+
"get_val_mask_func": None,
165+
"get_trainer_func": None,
166+
},
167+
None
166168
),
167169
)
168170
def test_load_from_checkpoint(
169171
self,
170172
model_cls: tp.Type[TransformerModelBase],
171173
test_dataset: str,
172174
map_location: tp.Union[str, torch.device, None],
173-
config_update: tp.Dict[str, tp.Any],
175+
model_params_update: tp.Dict[str, tp.Any],
174176
request: FixtureRequest,
175177
) -> None:
176178

@@ -188,9 +190,8 @@ def test_load_from_checkpoint(
188190
raise ValueError("No log dir")
189191
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
190192
assert os.path.isfile(ckpt_path)
191-
config_update_flatten = flatten(config_update, reducer="dot")
192193
recovered_model = model_cls.load_from_checkpoint(
193-
ckpt_path, map_location=map_location, config_update=config_update_flatten
194+
ckpt_path, map_location=map_location, model_params_update=model_params_update
194195
)
195196
assert isinstance(recovered_model, model_cls)
196197

0 commit comments

Comments
 (0)