|
16 | 16 | ) |
17 | 17 | from axlearn.audio.test_utils import fake_audio |
18 | 18 | from axlearn.common.attention import RepeatedTransformerLayer |
| 19 | +from axlearn.common.convolution import Conv1DWithPadding |
19 | 20 | from axlearn.common.kv_cache.sliding_window_kv_cache import enable_sliding_window_attention |
20 | 21 | from axlearn.common.module import functional as F |
21 | 22 | from axlearn.common.test_utils import TestCase |
@@ -263,6 +264,121 @@ def test_transformer(self, is_training: bool) -> None: |
263 | 264 | output_collections.summaries["activations/speech_context_norm"].weight, weights |
264 | 265 | ) |
265 | 266 |
|
| 267 | + @parameterized.parameters([True, False]) |
| 268 | + @pytest.mark.fp64 |
| 269 | + def test_conformer_without_pos_emb(self, is_training: bool): |
| 270 | + """Tests SpeechContextNetwork with RepeatedConformerLayer when pos_emb is None.""" |
| 271 | + input_dim, output_dim, dropout_rate, num_layers = 32, 16, 0.2, 2 |
| 272 | + |
| 273 | + cfg = SpeechContextNetwork.default_config().set( |
| 274 | + input_dim=input_dim, output_dim=output_dim, dtype=jnp.float64 |
| 275 | + ) |
| 276 | + cfg.dropout.rate = dropout_rate |
| 277 | + cfg.context.num_layers = num_layers |
| 278 | + cfg.context.layer.self_attention.attention.num_heads = 4 |
| 279 | + cfg.context.layer.lconv.dropout.rate = dropout_rate |
| 280 | + cfg.pos_emb = None |
| 281 | + |
| 282 | + prng_key = jax.random.PRNGKey(123) |
| 283 | + prng_key, init_key, input_key, length_key = jax.random.split(prng_key, num=4) |
| 284 | + layer = cfg.set(name="test").instantiate(parent=None) |
| 285 | + layer_params = layer.initialize_parameters_recursively(init_key) |
| 286 | + |
| 287 | + # pos_emb should be absent from parameters when disabled. |
| 288 | + self.assertNotIn("pos_emb", layer.children) |
| 289 | + self.assertNotIn("pos_emb", layer_params) |
| 290 | + |
| 291 | + # Generate inputs. |
| 292 | + batch_size, seq_len = 4, 10 |
| 293 | + inputs = jnp.tile( |
| 294 | + jax.random.normal(input_key, [batch_size // 2, seq_len, input_dim]), [2, 1, 1] |
| 295 | + ) |
| 296 | + lengths = jnp.tile( |
| 297 | + jax.random.randint(length_key, shape=[batch_size // 2, 1], minval=0, maxval=seq_len), |
| 298 | + [2, 1], |
| 299 | + ) |
| 300 | + segment_ids = (jnp.arange(seq_len)[None, :] < lengths).astype(jnp.int32) |
| 301 | + padding_data = jax.random.normal(jax.random.PRNGKey(135), inputs.shape) |
| 302 | + inputs = jnp.where(segment_ids[..., None] == 0, padding_data, inputs) |
| 303 | + |
| 304 | + output_batch, _ = F( |
| 305 | + layer, |
| 306 | + inputs=dict(inputs=inputs, segment_ids=segment_ids), |
| 307 | + is_training=is_training, |
| 308 | + prng_key=prng_key, |
| 309 | + state=layer_params, |
| 310 | + ) |
| 311 | + outputs, output_segment_ids = output_batch["outputs"], output_batch["segment_ids"] |
| 312 | + self.assertSequenceEqual(outputs.shape, (batch_size, seq_len, output_dim)) |
| 313 | + self.assertTrue(jnp.all(output_segment_ids == segment_ids)) |
| 314 | + |
| 315 | + # If is_training, outputs differ due to dropout; otherwise identical despite padding noise. |
| 316 | + self.assertEqual(not is_training, bool(jnp.allclose(outputs[:2], outputs[2:]))) |
| 317 | + |
| 318 | + @parameterized.parameters([2, 3, 4, 5]) |
| 319 | + def test_post_downsample(self, strides) -> None: |
| 320 | + """Tests post_downsample with RepeatedTransformerLayer as the context layer. |
| 321 | +

| 322 | + Args: |
| 323 | + strides: Window size and stride of the post_downsample Conv1DWithPadding layer. |
| 324 | + """ |
| 325 | + is_training = True |
| 326 | + input_dim, output_dim, dropout_rate, num_layers = 32, 16, 0.2, 2 |
| 327 | + num_heads = 8 |
| 328 | + hidden_dim = 4 * input_dim |
| 329 | + |
| 330 | + cfg = SpeechContextNetwork.default_config().set( |
| 331 | + input_dim=input_dim, output_dim=output_dim, dtype=jnp.float64 |
| 332 | + ) |
| 333 | + cfg.dropout.rate = dropout_rate |
| 334 | + cfg.context = RepeatedTransformerLayer.default_config().set(num_layers=num_layers) |
| 335 | + attention = cfg.context.layer.self_attention.attention |
| 336 | + attention.num_heads = num_heads |
| 337 | + attention = enable_sliding_window_attention(attention, sliding_window_size=3) |
| 338 | + cfg.context.layer.self_attention.attention = attention |
| 339 | + # Dropout in transformer |
| 340 | + cfg.context.layer.self_attention.dropout.rate = dropout_rate |
| 341 | + cfg.context.layer.feed_forward.set( |
| 342 | + hidden_dim=hidden_dim, |
| 343 | + ) |
| 344 | + |
| 345 | + # Derive separate PRNG keys for parameter initialization, inputs, and lengths. |
| 346 | + prng_key = jax.random.PRNGKey(123) |
| 347 | + prng_key, init_key, input_key, length_key = jax.random.split(prng_key, num=4) |
| 348 | + |
| 349 | + # Generate inputs. |
| 350 | + batch_size, seq_len = 4, 10 |
| 351 | + inputs = jnp.tile( |
| 352 | + jax.random.normal(input_key, [batch_size // 2, seq_len, input_dim]), [2, 1, 1] |
| 353 | + ) |
| 354 | + lengths = jnp.tile( |
| 355 | + jax.random.randint(length_key, shape=[batch_size // 2, 1], minval=0, maxval=seq_len), |
| 356 | + [2, 1], |
| 357 | + ) |
| 358 | + segment_ids = (jnp.arange(seq_len)[None, :] < lengths).astype(jnp.int32) |
| 359 | + padding_data = jax.random.normal(jax.random.PRNGKey(135), inputs.shape) |
| 360 | + inputs = jnp.where(segment_ids[..., None] == 0, padding_data, inputs) |
| 361 | + |
| 362 | + cfg.post_downsample = Conv1DWithPadding.default_config().set( |
| 363 | + window=strides, strides=strides, padding=((strides - 1, 0),) |
| 364 | + ) |
| 365 | + layer = cfg.set(name="test").instantiate(parent=None) |
| 366 | + layer_params = layer.initialize_parameters_recursively(init_key) |
| 367 | + output_batch, _ = F( |
| 368 | + layer, |
| 369 | + inputs=dict(inputs=inputs, segment_ids=segment_ids), |
| 370 | + is_training=is_training, |
| 371 | + prng_key=prng_key, |
| 372 | + state=layer_params, |
| 373 | + ) |
| 374 | + outputs, output_segment_ids = output_batch["outputs"], output_batch["segment_ids"] |
| 375 | + self.assertSequenceEqual( |
| 376 | + outputs.shape, (batch_size, (seq_len + strides - 1) // strides, output_dim) |
| 377 | + ) |
| 378 | + self.assertSequenceEqual( |
| 379 | + output_segment_ids.shape, (batch_size, (seq_len + strides - 1) // strides) |
| 380 | + ) |
| 381 | + |
266 | 382 |
|
267 | 383 | class ASREncoderTest(TestCase): |
268 | 384 | """Tests ASREncoder.""" |
|
0 commit comments