Skip to content

Commit b3bef65

Browse files
kmxyvbwangkuiyi
authored andcommitted
Allow multiple convolution layers instead of a fixed two-layer setup in subsampler
GitOrigin-RevId: 0115da7aeac05b5d9661a334ee55133be968b4ef
1 parent 98e654d commit b3bef65

5 files changed

Lines changed: 256 additions & 33 deletions

File tree

axlearn/audio/encoder_asr.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from axlearn.common.base_layer import BaseLayer
2222
from axlearn.common.config import REQUIRED, Required, config_class
2323
from axlearn.common.conformer import RepeatedConformerLayer
24+
from axlearn.common.convolution import Conv1DWithPadding
2425
from axlearn.common.ein_ops import rearrange
2526
from axlearn.common.layers import Dropout, Linear
2627
from axlearn.common.module import Module, nowrap
@@ -139,7 +140,9 @@ class Config(BaseLayer.Config):
139140
# Dropout applied after projection.
140141
dropout: Dropout.Config = Dropout.default_config()
141142
# Positional embeddings.
142-
pos_emb: BaseLayer.Config = SinusoidalPositionalEmbedding.default_config()
143+
pos_emb: Optional[BaseLayer.Config] = SinusoidalPositionalEmbedding.default_config()
144+
# Post Convolution downsample
145+
post_downsample: Optional[Conv1DWithPadding.Config] = None
143146
# Context layers, e.g. a conformer stack.
144147
context: BaseLayer.Config = RepeatedConformerLayer.default_config()
145148

@@ -150,8 +153,14 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
150153
"input_linear", cfg.input_linear.set(input_dim=cfg.input_dim, output_dim=cfg.output_dim)
151154
)
152155
self._add_child("dropout", cfg.dropout)
153-
self._add_child("pos_emb", cfg.pos_emb.set(dim=cfg.output_dim))
156+
if cfg.pos_emb is not None:
157+
self._add_child("pos_emb", cfg.pos_emb.set(dim=cfg.output_dim))
154158
self._add_child("context", cfg.context.set(input_dim=cfg.output_dim))
159+
if cfg.post_downsample is not None:
160+
self._add_child(
161+
"post_downsample",
162+
cfg.post_downsample.set(input_dim=cfg.output_dim, output_dim=cfg.output_dim),
163+
)
155164

