@@ -725,64 +725,9 @@ def step(_step_id, task_key="Default") -> None:
725725 pref_lr = _lr .start_lr
726726 else :
727727 pref_lr = cur_lr
728-
729- # save
730- # torch.save(self.wrapper.state_dict(), "wrapper_dict.pt")
731- # import paddle
732- # psd = {}
733- # for k, v in self.wrapper.state_dict().items():
734- # if isinstance(v, torch.Tensor):
735- # psd[k] = paddle.from_dlpack(v.detach())
736- # else:
737- # psd[k] = v
738- # paddle.save(psd, "wrapper_dict.pd")
739- # inp = {}
740- # for k, v in input_dict.items():
741- # if isinstance(v, torch.Tensor):
742- # inp[k] = v.detach().cpu().numpy()
743- # else:
744- # inp[k] = v
745- # np.savez("./input_dict.npz", **inp)
746- # lab = {}
747- # for k, v in label_dict.items():
748- # if isinstance(v, torch.Tensor):
749- # lab[k] = v.detach().cpu().numpy()
750- # else:
751- # lab[k] = v
752- # np.savez("./label_dict.npz", **lab)
753-
754- # load
755- self .wrapper .load_state_dict (torch .load ("./wrapper_dict.pt" ))
756- print ("model loaded" )
757- inp = np .load ("./input_dict.npz" , allow_pickle = True )
758- for k , v in inp .items ():
759- if isinstance (v , np .ndarray ):
760- # print(k, type(v), v.shape, v.dtype)
761- try :
762- input_dict [k ] = torch .tensor (v )
763- except TypeError :
764- pass
765- if isinstance (input_dict [k ], torch .Tensor ):
766- input_dict [k ] = input_dict [k ].cuda ()
767- print ("input_dict loaded" )
768- lab = np .load ("./label_dict.npz" , allow_pickle = True )
769- for k , v in lab .items ():
770- if isinstance (v , np .ndarray ):
771- # print(k, type(v), v.shape, v.dtype)
772- try :
773- label_dict [k ] = torch .tensor (v )
774- except TypeError :
775- pass
776- if isinstance (label_dict [k ], torch .Tensor ):
777- label_dict [k ] = label_dict [k ].cuda ()
778- print ("label_dict loaded" )
779-
780728 model_pred , loss , more_loss = self .wrapper (
781729 ** input_dict , cur_lr = pref_lr , label = label_dict , task_key = task_key
782730 )
783- print ({k : float (v ) for k , v in more_loss .items ()})
784- print (f"{ loss .item ():.10f} " )
785- exit ()
786731 loss .backward ()
787732 if self .gradient_max_norm > 0.0 :
788733 torch .nn .utils .clip_grad_norm_ (
0 commit comments