@@ -779,6 +779,11 @@ def _forward_common(
779779 assert fparam is not None , "fparam should not be None"
780780 assert self .fparam_avg is not None
781781 assert self .fparam_inv_std is not None
782+ if fparam .numel () != nf * self .numb_fparam :
783+ raise ValueError (
784+ f"input fparam: cannot reshape { list (fparam .shape )} "
785+ f"into ({ nf } , { self .numb_fparam } )."
786+ )
782787 fparam = fparam .view ([nf , self .numb_fparam ])
783788 nb , _ = fparam .shape
784789 t_fparam_avg = self ._extend_f_avg_std (self .fparam_avg , nb )
@@ -799,6 +804,11 @@ def _forward_common(
799804 assert aparam is not None , "aparam should not be None"
800805 assert self .aparam_avg is not None
801806 assert self .aparam_inv_std is not None
807+ if aparam .numel () % (nf * self .numb_aparam ) != 0 :
808+ raise ValueError (
809+ f"input aparam: cannot reshape { list (aparam .shape )} "
810+ f"into ({ nf } , nloc, { self .numb_aparam } )."
811+ )
802812 aparam = aparam .view ([nf , - 1 , self .numb_aparam ])
803813 nb , nloc , _ = aparam .shape
804814 t_aparam_avg = self ._extend_a_avg_std (self .aparam_avg , nb , nloc )
0 commit comments