@@ -33,7 +33,7 @@ def test_forward_basic(self):
3333
3434 outputs , _ = F (
3535 layer ,
36- inputs = inputs ,
36+ inputs = ( inputs ,) ,
3737 is_training = True ,
3838 prng_key = prng_key ,
3939 state = layer_params ,
@@ -61,7 +61,7 @@ def test_forward_with_layer_norm(self):
6161
6262 outputs , _ = F (
6363 layer ,
64- inputs = inputs ,
64+ inputs = ( inputs ,) ,
6565 is_training = True ,
6666 prng_key = prng_key ,
6767 state = layer_params ,
@@ -88,7 +88,7 @@ def test_forward_without_residual(self):
8888
8989 outputs , _ = F (
9090 layer ,
91- inputs = inputs ,
91+ inputs = ( inputs ,) ,
9292 is_training = True ,
9393 prng_key = prng_key ,
9494 state = layer_params ,
@@ -119,7 +119,7 @@ def test_forward_with_scaling(self):
119119
120120 outputs , _ = F (
121121 layer ,
122- inputs = inputs ,
122+ inputs = ( inputs ,) ,
123123 is_training = True ,
124124 prng_key = prng_key ,
125125 state = layer_params ,
@@ -147,7 +147,7 @@ def test_forward_with_activation(self, activation: str):
147147
148148 outputs , _ = F (
149149 layer ,
150- inputs = inputs ,
150+ inputs = ( inputs ,) ,
151151 is_training = True ,
152152 prng_key = prng_key ,
153153 state = layer_params ,
@@ -175,20 +175,23 @@ def test_parameter_counts(self):
175175 up_proj_weight = layer_params ["up_proj" ]["weight" ]
176176 up_proj_bias = layer_params ["up_proj" ]["bias" ]
177177 layer_norm_scale = layer_params ["layer_norm" ]["scale" ]
178+ layer_norm_bias = layer_params ["layer_norm" ]["bias" ]
178179
179180 self .assertEqual (down_proj_weight .shape , (input_dim , bottleneck_dim ))
180181 self .assertEqual (down_proj_bias .shape , (bottleneck_dim ,))
181182 self .assertEqual (up_proj_weight .shape , (bottleneck_dim , input_dim ))
182183 self .assertEqual (up_proj_bias .shape , (input_dim ,))
183184 self .assertEqual (layer_norm_scale .shape , (input_dim ,))
185+ self .assertEqual (layer_norm_bias .shape , (input_dim ,))
184186
185187 total_params = np .prod (down_proj_weight .shape )
186188 total_params += np .prod (down_proj_bias .shape )
187189 total_params += np .prod (up_proj_weight .shape )
188190 total_params += np .prod (up_proj_bias .shape )
189191 total_params += np .prod (layer_norm_scale .shape )
192+ total_params += np .prod (layer_norm_bias .shape )
190193
191- self .assertEqual (total_params , 82368 )
194+ self .assertEqual (total_params , 33664 )
192195
193196 @parameterized .parameters ([True , False ])
194197 def test_training_vs_eval_mode (self , is_training : bool ):
@@ -209,7 +212,7 @@ def test_training_vs_eval_mode(self, is_training: bool):
209212
210213 outputs , _ = F (
211214 layer ,
212- inputs = inputs ,
215+ inputs = ( inputs ,) ,
213216 is_training = is_training ,
214217 prng_key = prng_key ,
215218 state = layer_params ,
0 commit comments