-
Notifications
You must be signed in to change notification settings - Fork 155
Expand file tree
/
Copy pathalignment.py
More file actions
267 lines (217 loc) · 9.08 KB
/
alignment.py
File metadata and controls
267 lines (217 loc) · 9.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import string
from argparse import ArgumentParser
from typing import List
import pynini
from pynini import Far
"""
This file takes 1. a FAR file containing an FST graph created by TN or ITN, 2. the entire input string.
Optionally: 3. start position of substring 4. end (exclusive) position of substring
and returns the input-to-output mapping of all whitespace-delimited words in the string.
If start and end positions are specified, it returns
1. mapped output string 2. start and end indices of mapped substring
Usage:
python alignment.py --fst=<fst file> --text="2615 Forest Av, 1 Aug 2016" --rule=TOKENIZE_AND_CLASSIFY --start=22 --end=26 --grammar=tn
Output:
inp string: |2615 Forest Av, 1 Aug 2016|
out string: |twenty six fifteen Forest Avenue , the first of august twenty sixteen|
inp indices: [22:26]
out indices: [55:69]
in: |2016| out: |twenty sixteen|
python alignment.py --fst=<fst file> --text="2615 Forest Av, 1 Aug 2016" --rule=TOKENIZE_AND_CLASSIFY
Output:
inp string: |2615 Forest Av, 1 Aug 2016|
out string: |twenty six fifteen Forest Avenue , the first of august twenty sixteen|
inp indices: [0:4] out indices: [0:18]
in: |2615| out: |twenty six fifteen|
inp indices: [5:11] out indices: [19:25]
in: |Forest| out: |Forest|
inp indices: [12:15] out indices: [26:34]
in: |Av,| out: |Avenue ,|
inp indices: [16:17] out indices: [39:44]
in: |1| out: |first|
inp indices: [18:21] out indices: [48:54]
in: |Aug| out: |august|
inp indices: [22:26] out indices: [55:69]
in: |2016| out: |twenty sixteen|
Disclaimer: The heuristic algorithm relies on monotonic alignment and can fail in certain situations,
e.g. when word pieces are reordered by the fst:
python alignment.py --fst=<fst file> --text=\"$1\" --rule=\"tokenize_and_classify\" --start=0 --end=1
inp string: |$1|
out string: |one dollar|
inp indices: [0:1] out indices: [0:3]
in: |$| out: |one|
"""
def parse_args():
    """
    Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace with attributes fst, grammar, rule, text, start and end.
        start and end are either both None (map every word) or both ints with
        0 <= start < end <= len(text).
    """
    # The first positional parameter of ArgumentParser is `prog` (the program name printed in usage
    # lines), so the human-readable summary has to be passed as `description`.
    parser = ArgumentParser(description="map substring to output with FST")
    parser.add_argument("--fst", help="FAR file containing FST", type=str, required=True)
    parser.add_argument(
        "--grammar", help="tn or itn", type=str, required=False, choices=[ITN_MODE, TN_MODE], default=TN_MODE
    )
    parser.add_argument(
        "--rule",
        help="rule name in FAR file containing FST",
        type=str,
        default='tokenize_and_classify',
        required=False,
    )
    parser.add_argument(
        "--text",
        help="input string",
        type=str,
        default="2615 Forest Av, 90601 CA, Santa Clara. 10kg, 12/16/2018, $123.25. 1 Aug 2016.",
    )
    parser.add_argument("--start", help="start index of substring to be mapped", type=int, required=False)
    parser.add_argument("--end", help="end index of substring to be mapped", type=int, required=False)
    args = parser.parse_args()

    # --start/--end describe a single substring; a lone bound or an empty/out-of-range span would
    # otherwise fail much later (e.g. `None - 1`) with an unhelpful traceback.
    if (args.start is None) != (args.end is None):
        parser.error("--start and --end must be specified together")
    if args.start is not None and not 0 <= args.start < args.end <= len(args.text):
        parser.error("--start and --end must satisfy 0 <= start < end <= len(text)")
    return args
