Skip to content

Commit c052711

Browse files
committed
black format
1 parent 23e4eaf commit c052711

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

108 files changed

+1246
-1855
lines changed

itn/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from itn.main import main
22

3-
if __name__ == '__main__':
3+
if __name__ == "__main__":
44
main()

itn/chinese/inverse_normalizer.py

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,52 +31,53 @@
3131

3232
class InverseNormalizer(Processor):
3333

34-
def __init__(self,
35-
cache_dir=None,
36-
overwrite_cache=False,
37-
enable_standalone_number=True,
38-
enable_0_to_9=False,
39-
enable_million=False):
40-
super().__init__(name='zh_inverse_normalizer', ordertype='itn')
34+
def __init__(
35+
self,
36+
cache_dir=None,
37+
overwrite_cache=False,
38+
enable_standalone_number=True,
39+
enable_0_to_9=False,
40+
enable_million=False,
41+
):
42+
super().__init__(name="zh_inverse_normalizer", ordertype="itn")
4143
self.convert_number = enable_standalone_number
4244
self.enable_0_to_9 = enable_0_to_9
4345
self.enable_million = enable_million
4446
if cache_dir is None:
4547
cache_dir = files("itn")
46-
self.build_fst('zh_itn', cache_dir, overwrite_cache)
48+
self.build_fst("zh_itn", cache_dir, overwrite_cache)
4749

4850
def build_tagger(self):
49-
tagger = (add_weight(Date().tagger, 1.02)
50-
| add_weight(Whitelist().tagger, 1.01)
51-
| add_weight(Fraction().tagger, 1.05)
52-
| add_weight(
53-
Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05)
54-
| add_weight(
55-
Money(enable_0_to_9=self.enable_0_to_9).tagger, 1.04)
56-
| add_weight(Time().tagger, 1.05)
57-
| add_weight(
58-
Cardinal(self.convert_number, self.enable_0_to_9,
59-
self.enable_million).tagger, 1.06)
60-
| add_weight(Math().tagger, 1.10)
61-
| add_weight(LicensePlate().tagger, 1.0)
62-
| add_weight(Char().tagger, 100)).optimize()
51+
tagger = (
52+
add_weight(Date().tagger, 1.02)
53+
| add_weight(Whitelist().tagger, 1.01)
54+
| add_weight(Fraction().tagger, 1.05)
55+
| add_weight(Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05)
56+
| add_weight(Money(enable_0_to_9=self.enable_0_to_9).tagger, 1.04)
57+
| add_weight(Time().tagger, 1.05)
58+
| add_weight(Cardinal(self.convert_number, self.enable_0_to_9, self.enable_million).tagger, 1.06)
59+
| add_weight(Math().tagger, 1.10)
60+
| add_weight(LicensePlate().tagger, 1.0)
61+
| add_weight(Char().tagger, 100)
62+
).optimize()
6363

6464
tagger = tagger.star
6565
# remove the last space
66-
self.tagger = tagger @ self.build_rule(delete(' '), '', '[EOS]')
66+
self.tagger = tagger @ self.build_rule(delete(" "), "", "[EOS]")
6767

6868
def build_verbalizer(self):
69-
verbalizer = (Cardinal(self.convert_number, self.enable_0_to_9,
70-
self.enable_million).verbalizer
71-
| Char().verbalizer
72-
| Date().verbalizer
73-
| Fraction().verbalizer
74-
| Math().verbalizer
75-
| Measure(enable_0_to_9=self.enable_0_to_9).verbalizer
76-
| Money(enable_0_to_9=self.enable_0_to_9).verbalizer
77-
| Time().verbalizer
78-
| LicensePlate().verbalizer
79-
| Whitelist().verbalizer).optimize()
69+
verbalizer = (
70+
Cardinal(self.convert_number, self.enable_0_to_9, self.enable_million).verbalizer
71+
| Char().verbalizer
72+
| Date().verbalizer
73+
| Fraction().verbalizer
74+
| Math().verbalizer
75+
| Measure(enable_0_to_9=self.enable_0_to_9).verbalizer
76+
| Money(enable_0_to_9=self.enable_0_to_9).verbalizer
77+
| Time().verbalizer
78+
| LicensePlate().verbalizer
79+
| Whitelist().verbalizer
80+
).optimize()
8081
postprocessor = PostProcessor(remove_interjections=True).processor
8182

8283
self.verbalizer = (verbalizer @ postprocessor).star

itn/chinese/rules/cardinal.py

Lines changed: 81 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,8 @@
2121

2222
class Cardinal(Processor):
2323

