Skip to content

Commit 3446afd

Browse files
authored
Allow flexible forwarding of NLP columns to SBS (#56)
* initial attempt
* refactor token properties inside stitching
* read from CLI
* fix test
* handle empty confidence
* set confidence when present
* use pinned kaldi dockerhub image
* update version
1 parent dcf2655 commit 3446afd

7 files changed

Lines changed: 89 additions & 50 deletions

File tree

Dockerfile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Using kaldi image for pre-built OpenFST, version is 1.7.2
2-
FROM kaldiasr/kaldi:latest as kaldi-base
2+
FROM kaldiasr/kaldi:cpu-debian10-2024-07-29 as kaldi-base
3+
34
FROM debian:11
45

56
COPY --from=kaldi-base /opt/kaldi/tools/openfst /opt/openfst

src/fstalign.cpp

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -242,15 +242,15 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
242242
stitches.emplace_back();
243243
Stitching &part = stitches.back();
244244
part.classLabel = tk_classLabel;
245-
part.reftk = ref_tk;
246-
part.hyptk = hyp_tk;
245+
part.reftk = {ref_tk};
246+
part.hyptk = {hyp_tk};
247247
bool del = false, ins = false, sub = false;
248248
if (ref_tk == INS) {
249249
part.comment = "ins";
250250
} else if (hyp_tk == DEL) {
251251
part.comment = "del";
252252
} else if (hyp_tk != ref_tk) {
253-
part.comment = "sub(" + part.hyptk + ")";
253+
part.comment = "sub(" + part.hyptk.token + ")";
254254
}
255255

256256
// for classes, we will have only one token in the global vector
@@ -281,10 +281,10 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
281281

282282
if (!hyp_ctm_rows.empty()) {
283283
auto ctmPart = hyp_ctm_rows[hypRowIndex];
284-
part.start_ts = ctmPart.start_time_secs;
285-
part.duration = ctmPart.duration_secs;
286-
part.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs;
287-
part.confidence = ctmPart.confidence;
284+
part.hyptk.start_ts = ctmPart.start_time_secs;
285+
part.hyptk.duration = ctmPart.duration_secs;
286+
part.hyptk.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs;
287+
part.hyptk.confidence = ctmPart.confidence;
288288

289289
part.hyp_orig = ctmPart.word;
290290
// sanity check
@@ -308,21 +308,24 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
308308
float ts = stof(hypNlpPart.ts);
309309
float endTs = stof(hypNlpPart.endTs);
310310

311-
part.start_ts = ts;
312-
part.end_ts = endTs;
313-
part.duration = endTs - ts;
311+
part.hyptk.start_ts = ts;
312+
part.hyptk.end_ts = endTs;
313+
part.hyptk.duration = endTs - ts;
314314
} else if (!hypNlpPart.ts.empty()) {
315315
float ts = stof(hypNlpPart.ts);
316316

317-
part.start_ts = ts;
318-
part.end_ts = ts;
319-
part.duration = 0.0;
317+
part.hyptk.start_ts = ts;
318+
part.hyptk.end_ts = ts;
319+
part.hyptk.duration = 0.0;
320320
} else if (!hypNlpPart.endTs.empty()) {
321321
float endTs = stof(hypNlpPart.endTs);
322322

323-
part.start_ts = endTs;
324-
part.end_ts = endTs;
325-
part.duration = 0.0;
323+
part.hyptk.start_ts = endTs;
324+
part.hyptk.end_ts = endTs;
325+
part.hyptk.duration = 0.0;
326+
}
327+
if (!hypNlpPart.confidence.empty()) {
328+
part.hyptk.confidence = stof(hypNlpPart.confidence);
326329
}
327330
}
328331

@@ -575,15 +578,15 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
575578
// if the comment starts with 'ins'
576579
if (stitch.comment.find("ins") == 0 && !add_inserts) {
577580
// there's no nlp row info for such case, let's skip over it
578-
if (stitch.confidence >= 1) {
579-
logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk, stitch.start_ts);
581+
if (stitch.hyptk.confidence >= 1) {
582+
logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk.token, stitch.hyptk.start_ts);
580583
}
581584

582585
continue;
583586
}
584587

585588
string original_nlp_token = stitch.nlpRow.token;
586-
string ref_tk = stitch.reftk;
589+
string ref_tk = stitch.reftk.token;
587590

588591
// trying to salvage some of the original punctuation in a relatively safe manner
589592
if (iequals(ref_tk, original_nlp_token)) {
@@ -597,21 +600,21 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
597600
ref_tk = original_nlp_token;
598601
} else if (stitch.comment.find("ins") == 0) {
599602
assert(add_inserts);
600-
logger->debug("an insertion was found for {} {}", stitch.hyptk, stitch.comment);
603+
logger->debug("an insertion was found for {} {}", stitch.hyptk.token, stitch.comment);
601604
ref_tk = "";
602-
stitch.comment = "ins(" + stitch.hyptk + ")";
605+
stitch.comment = "ins(" + stitch.hyptk.token + ")";
603606
}
604607

605608
if (ref_tk == NOOP) {
606609
continue;
607610
}
608611

