Skip to content

Commit aa3f531

Browse files
pre-commit-ci[bot] authored and folivoramanh committed
[pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

Signed-off-by: Mai Anh <palasek182@gmail.com>
1 parent 36416da commit aa3f531

17 files changed

Lines changed: 152 additions & 288 deletions

File tree

nemo_text_processing/text_normalization/pt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,4 +10,4 @@
1010
# distributed under the License is distributed on an "AS IS" BASIS,
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
13-
# limitations under the License.
13+
# limitations under the License.

nemo_text_processing/text_normalization/pt/graph_utils.py

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -47,9 +47,7 @@
4747

4848
delete_space = pynutil.delete(pynini.closure(NEMO_WHITE_SPACE))
4949
insert_space = pynutil.insert(" ")
50-
delete_extra_space = pynini.cross(
51-
pynini.closure(NEMO_WHITE_SPACE, 1), " "
52-
).optimize()
50+
delete_extra_space = pynini.cross(pynini.closure(NEMO_WHITE_SPACE, 1), " ").optimize()
5351

5452

5553
def generator_main(file_name: str, graphs: Dict[str, "pynini.FstLike"]) -> None:
@@ -84,13 +82,9 @@ def __init__(self, name: str, kind: str, deterministic: bool = True):
8482
self._fst = None
8583
self.deterministic = deterministic
8684

87-
self.far_path = Path(
88-
os.path.dirname(os.path.abspath(__file__)) + "/grammars/" + kind + "/" + name + ".far"
89-
)
85+
self.far_path = Path(os.path.dirname(os.path.abspath(__file__)) + "/grammars/" + kind + "/" + name + ".far")
9086
if self.far_exist():
91-
self._fst = Far(
92-
self.far_path, mode="r", arc_type="standard", far_type="default"
93-
).get_fst()
87+
self._fst = Far(self.far_path, mode="r", arc_type="standard", far_type="default").get_fst()
9488

9589
def far_exist(self) -> bool:
9690
return self.far_path.exists()
@@ -116,9 +110,7 @@ def delete_tokens(self, fst) -> "pynini.FstLike":
116110
+ delete_space
117111
+ pynutil.delete("}")
118112
)
119-
return res @ pynini.cdrewrite(
120-
pynini.cross("\u00a0", " "), "", "", NEMO_SIGMA
121-
)
113+
return res @ pynini.cdrewrite(pynini.cross("\u00a0", " "), "", "", NEMO_SIGMA)
122114

123115

124116
# ---- PT-specific (Brazilian: 1.000.000 or 1 000 000) ----
@@ -172,9 +164,7 @@ def shift_cardinal_gender_pt(fst: "pynini.FstLike") -> "pynini.FstLike":
172164
)
173165
fem_hundreds = pynini.cdrewrite(
174166
pynini.cross("entos", "entas"),
175-
pynini.union(
176-
"duz", "trez", "quatroc", "quinh", "seisc", "setec", "oitoc", "novec"
177-
),
167+
pynini.union("duz", "trez", "quatroc", "quinh", "seisc", "setec", "oitoc", "novec"),
178168
pynini.union(NEMO_SPACE, pynini.accep("[EOS]"), pynini.accep('"')),
179169
NEMO_SIGMA,
180170
)

nemo_text_processing/text_normalization/pt/taggers/cardinal.py

Lines changed: 75 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,8 @@
2626
NEMO_WHITE_SPACE,
2727
GraphFst,
2828
delete_space,
29-
insert_space,
3029
filter_cardinal_punctuation,
30+
insert_space,
3131
)
3232
from nemo_text_processing.text_normalization.pt.utils import get_abs_path, load_labels
3333

@@ -48,33 +48,29 @@ def __init__(self, deterministic: bool = True):
4848
super().__init__(name="cardinal", kind="classify", deterministic=deterministic)
4949

5050
specials = {
51-
row[0]: row[1]
52-
for row in load_labels(get_abs_path("data/numbers/cardinal_specials.tsv"))
53-
if len(row) >= 2
51+
row[0]: row[1] for row in load_labels(get_abs_path("data/numbers/cardinal_specials.tsv")) if len(row) >= 2
5452
}
5553
connector_e = insert_space + pynutil.insert(specials["connector"]) + insert_space
5654
thousand = specials["thousand"]
5755
hundred_100 = specials["hundred_100"]
5856
hundred_1 = specials["hundred_1"]
5957

