@@ -111,7 +111,11 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin):
111111 att_context_probs (List[float]): a list of probabilities of each one of the att_context_size
112112 when a list of them is passed. If not specified, uniform distribution is being used.
113113 Defaults to None
114- att_context_style (str): 'regular' or 'chunked_limited'.
114+ att_chunk_context_size (List[List[int]]): specifies the context sizes for unified (offline/streaming) ASR training.
115+ It defines the range of Left, Middle, and Right context sizes for the attention mechanism.
116+ At each streaming step, the context size is sampled from the range of Left, Middle, and Right context sizes.
117+ Example: att_chunk_context_size=[[70],[1,2,7,13],[0,1,3,7,13]] -> sampling -> [70, 2, 3] -> attention mask generation
118+ att_context_style (str): 'regular', 'chunked_limited', or 'chunked_limited_with_rc'.
115119 Defaults to 'regular'
116120 xscaling (bool): enables scaling the inputs to the multi-headed attention layers by `sqrt(d_model)`.
117121 Defaults to True.
@@ -126,6 +130,9 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin):
126130 `None` means `[(conv_kernel_size-1)//2`, `(conv_kernel_size-1)//2]`, and 'causal' means
127131 `[(conv_kernel_size-1), 0]`.
128132 Defaults to None.
133+ conv_context_style (str): 'regular' or 'dcc'
134+ DCC - Dynamic Chunked Convolution that is used for unified ASR training.
135+ Defaults to 'regular'.
129136 conv_dual_mode (bool): specifies if convolution should be dual mode when dual_offline mode is being used.
130137 When enables, the left half of the convolution kernel would get masked in streaming cases.
131138 Defaults to False.
@@ -305,13 +312,15 @@ def __init__(
305312 n_heads = 4 ,
306313 att_context_size = None ,
307314 att_context_probs = None ,
315+ att_chunk_context_size = None ,
308316 att_context_style = 'regular' ,
309317 xscaling = True ,
310318 untie_biases = True ,
311319 pos_emb_max_len = 5000 ,
312320 conv_kernel_size = 31 ,
313321 conv_norm_type = 'batch_norm' ,
314322 conv_context_size = None ,
323+ conv_context_style = 'regular' ,
315324 use_bias = True ,
316325 dropout = 0.1 ,
317326 dropout_pre_encoder = 0.1 ,
@@ -346,6 +355,22 @@ def __init__(
346355 self .use_pytorch_sdpa_backends = use_pytorch_sdpa_backends
347356 self .sync_max_audio_length = sync_max_audio_length
348357
358+ assert conv_context_style in ["regular" , "dcc" ], f"Invalid conv_context_style: { conv_context_style } !"
359+ self .conv_context_style = conv_context_style
360+ self .conv_kernel_size = conv_kernel_size
361+
362+ # Setting up the att_chunk_context_size
363+ if att_chunk_context_size is not None :
364+ assert (
365+ att_context_style == "chunked_limited_with_rc"
366+ ), "att_chunk_context_size is only supported for chunked_limited_with_rc attention style!"
367+ assert (
368+ len (att_chunk_context_size ) == 3
369+ ), "att_chunk_context_size must have 3 elements: [left_context, chunk_size, right_context]"
370+ self .att_chunk_context_size = att_chunk_context_size
371+ else :
372+ self .att_chunk_context_size = None
373+
349374 # Setting up the att_context_size
350375 (
351376 self .att_context_size_all ,
@@ -716,7 +741,6 @@ def forward_internal(
716741 offset = offset ,
717742 device = audio_signal .device ,
718743 )
719-
720744 # saving tensors if required for interctc loss
721745 if self .is_access_enabled (getattr (self , "model_guid" , None )):
722746 if self .interctc_capture_at_layers is None :
@@ -816,6 +840,33 @@ def _create_masks(self, att_context_size, padding_length, max_audio_length, offs
816840 torch .le (diff_chunks , left_chunks_num ), torch .ge (diff_chunks , 0 )
817841 )
818842 att_mask = torch .logical_and (att_mask , chunked_limited_mask .unsqueeze (0 ))
843+ elif self .att_context_style == "chunked_limited_with_rc" and sum (att_context_size ) != - 3 :
844+ assert (
845+ len (att_context_size ) == 3
846+ ), "att_context_size must have 3 elements: [left_context, chunk_size, right_context]"
847+
848+ left_context_frames = att_context_size [0 ]
849+ chunk_size_frames = att_context_size [1 ]
850+ right_context_frames = att_context_size [2 ]
851+ assert chunk_size_frames >= 1 , "chunk_size_frames must be greater than 0!"
852+ # Calculate chunk index for each frame (which processing group it belongs to)
853+ frame_idx = torch .arange (0 , max_audio_length , dtype = torch .int , device = att_mask .device )
854+ chunk_idx = torch .div (frame_idx , chunk_size_frames , rounding_mode = "trunc" )
855+
856+ window_start = chunk_idx * chunk_size_frames - left_context_frames
857+ window_start = torch .maximum (window_start , torch .zeros_like (window_start ))
858+ window_end = chunk_idx * chunk_size_frames + chunk_size_frames - 1 + right_context_frames
859+
860+ window_end = torch .minimum (window_end , torch .full_like (window_end , max_audio_length - 1 ))
861+ # Create the mask: frame i can see frame j if window_start[i] <= j <= window_end[i]
862+ j_indices = frame_idx .unsqueeze (0 ) # [1, T]
863+ window_start_expanded = window_start .unsqueeze (1 ) # [T, 1]
864+ window_end_expanded = window_end .unsqueeze (1 ) # [T, 1]
865+
866+ chunked_limited_mask = torch .logical_and (
867+ j_indices >= window_start_expanded , j_indices <= window_end_expanded
868+ )
869+ att_mask = torch .logical_and (att_mask , chunked_limited_mask .unsqueeze (0 ))
819870 else :
820871 att_mask = None
821872
@@ -876,6 +927,9 @@ def _calc_context_sizes(
876927 else :
877928 att_context_size_all = [[- 1 , - 1 ]]
878929
930+ if att_context_style == "chunked_limited_with_rc" :
931+ att_context_size_all = [[- 1 , - 1 , - 1 ]]
932+
879933 if att_context_probs :
880934 if len (att_context_probs ) != len (att_context_size_all ):
881935 raise ValueError ("The size of the att_context_probs should be the same as att_context_size." )
@@ -955,6 +1009,9 @@ def setup_streaming_params(
9551009 elif self .att_context_style == "chunked_limited" :
9561010 lookahead_steps = att_context_size [1 ]
9571011 streaming_cfg .cache_drop_size = 0
1012+ elif self .att_context_style == "chunked_limited_with_rc" :
1013+ lookahead_steps = att_context_size [2 ] * self .n_layers + self .conv_context_size [1 ] * self .n_layers
1014+ streaming_cfg .cache_drop_size = 0
9581015 elif self .att_context_style == "regular" :
9591016 lookahead_steps = att_context_size [1 ] * self .n_layers + self .conv_context_size [1 ] * self .n_layers
9601017 streaming_cfg .cache_drop_size = lookahead_steps
0 commit comments