2424
2525def _get_alignment (a : str , b : str ) -> Dict :
2626 """
27-
28- Construscts alignment between a and b
27+ Constructs alignment between a and b
2928
3029 Returns:
3130 a dictionary, where keys are a's word index and values is a Tuple that contains span from b, and whether it
@@ -62,7 +61,7 @@ def _get_alignment(a: str, b: str) -> Dict:
6261
6362def adjust_boundaries (norm_raw_diffs : Dict , norm_pred_diffs : Dict , raw : str , norm : str , pred_text : str , verbose = False ):
6463 """
65- Adjust alignement boundaries by taking norm--raw texts and norm--pred_text alignements , and creating raw-pred_text
64+ Adjust alignment boundaries by taking norm--raw texts and norm--pred_text alignments , and creating raw-pred_text alignment
6665 alignment.
6766
6867 norm_raw_diffs: output of _get_alignment(norm, raw)
@@ -92,10 +91,12 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
9291 raw_text_mask_idx: [1, 4]
9392 """
9493
95- adjusted = []
94+ raw_pred_spans = []
9695 word_id = 0
9796 while word_id < len (norm .split ()):
9897 norm_raw , norm_pred = norm_raw_diffs [word_id ], norm_pred_diffs [word_id ]
98+ # if there is a mismatch in norm_raw and norm_pred, expand the boundaries of the shortest mismatch to align with the longest one
99+ # e.g., norm_raw = (1, 2, 'match') norm_pred = (1, 5, 'non-match') => expand norm_raw until the next matching sequence or the end of string to align with norm_pred
99100 if (norm_raw [2 ] == MATCH and norm_pred [2 ] == NONMATCH ) or (norm_raw [2 ] == NONMATCH and norm_pred [2 ] == MATCH ):
100101 mismatched_id = word_id
101102 non_match_raw_start = norm_raw [0 ]
@@ -114,20 +115,21 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
114115 if not done :
115116 non_match_raw_end = len (raw .split ())
116117 non_match_pred_end = len (pred_text .split ())
117- adjusted .append (
118+ raw_pred_spans .append (
118119 (
119120 mismatched_id ,
120121 (non_match_raw_start , non_match_raw_end , NONMATCH ),
121122 (non_match_pred_start , non_match_pred_end , NONMATCH ),
122123 )
123124 )
124125 else :
125- adjusted .append ((word_id , norm_raw , norm_pred ))
126+ raw_pred_spans .append ((word_id , norm_raw , norm_pred ))
126127 word_id += 1
127128
128- adjusted2 = []
129+ # aggregate neighboring spans with the same status
130+ spans_merged_neighbors = []
129131 last_status = None
130- for idx , item in enumerate (adjusted ):
132+ for idx , item in enumerate (raw_pred_spans ):
131133 if last_status is None :
132134 last_status = item [1 ][2 ]
133135 raw_start = item [1 ][0 ]
@@ -139,7 +141,7 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
139141 raw_end = item [1 ][1 ]
140142 pred_text_end = item [2 ][1 ]
141143 else :
142- adjusted2 .append (
144+ spans_merged_neighbors .append (
143145 [[norm_span_start , item [0 ]], [raw_start , raw_end ], [pred_text_start , pred_text_end ], last_status ]
144146 )
145147 last_status = item [1 ][2 ]
@@ -152,13 +154,13 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
152154 if last_status == item [1 ][2 ]:
153155 raw_end = item [1 ][1 ]
154156 pred_text_end = item [2 ][1 ]
155- adjusted2 .append (
157+ spans_merged_neighbors .append (
156158 [[norm_span_start , item [0 ]], [raw_start , raw_end ], [pred_text_start , pred_text_end ], last_status ]
157159 )
158160 else :
159- adjusted2 .append (
161+ spans_merged_neighbors .append (
160162 [
161- [adjusted [idx - 1 ][0 ], len (norm .split ())],
163+ [raw_pred_spans [idx - 1 ][0 ], len (norm .split ())],
162164 [item [1 ][0 ], len (raw .split ())],
163165 [item [2 ][0 ], len (pred_text .split ())],
164166 item [1 ][2 ],
@@ -171,10 +173,10 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
171173
172174 # increase boundaries between raw and pred_text if some spans contain empty pred_text
173175 extended_spans = []
174- adjusted3 = []
176+ raw_norm_spans_corrected_for_pred_text = []
175177 idx = 0
176- while idx < len (adjusted2 ):
177- item = adjusted2 [idx ]
178+ while idx < len (spans_merged_neighbors ):
179+ item = spans_merged_neighbors [idx ]
178180
179181 cur_semiotic = " " .join (raw_list [item [1 ][0 ] : item [1 ][1 ]])
180182 cur_pred_text = " " .join (pred_text_list [item [2 ][0 ] : item [2 ][1 ]])
@@ -186,8 +188,8 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
186188 # if cur_pred_text is an empty string
187189 if item [2 ][0 ] == item [2 ][1 ]:
188190 # for the last item
189- if idx == len (adjusted2 ) - 1 and len (adjusted3 ) > 0 :
190- last_item = adjusted3 [- 1 ]
191+ if idx == len (spans_merged_neighbors ) - 1 and len (raw_norm_spans_corrected_for_pred_text ) > 0 :
192+ last_item = raw_norm_spans_corrected_for_pred_text [- 1 ]
191193 last_item [0 ][1 ] = item [0 ][1 ]
192194 last_item [1 ][1 ] = item [1 ][1 ]
193195 last_item [2 ][1 ] = item [2 ][1 ]
@@ -196,29 +198,31 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
196198 raw_start , raw_end = item [0 ]
197199 norm_start , norm_end = item [1 ]
198200 pred_start , pred_end = item [2 ]
199- while idx < len (adjusted2 ) - 1 and not ((pred_end - pred_start ) > 2 and adjusted2 [idx ][- 1 ] == MATCH ):
201+ while idx < len (spans_merged_neighbors ) - 1 and not (
202+ (pred_end - pred_start ) > 2 and spans_merged_neighbors [idx ][- 1 ] == MATCH
203+ ):
200204 idx += 1
201- raw_end = adjusted2 [idx ][0 ][1 ]
202- norm_end = adjusted2 [idx ][1 ][1 ]
203- pred_end = adjusted2 [idx ][2 ][1 ]
205+ raw_end = spans_merged_neighbors [idx ][0 ][1 ]
206+ norm_end = spans_merged_neighbors [idx ][1 ][1 ]
207+ pred_end = spans_merged_neighbors [idx ][2 ][1 ]
204208 cur_item = [[raw_start , raw_end ], [norm_start , norm_end ], [pred_start , pred_end ], NONMATCH ]
205- adjusted3 .append (cur_item )
206- extended_spans .append (len (adjusted3 ) - 1 )
209+ raw_norm_spans_corrected_for_pred_text .append (cur_item )
210+ extended_spans .append (len (raw_norm_spans_corrected_for_pred_text ) - 1 )
207211 idx += 1
208212 else :
209- adjusted3 .append (item )
213+ raw_norm_spans_corrected_for_pred_text .append (item )
210214 idx += 1
211215
212216 semiotic_spans = []
213217 norm_spans = []
214218 pred_texts = []
215219 raw_text_masked = ""
216- for idx , item in enumerate (adjusted3 ):
220+ for idx , item in enumerate (raw_norm_spans_corrected_for_pred_text ):
217221 cur_semiotic = " " .join (raw_list [item [1 ][0 ] : item [1 ][1 ]])
218222 cur_pred_text = " " .join (pred_text_list [item [2 ][0 ] : item [2 ][1 ]])
219223 cur_norm_span = " " .join (norm_list [item [0 ][0 ] : item [0 ][1 ]])
220224
221- if idx == len (adjusted3 ) - 1 :
225+ if idx == len (raw_norm_spans_corrected_for_pred_text ) - 1 :
222226 cur_norm_span = " " .join (norm_list [item [0 ][0 ] : len (norm_list )])
223227 if (item [- 1 ] == NONMATCH and cur_semiotic != cur_norm_span ) or (idx in extended_spans ):
224228 raw_text_masked += " " + SEMIOTIC_TAG
@@ -233,24 +237,31 @@ def adjust_boundaries(norm_raw_diffs: Dict, norm_pred_diffs: Dict, raw: str, nor
233237
234238 if verbose :
235239 print ("+" * 50 )
236- print ("adjusted :" )
237- for item in adjusted2 :
240+ print ("raw_pred_spans :" )
241+ for item in spans_merged_neighbors :
238242 print (f"{ raw .split ()[item [1 ][0 ]: item [1 ][1 ]]} -- { pred_text .split ()[item [2 ][0 ]: item [2 ][1 ]]} " )
239243
240244 print ("+" * 50 )
241- print ("adjusted2 :" )
242- for item in adjusted2 :
245+ print ("spans_merged_neighbors :" )
246+ for item in spans_merged_neighbors :
243247 print (f"{ raw .split ()[item [1 ][0 ]: item [1 ][1 ]]} -- { pred_text .split ()[item [2 ][0 ]: item [2 ][1 ]]} " )
244248 print ("+" * 50 )
245- print ("adjusted3 :" )
246- for item in adjusted3 :
249+ print ("raw_norm_spans_corrected_for_pred_text :" )
250+ for item in raw_norm_spans_corrected_for_pred_text :
247251 print (f"{ raw .split ()[item [1 ][0 ]: item [1 ][1 ]]} -- { pred_text .split ()[item [2 ][0 ]: item [2 ][1 ]]} " )
248252 print ("+" * 50 )
249253
250254 return semiotic_spans , pred_texts , norm_spans , raw_text_masked_list , raw_text_mask_idx
251255
252256
253- def get_alignment (raw , norm , pred_text , verbose : bool = False ):
257+ def get_alignment (raw : str , norm : str , pred_text : str , verbose : bool = False ):
258+ """
259+ Aligns raw text with deterministically normalized text and ASR output, finds semiotic spans
260+ """
261+ for value in [raw , norm , pred_text ]:
262+ if value is None or value == "" :
263+ return [], [], [], [], []
264+
254265 norm_pred_diffs = _get_alignment (norm , pred_text )
255266 norm_raw_diffs = _get_alignment (norm , raw )
256267
@@ -271,8 +282,9 @@ def get_alignment(raw, norm, pred_text, verbose: bool = False):
271282
272283
273284if __name__ == "__main__" :
274- raw = 'This is #4 ranking on G.S.K.T.'
275- pred_text = 'this iss for ranking on g k p'
285+ raw = 'This is a #4 ranking on G.S.K.T.'
286+ pred_text = 'this iss p k for ranking on g k p'
276287 norm = 'This is nubmer four ranking on GSKT'
277288
278- get_alignment (raw , norm , pred_text , True )
289+ output = get_alignment (raw , norm , pred_text , True )
290+ print (output )
0 commit comments