Skip to content

Commit 4cc07ee

Browse files
authored
Add whitelist param to ITN (#30)
* add whitelist param to itn Signed-off-by: ekmb <ebakhturina@nvidia.com> * add whitelist to export Signed-off-by: ekmb <ebakhturina@nvidia.com> * update docstrings Signed-off-by: ekmb <ebakhturina@nvidia.com> --------- Signed-off-by: ekmb <ebakhturina@nvidia.com>
1 parent 2f6f2f6 commit 4cc07ee

18 files changed

Lines changed: 133 additions & 39 deletions

File tree

nemo_text_processing/inverse_text_normalization/ar/taggers/tokenize_and_classify.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ class ClassifyFst(GraphFst):
4141
Args:
4242
cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
4343
overwrite_cache: set to True to overwrite .far files
44+
whitelist: path to a file with whitelist replacements
4445
"""
4546

46-
def __init__(self, cache_dir: str = None, overwrite_cache: bool = False):
47+
def __init__(self, cache_dir: str = None, overwrite_cache: bool = False, whitelist: str = None):
4748
super().__init__(name="tokenize_and_classify", kind="classify")
4849

4950
far_file = None

nemo_text_processing/inverse_text_normalization/de/taggers/tokenize_and_classify.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,12 @@ class ClassifyFst(GraphFst):
5757
Args:
5858
cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
5959
overwrite_cache: set to True to overwrite .far files
60+
whitelist: path to a file with whitelist replacements
6061
"""
6162

62-
def __init__(self, cache_dir: str = None, overwrite_cache: bool = False, deterministic: bool = True):
63+
def __init__(
64+
self, cache_dir: str = None, overwrite_cache: bool = False, deterministic: bool = True, whitelist: str = None
65+
):
6366
super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)
6467