156165
def forward(self, inputs: Tensor, *, segment_ids: Tensor) -> dict[str, Tensor]:
157166
"""Computes context features.
@@ -172,7 +181,8 @@ def forward(self, inputs: Tensor, *, segment_ids: Tensor) -> dict[str, Tensor]:
172181

173182
if isinstance(cfg.context, RepeatedConformerLayer.Config):
174183
positions = _segment_relative_positions(segment_ids)
175-
x = x + self.pos_emb(positions)
184+
if cfg.pos_emb is not None:
185+
x = x + self.pos_emb(positions)
176186
x = self.context(inputs=x, segment_ids=segment_ids)
177187
elif isinstance(cfg.context, RepeatedTransformerLayer.Config):
178188
# We don't need to do add pos_emb for transformer block
@@ -188,6 +198,9 @@ def forward(self, inputs: Tensor, *, segment_ids: Tensor) -> dict[str, Tensor]:
188198
activations=x,
189199
activation_paddings=segment_ids == 0,
190200
)
201+
if cfg.post_downsample is not None:
202+
x, _ = self.post_downsample(x=x, paddings=segment_ids == 0)
203+
segment_ids = self.post_downsample.conv_paddings(segment_ids)
191204
return dict(outputs=x * (segment_ids != 0)[..., None], segment_ids=segment_ids)
192205

193206

axlearn/audio/encoder_asr_test.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from axlearn.audio.test_utils import fake_audio
1818
from axlearn.common.attention import RepeatedTransformerLayer
19+
from axlearn.common.convolution import Conv1DWithPadding
1920
from axlearn.common.kv_cache.sliding_window_kv_cache import enable_sliding_window_attention
2021
from axlearn.common.module import functional as F
2122
from axlearn.common.test_utils import TestCase
@@ -263,6 +264,121 @@ def test_transformer(self, is_training: bool) -> None:
263264
output_collections.summaries["activations/speech_context_norm"].weight, weights
264265
)
265266

267+
@parameterized.parameters([True, False])
268+
@pytest.mark.fp64
269+
def test_conformer_without_pos_emb(self, is_training: bool):
270+
"""Tests SpeechContextNetwork with RepeatedConformerLayer when pos_emb is None."""
271+
input_dim, output_dim, dropout_rate, num_layers = 32, 16, 0.2, 2
272+
273+
cfg = SpeechContextNetwork.default_config().set(
274+
input_dim=input_dim, output_dim=output_dim, dtype=jnp.float64
275+
)
276+
cfg.dropout.rate = dropout_rate
277+
cfg.context.num_layers = num_layers
278+
cfg.context.layer.self_attention.attention.num_heads = 4
279+
cfg.context.layer.lconv.dropout.rate = dropout_rate
280+
cfg.pos_emb = None
281+
282+
prng_key = jax.random.PRNGKey(123)
283+
prng_key, init_key, input_key, length_key = jax.random.split(prng_key, num=4)
284+
layer = cfg.set(name="test").instantiate(parent=None)
285+
layer_params = layer.initialize_parameters_recursively(init_key)
286+
287+
# pos_emb should be absent from parameters when disabled.
288+
self.assertNotIn("pos_emb", layer.children)
289+
self.assertNotIn("pos_emb", layer_params)
290+
291+
# Generate inputs.
292+
batch_size, seq_len = 4, 10
293+
inputs = jnp.tile(
294+
jax.random.normal(input_key, [batch_size // 2, seq_len, input_dim]), [2, 1, 1]
295+
)
296+
lengths = jnp.tile(
297+
jax.random.randint(length_key, shape=[batch_size // 2, 1], minval=0, maxval=seq_len),
298+
[2, 1],
299+
)
300+
segment_ids = (jnp.arange(seq_len)[None, :] < lengths).astype(jnp.int32)
301+
padding_data = jax.random.normal(jax.random.PRNGKey(135), inputs.shape)
302+
inputs = jnp.where(segment_ids[..., None] == 0, padding_data, inputs)
303+
304+
output_batch, _ = F(
305+
layer,
306+
inputs=dict(inputs=inputs, segment_ids=segment_ids),
307+
is_training=is_training,
308+
prng_key=prng_key,
309+
state=layer_params,
310+
)
311+
outputs, output_segment_ids = output_batch["outputs"], output_batch["segment_ids"]
312+
self.assertSequenceEqual(outputs.shape, (batch_size, seq_len, output_dim))
313+
self.assertTrue(jnp.all(output_segment_ids == segment_ids))
314+
315+
# If is_training, outputs differ due to dropout; otherwise identical despite padding noise.
316+
self.assertEqual(not is_training, bool(jnp.allclose(outputs[:2], outputs[2:])))
317+
318+
@parameterized.parameters([2, 3, 4, 5])
319+
def test_post_downsample(self, strides) -> None:
320+
"""Test the code branch with RepeatedTransformerLayer as context layer.
321+
322+
Args:
323+
is_training: Whether the is_training code path is tested.
324+
"""
325+
is_training = True
326+
input_dim, output_dim, dropout_rate, num_layers = 32, 16, 0.2, 2
327+
num_heads = 8
328+
hidden_dim = 4 * input_dim
329+
330+
cfg = SpeechContextNetwork.default_config().set(
331+
input_dim=input_dim, output_dim=output_dim, dtype=jnp.float64
332+
)
333+
cfg.dropout.rate = dropout_rate
334+
cfg.context = RepeatedTransformerLayer.default_config().set(num_layers=num_layers)
335+
attention = cfg.context.layer.self_attention.attention
336+
attention.num_heads = num_heads
337+
attention = enable_sliding_window_attention(attention, sliding_window_size=3)
338+
cfg.context.layer.self_attention.attention = attention
339+
# Dropout in transformer
340+
cfg.context.layer.self_attention.dropout.rate = dropout_rate
341+
cfg.context.layer.feed_forward.set(
342+
hidden_dim=hidden_dim,
343+
)
344+
345+
# Initialize layer parameters.
346+
prng_key = jax.random.PRNGKey(123)
347+
prng_key, init_key, input_key, length_key = jax.random.split(prng_key, num=4)
348+
349+
# Generate inputs.
350+
batch_size, seq_len = 4, 10
351+
inputs = jnp.tile(
352+
jax.random.normal(input_key, [batch_size // 2, seq_len, input_dim]), [2, 1, 1]
353+
)
354+
lengths = jnp.tile(
355+
jax.random.randint(length_key, shape=[batch_size // 2, 1], minval=0, maxval=seq_len),
356+
[2, 1],
357+
)
358+
segment_ids = (jnp.arange(seq_len)[None, :] < lengths).astype(jnp.int32)
359+
padding_data = jax.random.normal(jax.random.PRNGKey(135), inputs.shape)
360+
inputs = jnp.where(segment_ids[..., None] == 0, padding_data, inputs)
361+
362+
cfg.post_downsample = Conv1DWithPadding.default_config().set(
363+
window=strides, strides=strides, padding=((strides - 1, 0),)
364+
)
365+
layer = cfg.set(name="test").instantiate(parent=None)
366+
layer_params = layer.initialize_parameters_recursively(init_key)
367+
output_batch, _ = F(
368+
layer,
369+
inputs=dict(inputs=inputs, segment_ids=segment_ids),
370+
is_training=is_training,
371+
prng_key=prng_key,
372+
state=layer_params,
373+
)
374+
outputs, output_segment_ids = output_batch["outputs"], output_batch["segment_ids"]
375+
self.assertSequenceEqual(
376+
outputs.shape, (batch_size, (seq_len + strides - 1) // strides, output_dim)
377+
)
378+
self.assertSequenceEqual(
379+
output_segment_ids.shape, (batch_size, (seq_len + strides - 1) // strides)
380+
)
381+
266382

267383
class ASREncoderTest(TestCase):
268384
"""Tests ASREncoder."""

axlearn/audio/subsamplers.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class Config(BaseLayer.Config):
3333
# Output channel dim.
3434
output_dim: Required[int] = REQUIRED
3535
# Hidden dim of the conv layers. If None, defaults to output_dim.
36-
hidden_dim: Optional[int] = None
36+
hidden_dim: int | list[int] | None = None
3737
# Configures both of the convolutions.
3838
conv: Conv2DWith1DPadding.Config = Conv2DWith1DPadding.default_config().set(
3939
window=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1))
@@ -46,27 +46,41 @@ class Config(BaseLayer.Config):
4646
# activation to only one convolution).
4747
activation: Optional[Union[Optional[str], tuple[Optional[str], Optional[str]]]] = None
4848

49+
@classmethod
50+
def get_hidden_dim_list(cls, cfg: Config) -> list[int]:
51+
if isinstance(cfg.hidden_dim, int):
52+
hidden_dim = [cfg.hidden_dim]
53+
elif cfg.hidden_dim is None:
54+
hidden_dim = [cfg.output_dim]
55+
else:
56+
hidden_dim = list(cfg.hidden_dim)
57+
return hidden_dim
58+
4959
def __init__(self, cfg: Config, *, parent: Optional[Module]):
5060
super().__init__(cfg, parent=parent)
5161
cfg = self.config
5262

5363
activation = cfg.activation
64+
hidden_dim = [cfg.input_dim] + self.get_hidden_dim_list(cfg) + [cfg.output_dim]
65+
self.num_layers = len(hidden_dim) - 1
5466
if not isinstance(activation, (list, tuple)):
55-
activation = (activation, activation)
56-
if len(activation) != 2 or not all(x is None or isinstance(x, str) for x in activation):
67+
activation = [activation] * self.num_layers
68+
if len(activation) != self.num_layers or not all(
69+
x is None or isinstance(x, str) for x in activation
70+
):
5771
raise ValueError(
58-
"Expected cfg.activation to be None, a string, or pair of string | None, "
72+
"Expected cfg.activation to be None, a string, or list/tuple of string | None, "
5973
f"got: {cfg.activation}"
6074
)
6175
self._activation = [None if act is None else get_activation_fn(act) for act in activation]
6276

63-
hidden_dim = cfg.hidden_dim or cfg.output_dim
64-
self._add_child("conv1", cfg.conv.set(input_dim=cfg.input_dim, output_dim=hidden_dim))
65-
self._add_child("conv2", cfg.conv.set(input_dim=hidden_dim, output_dim=cfg.output_dim))
66-
77+
for i in range(1, len(hidden_dim)):
78+
self._add_child(
79+
f"conv{i}", cfg.conv.set(input_dim=hidden_dim[i - 1], output_dim=hidden_dim[i])
80+
)
6781
if cfg.norm:
68-
self._add_child("norm1", cfg.norm.set(input_dim=hidden_dim))
69-
self._add_child("norm2", cfg.norm.set(input_dim=cfg.output_dim))
82+
for i in range(1, len(hidden_dim)):
83+
self._add_child(f"norm{i}", cfg.norm.set(input_dim=hidden_dim[i]))
7084

7185
def output_shape(self, *, input_shape: Sequence[Optional[int]]):
7286
"""Computes the output shape after subsampling.
@@ -90,9 +104,9 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]):
90104
f"input_shape[-1] = {input_shape[-1]} does not match "
91105
f"cfg.input_dim = {cfg.input_dim}."
92106
)
93-
conv1_shape = self.conv1.output_shape(input_shape=input_shape)
94-
conv2_shape = self.conv2.output_shape(input_shape=conv1_shape)
95-
return conv2_shape
107+
for i in range(1, self.num_layers + 1):
108+
input_shape = self._children[f"conv{i}"].output_shape(input_shape=input_shape)
109+
return input_shape
96110