6058
scale_rows = load_labels(get_abs_path("data/numbers/scales.tsv"))
61-
scales = [
62-
(row[0], row[1], int(row[2]))
63-
for row in scale_rows
64-
if len(row) >= 3 and row[2].strip().isdigit()
65-
]
59+
scales = [(row[0], row[1], int(row[2])) for row in scale_rows if len(row) >= 3 and row[2].strip().isdigit()]
6660

6761
_num = lambda p: pynini.string_file(get_abs_path(f"data/numbers/{p}"))
6862
zero, digit, teens, tens, hundreds = (
69-
_num("zero.tsv"), _num("digit.tsv"), _num("teens.tsv"), _num("tens.tsv"), _num("hundreds.tsv")
63+
_num("zero.tsv"),
64+
_num("digit.tsv"),
65+
_num("teens.tsv"),
66+
_num("tens.tsv"),
67+
_num("hundreds.tsv"),
7068
)
7169
digits_no_one = (NEMO_DIGIT - "1") @ digit
7270

7371
graph_tens = teens | (tens + (pynutil.delete("0") | (connector_e + digit)))
7472
self.tens = graph_tens.optimize()
75-
self.two_digit_non_zero = pynini.union(
76-
digit, graph_tens, (pynini.cross("0", NEMO_SPACE) + digit)
77-
).optimize()
73+
self.two_digit_non_zero = pynini.union(digit, graph_tens, (pynini.cross("0", NEMO_SPACE) + digit)).optimize()
7874

7975
graph_hundreds = hundreds + pynini.union(
8076
pynutil.delete("00"),
@@ -109,7 +105,8 @@ def __init__(self, deterministic: bool = True):
109105
(connector_e + graph_tens),
110106
(connector_e + pynutil.delete("0") + digit),
111107
),
112-
hundreds + pynini.union(
108+
hundreds
109+
+ pynini.union(
113110
(connector_e + graph_tens),
114111
(connector_e + digit),
115112
),
@@ -129,7 +126,9 @@ def __init__(self, deterministic: bool = True):
129126
)
130127
t_comp_no_one = pynini.union(
131128
pynutil.delete("000") + h_comp_no_one,
132-
h_comp_no_one + insert_space + pynutil.insert(thousand)
129+
h_comp_no_one
130+
+ insert_space
131+
+ pynutil.insert(thousand)
133132
+ ((insert_space + h_comp) | pynutil.delete("000")),
134133
pynini.cross("001", thousand) + ((insert_space + h_comp) | pynutil.delete("000")),
135134
)
@@ -154,8 +153,10 @@ def __init__(self, deterministic: bool = True):
154153
# Units 6 (u6): pure get "e" after scale; compound no "e"
155154
u6_one = pynini.cross("000001", "1") @ digit
156155
u6_pure = pynini.union(
157-
u6_one, pynini.cross("001000", thousand),
158-
pynini.cross("000010", "10") @ graph_tens, pynini.cross("000100", hundred_100),
156+
u6_one,
157+
pynini.cross("001000", thousand),
158+
pynini.cross("000010", "10") @ graph_tens,
159+
pynini.cross("000100", hundred_100),
159160
(pynini.cross("010000", "10") @ graph_tens) + insert_space + pynutil.insert(thousand),
160161
pynini.cross("100000", hundred_100) + insert_space + pynutil.insert(thousand),
161162
)
@@ -164,36 +165,40 @@ def __init__(self, deterministic: bool = True):
164165
z18 = pynini.accep("0" * 18) # 18 zeros: branch no "e"
165166
smaller_e = (connector_e + u6_pure) | u6_compound | pynutil.delete("0" * 6)
166167
smaller = u6 | pynutil.delete("0" * 6)
167-
graph_24 = (
168-
((NEMO_DIGIT**18 - z18) + NEMO_DIGIT**6) @ (graph_large_scales + smaller_e)
169-
) | ((z18 + NEMO_DIGIT**6) @ (pynutil.delete(z18) + smaller))
168+
graph_24 = (((NEMO_DIGIT**18 - z18) + NEMO_DIGIT**6) @ (graph_large_scales + smaller_e)) | (
169+
(z18 + NEMO_DIGIT**6) @ (pynutil.delete(z18) + smaller)
170+
)
170171

