Skip to content

Commit 82dec1d

Browse files
authored
dynamic pruning and substitution restrictions (#60)
* massive speed improvement, with some WER improvements too in large complicated files with lots of deletions * adding a flexible beam, adjustable via command line parameters * adding --strict-punctuation mode, that will allow punctuation marks to be substituted only within themselves * adding support for strict punctuation and favouring same words in alignments * fixing the default value of the beam to 50 * adding new test case file * adding new result for std composition * fixes, and new test cases updates * bumping version and adding release notes to the readme
1 parent b10f0ea commit 82dec1d

28 files changed

Lines changed: 793 additions & 213 deletions

README.md

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
# fstalign
55
- [Overview](#Overview)
6+
- [What's new in 2.0](#What's-new-in-2.0)
67
- [Installation](#Installation)
78
* [Dependencies](#Dependencies)
89
* [Build](#Build)
@@ -14,6 +15,97 @@
1415

1516
Due to its use of OpenFST and lazy algorithms for text-based alignment, `fstalign` is efficient for calculating WER while also providing significant flexibility for different measurement features and error analysis.
1617

18+
## What's new in 2.0
19+
20+
Version 2.0 introduces two major changes:
21+
1. A new method to traverse the composition graph, which dramatically improves the overall speed, especially when the sequences are long contain many errors.
22+
We have files that took 25 minutes to align before that can now take about 7 seconds. This is especially noticeable with the adapted composition (the default).
23+
1. Some smarts were introduced when --use-case and --use-punctuation are enabled.
24+
Now, by default, punctuation symbols can only be substituted by other punctuation symbols (or deleted/inserted).
25+
Also, words that differ only by the first letter case will be preffered for substitution.
26+
27+
28+
Here's an example of the 1.x behavior and the 2.0 version
29+
```
30+
==> v1.x sbs.txt <==
31+
ref_token hyp_token IsErr Class Wer_Tag_Entities
32+
Welcome Welcome ###322_###|
33+
back back
34+
to to
35+
another another
36+
episode episode ###323_###|
37+
of of
38+
Podcasts Podcast ERR ###324_###|
39+
in and ERR
40+
Color Color ###167_###|###325_###|
41+
: of ERR
42+
The the ERR
43+
Podcast Podcast ###168_###|###326_###|
44+
. .
45+
I I
46+
47+
==> v2.0 sbs.txt <==
48+
ref_token hyp_token IsErr Class Wer_Tag_Entities
49+
Welcome Welcome ###322_###|
50+
back back
51+
to to
52+
another another
53+
episode episode ###323_###|
54+
of of
55+
Podcasts Podcast ERR ###324_###|
56+
in and ERR
57+
Color Color ###167_###|###325_###|
58+
<ins> of ERR
59+
: <del> ERR
60+
The the ERR
61+
Podcast Podcast ###168_###|###326_###|
62+
```
63+
The confusion between `:` and `of` is not longer allowed.
64+
65+
Also, here's how favoring or not the substitution based on case-insensitive comparison, while still counting it as an error, looks like:
66+
```
67+
==> v1.x sbs.txt <==
68+
ref_token hyp_token IsErr Class Wer_Tag_Entities
69+
shorten shorten ###801_###|
70+
It's it's ERR
71+
Berry Barry ERR ###785_###|###788_###|###802_###|
72+
. .
73+
Just Just
74+
Yeah like ERR ###805_###|
75+
. <del> ERR
76+
Like <del> ERR
77+
, <del> ERR
78+
I I ###809_###|
79+
have have
80+
a a
81+
nickname nickname
82+
83+
==> v2.0 sbs.txt <==
84+
ref_token hyp_token IsErr Class Wer_Tag_Entities
85+
It's it's ERR
86+
Berry Barry ERR ###785_###|###788_###|###802_###|
87+
. .
88+
Just Just
89+
Yeah <del> ERR ###805_###|
90+
. <del> ERR
91+
Like like ERR
92+
, <del> ERR
93+
I I ###809_###|
94+
have have
95+
a a
96+
nickname nickname
97+
```
98+
Here, `Like <-> like` substitution is favored. While this generally won't change the WER value itself (although it can), it will improve the timing alignments.
99+
100+
101+
These behavior, as well as the beam size (that has a default value of 50.0) can be controlled with the following new parameters:
102+
```
103+
--disable-strict-punctuation
104+
Disable strict punctuation alignment (which prevents punctuation aligning with words).
105+
--disable-favored-subs Disable favored substitutions (which makes alignment favor substitutions between words which differ only by case).
106+
--favored-sub-cost FLOAT Cost for favored substitutions (e.g., case diff). Default: 0.1
107+
```
108+
17109
## Installation
18110

19111
### Dependencies

src/AdaptedComposition.cpp

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ AdaptedCompositionFst::AdaptedCompositionFst(const fst::StdFst &fstA, const fst:
2121
#endif
2222
}
2323

24-
AdaptedCompositionFst::AdaptedCompositionFst(const fst::StdFst &fstA, const fst::StdFst &fstB, SymbolTable &symbols)
24+
AdaptedCompositionFst::AdaptedCompositionFst(const fst::StdFst &fstA, const fst::StdFst &fstB, const SymbolTable &symbols)
2525
: fstA_{fstA}, fstB_{fstB} {
2626
logger_ = logger::GetOrCreateLogger("AdaptedCompositionFst");
2727
logger_->set_level(spdlog::level::info);
@@ -36,6 +36,35 @@ AdaptedCompositionFst::AdaptedCompositionFst(const fst::StdFst &fstA, const fst:
3636
ins_label_id_ = symbols.Find(options.symIns);
3737
}
3838

39+
AdaptedCompositionFst::AdaptedCompositionFst(const fst::StdFst &fstA, const fst::StdFst &fstB, const SymbolTable &symbols, const AlignerOptions& options)
40+
: fstA_(fstA), // Reference to input FST A
41+
fstB_(fstB), // Reference to input FST B
42+
symbols_(&symbols), // Pointer to symbol table
43+
strict_punctuation_(options.strict_punctuation), // Store strict punctuation flag
44+
punctuation_ids_(options.punctuation_ids),
45+
// Favored substitutions
46+
use_favored_substitutions_(options.use_favored_substitutions),
47+
favored_substitution_cost_(options.favored_substitution_cost),
48+
favorable_substitution_map_(options.favorable_substitution_map)
49+
// Initialize other members if they exist (e.g., current_composed_next_state_id = 0;)
50+
{
51+
logger_ = logger::GetOrCreateLogger("AdaptedCompositionFst"); // Use member logger_ if declared
52+
logger_->set_level(spdlog::level::info);
53+
#if TRACE
54+
logger_->set_level(spdlog::level::trace);
55+
#endif
56+
57+
// Initialize special symbol IDs (assuming these are member variables now)
58+
FstAlignOption fst_options; // Contains special symbol names (symSub etc)
59+
// Ensure sub_label_id_ etc are members if used elsewhere
60+
sub_label_id_ = symbols_->Find(fst_options.symSub);
61+
del_label_id_ = symbols_->Find(fst_options.symDel);
62+
ins_label_id_ = symbols_->Find(fst_options.symIns);
63+
64+
// Initialize entity/synonym label vectors
65+
SetSymbols(&symbols);
66+
}
67+
3968
AdaptedCompositionFst::~AdaptedCompositionFst() {}
4069

4170
bool AdaptedCompositionFst::IsEntityLabel(int labelId) {
@@ -349,8 +378,40 @@ bool AdaptedCompositionFst::TryGetArcsAtState(StateId fromStateId, vector<fst::S
349378
#if TRACE
350379
logger_->trace("{}/{} >] adding sub/{}/{}", dbg_count, here_snap, arcA.ilabel, arcB.olabel);
351380
#endif
352-
out_vector->push_back(StdArc(arcA.ilabel, arcB.olabel, substitution_cost, sub_state_ref_id));
353-
arc_added++;
381+
// --- Strict Punctuation Check ---
382+
bool skip_substitution = false;
383+
if (strict_punctuation_) {
384+
bool ilabel_is_punct = (punctuation_ids_.count(arcA.ilabel) > 0);
385+
bool olabel_is_punct = (punctuation_ids_.count(arcB.olabel) > 0); // Check arcB.olabel for hyp side?
386+
387+
if (ilabel_is_punct != olabel_is_punct) {
388+
skip_substitution = true;
389+
}
390+
}
391+
// --- End Strict Punctuation Check ---
392+
393+
if (!skip_substitution) {
394+
// --- Favored Substitution Cost Check ---
395+
float current_sub_cost = 1.0f; // Default substitution cost
396+
if (use_favored_substitutions_) {
397+
int labelA = arcA.ilabel;
398+
int labelB = arcB.olabel;
399+
// Check bounds and if labelA has a favored partner which is labelB
400+
if (labelA >= 0 && labelA < favorable_substitution_map_.size() &&
401+
favorable_substitution_map_[labelA] == labelB)
402+
{
403+
current_sub_cost = favored_substitution_cost_; // Use lower cost
404+
#if TRACE
405+
logger_->trace("Applying favored sub cost ({}) for {} ({}) <-> {} ({})",
406+
current_sub_cost, symbols_->Find(labelA), labelA, symbols_->Find(labelB), labelB);
407+
#endif
408+
}
409+
}
410+
// --- End Favored Substitution Cost Check ---
411+
412+
out_vector->push_back(StdArc(arcA.ilabel, arcB.olabel, current_sub_cost, sub_state_ref_id));
413+
arc_added++;
414+
}
354415
}
355416
}
356417

@@ -402,7 +463,7 @@ bool AdaptedCompositionFst::TryGetArcsAtState(StateId fromStateId, vector<fst::S
402463
return true;
403464
}
404465

405-
void AdaptedCompositionFst::SetSymbols(fst::SymbolTable *symbols) {
466+
void AdaptedCompositionFst::SetSymbols(const fst::SymbolTable *symbols) {
406467
symbols_ = symbols;
407468
synonyms_label_ids.clear();
408469
synonyms_label_ids.resize(symbols->NumSymbols(), false);

src/AdaptedComposition.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ AdaptedComposition.h
1313
#include <utility>
1414
#include "IComposition.h"
1515
#include "utilities.h"
16+
#include "fstalign.h"
1617

1718
using namespace std;
1819

@@ -48,7 +49,7 @@ class AdaptedCompositionFst : public IComposition {
4849

4950
StateId current_composed_next_state_id = 0;
5051

51-
fst::SymbolTable *symbols_;
52+
const fst::SymbolTable *symbols_;
5253
std::vector<bool> synonyms_label_ids;
5354
std::vector<bool> entity_label_ids;
5455

@@ -57,6 +58,13 @@ class AdaptedCompositionFst : public IComposition {
5758
// possible optimizations : limit to const FST or limit to StdVectorFst
5859
const fst::StdFst &fstA_;
5960
const fst::StdFst &fstB_;
61+
// Add members to store options
62+
bool strict_punctuation_ = false;
63+
std::unordered_set<int> punctuation_ids_;
64+
// Favored substitutions
65+
bool use_favored_substitutions_ = false;
66+
float favored_substitution_cost_ = 0.1f;
67+
std::vector<int> favorable_substitution_map_;
6068

6169
StateId GetOrCreateComposedState(StateId a, StateId b);
6270
bool IsEntityLabel(int labelId);
@@ -65,7 +73,8 @@ class AdaptedCompositionFst : public IComposition {
6573

6674
public:
6775
AdaptedCompositionFst(const fst::StdFst &fstA, const fst::StdFst &fstB);
68-
AdaptedCompositionFst(const fst::StdFst &fstA, const fst::StdFst &fstB, SymbolTable &symbols);
76+
AdaptedCompositionFst(const fst::StdFst &fstA, const fst::StdFst &fstB, const SymbolTable &symbols);
77+
AdaptedCompositionFst(const fst::StdFst &fstA, const fst::StdFst &fstB, const SymbolTable &symbols, const AlignerOptions& options);
6978
~AdaptedCompositionFst();
7079

7180
StateId Start();
@@ -78,7 +87,7 @@ class AdaptedCompositionFst : public IComposition {
7887
// a is in the composed-graph referencial
7988
bool DoesComposedStateExist(StateId a);
8089

81-
void SetSymbols(fst::SymbolTable *symbols);
90+
void SetSymbols(const fst::SymbolTable *symbols);
8291

8392
void DebugComposedGraph();
8493
};

src/PathHeap.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,48 @@ bool pruneMe = (*last_wer_index)->numErrors + 20 < (*iter)->numErrors; --> seem
114114

115115
return pruned;
116116
}
117+
118+
int PathHeap::prune_relative(float beam_width) {
119+
if (heap.empty()) {
120+
return 0;
121+
}
122+
123+
auto logger = logger::GetOrCreateLogger("pathheap");
124+
size_t initial_size = heap.size();
125+
126+
// Find the best costSoFar in the current heap
127+
// Note: The heap is ordered by the complex shortlistComparatorSharedPtr,
128+
// so the first element isn't necessarily the one with the lowest costSoFar.
129+
// We need to iterate to find the minimum costSoFar.
130+
float best_cost = std::numeric_limits<float>::max();
131+
for (const auto& entry : heap) {
132+
if (entry->costSoFar < best_cost) {
133+
best_cost = entry->costSoFar;
134+
}
135+
}
136+
137+
float cost_threshold = best_cost + beam_width;
138+
139+
logger->debug("==== Relative pruning starting (Beam: {}) =====", beam_width);
140+
logger->debug("Initial size: {}, Best cost: {:.4f}, Threshold: {:.4f}",
141+
initial_size, best_cost, cost_threshold);
142+
143+
int pruned_count = 0;
144+
auto iter = heap.begin();
145+
while (iter != heap.end()) {
146+
// Check if the current entry's cost exceeds the threshold
147+
if ((*iter)->costSoFar > cost_threshold) {
148+
// Remove the element and advance the iterator
149+
iter = heap.erase(iter);
150+
pruned_count++;
151+
} else {
152+
// Otherwise, just advance the iterator
153+
++iter;
154+
}
155+
}
156+
157+
logger->debug("After relative pruning: {} items remain ({} pruned)", heap.size(), pruned_count);
158+
logger->debug("-----\n");
159+
160+
return pruned_count;
161+
}

src/PathHeap.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class PathHeap {
6464
void insert(std::shared_ptr<ShortlistEntry> entry);
6565
shared_ptr<ShortlistEntry> removeFirst();
6666
int prune(int targetSz);
67+
int prune_relative(float beam_width);
6768
int size();
6869
std::shared_ptr<ShortlistEntry> GetBestWerCandidate();
6970
int pruningErrorOffset = 20;

0 commit comments

Comments
 (0)