6568
far_file = None
@@ -80,7 +83,7 @@ def __init__(self, cache_dir: str = None, overwrite_cache: bool = False, determi
8083
tn_date_verbalizer = TNDateVerbalizer(ordinal=tn_ordinal_verbalizer, deterministic=False)
8184
tn_electronic_tagger = TNElectronicTagger(deterministic=False)
8285
tn_electronic_verbalizer = TNElectronicVerbalizer(deterministic=False)
83-
tn_whitelist_tagger = TNWhitelistTagger(input_case="cased", deterministic=False)
86+
tn_whitelist_tagger = TNWhitelistTagger(input_case="cased", deterministic=False, input_file=whitelist)
8487

8588
cardinal = CardinalFst(tn_cardinal_tagger=tn_cardinal_tagger)
8689
cardinal_graph = cardinal.fst

nemo_text_processing/inverse_text_normalization/de/taggers/whitelist.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,17 @@ class WhiteListFst(GraphFst):
2424
e.g. misses -> tokens { name: "Mrs." }
2525
Args:
2626
tn_whitelist_tagger: TN whitelist tagger
27+
input_file: path to a file with whitelist replacements (each line of the file: written_form\tspoken_form\n),
28+
e.g. nemo_text_processing/inverse_text_normalization/en/data/whitelist.tsv
2729
"""
2830

29-
def __init__(self, tn_whitelist_tagger: GraphFst, deterministic: bool = True):
31+
def __init__(self, tn_whitelist_tagger: GraphFst, deterministic: bool = True, input_file: str = None):
3032
super().__init__(name="whitelist", kind="classify", deterministic=deterministic)
3133

32-
whitelist = pynini.invert(tn_whitelist_tagger.graph)
34+
if input_file:
35+
whitelist = pynini.string_file(input_file).invert()
36+
else:
37+
whitelist = pynini.invert(tn_whitelist_tagger.graph)
38+
3339
graph = pynutil.insert("name: \"") + convert_space(whitelist) + pynutil.insert("\"")
3440
self.fst = graph.optimize()

nemo_text_processing/inverse_text_normalization/en/taggers/tokenize_and_classify.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ class ClassifyFst(GraphFst):
4747
Args:
4848
cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
4949
overwrite_cache: set to True to overwrite .far files
50+
whitelist: path to a file with whitelist replacements
5051
"""
5152

52-
def __init__(self, cache_dir: str = None, overwrite_cache: bool = False):
53+
def __init__(self, cache_dir: str = None, overwrite_cache: bool = False, whitelist: str = None):
5354
super().__init__(name="tokenize_and_classify", kind="classify")
5455

5556
far_file = None
@@ -75,7 +76,7 @@ def __init__(self, cache_dir: str = None, overwrite_cache: bool = False):
7576
word_graph = WordFst().fst
7677
time_graph = TimeFst().fst
7778
money_graph = MoneyFst(cardinal=cardinal, decimal=decimal).fst
78-
whitelist_graph = WhiteListFst().fst
79+
whitelist_graph = WhiteListFst(input_file=whitelist).fst
7980
punct_graph = PunctuationFst().fst
8081
electronic_graph = ElectronicFst().fst
8182
telephone_graph = TelephoneFst(cardinal).fst

nemo_text_processing/inverse_text_normalization/en/taggers/whitelist.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import logging
17+
1618
import pynini
1719
from nemo_text_processing.inverse_text_normalization.en.utils import get_abs_path
1820
from nemo_text_processing.text_normalization.en.graph_utils import GraphFst, convert_space
@@ -23,12 +25,20 @@ class WhiteListFst(GraphFst):
2325
"""
2426
Finite state transducer for classifying whitelisted tokens
2527
e.g. misses -> tokens { name: "mrs." }
26-
This class has highest priority among all classifier grammars. Whitelisted tokens are defined and loaded from "data/whitelist.tsv".
28+
This class has highest priority among all classifier grammars.
29+
Whitelisted tokens are defined and loaded from "data/whitelist.tsv" (unless input_file specified).
30+
31+
Args:
32+
input_file: path to a file with whitelist replacements (each line of the file: written_form\tspoken_form\n),
33+
e.g. nemo_text_processing/inverse_text_normalization/en/data/whitelist.tsv
2734
"""
2835

29-
def __init__(self):
36+
def __init__(self, input_file: str = None):
3037
super().__init__(name="whitelist", kind="classify")
3138

32-
whitelist = pynini.string_file(get_abs_path("data/whitelist.tsv")).invert()
39+
if input_file:
40+
whitelist = pynini.string_file(input_file).invert()
41+
else:
42+
whitelist = pynini.string_file(get_abs_path("data/whitelist.tsv")).invert()
3343
graph = pynutil.insert("name: \"") + convert_space(whitelist) + pynutil.insert("\"")
3444
self.fst = graph.optimize()

nemo_text_processing/inverse_text_normalization/es/taggers/tokenize_and_classify.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ class ClassifyFst(GraphFst):
4747
Args:
4848
cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
4949
overwrite_cache: set to True to overwrite .far files
50+
whitelist: path to a file with whitelist replacements
5051
"""
5152

52-
def __init__(self, cache_dir: str = None, overwrite_cache: bool = False):
53+
def __init__(self, cache_dir: str = None, overwrite_cache: bool = False, whitelist: str = None):
5354
super().__init__(name="tokenize_and_classify", kind="classify")
5455

5556
far_file = None
@@ -79,7 +80,7 @@ def __init__(self, cache_dir: str = None, overwrite_cache: bool = False):
7980
word_graph = WordFst().fst
8081
time_graph = TimeFst().fst
8182
money_graph = MoneyFst(cardinal=cardinal, decimal=decimal).fst
82-
whitelist_graph = WhiteListFst().fst
83+
whitelist_graph = WhiteListFst(input_file=whitelist).fst
8384
punct_graph = PunctuationFst().fst
8485
electronic_graph = ElectronicFst().fst
8586
telephone_graph = TelephoneFst().fst

nemo_text_processing/inverse_text_normalization/es/taggers/whitelist.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,21 @@ class WhiteListFst(GraphFst):
2222
"""
2323
Finite state transducer for classifying whitelisted tokens
2424
e.g. usted -> tokens { name: "ud." }
25-
This class has highest priority among all classifier grammars. Whitelisted tokens are defined and loaded from "data/whitelist.tsv".
25+
This class has highest priority among all classifier grammars.
26+
27+
Whitelisted tokens are defined and loaded from "data/whitelist.tsv" (unless input_file specified).
28+
29+
Args:
30+
input_file: path to a file with whitelist replacements (each line of the file: written_form\tspoken_form\n),
31+
e.g. nemo_text_processing/inverse_text_normalization/es/data/whitelist.tsv
2632
"""
2733

28-
def __init__(self):
34+
def __init__(self, input_file: str = None):
2935
super().__init__(name="whitelist", kind="classify")
3036

31-
whitelist = pynini.string_file(get_abs_path("data/whitelist.tsv")).invert()
37+
if input_file:
38+
whitelist = pynini.string_file(input_file).invert()
39+
else:
40+
whitelist = pynini.string_file(get_abs_path("data/whitelist.tsv")).invert()
3241
graph = pynutil.insert("name: \"") + convert_space(whitelist) + pynutil.insert("\"")
3342
self.fst = graph.optimize()

nemo_text_processing/inverse_text_normalization/fr/taggers/tokenize_and_classify.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ class ClassifyFst(GraphFst):
4747
Args:
4848
cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
4949
overwrite_cache: set to True to overwrite .far files
50+
whitelist: path to a file with whitelist replacements
5051
"""
5152

52-
def __init__(self, cache_dir: str = None, overwrite_cache: bool = False):
53+
def __init__(self, cache_dir: str = None, overwrite_cache: bool = False, whitelist: str = None):
5354
super().__init__(name="tokenize_and_classify", kind="classify")
5455

5556
far_file = None
@@ -79,7 +80,7 @@ def __init__(self, cache_dir: str = None, overwrite_cache: bool = False):
7980
word_graph = WordFst().fst
8081
time_graph = TimeFst().fst
8182
money_graph = MoneyFst(cardinal, decimal).fst
82-
whitelist_graph = WhiteListFst().fst
83+
whitelist_graph = WhiteListFst(input_file=whitelist).fst
8384
punct_graph = PunctuationFst().fst
8485
electronic_graph = ElectronicFst().fst
8586
telephone_graph = TelephoneFst().fst

nemo_text_processing/inverse_text_normalization/fr/taggers/whitelist.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,20 @@ class WhiteListFst(GraphFst):
2222
"""
2323
Finite state transducer for classifying whitelisted tokens
2424
e.g. misses -> tokens { name: "mrs." }
25-
This class has highest priority among all classifier grammars. Whitelisted tokens are defined and loaded from "data/whitelist.tsv".
25+
This class has highest priority among all classifier grammars.
26+
Whitelisted tokens are defined and loaded from "data/whitelist.tsv" (unless input_file specified).
27+
28+
Args:
29+
input_file: path to a file with whitelist replacements (each line of the file: written_form\tspoken_form\n),
30+
e.g. nemo_text_processing/inverse_text_normalization/fr/data/whitelist.tsv
2631
"""
2732

28-
def __init__(self):
33+
def __init__(self, input_file: str = None):
2934
super().__init__(name="whitelist", kind="classify")
3035

31-
whitelist = pynini.string_file(get_abs_path("data/whitelist.tsv"))
36+
if input_file:
37+
whitelist = pynini.string_file(input_file).invert()
38+
else:
39+
whitelist = pynini.string_file(get_abs_path("data/whitelist.tsv"))
3240
graph = pynutil.insert("name: \"") + convert_space(whitelist) + pynutil.insert("\"")
3341
self.fst = graph.optimize()

nemo_text_processing/inverse_text_normalization/inverse_normalize.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
from argparse import ArgumentParser
1617
from time import perf_counter
1718
from typing import List
@@ -28,6 +29,8 @@ class InverseNormalizer(Normalizer):
2829
2930
Args:
3031
lang: language specifying the ITN
32+
whitelist: path to a file with whitelist replacements (each line of the file: written_form\tspoken_form\n),
33+
e.g. nemo_text_processing/inverse_text_normalization/en/data/whitelist.tsv
3134
cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
3235
overwrite_cache: set to True to overwrite .far files
3336
max_number_of_permutations_per_split: a maximum number
@@ -37,6 +40,7 @@ class InverseNormalizer(Normalizer):
3740
def __init__(
3841
self,
3942
lang: str = 'en',
43+
whitelist: str = None,
4044
cache_dir: str = None,
4145
overwrite_cache: bool = False,
4246
max_number_of_permutations_per_split: int = 729,
@@ -87,7 +91,7 @@ def __init__(
8791
VerbalizeFinalFst,
8892
)
8993

90-
self.tagger = ClassifyFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache)
94+
self.tagger = ClassifyFst(cache_dir=cache_dir, whitelist=whitelist, overwrite_cache=overwrite_cache)
9195
self.verbalizer = VerbalizeFinalFst()
9296
self.parser = TokenParser()
9397
self.lang = lang
@@ -128,6 +132,12 @@ def parse_args():
128132
parser.add_argument(
129133
"--language", help="language", choices=['en', 'de', 'es', 'pt', 'ru', 'fr', 'vi'], default="en", type=str
130134
)
135+
parser.add_argument(
136+
"--whitelist",
137+
help="Path to a file with whitelist replacements, " "e.g., inverse_normalization/en/data/whitelist.tsv",
138+
default=None,
139+
type=str,
140+
)
131141
parser.add_argument("--verbose", help="print info for debugging", action='store_true')
132142
parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
133143
parser.add_argument(
@@ -141,9 +151,11 @@ def parse_args():
141151

142152
if __name__ == "__main__":
143153
args = parse_args()
154+
155+
whitelist = os.path.abspath(args.whitelist) if args.whitelist else None
144156
start_time = perf_counter()
145157
inverse_normalizer = InverseNormalizer(
146-
lang=args.language, cache_dir=args.cache_dir, overwrite_cache=args.overwrite_cache
158+
lang=args.language, cache_dir=args.cache_dir, overwrite_cache=args.overwrite_cache, whitelist=whitelist,
147159
)
148160
print(f'Time to generate graph: {round(perf_counter() - start_time, 2)} sec')
149161

0 commit comments

Comments
 (0)