@@ -725,9 +725,64 @@ def step(_step_id, task_key="Default") -> None:
                     pref_lr = _lr.start_lr
                 else:
                     pref_lr = cur_lr
+
+                # save
+                # torch.save(self.wrapper.state_dict(), "wrapper_dict.pt")
+                # import paddle
+                # psd = {}
+                # for k, v in self.wrapper.state_dict().items():
+                #     if isinstance(v, torch.Tensor):
+                #         psd[k] = paddle.from_dlpack(v.detach())
+                #     else:
+                #         psd[k] = v
+                # paddle.save(psd, "wrapper_dict.pd")
+                # inp = {}
+                # for k, v in input_dict.items():
+                #     if isinstance(v, torch.Tensor):
+                #         inp[k] = v.detach().cpu().numpy()
+                #     else:
+                #         inp[k] = v
+                # np.savez("./input_dict.npz", **inp)
+                # lab = {}
+                # for k, v in label_dict.items():
+                #     if isinstance(v, torch.Tensor):
+                #         lab[k] = v.detach().cpu().numpy()
+                #     else:
+                #         lab[k] = v
+                # np.savez("./label_dict.npz", **lab)
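+                # When uncommented, the block above dumps the wrapper weights (the
+                # PyTorch "wrapper_dict.pt" plus a Paddle "wrapper_dict.pd" copy made
+                # via DLPack) and the current input/label dicts to disk; the load
+                # block below replays exactly those files.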
+
+                # load
+                self.wrapper.load_state_dict(torch.load("./wrapper_dict.pt"))
+                print("model loaded")
+                inp = np.load("./input_dict.npz", allow_pickle=True)
+                for k, v in inp.items():
+                    if isinstance(v, np.ndarray):
+                        # print(k, type(v), v.shape, v.dtype)
+                        try:
+                            input_dict[k] = torch.tensor(v)
+                        except TypeError:
+                            pass
+                        if isinstance(input_dict[k], torch.Tensor):
+                            input_dict[k] = input_dict[k].cuda()
+                print("input_dict loaded")
+                lab = np.load("./label_dict.npz", allow_pickle=True)
+                for k, v in lab.items():
+                    if isinstance(v, np.ndarray):
+                        # print(k, type(v), v.shape, v.dtype)
+                        try:
+                            label_dict[k] = torch.tensor(v)
+                        except TypeError:
+                            pass
+                        if isinstance(label_dict[k], torch.Tensor):
+                            label_dict[k] = label_dict[k].cuda()
+                print("label_dict loaded")
+
                 model_pred, loss, more_loss = self.wrapper(
                     **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
                 )
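+                # debug: print every loss component and the total loss of this single
+                # forward pass, then abort before backward/optimizer so the numbers
+                # can be compared against the other framework on the same data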
+                print({k: float(v) for k, v in more_loss.items()})
+                print(f"{loss.item():.10f}")
+                exit()
                 loss.backward()
                 if self.gradient_max_norm > 0.0:
                     torch.nn.utils.clip_grad_norm_(
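Aside, not part of the commit: a minimal sketch for sanity-checking the dumped
reference data before replaying it. Only the file names come from the diff
above; everything else is illustrative.

    import numpy as np

    for name in ("input_dict.npz", "label_dict.npz"):
        data = np.load(name, allow_pickle=True)
        print(f"== {name} ==")
        for key, value in data.items():
            # np.savez stores every entry as an ndarray; object-dtype entries
            # are the ones the loader above skips via `except TypeError`.
            print(f"{key}: shape={value.shape}, dtype={value.dtype}")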