33import typing as tp
44from pathlib import Path
55
6- import torch
76import pytorch_lightning as pl
7+ import torch
88from pytorch_lightning .callbacks import EarlyStopping
99
10+ from .gpu_data import align_embeddings , build_sequences , make_dataloader
11+ from .unisrec_lightning import SUPPORTED_LOSSES , SUPPORTED_OPTIMIZERS , SUPPORTED_SCHEDULERS , UniSRecLightning
1012from .unisrec_net import UniSRec
11- from .unisrec_lightning import UniSRecLightning , SUPPORTED_LOSSES , SUPPORTED_OPTIMIZERS , SUPPORTED_SCHEDULERS
12- from .gpu_data import build_sequences , align_embeddings , make_dataloader
1313
1414
1515class UniSRecModel :
@@ -143,7 +143,12 @@ def _make_trainer(self, max_epochs: int, val_dl: tp.Any = None) -> pl.Trainer:
143143 )
144144
145145 def _make_lightning (
146- self , net : UniSRec , param_groups : tp .List [tp .Dict ], use_id : bool , max_epochs : int , train_dl : tp .Any ,
146+ self ,
147+ net : UniSRec ,
148+ param_groups : tp .List [tp .Dict ],
149+ use_id : bool ,
150+ max_epochs : int ,
151+ train_dl : tp .Any ,
147152 ) -> UniSRecLightning :
148153 total_steps = len (train_dl ) * max_epochs if self .scheduler else None
149154 return UniSRecLightning (
@@ -172,16 +177,22 @@ def _phase2_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]:
172177 {"params" : [net .whitening_bias ], "lr" : self .phase2_lr * 10.0 , "weight_decay" : 0.0 },
173178 ]
174179 if net .head is not None :
175- groups .append ({
176- "params" : list (net .head .parameters ()),
177- "lr" : self .phase2_lr * self .lr_head ,
178- "weight_decay" : self .weight_decay ,
179- })
180+ groups .append (
181+ {
182+ "params" : list (net .head .parameters ()),
183+ "lr" : self .phase2_lr * self .lr_head ,
184+ "weight_decay" : self .weight_decay ,
185+ }
186+ )
180187 else :
181188 groups = [
182189 {"params" : list (net .bn_input .parameters ()), "lr" : self .phase2_lr , "weight_decay" : 0.0 },
183190 {"params" : list (net .bn_score .parameters ()), "lr" : self .phase2_lr , "weight_decay" : 0.0 },
184- {"params" : list (net .head .parameters ()), "lr" : self .phase2_lr * self .lr_head , "weight_decay" : self .weight_decay },
191+ {
192+ "params" : list (net .head .parameters ()),
193+ "lr" : self .phase2_lr * self .lr_head ,
194+ "weight_decay" : self .weight_decay ,
195+ },
185196 ]
186197 return groups
187198
@@ -198,21 +209,27 @@ def _phase3_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]:
198209 ]
199210 head : tp .List [tp .Dict [str , tp .Any ]] = []
200211 if net .head is not None :
201- head = [{"params" : list (net .head .parameters ()), "lr" : self .phase3_lr * self .lr_head , "weight_decay" : self .weight_decay }]
212+ head = [
213+ {
214+ "params" : list (net .head .parameters ()),
215+ "lr" : self .phase3_lr * self .lr_head ,
216+ "weight_decay" : self .weight_decay ,
217+ }
218+ ]
202219 transformer = [
203220 {"params" : list (net .pos_emb .parameters ()), "lr" : self .phase3_lr * self .lr_transformer , "weight_decay" : 0.0 },
204221 {
205222 "params" : (
206- [p for l in net .attention_layers for p in l .parameters ()]
207- + [p for l in net .forward_layers for p in l .parameters ()]
223+ [p for layer in net .attention_layers for p in layer .parameters ()]
224+ + [p for layer in net .forward_layers for p in layer .parameters ()]
208225 ),
209226 "lr" : self .phase3_lr * self .lr_transformer ,
210227 "weight_decay" : self .weight_decay ,
211228 },
212229 {
213230 "params" : (
214- [p for l in net .attention_layernorms for p in l .parameters ()]
215- + [p for l in net .forward_layernorms for p in l .parameters ()]
231+ [p for layer in net .attention_layernorms for p in layer .parameters ()]
232+ + [p for layer in net .forward_layernorms for p in layer .parameters ()]
216233 + list (net .last_layernorm .parameters ())
217234 ),
218235 "lr" : self .phase3_lr ,
@@ -246,7 +263,9 @@ def fit(
246263 self
247264 """
248265 x , y , unique_items , unique_users = build_sequences (
249- user_ids , item_ids , timestamps ,
266+ user_ids ,
267+ item_ids ,
268+ timestamps ,
250269 max_len = self .session_max_len ,
251270 min_interactions = self .train_min_user_interactions ,
252271 )
@@ -303,12 +322,15 @@ def _run_phase(param_groups: tp.List[tp.Dict], use_id: bool, max_epochs: int) ->
303322
304323 def save_checkpoint (self , path : tp .Union [str , Path ]) -> None :
305324 assert self ._net is not None
306- torch .save ({
307- "net" : self ._net .state_dict (),
308- "unique_items" : self ._unique_items ,
309- "unique_users" : self ._unique_users ,
310- "n_items" : len (self ._unique_items ),
311- }, path )
325+ torch .save (
326+ {
327+ "net" : self ._net .state_dict (),
328+ "unique_items" : self ._unique_items ,
329+ "unique_users" : self ._unique_users ,
330+ "n_items" : len (self ._unique_items ),
331+ },
332+ path ,
333+ )
312334
313335 def load_checkpoint (self , path : tp .Union [str , Path ], device : str = "cuda" ) -> None :
314336 ckpt = torch .load (path , map_location = device , weights_only = False )
0 commit comments