24-
def __init__(self,
25-
enable_standalone_number=True,
26-
enable_0_to_9=True,
27-
enable_million=False):
28-
super().__init__('cardinal')
24+
def __init__(self, enable_standalone_number=True, enable_0_to_9=True, enable_million=False):
25+
super().__init__("cardinal")
2926
self.number = None
3027
self.number_exclude_0_to_9 = None
3128
self.enable_standalone_number = enable_standalone_number
@@ -35,84 +32,97 @@ def __init__(self,
3532
self.build_verbalizer()
3633

3734
def build_tagger(self):
38-
zero = string_file(
39-
get_abs_path('../itn/chinese/data/number/zero.tsv')) # 0
40-
digit = string_file(
41-
get_abs_path('../itn/chinese/data/number/digit.tsv')) # 1 ~ 9
42-
special_tilde = string_file(
43-
get_abs_path(
44-
'../itn/chinese/data/number/special_tilde.tsv')) # 七八十->70~80
45-
special_tilde = special_tilde + add_weight(
46-
(accep("万") | accep("亿")), -0.1).ques
47-
special_dash = string_file(
48-
get_abs_path(
49-
'../itn/chinese/data/number/special_dash.tsv')) # 七八十->70-80
50-
special_dash = special_dash + add_weight(
51-
(accep("万") | accep("亿")), -0.1).ques
52-
sign = string_file(
53-
get_abs_path('../itn/chinese/data/number/sign.tsv')) # + -
54-
dot = string_file(
55-
get_abs_path('../itn/chinese/data/number/dot.tsv')) # .
35+
zero = string_file(get_abs_path("../itn/chinese/data/number/zero.tsv")) # 0
36+
digit = string_file(get_abs_path("../itn/chinese/data/number/digit.tsv")) # 1 ~ 9
37+
special_tilde = string_file(get_abs_path("../itn/chinese/data/number/special_tilde.tsv")) # 七八十->70~80
38+
special_tilde = special_tilde + add_weight((accep("万") | accep("亿")), -0.1).ques
39+
special_dash = string_file(get_abs_path("../itn/chinese/data/number/special_dash.tsv")) # 七八十->70-80
40+
special_dash = special_dash + add_weight((accep("万") | accep("亿")), -0.1).ques
41+
sign = string_file(get_abs_path("../itn/chinese/data/number/sign.tsv")) # + -
42+
dot = string_file(get_abs_path("../itn/chinese/data/number/dot.tsv")) # .
5643

5744
# 0. 基础数字
58-
addzero = insert('0')
45+
addzero = insert("0")
5946
digits = zero | digit # 0 ~ 9
6047
# 十一 => 11, 十二 => 12
61-
teen = cross('十', '1') + (digit | add_weight(addzero, 0.1))
48+
teen = cross("十", "1") + (digit | add_weight(addzero, 0.1))
6249
# 一十一 => 11, 二十一 => 21, 三十 => 30
63-
tens = digit + delete('十') + (digit | add_weight(addzero, 0.1))
50+
tens = digit + delete("十") + (digit | add_weight(addzero, 0.1))
6451
# 一百一十 => 110, 一百零一 => 101, 一百一 => 110, 一百 => 100
65-
hundred = (digit + delete('百') + (tens
66-
| teen
67-
| add_weight(zero + digit, 0.1)
68-
| add_weight(digit + addzero, 0.5)
69-
| add_weight(addzero**2, 1.0)))
52+
hundred = (
53+
digit
54+
+ delete("百")
55+
+ (
56+
tens
57+
| teen
58+
| add_weight(zero + digit, 0.1)
59+
| add_weight(digit + addzero, 0.5)
60+
| add_weight(addzero**2, 1.0)
61+
)
62+
)
7063
# 一千一百一十一 => 1111, 一千零一十一 => 1011, 一千零一 => 1001
7164
# 一千一 => 1100, 一千 => 1000
72-
thousand = (digit + delete('千') +
73-
(hundred
74-
| add_weight(zero + (tens | teen), 0.1)
75-
| add_weight(addzero + zero + digit, 0.5)
76-
| add_weight(digit + addzero**2, 0.8)
77-
| add_weight(addzero**3, 1.0)))
65+
thousand = (
66+
digit
67+
+ delete("千")
68+
+ (
69+
hundred
70+
| add_weight(zero + (tens | teen), 0.1)
71+
| add_weight(addzero + zero + digit, 0.5)
72+
| add_weight(digit + addzero**2, 0.8)
73+
| add_weight(addzero**3, 1.0)
74+
)
75+
)
7876
# 10001111, 1001111, 101111, 11111, 10111, 10011, 10001, 10000
7977
if self.enable_million:
8078
ten_thousand = (
81-
(thousand | hundred | teen | tens | digit) + delete('万') +
82-
(thousand
83-
| add_weight(zero + hundred, 0.1)
84-
| add_weight(addzero + zero + (tens | teen), 0.5)
85-
| add_weight(addzero + addzero + zero + digit, 0.5)
86-
| add_weight(digit + addzero**3, 0.8)
87-
| add_weight(addzero**4, 1.0)))
79+
(thousand | hundred | teen | tens | digit)
80+
+ delete("万")
81+
+ (
82+
thousand
83+
| add_weight(zero + hundred, 0.1)
84+
| add_weight(addzero + zero + (tens | teen), 0.5)
85+
| add_weight(addzero + addzero + zero + digit, 0.5)
86+
| add_weight(digit + addzero**3, 0.8)
87+
| add_weight(addzero**4, 1.0)
88+
)
89+
)
8890
else:
8991
ten_thousand = (
90-
(teen | tens | digit) + delete('万') +
91-
(thousand
92-
| add_weight(zero + hundred, 0.1)
93-
| add_weight(addzero + zero + (tens | teen), 0.5)
94-
| add_weight(addzero + addzero + zero + digit, 0.5)
95-
| add_weight(digit + addzero**3, 0.8)
96-
| add_weight(addzero**4, 1.0)))
97-
ten_thousand |= (thousand | hundred) + accep("万") + delete(
98-
"零").ques + (thousand | hundred | tens | teen | digits).ques
92+
(teen | tens | digit)
93+
+ delete("万")
94+
+ (
95+
thousand
96+
| add_weight(zero + hundred, 0.1)
97+
| add_weight(addzero + zero + (tens | teen), 0.5)
98+
| add_weight(addzero + addzero + zero + digit, 0.5)
99+
| add_weight(digit + addzero**3, 0.8)
100+
| add_weight(addzero**4, 1.0)
101+
)
102+
)
103+
ten_thousand |= (
104+
(thousand | hundred)
105+
+ accep("万")
106+
+ delete("零").ques
107+
+ (thousand | hundred | tens | teen | digits).ques
108+
)
99109