609612
output_nlp_file << ref_tk << "|" << stitch.nlpRow.speakerId << "|";
610-
if (stitch.hyptk == DEL) {
613+
if (stitch.hyptk.token == DEL) {
611614
// we have no ts/endTs data to put...
612615
output_nlp_file << "||";
613616
} else {
614-
output_nlp_file << fmt::format("{0:.4f}", stitch.start_ts) << "|" << fmt::format("{0:.4f}", stitch.end_ts)
617+
output_nlp_file << fmt::format("{0:.4f}", stitch.hyptk.start_ts) << "|" << fmt::format("{0:.4f}", stitch.hyptk.end_ts)
615618
<< "|";
616619
}
617620

@@ -632,7 +635,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
632635
}
633636

634637
void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
635-
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case) {
638+
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case, std::vector<string> ref_extra_columns, std::vector<string> hyp_extra_columns) {
636639
// int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
637640
// string composition_approach, bool record_case_stats) {
638641
auto logger = logger::GetOrCreateLogger("fstalign");
@@ -698,7 +701,7 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
698701
JsonLogUnigramBigramStats(topAlignment);
699702
if (!output_sbs.empty()) {
700703
logger->info("output_sbs = {}", output_sbs);
701-
WriteSbs(topAlignment, stitches, output_sbs);
704+
WriteSbs(topAlignment, stitches, output_sbs, ref_extra_columns, hyp_extra_columns);
702705
}
703706

704707
if (!output_nlp.empty() && !nlp_ref_loader) {
@@ -720,3 +723,15 @@ void HandleAlign(NlpFstLoader& refLoader, CtmFstLoader& hypLoader, SynonymEngine
720723
align_stitches_to_nlp(refLoader, stitches);
721724
write_stitches_to_nlp(stitches, output_nlp_file, refLoader.mJsonNorm);
722725
}
726+
727+
// Renders one named property of a stitch's reference or hypothesis token as
// text, for use as an extra column in the side-by-side (SBS) output.
//
// stitch   - record whose reftk / hyptk is read.
// refToken - true reads stitch.reftk, false reads stitch.hyptk.
// property - column name; recognised values are "speaker", "ts", "endTs" and
//            "confidence".
//
// Returns the speaker string as-is, or the numeric field formatted with
// to_string (fixed notation, six decimals; a token whose confidence was never
// set still holds the Token default of -1.0 and renders as "-1.000000").
//
// An unrecognised property yields an empty string. The previous table lookup
// used operator[] on the map, which default-inserted an empty std::function
// for an unknown name and then threw std::bad_function_call when invoked, so
// a mistyped column name aborted the run while the SBS file was being written.
string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property) {
  // Bind by reference: this runs once per extra column per output row, so
  // avoid copying the Token (two strings) on every lookup.
  const Token &tk = refToken ? stitch.reftk : stitch.hyptk;

  if (property == "speaker") return tk.speaker;
  if (property == "ts") return to_string(tk.start_ts);
  if (property == "endTs") return to_string(tk.end_ts);
  if (property == "confidence") return to_string(tk.confidence);

  // Unknown column: emit an empty cell so the row keeps the same number of
  // tab-separated fields as the header.
  return "";
}

src/fstalign.h

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,21 @@ fstalign.h
1515
using namespace std;
1616
using namespace fst;
1717

18+
// Per-token record for one side (reference or hypothesis) of an alignment.
// Kept as a plain aggregate so it can be brace-initialised from just the token
// text, leaving every other field at its default; member order therefore
// matters and must not be changed.
struct Token {
  std::string token;        // token text
  float start_ts{0.0f};     // start timestamp (filled from CTM / NLP timing when available)
  float end_ts{0.0f};       // end timestamp
  float duration{0.0f};     // end_ts - start_ts where both are known, otherwise 0
  float confidence{-1.0f};  // stays at -1 until a confidence value is supplied
  std::string speaker;      // NOTE(review): presumably the speaker id; not populated in the visible code - confirm
};
27+
1828
// Stitchings will be used to represent fstalign output, combining reference,
1929
// hypothesis, and error information into a record-like data structure.
2030
struct Stitching {
21-
string reftk;
22-
string hyptk;
23-
float start_ts;
24-
float end_ts;
25-
float duration;
26-
float confidence;
31+
Token reftk;
32+
Token hyptk;
2733
string classLabel;
2834
RawNlpRecord nlpRow;
2935
string hyp_orig;
@@ -42,17 +48,12 @@ struct AlignerOptions {
4248
int levenstein_maximum_error_streak = 100;
4349
};
4450

45-
// original
46-
// void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine, string output_sbs, string
47-
// output_nlp,
48-
// int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
49-
// string composition_approach, bool record_case_stats);
50-
// void HandleAlign(NlpFstLoader *refLoader, CtmFstLoader *hypLoader, SynonymEngine *engine, ofstream &output_nlp_file,
51-
// int numBests, string symbols_filename, string composition_approach);
5251

5352
void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
54-
AlignerOptions alignerOptions, bool add_inserts_nlp = false, bool use_case = false);
53+
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case, std::vector<string> ref_extra_columns, std::vector<string> hyp_extra_columns);
5554
void HandleAlign(NlpFstLoader &refLoader, CtmFstLoader &hypLoader, SynonymEngine &engine, ofstream &output_nlp_file,
5655
AlignerOptions alignerOptions);
5756

