Skip to content

Commit cacd3c2

Browse files
nsundalovblondered
andauthored
Apply suggestions from code review
accept suggestions Co-authored-by: Daria <93913290+blondered@users.noreply.github.com>
1 parent e3b0839 commit cacd3c2

2 files changed

Lines changed: 2 additions & 3 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## Unreleased
99
### Added
1010
- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))
11-
- allow resaving transformer model multiple times. Load train state on model loading ([#289](https://github.com/MobileTeleSystems/RecTools/pull/289))
11+
- Support for resaving transformer models multiple times and loading trainer state ([#289](https://github.com/MobileTeleSystems/RecTools/pull/289))
1212

1313
## [0.14.0] - 16.05.2025
1414

rectools/models/nn/transformers/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,8 +610,7 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
610610
# save checkpoint to temp file to be able to use it in trainer
611611
with NamedTemporaryFile() as f:
612612
torch.save(checkpoint, f.name)
613-
fit_trainer = deepcopy(loaded._trainer)
614-
loaded.fit_trainer = fit_trainer
613+
loaded.fit_trainer = deepcopy(loaded._trainer)
615614
# use stub dataset to load trainer state
616615
loaded.fit_trainer.fit(
617616
loaded.lightning_model,

0 commit comments

Comments
 (0)