EPS = "<eps>"
WHITE_SPACE = "\u23b5"
ITN_MODE = "itn"
TN_MODE = "tn"
tn_item_special_chars = ["$", "\\", ":", "+", "-", "="]
tn_itn_symbols = list(string.ascii_letters + string.digits) + tn_item_special_chars
def get_word_segments(text: str) -> List[List[int]]:
    """
    Returns word segments from given text based on white space in form of list of index spans.

    A word is a maximal run of characters other than the ASCII space " ". Each span is
    [start, end] with start inclusive and end exclusive, so text[start:end] is the word.
    Leading, trailing and repeated spaces never produce empty spans.

    Args:
        text: input string
    Returns:
        list of [start, end] spans in left-to-right order; empty for empty or all-space text
    """
    spans = []
    word_start = None  # index where the current word began; None while between words
    for idx, ch in enumerate(text):
        if ch == " ":
            if word_start is not None:
                spans.append([word_start, idx])
                word_start = None
        elif word_start is None:
            word_start = idx
    # Close a word that runs up to the end of the text (including a one-character final word).
    if word_start is not None:
        spans.append([word_start, len(text)])
    return spans
def create_symbol_table() -> pynini.SymbolTable:
    """
    Creates and returns Pynini SymbolTable used to label alignment with ascii instead of integers

    Labels 33..199 map to the character with that code point, label 0 maps to EPS and
    label 32 (space) maps to the printable WHITE_SPACE marker.
    """
    table = pynini.SymbolTable()
    # Start at 33 so "!" has a symbol: a missing label makes SymbolTable.find return "", which
    # silently drops the character from the output string and shifts every later output index.
    for num in range(33, 200):  # printable ascii + extended letter range
        table.add_symbol(chr(num), num)
    table.add_symbol(EPS, 0)
    table.add_symbol(WHITE_SPACE, 32)
    return table
def get_string_alignment(fst: pynini.Fst, input_text: str, symbol_table: pynini.SymbolTable):
    """
    create alignment of input text based on shortest path in FST. Symbols used for alignment are from symbol_table

    Args:
        fst: FST applied to input_text
        input_text: string to be rewritten by fst
        symbol_table: table translating arc labels to one-character symbols (see create_symbol_table)
    Returns:
        output: list of (input_symbol, output_symbol) tuples along the shortest path, each mapping
            input character to output; either side may be EPS
        output_str: the output symbols concatenated, with EPS dropped and WHITE_SPACE turned back into " "
    Raises:
        ValueError: if the composition yields no path for input_text
    """
    lattice = pynini.shortestpath(input_text @ fst)
    paths = lattice.paths(input_token_type=symbol_table, output_token_type=symbol_table)
    if paths.done():
        # No path at all; presumably fst does not accept input_text.
        raise ValueError(f"FST produced no path for input: |{input_text}|")
    ilabels = paths.ilabels()
    olabels = paths.olabels()
    output = list(zip([symbol_table.find(x) for x in ilabels], [symbol_table.find(x) for x in olabels]))
    # istring()/ostring() and the alignment repr build whole strings; skip them unless DEBUG is on.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug("input: " + paths.istring())
        logging.debug("output: " + paths.ostring())
        logging.debug(f"alignment: {output}")
    paths.next()
    # Exactly one path is expected from shortestpath; a second one means the alignment is ambiguous.
    assert paths.done()
    output_str = "".join(map(remove, [x[1] for x in output]))
    return output, output_str
def _get_aligned_index(alignment: List[tuple], index: int):
    """
    Translate a position in the contracted input string into a position in alignment.

    The contracted input string is the input side of alignment with EPS entries removed.

    Args:
        alignment: list of (input_symbol, output_symbol) pairs that may contain EPS input symbols
        index: character position in the contracted input string
    Returns:
        position in alignment of the index-th (0-based) non-EPS input symbol; a negative
        index resolves to the first non-EPS input symbol
    Raises:
        IndexError: if alignment has fewer than index + 1 non-EPS input symbols
    """
    non_eps_positions = (pos for pos, pair in enumerate(alignment) if pair[0] != EPS)
    for seen, pos in enumerate(non_eps_positions):
        # ">=" (not "==") so that a negative index lands on the first non-EPS entry.
        if seen >= index:
            return pos
    raise IndexError("list index out of range")
