Skip to content

Commit 97eaa3e

Browse files
authored
Adding multireference script (#58)
1 parent 6eb2b6e commit 97eaa3e

2 files changed

Lines changed: 363 additions & 0 deletions

File tree

tools/README.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,40 @@ A simple bash script that is meant for benchmarking the resource (RAM and runtim
2525

2626
Example usage:
2727
`bash gather_runtime_metrics.sh output_for_this_release.csv`
28+
29+
## sbs2fst.py
30+
A python interface to simplify the conversion of a side-by-side file, generated from fstalign's `--output-sbs` flag, into [files that can be used to produce an FST using OpenFST](https://www.openfst.org/twiki/bin/view/FST/FstQuickTour).
31+
32+
Example usage:
33+
34+
`python sbs2fst.py sbs_file.txt fst_file_name`
35+
36+
The output will be two files: `fst_file_name.fst` which will describe the FST in the AT&T FSM format used by OpenFST, and `fst_file_name.txt` which contains the complete list of symbols in the FST.
37+
38+
The additional flags can be passed into the python script to add metadata that fstalign uses for tracking performance. These are useful to understand when fstalign picks tokens that are: only in the side-by-side's `ref_token` column (labeled by the `--left` flag), only in the side-by-side's `hyp_token` column (labeled by the `--right` flag), or in both columns because the `ref_token` and `hyp_token` agree (labeled by the `--gold` flag).
39+
40+
Example usage:
41+
42+
`python sbs2fst.py --tag --left VERBATIM --right NONVERBATIM --gold AGREEMENT sbs_file.txt fst_file_name`
43+
44+
The output will produce an FST with tags indicating tokens that were only in the `ref_token` with `VERBATIM`, tokens that were only in the `hyp_token` with `NONVERBATIM`, and tokens that were in both columns with `AGREEMENT`.
45+
46+
### Compiling the FST
47+
Once you have used `sbs2fst.py` to produce the `.txt` and `.fst` files, you *must* then compile the FST before passing it into fstalign. An example command can be found below:
48+
49+
`fstcompile --isymbols=${SYMBOLS} --osymbols=${SYMBOLS} ${TXT_FST} ${COMPILED_FST}`
50+
51+
where `SYMBOLS` is the `.txt` file produced by `sbs2fst.py`, `TXT_FST` is the `.fst` file, and `COMPILED_FST` is a new `.fst` file that produces the binary FST usable by fstalign.
52+
53+
Example usage:
54+
```bash
55+
python sbs2fst.py --tag --left VERBATIM --right NONVERBATIM --gold AGREEMENT sbs_file.txt fst_file_name
56+
fstcompile --isymbols=fst_file_name.txt --osymbols=fst_file_name.txt fst_file_name.fst fst_file_name.compiled.fst
57+
```
58+
You can now use `fst_file_name.compiled.fst` in fstalign with the corresponding symbols file as follows:
59+
```bash
60+
fstalign --ref fst_file_name.compiled.fst --symbols fst_file_name.txt ...
61+
```
62+
63+
Note that when you use `sbs2fst.py` to produce a "tagged" FST with the `--tag` flag, fstalign will aggregate WER metrics for each of the specified tags (`--left`, `--right`, and `--gold`) in the JSON log file specified by fstalign's `--json-log` flag.
64+

tools/sbs2fst.py

Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (C) 2023
5+
# Author: Miguel Ángel del Río Fernández <miguel.delrio@rev.com>
6+
# All Rights Reserved
7+
8+
from argparse import ArgumentParser
9+
from collections import OrderedDict
10+
from dataclasses import dataclass, field
11+
from itertools import takewhile
12+
from pathlib import Path
13+
from typing import Dict, Generator, List, Optional, Tuple
14+
15+
16+
@dataclass
class SbsEntry:
    """Represent a single line of a tab-separated side-by-side (SBS) file.

    A line has at least four tab-separated columns: reference word,
    hypothesis word, error marker, and entity class. An optional fifth
    column holds `|`-separated WER tags, and any further columns are kept
    verbatim in `extra_columns`.
    """
    # First column, stripped of surrounding whitespace.
    ref_word: str
    # Second column, stripped of surrounding whitespace.
    hyp_word: str
    # True when the third column is exactly 'ERR'.
    error: bool
    # Fourth column, kept as-is.
    entity_class: str
    # Non-empty pieces of the fifth column after splitting on '|'.
    wer_tags: List[str] = field(default_factory=list)
    # Every column after the fifth, untouched.
    extra_columns: List[str] = field(default_factory=list)

    @classmethod
    def from_line(cls, line: str) -> 'SbsEntry':
        """Parse one SBS line into an entry.

        Args:
            line: A raw line; surrounding spaces and newlines are ignored.

        Returns:
            An instance of `cls` (so subclasses get instances of themselves).

        Raises:
            RuntimeError: If the line has fewer than four tab-separated
                columns.
        """
        parts = line.strip(' \n').split('\t')
        if len(parts) < 4:
            raise RuntimeError(f"Could not parse the line as SBS:\n{line}")

        # Four columns is the old format (no WER tags); five or more is the
        # new format where the fifth column carries the tags.
        if len(parts) > 4:
            wer_tags = [tag for tag in parts[4].split('|') if tag]
        else:
            wer_tags = []

        return cls(
            parts[0].strip(),
            parts[1].strip(),
            parts[2] == 'ERR',
            parts[3],
            wer_tags,
            # Empty for 4- and 5-column lines, matching the field default.
            extra_columns=parts[5:],
        )

    def __str__(self):
        """Render the entry back into a tab-separated SBS line (no newline).

        The WER tag column is always emitted; when there are tags they are
        joined with '|' and followed by a trailing '|'.
        """
        err_str = "ERR" if self.error else ""
        wer_tags_str = "|".join(self.wer_tags)
        if wer_tags_str:
            wer_tags_str += "|"
        return '\t'.join([self.ref_word, self.hyp_word, err_str,
                          self.entity_class, wer_tags_str] + self.extra_columns)
57+
58+
59+
def load_from_file(fp: Path) -> Generator[SbsEntry, None, None]:
    """Yield the entries of an SBS file one at a time.

    The first line of the file is skipped, and reading stops at the first
    line that starts with "--------"; nothing after that line is parsed.

    Args:
        fp: Path of the SBS file to read.

    Yields:
        One `SbsEntry` per line between the first line and the separator.

    Raises:
        RuntimeError: Propagated from `SbsEntry.from_line` for a line that
            cannot be parsed.
    """
    with open(fp) as f:
        # Discard the first line (not parsed as an entry).
        f.readline()
        # Stream the file instead of materialising every line up front:
        # everything after the separator is never needed.
        for line in f:
            if line.startswith("--------"):
                break
            yield SbsEntry.from_line(line)
65+
66+
67+
class FSTState:
    """Mutable bookkeeping shared while the FST text is being generated.

    Tracks the current state number, the symbol table (symbol -> integer id,
    seeded with "<eps>" as 0), and a counter for handing out unique ids.
    """

    def __init__(self):
        self.state: int = 0
        self.vocabulary: OrderedDict = OrderedDict([("<eps>", 0)])
        self.unique_id: int = 0

    def update_vocabulary(self, word):
        """Register `word` in the vocabulary if it is not there yet.

        A new word receives the next free id, i.e. the vocabulary size
        before insertion; known words keep their existing id.
        """
        self.vocabulary.setdefault(word, len(self.vocabulary))

    def get_uid(self):
        """Hand out the current unique id and advance the counter by one."""
        issued = self.unique_id
        self.unique_id = issued + 1
        return issued
85+
86+
87+
def init_args():
    """Build the command-line parser and parse the process arguments.

    Returns:
        The parsed namespace with `sbs_file`, `fst_file`, `left`, `right`,
        `gold` and `tag` attributes.
    """
    cli = ArgumentParser(description="SBS to FST")
    cli.add_argument("sbs_file", type=Path, help="The input SBS file")
    cli.add_argument("fst_file", type=Path, help="The output FST file")

    # The three label options share the same shape: (flag, default, help).
    label_options = (
        (
            "--left",
            "LEFT",
            "Label for the left column. This label will be given to "
            "words that occur on the left (reference) side of the SBS "
            "during an ERR.",
        ),
        (
            "--right",
            "RIGHT",
            "Label for the right column. This label will be given to "
            "words that occur on the right (hypothesis) side of the SBS "
            "during an ERR.",
        ),
        (
            "--gold",
            "GOLD",
            "Label for the gold column. This is for words that both "
            "transcripts agree upon in the SBS.",
        ),
    )
    for flag, default_label, help_text in label_options:
        cli.add_argument(flag, type=str, default=default_label, help=help_text)

    cli.add_argument(
        "--tag",
        action="store_true",
        help="If set, the script will add extra tagging information",
    )
    return cli.parse_args()
120+
121+
122+
def prepare_IO(
    input: Path,
    output: Path,
):
    """Pair up input files with output path stems.

    When `input` is a directory, `output` is created as a directory (with
    parents) and every `*.txt` file found recursively under `input` is paired
    with `output / <file stem>`. Otherwise `input` and `output` are paired
    as they are.

    Returns:
        A tuple `(input_files, output_files)` of equally long lists.
    """
    # Single-file mode: nothing to create, nothing to discover.
    if not input.is_dir():
        return [input], [output]

    output.mkdir(parents=True, exist_ok=True)
    sources = list(input.glob("**/*.txt"))
    targets = [output / source.stem for source in sources]
    return sources, targets
138+
139+
140+
def _to_fst_line(state1, state2, arc, weight: float=0):
141+
return f"{state1} {state2} {arc} {arc} {weight}"
142+
143+
144+
def flush_span(
    span: List[str], state: int, *, tag: Optional[str] = None, branch_factor: int = 0
) -> Tuple[List[str], int]:
    """Turn a span of tokens into a chain of FST arc lines starting at `state`.

    Without `tag`, the first token labels the arc leaving `state` and each
    remaining token extends the chain by one state. With `tag`, an arc
    labelled `tag` leaves `state`, every token follows, and a closing arc
    labelled `tag` ends the chain.

    `branch_factor` shifts the first destination state: the first arc goes
    from `state` to `state + branch_factor + 1`. This lets a second path
    leave the same start state while skipping the state numbers already
    consumed by another path (e.g. the right side of a disagreement skipping
    the states used by the left side).

    Args:
        span: Tokens to emit, in order.
        state: State number the chain starts from.
        tag: Optional label wrapped around the span.
        branch_factor: Number of state ids to skip after `state`.

    Returns:
        `(fst_lines, end_state)`. For an empty span this is `([], state)`.
    """
    if not span:
        return [], state

    cursor = state + branch_factor + 1
    if tag:
        # The opening tag consumes the branching arc; all tokens follow it.
        arcs = [_to_fst_line(state, cursor, tag)]
        remaining = span
    else:
        # The first token itself is the branching arc.
        arcs = [_to_fst_line(state, cursor, span[0])]
        remaining = span[1:]

    for offset, token in enumerate(remaining):
        arcs.append(_to_fst_line(cursor + offset, cursor + offset + 1, token))
    cursor += len(remaining)

    if tag:
        arcs.append(_to_fst_line(cursor, cursor + 1, tag))
        cursor += 1
    return arcs, cursor
178+
179+
180+
def agreement_flush(
    gold_span: List[str], fst_state: FSTState, *, tag: bool = False, gold: Optional[str] = None
) -> List[str]:
    """Emit the arcs for a span on which both SBS columns agree.

    Advances `fst_state.state` to the end of the emitted chain. When `tag`
    is true, a tag symbol built from a fresh unique id and the `gold` label
    is added to the vocabulary and wrapped around the span.

    Returns:
        The generated FST lines.
    """
    span_tag = None
    if tag:
        span_tag = f"___MULTIREF:{fst_state.get_uid()}_{gold}___"
        fst_state.update_vocabulary(span_tag)

    lines, end_state = flush_span(gold_span, fst_state.state, tag=span_tag)
    fst_state.state = end_state
    return lines
195+
196+
197+
def disagreement_flush(
    left_span: List[str],
    right_span: List[str],
    fst_state: FSTState,
    *,
    tag: bool = False,
    left: Optional[str] = None,
    right: Optional[str] = None,
) -> List[str]:
    """Emit two alternative paths for a region where the SBS columns disagree.

    Both paths leave the current state; the right path skips the state
    numbers consumed by the left one. Each path end is then connected by an
    "<eps>" arc to one shared state, which becomes `fst_state.state`.

    When `tag` is true, each side gets its own tag symbol (fresh unique id
    plus the `left` / `right` label), which is added to the vocabulary and
    wrapped around that side's span.

    Returns:
        The generated FST lines: left path, right path, then the two
        "<eps>" arcs.
    """

    def _make_tag(label):
        # Tags are only minted (and a uid consumed) when tagging is enabled.
        if not tag:
            return None
        name = f"___MULTIREF:{fst_state.get_uid()}_{label}___"
        fst_state.update_vocabulary(name)
        return name

    origin = fst_state.state

    # The left tag must be created before the right one so the uids are
    # handed out in left-then-right order.
    left_lines, left_end = flush_span(left_span, origin, tag=_make_tag(left))
    right_lines, right_end = flush_span(
        right_span, origin, tag=_make_tag(right), branch_factor=len(left_lines)
    )

    # Both alternatives have to rejoin in one state before the FST continues.
    merge_state = max(left_end, right_end) + 1
    fst_state.state = merge_state

    return [
        *left_lines,
        *right_lines,
        _to_fst_line(left_end, merge_state, "<eps>"),
        _to_fst_line(right_end, merge_state, "<eps>"),
    ]
238+
239+
240+
def sbs2fst(
    sbs_file: Path,
    *,
    tag: bool = False,
    gold: Optional[str] = None,
    left: Optional[str] = None,
    right: Optional[str] = None,
) -> Tuple[List[str], Dict[str, int]]:
    """Convert an SBS file into FST text lines plus its symbol table.

    Consecutive rows where the reference and hypothesis words match are
    emitted as a single path; consecutive rows where they differ are emitted
    as two alternative paths (reference words vs. hypothesis words). A
    reference word of "<ins>" and a hypothesis word of "<del>" are treated
    as "<eps>" and contribute no token to their side.

    With `tag` set, each span is wrapped in a tag built from `gold`
    (agreements), `left` (reference-only words) or `right` (hypothesis-only
    words).

    Returns:
        `(fst_lines, vocabulary)` where the last FST line is the number of
        the final state and `vocabulary` maps each symbol to its id.
    """
    fst_state = FSTState()
    fst_lines = []

    ref_only = []
    hyp_only = []
    agreed = []

    for entry in load_from_file(sbs_file):
        ref_symbol = "<eps>" if entry.ref_word == "<ins>" else entry.ref_word
        hyp_symbol = "<eps>" if entry.hyp_word == "<del>" else entry.hyp_word

        fst_state.update_vocabulary(ref_symbol)
        fst_state.update_vocabulary(hyp_symbol)

        if entry.ref_word != entry.hyp_word:
            # A disagreement starts (or continues): close any pending
            # agreement span first.
            if agreed:
                fst_lines.extend(
                    agreement_flush(agreed, fst_state, tag=tag, gold=gold)
                )
                agreed = []

            if ref_symbol != "<eps>":
                ref_only.append(ref_symbol)
            if hyp_symbol != "<eps>":
                hyp_only.append(hyp_symbol)
            continue

        # An agreement starts (or continues): close any pending
        # disagreement spans first.
        if ref_only or hyp_only:
            fst_lines.extend(
                disagreement_flush(
                    ref_only, hyp_only, fst_state, tag=tag, left=left, right=right
                )
            )
            ref_only = []
            hyp_only = []
        agreed.append(entry.ref_word)

    # Flush whatever is still pending. Only one kind of span can hold
    # information here: either an agreement or a disagreement, never both.
    if agreed:
        fst_lines.extend(agreement_flush(agreed, fst_state, tag=tag, gold=gold))
    elif ref_only or hyp_only:
        fst_lines.extend(
            disagreement_flush(
                ref_only, hyp_only, fst_state, tag=tag, left=left, right=right
            )
        )

    # A line holding just a state number closes the FST text.
    fst_lines.append(str(fst_state.state))

    return fst_lines, fst_state.vocabulary
303+
304+
305+
def main(
    sbs_file: Path,
    fst_file: Path,
    tag: bool = False,
    gold: Optional[str] = None,
    left: Optional[str] = None,
    right: Optional[str] = None,
):
    """Convert one SBS file, or every SBS file in a directory, to FST text.

    For each input/output pair produced by `prepare_IO`, two files are
    written: `<output>.fst` with the FST lines joined by newlines, and
    `<output>.txt` with one "<symbol> <id>" line per vocabulary entry.
    """
    sources, targets = prepare_IO(sbs_file, fst_file)
    for source, target in zip(sources, targets):
        arcs, symbols = sbs2fst(source, tag=tag, gold=gold, left=left, right=right)

        with open(f"{target}.fst", "w") as arc_file:
            arc_file.write("\n".join(arcs))

        with open(f"{target}.txt", "w") as symbol_file:
            symbol_file.writelines(
                f"{symbol} {index}\n" for symbol, index in symbols.items()
            )
322+
323+
324+
if __name__ == "__main__":
    # Script entry point: forward every parsed CLI option to main() by name.
    main(**vars(init_args()))

0 commit comments

Comments
 (0)