Skip to content

Commit a1cf5bf

Browse files
committed
added test& edited pyproject
1 parent 659d016 commit a1cf5bf

3 files changed

Lines changed: 28 additions & 8 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ ipywidgets = {version = ">=7.7,<8.2", optional = true}
8585
plotly = {version="^5.22.0", optional = true}
8686
nbformat = {version = ">=4.2.0", optional = true}
8787
cupy-cuda12x = {version = "^13.3.0", python = "<3.13", optional = true}
88+
flatten-dict = "^0.4.2"
8889

8990

9091
[tool.poetry.extras]

rectools/models/nn/transformers/base.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
from copy import deepcopy
1919
from pathlib import Path
2020
from tempfile import NamedTemporaryFile
21-
from flatten_dict import flatten, unflatten
2221

2322
import numpy as np
2423
import torch
2524
import typing_extensions as tpe
25+
from flatten_dict import flatten, unflatten
2626
from pydantic import BeforeValidator, PlainSerializer
2727
from pytorch_lightning import Trainer
2828

@@ -608,8 +608,7 @@ def load_from_checkpoint(
608608
map_location: tp.Union[str, torch.device, None] = None,
609609
config_update: tp.Dict[str, tp.Any] = {},
610610
) -> tpe.Self:
611-
"""
612-
Load model from Lightning checkpoint path.
611+
"""Load model from Lightning checkpoint path.
613612
614613
Parameters
615614
----------
@@ -619,16 +618,18 @@ def load_from_checkpoint(
619618
Target device to load the checkpoint (e.g., 'cpu', 'cuda:0').
620619
If None, will use the device the checkpoint was saved on.
621620
config_update: tp.Dict[str, tp.Any], default '{}'
622-
Сontains custom values for checkpoint['hyper_parameters']
621+
Contains custom values for checkpoint['hyper_parameters'].
622+
Config_update has to be flattened with 'dot' reducer, before passed.
623+
623624
Returns
624625
-------
625626
Model instance.
626627
"""
627628
checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
628629
prev_config = checkpoint["hyper_parameters"]
629-
prev_config_flatten = flatten(prev_config, reducer='dot')
630-
prev_config_flatten.update(flatten(config_update, reducer='dot'))
631-
checkpoint["hyper_parameters"] = unflatten(prev_config_flatten, splitter='dot')
630+
prev_config_flatten = flatten(prev_config, reducer="dot")
631+
prev_config_flatten.update(config_update)
632+
checkpoint["hyper_parameters"] = unflatten(prev_config_flatten, splitter="dot")
632633
loaded = cls._model_from_checkpoint(checkpoint)
633634
return loaded
634635

tests/models/nn/transformers/test_base.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pandas as pd
2020
import pytest
2121
import torch
22+
from flatten_dict import flatten
2223
from pytest import FixtureRequest
2324
from pytorch_lightning import Trainer, seed_everything
2425
from pytorch_lightning.loggers import CSVLogger
@@ -152,10 +153,24 @@ def test_save_load_for_fitted_model(
152153

153154
@pytest.mark.parametrize("test_dataset", ("dataset", "dataset_item_features"))
154155
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
156+
@pytest.mark.parametrize("map_location", ("cpu", torch.device("cuda:0"), None))
157+
@pytest.mark.parametrize(
158+
"config_update",
159+
(
160+
{
161+
"model_config": {
162+
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
163+
"get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
164+
}
165+
},
166+
),
167+
)
155168
def test_load_from_checkpoint(
156169
self,
157170
model_cls: tp.Type[TransformerModelBase],
158171
test_dataset: str,
172+
map_location: tp.Union[str, torch.device, None],
173+
config_update: tp.Dict[str, tp.Any],
159174
request: FixtureRequest,
160175
) -> None:
161176

@@ -173,7 +188,10 @@ def test_load_from_checkpoint(
173188
raise ValueError("No log dir")
174189
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
175190
assert os.path.isfile(ckpt_path)
176-
recovered_model = model_cls.load_from_checkpoint(ckpt_path)
191+
config_update_flatten = flatten(config_update, reducer="dot")
192+
recovered_model = model_cls.load_from_checkpoint(
193+
ckpt_path, map_location=map_location, config_update=config_update_flatten
194+
)
177195
assert isinstance(recovered_model, model_cls)
178196

179197
self._assert_same_reco(model, recovered_model, dataset)

0 commit comments

Comments
 (0)