171172
trail_by_z = {9: trail_9, 12: trail_12}
172173
magnitude_patterns = [
173174
self._build_magnitude_pattern(
174-
one_label, plural_suffix, magnitude_zeros, trail_by_z.get(magnitude_zeros),
175-
connector_e, insert_space, digit, graph_tens, graph_hundreds,
175+
one_label,
176+
plural_suffix,
177+
magnitude_zeros,
178+
trail_by_z.get(magnitude_zeros),
179+
connector_e,
180+
insert_space,
181+
digit,
182+
graph_tens,
183+
graph_hundreds,
176184
)
177185
for one_label, plural_suffix, magnitude_zeros in scales
178186
if magnitude_zeros > 0
179187
]
180188

181189
pad = (NEMO_DIGIT - "0") + pynini.closure(NEMO_DIGIT, 0)
182190
pad = pad @ pynini.cdrewrite(pynini.closure(pynutil.insert("0")), "[BOS]", "", NEMO_SIGMA) @ NEMO_DIGIT**24
183-
norm = pynini.cdrewrite(delete_space, "[BOS]", "", NEMO_SIGMA) @ pynini.cdrewrite(delete_space, "", "[EOS]", NEMO_SIGMA)
184-
norm = norm @ pynini.cdrewrite(pynini.cross(pynini.closure(NEMO_WHITE_SPACE, 2), NEMO_SPACE), NEMO_ALPHA, NEMO_ALPHA, NEMO_SIGMA)
191+
norm = pynini.cdrewrite(delete_space, "[BOS]", "", NEMO_SIGMA) @ pynini.cdrewrite(
192+
delete_space, "", "[EOS]", NEMO_SIGMA
193+
)
194+
norm = norm @ pynini.cdrewrite(
195+
pynini.cross(pynini.closure(NEMO_WHITE_SPACE, 2), NEMO_SPACE), NEMO_ALPHA, NEMO_ALPHA, NEMO_SIGMA
196+
)
185197
self.graph = reduce(lambda a, b: a | b, magnitude_patterns, pad @ graph_24 @ norm) | zero
186198
self.graph = filter_cardinal_punctuation(self.graph).optimize()
187199

188-
optional_minus_graph = pynini.closure(
189-
pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1
190-
)
191-
final_graph = (
192-
optional_minus_graph
193-
+ pynutil.insert("integer: \"")
194-
+ self.graph
195-
+ pynutil.insert("\"")
196-
)
200+
optional_minus_graph = pynini.closure(pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1)
201+
final_graph = optional_minus_graph + pynutil.insert("integer: \"") + self.graph + pynutil.insert("\"")
197202
final_graph = self.add_tokens(final_graph)
198203
self.fst = final_graph.optimize()
199204

