Skip to content

Commit 9ddd2b3

Browse files
audio-based TN fix for empty pred_text/text (#92)
* fix for empty pred_text Signed-off-by: Evelina <ebakhturina@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add unittests Signed-off-by: Evelina <ebakhturina@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix path Signed-off-by: Evelina <ebakhturina@nvidia.com> * fix path Signed-off-by: Evelina <ebakhturina@nvidia.com> * fix pytest Signed-off-by: Evelina <ebakhturina@nvidia.com> --------- Signed-off-by: Evelina <ebakhturina@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1fdfff5 commit 9ddd2b3

5 files changed

Lines changed: 103 additions & 39 deletions

File tree

Jenkinsfile

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ pipeline {
5757
}
5858

5959

60-
6160
stage('L0: Create EN TN/ITN Grammars') {
6261
when {
6362
anyOf {
@@ -67,7 +66,11 @@ pipeline {
6766
}
6867
failFast true
6968
parallel {
70-
69+
stage('L0: Test utils') {
70+
steps {
71+
sh 'CUDA_VISIBLE_DEVICES="" pytest tests/nemo_text_processing/audio_based_utils/ --cpu'
72+
}
73+
}
7174
stage('L0: En TN grammars') {
7275
steps {
7376
sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/text_normalization/normalize.py --text="1" --cache_dir ${EN_TN_CACHE}'

nemo_text_processing/text_normalization/normalize_with_audio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def normalize(
141141
Returns:
142142
normalized text options (usually there are multiple ways of normalizing a given semiotic class)
143143
"""
144-
if pred_text is None or self.tagger is None:
144+
if pred_text is None or pred_text == "" or self.tagger is None:
145145
return self.normalize_non_deterministic(
146146
text=text, n_tagged=n_tagged, punct_post_process=punct_post_process, verbose=verbose
147147
)
@@ -156,6 +156,7 @@ def normalize(
156156
semiotic_spans, pred_text_spans, norm_spans, text_with_span_tags_list, masked_idx_list = get_alignment(
157157
text, det_norm, pred_text, verbose=False
158158
)
159+
159160
sem_tag_idx = 0
160161
for cur_semiotic_span, cur_pred_text, cur_deter_norm in zip(semiotic_spans, pred_text_spans, norm_spans):
161162
if len(cur_semiotic_span) == 0:

nemo_text_processing/text_normalization/utils_audio_based.py

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424

2525
def _get_alignment(a: str, b: str) -> Dict:
2626
"""
27-
28-
Construscts alignment between a and b
27+
Constructs alignment between a and b
2928
3029
Returns:
3130
a dictionary, where each key is a word index in a and each value is a Tuple that contains the span from b, and whether it
@@ -62,7 +61,7 @@ def _get_alignment(a: str, b: str) -> Dict:
6261

6362
def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, norm: str, pred_text: str, verbose=False):
6463
"""
65-
Adjust alignement boundaries by taking norm--raw texts and norm--pred_text alignements, and creating raw-pred_text
64+
Adjust alignment boundaries by taking norm--raw texts and norm--pred_text alignments, and creating raw-pred_text
6665
alignment.
6766
6867
norm_raw_diffs: output of _get_alignment(norm, raw)
@@ -92,10 +91,12 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
9291
raw_text_mask_idx: [1, 4]
9392
"""
9493

95-
adjusted = []
94+
raw_pred_spans = []
9695
word_id = 0
9796
while word_id < len(norm.split()):
9897
norm_raw, norm_pred = norm_raw_diffs[word_id], norm_pred_diffs[word_id]
98+
# if there is a mismatch between norm_raw and norm_pred, expand the boundaries of the shorter span to align with the longer one
99+
# e.g., norm_raw = (1, 2, 'match') norm_pred = (1, 5, 'non-match') => expand norm_raw until the next matching sequence or the end of string to align with norm_pred
99100
if (norm_raw[2] == MATCH and norm_pred[2] == NONMATCH) or (norm_raw[2] == NONMATCH and norm_pred[2] == MATCH):
100101
mismatched_id = word_id
101102
non_match_raw_start = norm_raw[0]
@@ -114,20 +115,21 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
114115
if not done:
115116
non_match_raw_end = len(raw.split())
116117
non_match_pred_end = len(pred_text.split())
117-
adjusted.append(
118+
raw_pred_spans.append(
118119
(
119120
mismatched_id,
120121
(non_match_raw_start, non_match_raw_end, NONMATCH),
121122
(non_match_pred_start, non_match_pred_end, NONMATCH),
122123
)
123124
)
124125
else:
125-
adjusted.append((word_id, norm_raw, norm_pred))
126+
raw_pred_spans.append((word_id, norm_raw, norm_pred))
126127
word_id += 1
127128

128-
adjusted2 = []
129+
# aggregate neighboring spans with the same status
130+
spans_merged_neighbors = []
129131
last_status = None
130-
for idx, item in enumerate(adjusted):
132+
for idx, item in enumerate(raw_pred_spans):
131133
if last_status is None:
132134
last_status = item[1][2]
133135
raw_start = item[1][0]
@@ -139,7 +141,7 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
139141
raw_end = item[1][1]
140142
pred_text_end = item[2][1]
141143
else:
142-
adjusted2.append(
144+
spans_merged_neighbors.append(
143145
[[norm_span_start, item[0]], [raw_start, raw_end], [pred_text_start, pred_text_end], last_status]
144146
)
145147
last_status = item[1][2]
@@ -152,13 +154,13 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
152154
if last_status == item[1][2]:
153155
raw_end = item[1][1]
154156
pred_text_end = item[2][1]
155-
adjusted2.append(
157+
spans_merged_neighbors.append(
156158
[[norm_span_start, item[0]], [raw_start, raw_end], [pred_text_start, pred_text_end], last_status]
157159
)
158160
else:
159-
adjusted2.append(
161+
spans_merged_neighbors.append(
160162
[
161-
[adjusted[idx - 1][0], len(norm.split())],
163+
[raw_pred_spans[idx - 1][0], len(norm.split())],
162164
[item[1][0], len(raw.split())],
163165
[item[2][0], len(pred_text.split())],
164166
item[1][2],
@@ -171,10 +173,10 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
171173

172174
# increase boundaries between raw and pred_text if some spans contain empty pred_text
173175
extended_spans = []
174-
adjusted3 = []
176+
raw_norm_spans_corrected_for_pred_text = []
175177
idx = 0
176-
while idx < len(adjusted2):
177-
item = adjusted2[idx]
178+
while idx < len(spans_merged_neighbors):
179+
item = spans_merged_neighbors[idx]
178180

179181
cur_semiotic = " ".join(raw_list[item[1][0] : item[1][1]])
180182
cur_pred_text = " ".join(pred_text_list[item[2][0] : item[2][1]])
@@ -186,8 +188,8 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
186188
# if cur_pred_text is an empty string
187189
if item[2][0] == item[2][1]:
188190
# for the last item
189-
if idx == len(adjusted2) - 1 and len(adjusted3) > 0:
190-
last_item = adjusted3[-1]
191+
if idx == len(spans_merged_neighbors) - 1 and len(raw_norm_spans_corrected_for_pred_text) > 0:
192+
last_item = raw_norm_spans_corrected_for_pred_text[-1]
191193
last_item[0][1] = item[0][1]
192194
last_item[1][1] = item[1][1]
193195
last_item[2][1] = item[2][1]
@@ -196,29 +198,31 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
196198
raw_start, raw_end = item[0]
197199
norm_start, norm_end = item[1]
198200
pred_start, pred_end = item[2]
199-
while idx < len(adjusted2) - 1 and not ((pred_end - pred_start) > 2 and adjusted2[idx][-1] == MATCH):
201+
while idx < len(spans_merged_neighbors) - 1 and not (
202+
(pred_end - pred_start) > 2 and spans_merged_neighbors[idx][-1] == MATCH
203+
):
200204
idx += 1
201-
raw_end = adjusted2[idx][0][1]
202-
norm_end = adjusted2[idx][1][1]
203-
pred_end = adjusted2[idx][2][1]
205+
raw_end = spans_merged_neighbors[idx][0][1]
206+
norm_end = spans_merged_neighbors[idx][1][1]
207+
pred_end = spans_merged_neighbors[idx][2][1]
204208
cur_item = [[raw_start, raw_end], [norm_start, norm_end], [pred_start, pred_end], NONMATCH]
205-
adjusted3.append(cur_item)
206-
extended_spans.append(len(adjusted3) - 1)
209+
raw_norm_spans_corrected_for_pred_text.append(cur_item)
210+
extended_spans.append(len(raw_norm_spans_corrected_for_pred_text) - 1)
207211
idx += 1
208212
else:
209-
adjusted3.append(item)
213+
raw_norm_spans_corrected_for_pred_text.append(item)
210214
idx += 1
211215

212216
semiotic_spans = []
213217
norm_spans = []
214218
pred_texts = []
215219
raw_text_masked = ""
216-
for idx, item in enumerate(adjusted3):
220+
for idx, item in enumerate(raw_norm_spans_corrected_for_pred_text):
217221
cur_semiotic = " ".join(raw_list[item[1][0] : item[1][1]])
218222
cur_pred_text = " ".join(pred_text_list[item[2][0] : item[2][1]])
219223
cur_norm_span = " ".join(norm_list[item[0][0] : item[0][1]])
220224

221-
if idx == len(adjusted3) - 1:
225+
if idx == len(raw_norm_spans_corrected_for_pred_text) - 1:
222226
cur_norm_span = " ".join(norm_list[item[0][0] : len(norm_list)])
223227
if (item[-1] == NONMATCH and cur_semiotic != cur_norm_span) or (idx in extended_spans):
224228
raw_text_masked += " " + SEMIOTIC_TAG
@@ -233,24 +237,31 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
233237

234238
if verbose:
235239
print("+" * 50)
236-
print("adjusted:")
237-
for item in adjusted2:
240+
print("raw_pred_spans:")
241+
for item in spans_merged_neighbors:
238242
print(f"{raw.split()[item[1][0]: item[1][1]]} -- {pred_text.split()[item[2][0]: item[2][1]]}")
239243

240244
print("+" * 50)
241-
print("adjusted2:")
242-
for item in adjusted2:
245+
print("spans_merged_neighbors:")
246+
for item in spans_merged_neighbors:
243247
print(f"{raw.split()[item[1][0]: item[1][1]]} -- {pred_text.split()[item[2][0]: item[2][1]]}")
244248
print("+" * 50)
245-
print("adjusted3:")
246-
for item in adjusted3:
249+
print("raw_norm_spans_corrected_for_pred_text:")
250+
for item in raw_norm_spans_corrected_for_pred_text:
247251
print(f"{raw.split()[item[1][0]: item[1][1]]} -- {pred_text.split()[item[2][0]: item[2][1]]}")
248252
print("+" * 50)
249253

250254
return semiotic_spans, pred_texts, norm_spans, raw_text_masked_list, raw_text_mask_idx
251255

252256

253-
def get_alignment(raw, norm, pred_text, verbose: bool = False):
257+
def get_alignment(raw: str, norm: str, pred_text: str, verbose: bool = False):
258+
"""
259+
Aligns raw text with deterministically normalized text and ASR output, finds semiotic spans
260+
"""
261+
for value in [raw, norm, pred_text]:
262+
if value is None or value == "":
263+
return [], [], [], [], []
264+
254265
norm_pred_diffs = _get_alignment(norm, pred_text)
255266
norm_raw_diffs = _get_alignment(norm, raw)
256267

@@ -271,8 +282,9 @@ def get_alignment(raw, norm, pred_text, verbose: bool = False):
271282

272283

273284
if __name__ == "__main__":
274-
raw = 'This is #4 ranking on G.S.K.T.'
275-
pred_text = 'this iss for ranking on g k p'
285+
raw = 'This is a #4 ranking on G.S.K.T.'
286+
pred_text = 'this iss p k for ranking on g k p'
276287
norm = 'This is nubmer four ranking on GSKT'
277288

278-
get_alignment(raw, norm, pred_text, True)
289+
output = get_alignment(raw, norm, pred_text, True)
290+
print(output)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from nemo_text_processing.text_normalization.utils_audio_based import get_alignment
17+
18+
19+
class TestAudioBasedTNUtils:
20+
@pytest.mark.run_only_on('CPU')
21+
@pytest.mark.unit
22+
def test_default(self):
23+
raw = 'This is #4 ranking on G.S.K.T.'
24+
pred_text = 'this iss for ranking on g k p'
25+
norm = 'This is nubmer four ranking on GSKT'
26+
27+
output = get_alignment(raw, norm, pred_text, True)
28+
reference = (
29+
['is #4', 'G.S.K.T.'],
30+
['iss for', 'g k p'],
31+
['is nubmer four', 'GSKT'],
32+
['This', '[SEMIOTIC_SPAN]', 'ranking', 'on', '[SEMIOTIC_SPAN]'],
33+
[1, 4],
34+
)
35+
assert output == reference

0 commit comments

Comments
 (0)