
Commit 63b30fb

Fix transcribe when nbest hypotheses are returned (NVIDIA-NeMo#13540)

Squashed commits:

* fix transcribe when nbest
* minor fix
* fix in process_aed_timestamp_outputs to return list
* Apply isort and black reformatting
* minor fix
* Apply isort and black reformatting
* clean up
* clean up
* clean up
* Apply isort and black reformatting
* restore canary not from hf

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>
Signed-off-by: lilithgrigoryan <lilithgrigoryan@users.noreply.github.com>
Co-authored-by: lilithgrigoryan <lilithgrigoryan@users.noreply.github.com>
1 parent b912a85 commit 63b30fb

4 files changed, 123 additions and 56 deletions
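In short: when beam decoding is configured with `return_best_hypothesis: false`, `transcribe()` now returns one list of n-best Hypothesis objects per utterance instead of failing while packing results. A minimal usage sketch adapted from the tests added below; the checkpoint name and audio paths are placeholders, not part of this commit:

import copy
from omegaconf import open_dict
from nemo.collections.asr.models import ASRModel

model = ASRModel.from_pretrained("nvidia/canary-1b-flash")  # placeholder checkpoint

decoding_config = copy.deepcopy(model.cfg.decoding)
with open_dict(decoding_config):
    decoding_config["beam"]["beam_size"] = 4
    decoding_config["beam"]["return_best_hypothesis"] = False  # request n-best output
model.change_decoding_strategy(decoding_config)

outputs = model.transcribe(["audio1.wav", "audio2.wav"], batch_size=1, timestamps=False)
# One inner list of Hypothesis objects (up to beam_size) per input file.
assert len(outputs) == 2 and all(isinstance(o, list) for o in outputs)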


nemo/collections/asr/parts/mixins/transcription.py

Lines changed: 3 additions & 13 deletions
@@ -29,10 +29,11 @@
 from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
 from nemo.collections.asr.parts.preprocessing.segment import AudioSegment, ChannelSelectorType
 from nemo.collections.asr.parts.utils import manifest_utils
+from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
 from nemo.collections.common.data.utils import move_data_to_device
 from nemo.utils import logging, logging_mode

-TranscriptionReturnType = Union[List[str], List['Hypothesis'], Tuple[List[str]], Tuple[List['Hypothesis']]]
+TranscriptionReturnType = Union[List[str], List[Hypothesis], Tuple[List[str]], Tuple[List[Hypothesis]]]
 GenericTranscriptionType = Union[List[Any], List[List[Any]], Tuple[Any], Tuple[List[Any]], Dict[str, List[Any]]]
@@ -273,18 +274,7 @@ def transcribe(
             if results is None:
                 results = []

-            # if list of inner list of results, copy structure
-            if isinstance(processed_outputs[0], list):
-                for _ in processed_outputs:
-                    results.append([])
-
-            # If nested list structure
-            if isinstance(processed_outputs[0], list):
-                for i, processed_output in enumerate(processed_outputs):
-                    results[i].extend(processed_output)
-            else:
-                # If flat list structure
-                results.extend(processed_outputs)
+            results.extend(processed_outputs)

         elif isinstance(processed_outputs, dict):
             # Create a results of the same type as each element in processed_outputs
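With this change, each per-batch `processed_outputs` entry is appended as-is: in n-best mode it is already a list of Hypothesis objects for one utterance, so a plain `extend` keeps one entry per utterance instead of rebuilding (and previously mangling) the nesting. A minimal sketch of the accumulation, using plain lists as stand-ins for Hypothesis objects (names are illustrative):

# Illustrative only: inner lists stand in for n-best lists of Hypothesis objects.
results = None
for processed_outputs in [
    [["utt1-best", "utt1-alt"], ["utt2-best", "utt2-alt"]],  # batch 1: two utterances, n-best each
    [["utt3-best", "utt3-alt"]],                             # batch 2: one utterance
]:
    if results is None:
        results = []
    results.extend(processed_outputs)

assert len(results) == 3       # one entry per utterance
assert len(results[0]) == 2    # n-best structure preserved, not flattened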

nemo/collections/asr/parts/submodules/multitask_beam_decoding.py

Lines changed: 26 additions & 17 deletions
@@ -220,8 +220,8 @@ def forward(
                 hypotheses = [Hypothesis(score=0.0, y_sequence=[], timestamp=[]) for _ in range(self.beam_size)]
                 # Pack results into Hypotheses
                 hypotheses = pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i])
-                self.format_hypotheses(hypotheses, decoder_input_ids)
                 packed_result.append(NBestHypotheses(hypotheses))
+            self.format_hypotheses(packed_result, decoder_input_ids)
         else:
             beam_scores = [None for _ in range(len(best_hypo))]
             best_hypo = best_hypo.detach().cpu()
@@ -234,7 +234,9 @@ def forward(

         return (packed_result,)

-    def format_hypotheses(self, packed_result: List[Hypothesis], decoder_input_ids: Union[torch.Tensor, None]) -> None:
+    def format_hypotheses(
+        self, packed_result: List[Hypothesis | NBestHypotheses], decoder_input_ids: Union[torch.Tensor, None]
+    ) -> None:
         """
         For each hypothesis in the mini-batch:
         * Remove the decoder input ids (prompt) from the predictions
@@ -246,21 +248,28 @@ def format_hypotheses(self, packed_result: List[Hypothesis], decoder_input_ids:
             len(packed_result) == decoder_input_ids.shape[0]
         ), f"Mismatching number of examples {len(packed_result)=} {decoder_input_ids.shape[0]=}"
         decoder_input_ids = decoder_input_ids.detach().cpu()
-        for hyp, prefix in zip(packed_result, decoder_input_ids):
-            assert (
-                hyp.y_sequence[: prefix.shape[0]] == prefix
-            ).all(), f"The decoder input IDs were not found at the beginning of prediction: {hyp.y_sequence=} {prefix=})"
-            hyp.y_sequence = hyp.y_sequence[prefix.shape[0] :]
-        for hyp in packed_result:
-            ids = hyp.y_sequence
-            ids_len = ids.shape[0]
-            pos = -1
-            while ids[pos] == self.pad or ids[pos] == self.eos:
-                pos -= 1
-                if ids_len + pos == -1:
-                    break  # empty sequence
-            if pos < -1:
-                hyp.y_sequence = ids[: pos + 1]
+
+        for h, prefix in zip(packed_result, decoder_input_ids):
+            hypotheses = h.n_best_hypotheses if isinstance(h, NBestHypotheses) else [h]
+            for hyp in hypotheses:
+                assert (hyp.y_sequence[: prefix.shape[0]] == prefix).all(), (
+                    f"The decoder input IDs were not found at the beginning of prediction: "
+                    f"{hyp.y_sequence=} {prefix=}"
+                )
+                hyp.y_sequence = hyp.y_sequence[prefix.shape[0] :]
+
+        for h in packed_result:
+            hyps = h.n_best_hypotheses if isinstance(h, NBestHypotheses) else [h]
+            for hyp in hyps:
+                ids = hyp.y_sequence
+                ids_len = ids.shape[0]
+                pos = -1
+                while ids[pos] == self.pad or ids[pos] == self.eos:
+                    pos -= 1
+                    if ids_len + pos == -1:
+                        break  # empty sequence
+                if pos < -1:
+                    hyp.y_sequence = ids[: pos + 1]


 @dataclass
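In the n-best branch, `packed_result` now holds one NBestHypotheses wrapper per utterance, and `format_hypotheses` strips the prompt and trailing pad/EOS tokens from every hypothesis inside each wrapper. A minimal sketch of a downstream consumer, assuming only the attributes used above (`NBestHypotheses.n_best_hypotheses`, `Hypothesis.y_sequence`); the import path and best-first ordering are assumptions, so treat it as illustrative rather than the library's API:

from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses  # assumed import path

def top_token_ids(packed_result):
    """Return the highest-ranked token-id sequence for each utterance in packed_result."""
    best = []
    for h in packed_result:
        # n-best path: unwrap the wrapper (assumed ordered best-first);
        # best-only path: the element is already a Hypothesis.
        hyp = h.n_best_hypotheses[0] if isinstance(h, NBestHypotheses) else h
        best.append(hyp.y_sequence)
    return best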

nemo/collections/asr/parts/utils/timestamp_utils.py

Lines changed: 44 additions & 25 deletions
@@ -22,22 +22,13 @@ def process_aed_timestamp_outputs(outputs, subsampling_factor: int = 1, window_s
     """
     Processes AED timestamp outputs and extracts word-level timestamps.
     Args:
-        outputs (list or Hypothesis): The hypothesis outputs to process. Can be a single Hypothesis object or a list of Hypothesis objects.
+        outputs (Hypothesis, list of Hypothesis, or list of list of Hypothesis): The hypothesis outputs to process.
        subsampling_factor (int, optional): The subsampling factor used in the model. Default is 1.
        window_stride (float, optional): The window stride used in the model. Default is 0.01.
     Returns:
-        list or Hypothesis: The processed hypothesis outputs with word-level timestamps added.
+        list of Hypothesis or list of list of Hypothesis: The processed hypothesis outputs with word-level timestamps added.
     """

-    if outputs is None:
-        return outputs
-
-    if isinstance(outputs, Hypothesis):
-        outputs = [outputs]
-
-    if not isinstance(outputs[0], Hypothesis):
-        raise ValueError(f"Expected Hypothesis object, got {type(outputs[0])}")
-
     def extract_words_with_timestamps(text, subsampling_factor: int = 1, window_stride: float = 0.01):
         text = text.strip()  # remove leading and trailing whitespaces - training data artifact
4334

@@ -77,24 +68,52 @@ def segments_offset_to_time(segments, window_stride, subsampling_factor):
         segment['end'] = segment['end_offset'] * window_stride * subsampling_factor
         return segments

-    for idx, hyp in enumerate(outputs):
+    def process_hypothesis(hyp, subsampling_factor: int, window_stride: float):
+        """
+        Processes a single Hypothesis object to extract timestamps.
+        """
         timestamp, text = extract_words_with_timestamps(hyp.text, subsampling_factor, window_stride)
+        hyp.text = text
         if timestamp is not None:
-            if len(outputs[idx].timestamp) == 0:
-                outputs[idx].timestamp = {}
-            outputs[idx].timestamp['char'] = []  # not supported for AED
-            outputs[idx].timestamp['word'] = timestamp
-            outputs[idx].text = text
+            if len(hyp.timestamp) == 0:
+                hyp.timestamp = {}
+
+            hyp.timestamp.update(
+                {
+                    'word': timestamp,
+                    'segment': [],
+                    'char': [],  # not supported for AED
+                }
+            )
+
             segments = AbstractCTCDecoding._get_segment_offsets(timestamp, segment_delimiter_tokens=['.', '?', '!'])
-            segments = segments_offset_to_time(segments, window_stride, subsampling_factor)
-            outputs[idx].timestamp['segment'] = segments
+            hyp.timestamp['segment'] = segments_offset_to_time(segments, window_stride, subsampling_factor)
         else:
-            outputs[idx].text = text
-            outputs[idx].timestamp = {}
-            outputs[idx].timestamp['word'] = []
-            outputs[idx].timestamp['segment'] = []
-            outputs[idx].timestamp['char'] = []
-    return outputs
+            hyp.timestamp = {
+                'word': [],
+                'segment': [],
+                'char': [],
+            }
+
+        return hyp
+
+    if outputs is None:
+        return outputs
+
+    if isinstance(outputs, Hypothesis):
+        return [process_hypothesis(outputs, subsampling_factor, window_stride)]
+    elif isinstance(outputs, list) and isinstance(outputs[0], Hypothesis):
+        # list of Hypothesis
+        return [process_hypothesis(hyp, subsampling_factor, window_stride) for hyp in outputs]
+    elif isinstance(outputs, list) and isinstance(outputs[0], list) and isinstance(outputs[0][0], Hypothesis):
+        # list of list of Hypothesis (for beam decoding)
+        return [
+            [process_hypothesis(hyp, subsampling_factor, window_stride) for hyp in hyps_list] for hyps_list in outputs
+        ]
+    else:
+        raise ValueError(
+            f"Expected Hypothesis, list of Hypothesis or list of list of Hypothesis object, got {type(outputs)}"
+        )


 def process_timestamp_outputs(outputs, subsampling_factor: int = 1, window_stride: float = 0.01):
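The refactor moves the per-hypothesis work into a nested `process_hypothesis` helper and dispatches on the input shape, so the nested (n-best) case produced by beam decoding keeps its nesting. A generic restatement of that dispatch as a standalone sketch (the function name and callable are illustrative, not part of NeMo):

from typing import Any, Callable, List, Union

def dispatch_like_process_aed(
    outputs: Union[Any, List[Any], List[List[Any]], None],
    process_one: Callable[[Any], Any],
):
    """Mirror of the shape dispatch above: single item, flat list, or nested (n-best) list."""
    if outputs is None:
        return outputs
    if not isinstance(outputs, list):
        return [process_one(outputs)]                                   # single item -> wrapped in a list
    if outputs and isinstance(outputs[0], list):
        return [[process_one(x) for x in inner] for inner in outputs]   # nesting preserved
    return [process_one(x) for x in outputs]                            # flat list preserved

# e.g. dispatch_like_process_aed([["a", "b"], ["c"]], str.upper) == [["A", "B"], ["C"]]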

tests/collections/asr/mixins/test_transcription.py

Lines changed: 50 additions & 1 deletion
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import copy
 import json
 import os
 from dataclasses import dataclass
 from typing import Any, Dict, List

 import pytest
 import torch
+from omegaconf import open_dict
 from torch.utils.data import DataLoader, Dataset

 from nemo.collections.asr.data.audio_to_text import _speech_collate_fn
@@ -366,6 +367,54 @@ def test_transcribe_dataloader(self, audio_files, fast_conformer_ctc_model):
         assert isinstance(outputs[0], Hypothesis)
         assert isinstance(outputs[1], Hypothesis)

+    @pytest.mark.unit
+    def test_transcribe_return_nbest_rnnt(self, audio_files, fast_conformer_transducer_model):
+        fast_conformer_transducer_model.eval()
+        audio1, audio2 = audio_files
+
+        orig_decoding_config = copy.deepcopy(fast_conformer_transducer_model.cfg.decoding)
+
+        decoding_config = copy.deepcopy(fast_conformer_transducer_model.cfg.decoding)
+        with open_dict(decoding_config):
+            decoding_config["strategy"] = "malsd_batch"
+            decoding_config["beam"]["beam_size"] = 4
+            decoding_config["beam"]["return_best_hypothesis"] = False
+            decoding_config["beam"]["allow_cuda_graphs"] = False
+        fast_conformer_transducer_model.change_decoding_strategy(decoding_config)
+
+        outputs = fast_conformer_transducer_model.transcribe([audio1, audio2], batch_size=1, timestamps=False)
+
+        assert len(outputs) == 2
+        assert all(len(output) >= 1 for output in outputs)
+        assert all(isinstance(output, list) for output in outputs)
+        assert all(isinstance(hyp, Hypothesis) for output in outputs for hyp in output)
+
+        # Reset the decoding strategy to original
+        fast_conformer_transducer_model.change_decoding_strategy(orig_decoding_config)
+
+    @pytest.mark.unit
+    def test_transcribe_return_nbest_canary(self, audio_files, canary_1b_flash):
+        canary_1b_flash.eval()
+        audio1, audio2 = audio_files
+
+        orig_decoding_config = copy.deepcopy(canary_1b_flash.cfg.decoding)
+
+        decoding_config = copy.deepcopy(canary_1b_flash.cfg.decoding)
+        with open_dict(decoding_config):
+            decoding_config["beam"]["beam_size"] = 4
+            decoding_config["beam"]["return_best_hypothesis"] = False
+        canary_1b_flash.change_decoding_strategy(decoding_config)
+
+        outputs = canary_1b_flash.transcribe([audio1, audio2], batch_size=1, timestamps=False)
+
+        assert len(outputs) == 2
+        assert all(len(output) >= 1 for output in outputs)
+        assert all(isinstance(output, list) for output in outputs)
+        assert all(isinstance(hyp, Hypothesis) for output in outputs for hyp in output)
+
+        # Reset the decoding strategy to original
+        canary_1b_flash.change_decoding_strategy(orig_decoding_config)
+
     @pytest.mark.with_downloads()
     @pytest.mark.unit
     def test_timestamps_with_transcribe(self, audio_files, fast_conformer_ctc_model):