97111
def forward(self, inputs: Tensor, *, segment_ids: Tensor) -> dict[str, Tensor]:
98112
"""Subsamples the speech.
@@ -112,20 +126,14 @@ def forward(self, inputs: Tensor, *, segment_ids: Tensor) -> dict[str, Tensor]:
112126
self._add_activation_summary(
113127
name="subsampler_inputs", activations=inputs, activation_paddings=paddings
114128
)
115-
x, paddings = self.conv1(inputs, paddings=paddings)
116-
segment_ids = self.conv1.conv_paddings(segment_ids)
117-
if cfg.norm:
118-
x = self.norm1(x, segment_ids=segment_ids)
119-
if self._activation[0]:
120-
x = self._activation[0](x)
121-
122-
x, paddings = self.conv2(x, paddings=paddings)
123-
segment_ids = self.conv2.conv_paddings(segment_ids)
124-
if cfg.norm:
125-
x = self.norm2(x, segment_ids=segment_ids)
126-
if self._activation[1]:
127-
x = self._activation[1](x)
128-
129+
x = inputs
130+
for i in range(1, self.num_layers + 1):
131+
x, paddings = self._children[f"conv{i}"](x, paddings=paddings)
132+
segment_ids = self._children[f"conv{i}"].conv_paddings(segment_ids)
133+
if cfg.norm:
134+
x = self._children[f"norm{i}"](x, segment_ids=segment_ids)
135+
if self._activation[i - 1]:
136+
x = self._activation[i - 1](x)
129137
self._add_activation_summary(
130138
name="subsampler_outputs", activations=x, activation_paddings=paddings
131139
)

axlearn/audio/subsamplers_test.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ class ConvSubSamplerTest(TestCase):
2121
"""Tests ConvSubSampler."""
2222

2323
@parameterized.parameters(
24-
dict(activation=("nn.tanh", "nn.relu", "nn.silu"), expected=ValueError("pair of string")),
25-
dict(activation=("nn.tanh",), expected=ValueError("pair of string")),
24+
dict(
25+
activation=("nn.tanh", "nn.relu", "nn.silu"),
26+
expected=ValueError("list/tuple of string"),
27+
),
28+
dict(activation=("nn.tanh",), expected=ValueError("list/tuple of string")),
2629
dict(activation="nn.tanh"), # Single value is broadcasted.
2730
dict(activation=("nn.tanh", None)), # Some of the values can be None.
2831
dict(activation=(None, None)), # Some of the values can be None.
@@ -189,6 +192,86 @@ def test_segment_ids(
189192
self.assertEqual(tuple(subsampled_shape), outputs["outputs"].shape)
190193
self.assertEqual(tuple(subsampled_shape)[:2], outputs["segment_ids"].shape)
191194

195+
@parameterized.parameters(
196+
dict(
197+
window=5,
198+
stride=2,
199+
conv_padding=(1, 1),
200+
hidden_dim=[12, 16],
201+
output_dim=8,
202+
activation=("nn.tanh", None, None),
203+
),
204+
dict(
205+
window=3,
206+
stride=2,
207+
conv_padding=(1, 0),
208+
hidden_dim=[12, 16, 32],
209+
output_dim=8,
210+
activation=("nn.tanh", None, None, None),
211+
),
212+
)
213+
def test_multi_layers(
214+
self,
215+
window: int,
216+
stride: int,
217+
conv_padding: tuple[int, int],
218+
output_dim: int,
219+
hidden_dim: Optional[int] = None,
220+
activation: Optional[Union[Optional[str], tuple[Optional[str], Optional[str]]]] = None,
221+
):
222+
"""Tests that padding inputs do not affect outputs."""
223+
batch_size, num_frames, num_filters, input_dim = 4, 10, 80, 1
224+
cfg = ConvSubSampler.default_config().set(
225+
input_dim=input_dim,
226+
output_dim=output_dim,
227+
hidden_dim=hidden_dim,
228+
activation=activation,
229+
)
230+
cfg.conv.window = (window, window)
231+
cfg.conv.strides = (stride, stride)
232+
cfg.conv.padding = (conv_padding, conv_padding)
233+
cfg.norm = BatchNorm.default_config()
234+
235+
# Initialize layer parameters.
236+
layer = cfg.set(name="test").instantiate(parent=None)
237+
prng_key = jax.random.PRNGKey(123)
238+
prng_key, init_key, data_key = jax.random.split(prng_key, num=3)
239+
layer_params = layer.initialize_parameters_recursively(init_key)
240+
241+
hidden_dim = [cfg.input_dim] + hidden_dim + [cfg.output_dim]
242+
self.assertEqual(
243+
{
244+
f"conv{i + 1}": dict(
245+
weight=(window, window, hidden_dim[i], hidden_dim[i + 1]),
246+
bias=(hidden_dim[i + 1],),
247+
)
248+
for i in range(len(hidden_dim) - 1)
249+
}
250+
| {
251+
f"norm{i + 1}": dict(
252+
bias=(hidden_dim[i + 1],),
253+
moving_mean=(hidden_dim[i + 1],),
254+
moving_variance=(hidden_dim[i + 1],),
255+
scale=(hidden_dim[i + 1],),
256+
)
257+
for i in range(len(hidden_dim) - 1)
258+
},
259+
utils.shapes(layer_params),
260+
)
261+
262+
inputs_shape = [batch_size, num_frames, num_filters, input_dim]
263+
inputs = jax.random.normal(key=data_key, shape=inputs_shape) * 10.0
264+
segment_ids = jnp.ones([batch_size, num_frames])
265+
outputs, _ = F(
266+
layer,
267+
inputs=dict(inputs=inputs, segment_ids=segment_ids),
268+
is_training=True,
269+
prng_key=prng_key,
270+
state=layer_params,
271+
)
272+
expected_shape = layer.output_shape(input_shape=tuple(inputs.shape))
273+
self.assertEqual(tuple(expected_shape), outputs["outputs"].shape)
274+
192275
@parameterized.parameters(jnp.float32, jnp.bfloat16)
193276
def test_activation_summaries(self, dtype):
194277
"""Tests that activation summaries behave as expected."""

0 commit comments

Comments
 (0)