100110
# 1. 利用基础数字所构建的包含0~9的标准数字
101111
# 个/十/百/千/万
102112
number = digits | teen | tens | hundred | thousand | ten_thousand
103113
# 兆/亿
104-
number = ((number + accep('兆') + delete('零').ques).ques +
105-
(number + accep('亿') + delete('零').ques).ques + number)
114+
number = (
115+
(number + accep("兆") + delete("零").ques).ques + (number + accep("亿") + delete("零").ques).ques + number
116+
)
106117
# 负的xxx 1.11, 1.01
107118
number = sign.ques + number + (dot + digits.plus).ques
108119
# 五六万 => 5~6万,三五千 => 3000~5000,六七百 => 600~700,三四十 => 30~40, 三四十亿 => 30~40亿
109120
number |= special_tilde
110121
# 十七八 => 17-8, 四十五六 => 45-6, 三百七八十 => 370-80, 四十五六万 => 45-6万, 一万六七 => 16000-7000
111-
_special_dash = cross('十', '1') + special_dash
112-
_special_dash |= digit + delete('十') + special_dash
113-
_special_dash |= digit + delete('百') + special_dash
114-
_special_dash |= digit + delete('万') + digit + insert(
115-
'000-') + digit + insert('000')
122+
_special_dash = cross("十", "1") + special_dash
123+
_special_dash |= digit + delete("十") + special_dash
124+
_special_dash |= digit + delete("百") + special_dash
125+
_special_dash |= digit + delete("万") + digit + insert("000-") + digit + insert("000")
116126
number |= _special_dash
117127

118128
self.number = number.optimize()
@@ -123,33 +133,30 @@ def build_tagger(self):
123133
# 十/百/千/万
124134
number_exclude_0_to_9 = teen | tens | hundred | thousand | ten_thousand
125135
# 兆/亿
126-
number_exclude_0_to_9 = (((number_exclude_0_to_9 | digits) +
127-
accep('兆') + delete('零').ques).ques +
128-
((number_exclude_0_to_9 | digits) +
129-
accep('亿') + delete('零').ques).ques +
130-
number_exclude_0_to_9)
136+
number_exclude_0_to_9 = (
137+
((number_exclude_0_to_9 | digits) + accep("兆") + delete("零").ques).ques
138+
+ ((number_exclude_0_to_9 | digits) + accep("亿") + delete("零").ques).ques
139+
+ number_exclude_0_to_9
140+
)
131141
# 负的xxx 1.11, 1.01
132-
number_exclude_0_to_9 |= ((number_exclude_0_to_9 | digits) +
133-
(dot + digits.plus).plus)
142+
number_exclude_0_to_9 |= (number_exclude_0_to_9 | digits) + (dot + digits.plus).plus
134143
# 五六万,三五千,六七百,三四十
135144
# 十七八美元 => $17~18, 四十五六岁 => 45-6岁,
136145
# 三百七八公里 => 370-80km, 三百七八十千克 => 370-80kg
137146
number_exclude_0_to_9 |= special_tilde
138147
number_exclude_0_to_9 |= add_weight(_special_dash, -0.1)
139148