@@ -221,13 +226,19 @@ def _build_scale_trailing_graph(self, scale_3, sub_graph, trailing_len, total_le
221226
@staticmethod
222227
def _pure_inputs(num_digits):
223228
"""Inputs 1, 10, 100, ... as num_digits-digit strings."""
224-
return pynini.union(
225-
*[pynini.accep(str(10**k).zfill(num_digits)) for k in range(0, num_digits)]
226-
)
229+
return pynini.union(*[pynini.accep(str(10**k).zfill(num_digits)) for k in range(0, num_digits)])
227230

228231
def _magnitude_graph(
229-
self, one_word, plural_suffix, zero_count, graph_digit, graph_tens, graph_hundreds,
230-
connector_e, insert_space, trailing_pair=None,
232+
self,
233+
one_word,
234+
plural_suffix,
235+
zero_count,
236+
graph_digit,
237+
graph_tens,
238+
graph_hundreds,
239+
connector_e,
240+
insert_space,
241+
trailing_pair=None,
231242
):
232243
"""Round (1–3 digit + scale + zeros); optional trailing (e + pure | space + compound)."""
233244
zeros = "0" * zero_count
@@ -236,54 +247,51 @@ def _magnitude_graph(
236247
for L in (1, 2, 3):
237248
total = zero_count + L
238249
if L == 1:
239-
lead = pynini.cross("1", one_word) | (
240-
(NEMO_DIGIT - "1") @ graph_digit + pynutil.insert(plural_suffix)
241-
)
250+
lead = pynini.cross("1", one_word) | ((NEMO_DIGIT - "1") @ graph_digit + pynutil.insert(plural_suffix))
242251
else:
243-
lead = (
244-
pynini.closure(NEMO_DIGIT, L, L)
245-
@ (graph_tens if L == 2 else graph_hundreds)
246-
+ pynutil.insert(plural_suffix)
252+
lead = pynini.closure(NEMO_DIGIT, L, L) @ (graph_tens if L == 2 else graph_hundreds) + pynutil.insert(
253+
plural_suffix
247254
)
248255
lead_fst = NEMO_DIGIT**L @ lead
249-
round_pats.append(
250-
pynini.closure(NEMO_DIGIT, total, total) @ (lead_fst + pynutil.delete(zeros))
251-
)
256+
round_pats.append(pynini.closure(NEMO_DIGIT, total, total) @ (lead_fst + pynutil.delete(zeros)))
252257
if trailing_pair:
253258
pure, compound = trailing_pair
254-
trail_part = (
255-
NEMO_DIGIT**zero_count @ (connector_e + pure)
256-
| NEMO_DIGIT**zero_count @ (insert_space + compound)
257-
)
258-
trail_pats.append(
259-
pynini.closure(NEMO_DIGIT, total, total) @ (lead_fst + trail_part)
259+
trail_part = NEMO_DIGIT**zero_count @ (connector_e + pure) | NEMO_DIGIT**zero_count @ (
260+
insert_space + compound
260261
)
262+
trail_pats.append(pynini.closure(NEMO_DIGIT, total, total) @ (lead_fst + trail_part))
261263
graph_round = pynini.union(*round_pats)
262264
graph_trail = pynini.union(*trail_pats) if trail_pats else None
263265
return graph_round, graph_trail
264266

265267
def _build_magnitude_pattern(
266268
self,
267-
one_label, plural_suffix, magnitude_zeros,
269+
one_label,
270+
plural_suffix,
271+
magnitude_zeros,
268272
trailing_pair,
269-
connector_e, insert_space,
270-
graph_digit, graph_tens, graph_hundreds,
273+
connector_e,
274+
insert_space,
275+
graph_digit,
276+
graph_tens,
277+
graph_hundreds,
271278
):
272279
"""Restrict length; round + optional non-zero trailing."""
273-
restrict = (NEMO_DIGIT - "0") + pynini.closure(
274-
NEMO_DIGIT, magnitude_zeros, magnitude_zeros + 2
275-
)
280+
restrict = (NEMO_DIGIT - "0") + pynini.closure(NEMO_DIGIT, magnitude_zeros, magnitude_zeros + 2)
276281
graph_round, graph_trail = self._magnitude_graph(
277-
one_label, plural_suffix, magnitude_zeros,
278-
graph_digit, graph_tens, graph_hundreds,
279-
connector_e, insert_space, trailing_pair,
282+
one_label,
283+
plural_suffix,
284+
magnitude_zeros,
285+
graph_digit,
286+
graph_tens,
287+
graph_hundreds,
288+
connector_e,
289+
insert_space,
290+
trailing_pair,
280291
)
281292
if graph_trail is None:
282293
return pynutil.add_weight(restrict @ graph_round, -1.0)
283294
non_zero_trail = pynini.union(
284-
*[
285-
NEMO_DIGIT**n + (NEMO_DIGIT**magnitude_zeros - pynini.accep("0" * magnitude_zeros))
286-
for n in (1, 2, 3)
287-
]
295+
*[NEMO_DIGIT**n + (NEMO_DIGIT**magnitude_zeros - pynini.accep("0" * magnitude_zeros)) for n in (1, 2, 3)]
288296
)
289297
return pynutil.add_weight(restrict @ (graph_round | (non_zero_trail @ graph_trail)), -1.0)

0 commit comments

Comments
 (0)