2626 NEMO_WHITE_SPACE ,
2727 GraphFst ,
2828 delete_space ,
29- insert_space ,
3029 filter_cardinal_punctuation ,
30+ insert_space ,
3131)
3232from nemo_text_processing .text_normalization .pt .utils import get_abs_path , load_labels
3333
@@ -48,33 +48,29 @@ def __init__(self, deterministic: bool = True):
4848 super ().__init__ (name = "cardinal" , kind = "classify" , deterministic = deterministic )
4949
5050 specials = {
51- row [0 ]: row [1 ]
52- for row in load_labels (get_abs_path ("data/numbers/cardinal_specials.tsv" ))
53- if len (row ) >= 2
51+ row [0 ]: row [1 ] for row in load_labels (get_abs_path ("data/numbers/cardinal_specials.tsv" )) if len (row ) >= 2
5452 }
5553 connector_e = insert_space + pynutil .insert (specials ["connector" ]) + insert_space
5654 thousand = specials ["thousand" ]
5755 hundred_100 = specials ["hundred_100" ]
5856 hundred_1 = specials ["hundred_1" ]
5957
6058 scale_rows = load_labels (get_abs_path ("data/numbers/scales.tsv" ))
61- scales = [
62- (row [0 ], row [1 ], int (row [2 ]))
63- for row in scale_rows
64- if len (row ) >= 3 and row [2 ].strip ().isdigit ()
65- ]
59+ scales = [(row [0 ], row [1 ], int (row [2 ])) for row in scale_rows if len (row ) >= 3 and row [2 ].strip ().isdigit ()]
6660
6761 _num = lambda p : pynini .string_file (get_abs_path (f"data/numbers/{ p } " ))
6862 zero , digit , teens , tens , hundreds = (
69- _num ("zero.tsv" ), _num ("digit.tsv" ), _num ("teens.tsv" ), _num ("tens.tsv" ), _num ("hundreds.tsv" )
63+ _num ("zero.tsv" ),
64+ _num ("digit.tsv" ),
65+ _num ("teens.tsv" ),
66+ _num ("tens.tsv" ),
67+ _num ("hundreds.tsv" ),
7068 )
7169 digits_no_one = (NEMO_DIGIT - "1" ) @ digit
7270
7371 graph_tens = teens | (tens + (pynutil .delete ("0" ) | (connector_e + digit )))
7472 self .tens = graph_tens .optimize ()
75- self .two_digit_non_zero = pynini .union (
76- digit , graph_tens , (pynini .cross ("0" , NEMO_SPACE ) + digit )
77- ).optimize ()
73+ self .two_digit_non_zero = pynini .union (digit , graph_tens , (pynini .cross ("0" , NEMO_SPACE ) + digit )).optimize ()
7874
7975 graph_hundreds = hundreds + pynini .union (
8076 pynutil .delete ("00" ),
@@ -109,7 +105,8 @@ def __init__(self, deterministic: bool = True):
109105 (connector_e + graph_tens ),
110106 (connector_e + pynutil .delete ("0" ) + digit ),
111107 ),
112- hundreds + pynini .union (
108+ hundreds
109+ + pynini .union (
113110 (connector_e + graph_tens ),
114111 (connector_e + digit ),
115112 ),
@@ -129,7 +126,9 @@ def __init__(self, deterministic: bool = True):
129126 )
130127 t_comp_no_one = pynini .union (
131128 pynutil .delete ("000" ) + h_comp_no_one ,
132- h_comp_no_one + insert_space + pynutil .insert (thousand )
129+ h_comp_no_one
130+ + insert_space
131+ + pynutil .insert (thousand )
133132 + ((insert_space + h_comp ) | pynutil .delete ("000" )),
134133 pynini .cross ("001" , thousand ) + ((insert_space + h_comp ) | pynutil .delete ("000" )),
135134 )
@@ -154,8 +153,10 @@ def __init__(self, deterministic: bool = True):
154153 # Units 6 (u6): pure get "e" after scale; compound no "e"
155154 u6_one = pynini .cross ("000001" , "1" ) @ digit
156155 u6_pure = pynini .union (
157- u6_one , pynini .cross ("001000" , thousand ),
158- pynini .cross ("000010" , "10" ) @ graph_tens , pynini .cross ("000100" , hundred_100 ),
156+ u6_one ,
157+ pynini .cross ("001000" , thousand ),
158+ pynini .cross ("000010" , "10" ) @ graph_tens ,
159+ pynini .cross ("000100" , hundred_100 ),
159160 (pynini .cross ("010000" , "10" ) @ graph_tens ) + insert_space + pynutil .insert (thousand ),
160161 pynini .cross ("100000" , hundred_100 ) + insert_space + pynutil .insert (thousand ),
161162 )
@@ -164,36 +165,40 @@ def __init__(self, deterministic: bool = True):
164165 z18 = pynini .accep ("0" * 18 ) # 18 zeros: branch no "e"
165166 smaller_e = (connector_e + u6_pure ) | u6_compound | pynutil .delete ("0" * 6 )
166167 smaller = u6 | pynutil .delete ("0" * 6 )
167- graph_24 = (
168- (( NEMO_DIGIT ** 18 - z18 ) + NEMO_DIGIT ** 6 ) @ (graph_large_scales + smaller_e )
169- ) | (( z18 + NEMO_DIGIT ** 6 ) @ ( pynutil . delete ( z18 ) + smaller ))
168+ graph_24 = ((( NEMO_DIGIT ** 18 - z18 ) + NEMO_DIGIT ** 6 ) @ ( graph_large_scales + smaller_e )) | (
169+ (z18 + NEMO_DIGIT ** 6 ) @ (pynutil . delete ( z18 ) + smaller )
170+ )
170171
171172 trail_by_z = {9 : trail_9 , 12 : trail_12 }
172173 magnitude_patterns = [
173174 self ._build_magnitude_pattern (
174- one_label , plural_suffix , magnitude_zeros , trail_by_z .get (magnitude_zeros ),
175- connector_e , insert_space , digit , graph_tens , graph_hundreds ,
175+ one_label ,
176+ plural_suffix ,
177+ magnitude_zeros ,
178+ trail_by_z .get (magnitude_zeros ),
179+ connector_e ,
180+ insert_space ,
181+ digit ,
182+ graph_tens ,
183+ graph_hundreds ,
176184 )
177185 for one_label , plural_suffix , magnitude_zeros in scales
178186 if magnitude_zeros > 0
179187 ]
180188
181189 pad = (NEMO_DIGIT - "0" ) + pynini .closure (NEMO_DIGIT , 0 )
182190 pad = pad @ pynini .cdrewrite (pynini .closure (pynutil .insert ("0" )), "[BOS]" , "" , NEMO_SIGMA ) @ NEMO_DIGIT ** 24
183- norm = pynini .cdrewrite (delete_space , "[BOS]" , "" , NEMO_SIGMA ) @ pynini .cdrewrite (delete_space , "" , "[EOS]" , NEMO_SIGMA )
184- norm = norm @ pynini .cdrewrite (pynini .cross (pynini .closure (NEMO_WHITE_SPACE , 2 ), NEMO_SPACE ), NEMO_ALPHA , NEMO_ALPHA , NEMO_SIGMA )
191+ norm = pynini .cdrewrite (delete_space , "[BOS]" , "" , NEMO_SIGMA ) @ pynini .cdrewrite (
192+ delete_space , "" , "[EOS]" , NEMO_SIGMA
193+ )
194+ norm = norm @ pynini .cdrewrite (
195+ pynini .cross (pynini .closure (NEMO_WHITE_SPACE , 2 ), NEMO_SPACE ), NEMO_ALPHA , NEMO_ALPHA , NEMO_SIGMA
196+ )
185197 self .graph = reduce (lambda a , b : a | b , magnitude_patterns , pad @ graph_24 @ norm ) | zero
186198 self .graph = filter_cardinal_punctuation (self .graph ).optimize ()
187199
188- optional_minus_graph = pynini .closure (
189- pynutil .insert ("negative: " ) + pynini .cross ("-" , "\" true\" " ), 0 , 1
190- )
191- final_graph = (
192- optional_minus_graph
193- + pynutil .insert ("integer: \" " )
194- + self .graph
195- + pynutil .insert ("\" " )
196- )
200+ optional_minus_graph = pynini .closure (pynutil .insert ("negative: " ) + pynini .cross ("-" , "\" true\" " ), 0 , 1 )
201+ final_graph = optional_minus_graph + pynutil .insert ("integer: \" " ) + self .graph + pynutil .insert ("\" " )
197202 final_graph = self .add_tokens (final_graph )
198203 self .fst = final_graph .optimize ()
199204
@@ -221,13 +226,19 @@ def _build_scale_trailing_graph(self, scale_3, sub_graph, trailing_len, total_le
221226 @staticmethod
222227 def _pure_inputs (num_digits ):
223228 """Inputs 1, 10, 100, ... as num_digits-digit strings."""
224- return pynini .union (
225- * [pynini .accep (str (10 ** k ).zfill (num_digits )) for k in range (0 , num_digits )]
226- )
229+ return pynini .union (* [pynini .accep (str (10 ** k ).zfill (num_digits )) for k in range (0 , num_digits )])
227230
228231 def _magnitude_graph (
229- self , one_word , plural_suffix , zero_count , graph_digit , graph_tens , graph_hundreds ,
230- connector_e , insert_space , trailing_pair = None ,
232+ self ,
233+ one_word ,
234+ plural_suffix ,
235+ zero_count ,
236+ graph_digit ,
237+ graph_tens ,
238+ graph_hundreds ,
239+ connector_e ,
240+ insert_space ,
241+ trailing_pair = None ,
231242 ):
232243 """Round (1–3 digit + scale + zeros); optional trailing (e + pure | space + compound)."""
233244 zeros = "0" * zero_count
@@ -236,54 +247,51 @@ def _magnitude_graph(
236247 for L in (1 , 2 , 3 ):
237248 total = zero_count + L
238249 if L == 1 :
239- lead = pynini .cross ("1" , one_word ) | (
240- (NEMO_DIGIT - "1" ) @ graph_digit + pynutil .insert (plural_suffix )
241- )
250+ lead = pynini .cross ("1" , one_word ) | ((NEMO_DIGIT - "1" ) @ graph_digit + pynutil .insert (plural_suffix ))
242251 else :
243- lead = (
244- pynini .closure (NEMO_DIGIT , L , L )
245- @ (graph_tens if L == 2 else graph_hundreds )
246- + pynutil .insert (plural_suffix )
252+ lead = pynini .closure (NEMO_DIGIT , L , L ) @ (graph_tens if L == 2 else graph_hundreds ) + pynutil .insert (
253+ plural_suffix
247254 )
248255 lead_fst = NEMO_DIGIT ** L @ lead
249- round_pats .append (
250- pynini .closure (NEMO_DIGIT , total , total ) @ (lead_fst + pynutil .delete (zeros ))
251- )
256+ round_pats .append (pynini .closure (NEMO_DIGIT , total , total ) @ (lead_fst + pynutil .delete (zeros )))
252257 if trailing_pair :
253258 pure , compound = trailing_pair
254- trail_part = (
255- NEMO_DIGIT ** zero_count @ (connector_e + pure )
256- | NEMO_DIGIT ** zero_count @ (insert_space + compound )
257- )
258- trail_pats .append (
259- pynini .closure (NEMO_DIGIT , total , total ) @ (lead_fst + trail_part )
259+ trail_part = NEMO_DIGIT ** zero_count @ (connector_e + pure ) | NEMO_DIGIT ** zero_count @ (
260+ insert_space + compound
260261 )
262+ trail_pats .append (pynini .closure (NEMO_DIGIT , total , total ) @ (lead_fst + trail_part ))
261263 graph_round = pynini .union (* round_pats )
262264 graph_trail = pynini .union (* trail_pats ) if trail_pats else None
263265 return graph_round , graph_trail
264266
265267 def _build_magnitude_pattern (
266268 self ,
267- one_label , plural_suffix , magnitude_zeros ,
269+ one_label ,
270+ plural_suffix ,
271+ magnitude_zeros ,
268272 trailing_pair ,
269- connector_e , insert_space ,
270- graph_digit , graph_tens , graph_hundreds ,
273+ connector_e ,
274+ insert_space ,
275+ graph_digit ,
276+ graph_tens ,
277+ graph_hundreds ,
271278 ):
272279 """Restrict length; round + optional non-zero trailing."""
273- restrict = (NEMO_DIGIT - "0" ) + pynini .closure (
274- NEMO_DIGIT , magnitude_zeros , magnitude_zeros + 2
275- )
280+ restrict = (NEMO_DIGIT - "0" ) + pynini .closure (NEMO_DIGIT , magnitude_zeros , magnitude_zeros + 2 )
276281 graph_round , graph_trail = self ._magnitude_graph (
277- one_label , plural_suffix , magnitude_zeros ,
278- graph_digit , graph_tens , graph_hundreds ,
279- connector_e , insert_space , trailing_pair ,
282+ one_label ,
283+ plural_suffix ,
284+ magnitude_zeros ,
285+ graph_digit ,
286+ graph_tens ,
287+ graph_hundreds ,
288+ connector_e ,
289+ insert_space ,
290+ trailing_pair ,
280291 )
281292 if graph_trail is None :
282293 return pynutil .add_weight (restrict @ graph_round , - 1.0 )
283294 non_zero_trail = pynini .union (
284- * [
285- NEMO_DIGIT ** n + (NEMO_DIGIT ** magnitude_zeros - pynini .accep ("0" * magnitude_zeros ))
286- for n in (1 , 2 , 3 )
287- ]
295+ * [NEMO_DIGIT ** n + (NEMO_DIGIT ** magnitude_zeros - pynini .accep ("0" * magnitude_zeros )) for n in (1 , 2 , 3 )]
288296 )
289297 return pynutil .add_weight (restrict @ (graph_round | (non_zero_trail @ graph_trail )), - 1.0 )
0 commit comments