Skip to content

Commit b16c188

Browse files
committed
Fix audio adapter test failures
- Fixed TypeError by wrapping single tensor inputs in tuples for F() calls in both adapter.py and adapter_test.py - Fixed parameter count assertion by including layer_norm.bias in the count calculation
1 parent ae7c73f commit b16c188

2 files changed

Lines changed: 12 additions & 9 deletions

File tree

axlearn/audio/adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def adapt_encoder_features(self, features, *, is_training=False, prng_key=None,
215215
if state is not None and prng_key is not None:
216216
outputs, _ = F(
217217
self.encoder_adapter,
218-
inputs=features,
218+
inputs=(features,),
219219
is_training=is_training,
220220
prng_key=prng_key,
221221
state=state,
@@ -245,7 +245,7 @@ def adapt_decoder_features(self, features, *, is_training=False, prng_key=None,
245245
if state is not None and prng_key is not None:
246246
outputs, _ = F(
247247
self.decoder_adapter,
248-
inputs=features,
248+
inputs=(features,),
249249
is_training=is_training,
250250
prng_key=prng_key,
251251
state=state,

axlearn/audio/adapter_test.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_forward_basic(self):
3333

3434
outputs, _ = F(
3535
layer,
36-
inputs=inputs,
36+
inputs=(inputs,),
3737
is_training=True,
3838
prng_key=prng_key,
3939
state=layer_params,
@@ -61,7 +61,7 @@ def test_forward_with_layer_norm(self):
6161

6262
outputs, _ = F(
6363
layer,
64-
inputs=inputs,
64+
inputs=(inputs,),
6565
is_training=True,
6666
prng_key=prng_key,
6767
state=layer_params,
@@ -88,7 +88,7 @@ def test_forward_without_residual(self):
8888

8989
outputs, _ = F(
9090
layer,
91-
inputs=inputs,
91+
inputs=(inputs,),
9292
is_training=True,
9393
prng_key=prng_key,
9494
state=layer_params,
@@ -119,7 +119,7 @@ def test_forward_with_scaling(self):
119119

120120
outputs, _ = F(
121121
layer,
122-
inputs=inputs,
122+
inputs=(inputs,),
123123
is_training=True,
124124
prng_key=prng_key,
125125
state=layer_params,
@@ -147,7 +147,7 @@ def test_forward_with_activation(self, activation: str):
147147

148148
outputs, _ = F(
149149
layer,
150-
inputs=inputs,
150+
inputs=(inputs,),
151151
is_training=True,
152152
prng_key=prng_key,
153153
state=layer_params,
@@ -175,20 +175,23 @@ def test_parameter_counts(self):
175175
up_proj_weight = layer_params["up_proj"]["weight"]
176176
up_proj_bias = layer_params["up_proj"]["bias"]
177177
layer_norm_scale = layer_params["layer_norm"]["scale"]
178+
layer_norm_bias = layer_params["layer_norm"]["bias"]
178179

179180
self.assertEqual(down_proj_weight.shape, (input_dim, bottleneck_dim))
180181
self.assertEqual(down_proj_bias.shape, (bottleneck_dim,))
181182
self.assertEqual(up_proj_weight.shape, (bottleneck_dim, input_dim))
182183
self.assertEqual(up_proj_bias.shape, (input_dim,))
183184
self.assertEqual(layer_norm_scale.shape, (input_dim,))
185+
self.assertEqual(layer_norm_bias.shape, (input_dim,))
184186

185187
total_params = np.prod(down_proj_weight.shape)
186188
total_params += np.prod(down_proj_bias.shape)
187189
total_params += np.prod(up_proj_weight.shape)
188190
total_params += np.prod(up_proj_bias.shape)
189191
total_params += np.prod(layer_norm_scale.shape)
192+
total_params += np.prod(layer_norm_bias.shape)
190193

191-
self.assertEqual(total_params, 82368)
194+
self.assertEqual(total_params, 33664)
192195

193196
@parameterized.parameters([True, False])
194197
def test_training_vs_eval_mode(self, is_training: bool):
@@ -209,7 +212,7 @@ def test_training_vs_eval_mode(self, is_training: bool):
209212

210213
outputs, _ = F(
211214
layer,
212-
inputs=inputs,
215+
inputs=(inputs,),
213216
is_training=is_training,
214217
prng_key=prng_key,
215218
state=layer_params,

0 commit comments

Comments
 (0)