57+
string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property);
58+
5859
#endif // __FSTALIGN_H__

src/main.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ int main(int argc, char **argv) {
4040
bool disable_cutoffs = false;
4141
bool disable_hyphen_ignore = false;
4242

43+
std::vector<string> ref_extra_columns = std::vector<string>();
44+
std::vector<string> hyp_extra_columns = std::vector<string>();
45+
4346
CLI::App app("Rev FST Align");
4447
app.set_help_all_flag("--help-all", "Expand all help");
4548
app.add_flag("--version", version, "Show fstalign version.");
@@ -97,6 +100,10 @@ int main(int argc, char **argv) {
97100

98101
c->add_option("--composition-approach", composition_approach,
99102
"Desired composition logic. Choices are 'standard' or 'adapted'");
103+
c->add_option("--ref-extra-cols", ref_extra_columns,
104+
"Extra columns from the reference to include in SBS output.");
105+
c->add_option("--hyp-extra-cols", hyp_extra_columns,
106+
"Extra columns from the hypothesis to include in SBS output.");
100107
}
101108
get_wer->add_option("--wer-sidecar", wer_sidecar_filename,
102109
"WER sidecar json file.");
@@ -180,7 +187,7 @@ int main(int argc, char **argv) {
180187
}
181188

182189
if (command == "wer") {
183-
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp, use_case);
190+
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp, use_case, ref_extra_columns, hyp_extra_columns);
184191
} else if (command == "align") {
185192
if (output_nlp.empty()) {
186193
console->error("the output nlp file must be specified");

src/version.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#pragma once
22

33
#define FSTALIGNER_VERSION_MAJOR 1
4-
#define FSTALIGNER_VERSION_MINOR 13
4+
#define FSTALIGNER_VERSION_MINOR 14
55
#define FSTALIGNER_VERSION_PATCH 0

src/wer.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ void RecordCaseWer(const vector<Stitching> &aligned_stitches) {
262262
for (const auto &stitch : aligned_stitches) {
263263
const string &hyp = stitch.hyp_orig;
264264
const string &ref = stitch.nlpRow.token;
265-
const string &reftk = stitch.reftk;
266-
const string &hyptk = stitch.hyptk;
265+
const string &reftk = stitch.reftk.token;
266+
const string &hyptk = stitch.hyptk.token;
267267
const string &ref_casing = stitch.nlpRow.casing;
268268

269269
if (hyptk == DEL || reftk == INS) {
@@ -526,7 +526,7 @@ void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp)
526526
hyp = "";
527527
}
528528

529-
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename) {
529+
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_ref_columns, const vector<string> extra_hyp_columns) {
530530
auto logger = logger::GetOrCreateLogger("wer");
531531
logger->set_level(spdlog::level::info);
532532

@@ -536,7 +536,14 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
536536
AlignmentTraversor visitor(topAlignment);
537537
string prev_tk_classLabel = "";
538538
logger->info("Side-by-Side alignment info going into {}", sbs_filename);
539-
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities") << endl;
539+
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities");
540+
for (string col_name: extra_ref_columns) {
541+
myfile << fmt::format("\tref_{0}", col_name);
542+
}
543+
for (string col_name: extra_hyp_columns) {
544+
myfile << fmt::format("\thyp_{0}", col_name);
545+
}
546+
myfile << endl;
540547

541548
// keep track of error groupings
542549
ErrorGroups groups_err;
@@ -554,8 +561,8 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
554561
for (auto wer_tag: wer_tags) {
555562
tk_wer_tags = tk_wer_tags + "###" + wer_tag.tag_id + "_" + wer_tag.entity_type + "###|";
556563
}
557-
string ref_tk = p_stitch.reftk;
558-
string hyp_tk = p_stitch.hyptk;
564+
string ref_tk = p_stitch.reftk.token;
565+
string hyp_tk = p_stitch.hyptk.token;
559566
string tag = "";
560567

561568
if (ref_tk == NOOP) {
@@ -587,7 +594,15 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
587594
eff_class = tk_classLabel;
588595
}
589596

590-
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags) << endl;
597+
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags);
598+
599+
for (string col_name: extra_ref_columns) {
600+
myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, true, col_name));
601+
}
602+
for (string col_name: extra_hyp_columns) {
603+
myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, false, col_name));
604+
}
605+
myfile << endl;
591606
offset++;
592607
}
593608

src/wer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,5 @@ void CalculatePrecisionRecall(wer_alignment &topAlignment, int threshold);
4949
typedef vector<pair<size_t, string>> ErrorGroups;
5050

5151
void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp);
52-
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename);
52+
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_ref_columns, const vector<string> extra_hyp_columns);
5353
void JsonLogUnigramBigramStats(wer_alignment &topAlignment);

0 commit comments

Comments
 (0)