140-
self.number_exclude_0_to_9 = (sign.ques +
141-
number_exclude_0_to_9).optimize()
149+
self.number_exclude_0_to_9 = (sign.ques + number_exclude_0_to_9).optimize()
142150

143151
# 3. 特殊格式的数字
144152
# cardinal string like 127.0.0.1, used in ID, IP, etc.
145153
cardinal = digits.plus + (dot + digits.plus).plus
146154
# float number like 1.11
147-
cardinal |= (number + dot + digits.plus)
155+
cardinal |= number + dot + digits.plus
148156
# cardinal string like 110 or 12306 or 13125617878, used in phone,
149157
# 340621199806051223, used in ID card
150-
idcard_last_char = digits | 'X' | 'x'
151-
cardinal |= (digits**3 | digits**4 | digits**5 | digits**11
152-
| (digits**17 + idcard_last_char) | digits**18)
158+
idcard_last_char = digits | "X" | "x"
159+
cardinal |= digits**3 | digits**4 | digits**5 | digits**11 | (digits**17 + idcard_last_char) | digits**18
153160

154161
# 4. 特殊格式的数字 + 标准数字
155162
# cardinal string like 23
@@ -160,6 +167,5 @@ def build_tagger(self):
160167
cardinal |= add_weight(number, 0.1)
161168
else:
162169
cardinal |= add_weight(number_exclude_0_to_9, 0.1)
163-
tagger = insert('value: "') + cardinal + (insert(" ") + cardinal).star \
164-
+ insert('"')
170+
tagger = insert('value: "') + cardinal + (insert(" ") + cardinal).star + insert('"')
165171
self.tagger = self.add_tokens(tagger)

itn/chinese/rules/char.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
class Char(Processor):
2121

2222
def __init__(self):
23-
super().__init__(name='char')
23+
super().__init__(name="char")
2424
self.build_tagger()
2525
self.build_verbalizer()
2626

itn/chinese/rules/date.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,33 +22,27 @@
2222
class Date(Processor):
2323

2424
def __init__(self):
25-
super().__init__(name='date')
25+
super().__init__(name="date")
2626
self.build_tagger()
2727
self.build_verbalizer()
2828

2929
def build_tagger(self):
30-
digit = string_file(
31-
get_abs_path('../itn/chinese/data/number/digit.tsv')) # 1 ~ 9
32-
zero = string_file(
33-
get_abs_path('../itn/chinese/data/number/zero.tsv')) # 0
34-
35-
yyyy = digit + (digit | zero)**3 # 二零零八年
36-
yyy = digit + (digit | zero)**2 # 公元一六八年
37-
yy = (digit | zero)**2 # 零八年奥运会
38-
mm = string_file(get_abs_path('../itn/chinese/data/date/mm.tsv'))
39-
dd = string_file(get_abs_path('../itn/chinese/data/date/dd.tsv'))
40-
41-
year = insert('year: "') + (yyyy | yyy | yy) + \
42-
delete('年') + insert('" ')
43-
year_only = insert('year: "') + (yyyy | yyy | yy) + \
44-
accep('年') + insert('"')
30+
digit = string_file(get_abs_path("../itn/chinese/data/number/digit.tsv")) # 1 ~ 9
31+
zero = string_file(get_abs_path("../itn/chinese/data/number/zero.tsv")) # 0
32+
33+
yyyy = digit + (digit | zero) ** 3 # 二零零八年
34+
yyy = digit + (digit | zero) ** 2 # 公元一六八年
35+
yy = (digit | zero) ** 2 # 零八年奥运会
36+
mm = string_file(get_abs_path("../itn/chinese/data/date/mm.tsv"))
37+
dd = string_file(get_abs_path("../itn/chinese/data/date/dd.tsv"))
38+
39+
year = insert('year: "') + (yyyy | yyy | yy) + delete("年") + insert('" ')
40+
year_only = insert('year: "') + (yyyy | yyy | yy) + accep("年") + insert('"')
4541
month = insert('month: "') + mm + insert('"')
4642
day = insert(' day: "') + dd + insert('"')
4743

4844
# yyyy/mm/dd | yyyy/mm | mm/dd | yyyy
49-
date = ((year + month + day)
50-
| (year + month)
51-
| (month + day)) | year_only
45+
date = ((year + month + day) | (year + month) | (month + day)) | year_only
5246
self.tagger = self.add_tokens(date)
5347

5448
def build_verbalizer(self):

0 commit comments

Comments
 (0)