@@ -242,15 +242,15 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
242242 stitches.emplace_back ();
243243 Stitching &part = stitches.back ();
244244 part.classLabel = tk_classLabel;
245- part.reftk = ref_tk;
246- part.hyptk = hyp_tk;
245+ part.reftk = { ref_tk} ;
246+ part.hyptk = { hyp_tk} ;
247247 bool del = false , ins = false , sub = false ;
248248 if (ref_tk == INS) {
249249 part.comment = " ins" ;
250250 } else if (hyp_tk == DEL) {
251251 part.comment = " del" ;
252252 } else if (hyp_tk != ref_tk) {
253- part.comment = " sub(" + part.hyptk + " )" ;
253+ part.comment = " sub(" + part.hyptk . token + " )" ;
254254 }
255255
256256 // for classes, we will have only one token in the global vector
@@ -281,10 +281,10 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
281281
282282 if (!hyp_ctm_rows.empty ()) {
283283 auto ctmPart = hyp_ctm_rows[hypRowIndex];
284- part.start_ts = ctmPart.start_time_secs ;
285- part.duration = ctmPart.duration_secs ;
286- part.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs ;
287- part.confidence = ctmPart.confidence ;
284+ part.hyptk . start_ts = ctmPart.start_time_secs ;
285+ part.hyptk . duration = ctmPart.duration_secs ;
286+ part.hyptk . end_ts = ctmPart.start_time_secs + ctmPart.duration_secs ;
287+ part.hyptk . confidence = ctmPart.confidence ;
288288
289289 part.hyp_orig = ctmPart.word ;
290290 // sanity check
@@ -308,21 +308,24 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
308308 float ts = stof (hypNlpPart.ts );
309309 float endTs = stof (hypNlpPart.endTs );
310310
311- part.start_ts = ts;
312- part.end_ts = endTs;
313- part.duration = endTs - ts;
311+ part.hyptk . start_ts = ts;
312+ part.hyptk . end_ts = endTs;
313+ part.hyptk . duration = endTs - ts;
314314 } else if (!hypNlpPart.ts .empty ()) {
315315 float ts = stof (hypNlpPart.ts );
316316
317- part.start_ts = ts;
318- part.end_ts = ts;
319- part.duration = 0.0 ;
317+ part.hyptk . start_ts = ts;
318+ part.hyptk . end_ts = ts;
319+ part.hyptk . duration = 0.0 ;
320320 } else if (!hypNlpPart.endTs .empty ()) {
321321 float endTs = stof (hypNlpPart.endTs );
322322
323- part.start_ts = endTs;
324- part.end_ts = endTs;
325- part.duration = 0.0 ;
323+ part.hyptk .start_ts = endTs;
324+ part.hyptk .end_ts = endTs;
325+ part.hyptk .duration = 0.0 ;
326+ }
327+ if (!hypNlpPart.confidence .empty ()) {
328+ part.hyptk .confidence = stof (hypNlpPart.confidence );
326329 }
327330 }
328331
@@ -575,15 +578,15 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
575578 // if the comment starts with 'ins'
576579 if (stitch.comment .find (" ins" ) == 0 && !add_inserts) {
577580 // there's no nlp row info for such case, let's skip over it
578- if (stitch.confidence >= 1 ) {
579- logger->warn (" an insertion with high confidence was found for {}@{}" , stitch.hyptk , stitch.start_ts );
581+ if (stitch.hyptk . confidence >= 1 ) {
582+ logger->warn (" an insertion with high confidence was found for {}@{}" , stitch.hyptk . token , stitch. hyptk .start_ts );
580583 }
581584
582585 continue ;
583586 }
584587
585588 string original_nlp_token = stitch.nlpRow .token ;
586- string ref_tk = stitch.reftk ;
589+ string ref_tk = stitch.reftk . token ;
587590
588591 // trying to salvage some of the original punctuation in a relatively safe manner
589592 if (iequals (ref_tk, original_nlp_token)) {
@@ -597,21 +600,21 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
597600 ref_tk = original_nlp_token;
598601 } else if (stitch.comment .find (" ins" ) == 0 ) {
599602 assert (add_inserts);
600- logger->debug (" an insertion was found for {} {}" , stitch.hyptk , stitch.comment );
603+ logger->debug (" an insertion was found for {} {}" , stitch.hyptk . token , stitch.comment );
601604 ref_tk = " " ;
602- stitch.comment = " ins(" + stitch.hyptk + " )" ;
605+ stitch.comment = " ins(" + stitch.hyptk . token + " )" ;
603606 }
604607
605608 if (ref_tk == NOOP) {
606609 continue ;
607610 }
608611
609612 output_nlp_file << ref_tk << " |" << stitch.nlpRow .speakerId << " |" ;
610- if (stitch.hyptk == DEL) {
613+ if (stitch.hyptk . token == DEL) {
611614 // we have no ts/endTs data to put...
612615 output_nlp_file << " ||" ;
613616 } else {
614- output_nlp_file << fmt::format (" {0:.4f}" , stitch.start_ts ) << " |" << fmt::format (" {0:.4f}" , stitch.end_ts )
617+ output_nlp_file << fmt::format (" {0:.4f}" , stitch.hyptk . start_ts ) << " |" << fmt::format (" {0:.4f}" , stitch. hyptk .end_ts )
615618 << " |" ;
616619 }
617620
@@ -632,7 +635,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
632635}
633636
634637void HandleWer (FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
635- AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case) {
638+ AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case, std::vector<string> ref_extra_columns, std::vector<string> hyp_extra_columns ) {
636639 // int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
637640 // string composition_approach, bool record_case_stats) {
638641 auto logger = logger::GetOrCreateLogger (" fstalign" );
@@ -698,7 +701,7 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
698701 JsonLogUnigramBigramStats (topAlignment);
699702 if (!output_sbs.empty ()) {
700703 logger->info (" output_sbs = {}" , output_sbs);
701- WriteSbs (topAlignment, stitches, output_sbs);
704+ WriteSbs (topAlignment, stitches, output_sbs, ref_extra_columns, hyp_extra_columns );
702705 }
703706
704707 if (!output_nlp.empty () && !nlp_ref_loader) {
@@ -720,3 +723,15 @@ void HandleAlign(NlpFstLoader& refLoader, CtmFstLoader& hypLoader, SynonymEngine
720723 align_stitches_to_nlp (refLoader, stitches);
721724 write_stitches_to_nlp (stitches, output_nlp_file, refLoader.mJsonNorm );
722725}
726+
727+ string GetTokenPropertyAsString (Stitching stitch, bool refToken, string property) {
728+ std::unordered_map<std::string, std::function<string (Token)>> col_name_to_val = {
729+ {" speaker" , [](Token tk) {return tk.speaker ;}},
730+ {" ts" , [](Token tk) {return to_string (tk.start_ts );}},
731+ {" endTs" , [](Token tk) {return to_string (tk.end_ts );}},
732+ {" confidence" , [](Token tk) {return to_string (tk.confidence );}},
733+ };
734+ if (refToken) return col_name_to_val[property](stitch.reftk );
735+ if (!refToken) return col_name_to_val[property](stitch.hyptk );
736+ return " " ;
737+ }
0 commit comments