def _get_original_index(alignment, aligned_index):
    """
    Translate a position in alignment into a position in the contracted output string.

    Args:
        alignment: list of (input_symbol, output_symbol) pairs that may contain EPS output symbols
        aligned_index: position in alignment, used as an exclusive bound
    Returns:
        number of entries among the first aligned_index whose output symbol is not EPS
        (0 when aligned_index <= 0)
    """
    return sum(alignment[pos][1] != EPS for pos in range(aligned_index))
def remove(x):
    """
    Convert one alignment symbol back into output text.

    Args:
        x: a symbol taken from the alignment
    Returns:
        "" for EPS, " " for WHITE_SPACE, otherwise x unchanged
    """
    if x == EPS:
        return ""
    if x == WHITE_SPACE:
        return " "
    return x
def indexed_map_to_output(alignment: List[tuple], start: int, end: int, mode: str):
    """
    Given input start and end index of contracted substring return corresponding output start and end index

    Args:
        alignment: alignment generated by FST with shortestpath, is longer than original string since including eps transitions
        start: inclusive start position in input string
        end: exclusive end position in input string
        mode: grammar type for either tn or itn; in tn mode the span is additionally extended to the
            right over following outputs that are in tn_itn_symbols or EPS, whatever their input side
    Returns:
        output_og_start_index: inclusive start position in output string
        output_og_end_index: exclusive end position in output string
    """

    def _attachable(out_symbol):
        # Output symbols that may be glued onto the span: word characters or no output at all.
        return out_symbol == EPS or out_symbol in tn_itn_symbols

    # get aligned start and end of input substring
    aligned_start = _get_aligned_index(alignment, start)
    aligned_end = _get_aligned_index(alignment, end - 1)  # inclusive
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug("0: |%s|", [remove(x[0]) for x in alignment[aligned_start : aligned_end + 1]])
        logging.debug("1: |%s:%s|", aligned_start, aligned_end + 1)

    # Extend aligned_start to the left over output emitted without consuming input. The bound is
    # "> 0" on aligned_start (i.e. aligned_start - 1 >= 0) so position 0 of the alignment can be
    # absorbed as well; the previous "aligned_start - 1 > 0" never looked at position 0.
    while aligned_start > 0 and alignment[aligned_start - 1][0] == EPS and _attachable(alignment[aligned_start - 1][1]):
        aligned_start -= 1
    # Extend aligned_end to the right in the same way.
    while (
        aligned_end + 1 < len(alignment)
        and alignment[aligned_end + 1][0] == EPS
        and _attachable(alignment[aligned_end + 1][1])
    ):
        aligned_end += 1
    if mode == TN_MODE:
        # tn: keep absorbing attachable outputs to the right even when they consume input.
        while aligned_end + 1 < len(alignment) and _attachable(alignment[aligned_end + 1][1]):
            aligned_end += 1

    output_og_start_index = _get_original_index(alignment=alignment, aligned_index=aligned_start)
    output_og_end_index = _get_original_index(alignment=alignment, aligned_index=aligned_end + 1)
    return output_og_start_index, output_og_end_index
if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    args = parse_args()

    # Read the requested rule and close the FAR file as soon as it is no longer needed.
    with Far(args.fst, mode='r') as far:
        try:
            fst = far[args.rule]
        except KeyError as err:
            # Only a missing rule is translated; other failures (I/O, Ctrl-C) propagate unchanged.
            raise ValueError(f"{args.rule} not found. Please specify valid --rule argument.") from err

    input_text = args.text
    table = create_symbol_table()
    alignment, output_text = get_string_alignment(fst=fst, input_text=input_text, symbol_table=table)
    logging.info(f"inp string: |{args.text}|")
    logging.info(f"out string: |{output_text}|")

    # Either map every space-delimited word or just the single requested substring.
    if args.start is None:
        indices = get_word_segments(input_text)
    else:
        indices = [(args.start, args.end)]
    for x in indices:
        start, end = indexed_map_to_output(start=x[0], end=x[1], alignment=alignment, mode=args.grammar)
        logging.info(f"inp indices: [{x[0]}:{x[1]}] out indices: [{start}:{end}]")
        logging.info(f"in: |{input_text[x[0]:x[1]]}| out: |{output_text[start:end]}|")