@@ -779,7 +779,13 @@ def _forward_common(
779779 assert fparam is not None , "fparam should not be None"
780780 assert self .fparam_avg is not None
781781 assert self .fparam_inv_std is not None
782- fparam = fparam .view ([nf , self .numb_fparam ])
782+ try :
783+ fparam = fparam .view ([nf , self .numb_fparam ])
784+ except RuntimeError as e :
785+ raise ValueError (
786+ f"input fparam: cannot reshape { list (fparam .shape )} "
787+ f"into ({ nf } , { self .numb_fparam } )."
788+ ) from e
783789 nb , _ = fparam .shape
784790 t_fparam_avg = self ._extend_f_avg_std (self .fparam_avg , nb )
785791 t_fparam_inv_std = self ._extend_f_avg_std (self .fparam_inv_std , nb )
@@ -799,7 +805,13 @@ def _forward_common(
799805 assert aparam is not None , "aparam should not be None"
800806 assert self .aparam_avg is not None
801807 assert self .aparam_inv_std is not None
802- aparam = aparam .view ([nf , - 1 , self .numb_aparam ])
808+ try :
809+ aparam = aparam .view ([nf , - 1 , self .numb_aparam ])
810+ except RuntimeError as e :
811+ raise ValueError (
812+ f"input aparam: cannot reshape { list (aparam .shape )} "
813+ f"into ({ nf } , nloc, { self .numb_aparam } )."
814+ ) from e
803815 nb , nloc , _ = aparam .shape
804816 t_aparam_avg = self ._extend_a_avg_std (self .aparam_avg , nb , nloc )
805817 t_aparam_inv_std = self ._extend_a_avg_std (self .aparam_inv_std , nb , nloc )
0 commit comments