Skip to content

Commit b02ab32

Browse files
agutkin authored and copybara-github committed
VectorFst: Add sanity checks on read for the start state and for the next (destination) state of every arc.
The next-state upper-bound check is disabled for the secondary component FST in `MergeFst`, because the secondary FST is allowed to be incomplete: its arcs may point to states in the primary FST. PiperOrigin-RevId: 910904063
1 parent b90e1b7 commit b02ab32

4 files changed

Lines changed: 62 additions & 2 deletions

File tree

openfst/lib/fst.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ FstReadOptions::FstReadOptions(const absl::string_view source,
117117
isymbols(isymbols),
118118
osymbols(osymbols),
119119
read_isymbols(true),
120-
read_osymbols(true) {
120+
read_osymbols(true),
121+
verify(true) {
121122
mode = ReadMode(absl::GetFlag(FLAGS_fst_read_mode));
122123
}
123124

@@ -140,7 +141,8 @@ std::string FstReadOptions::DebugString() const {
140141
"\" read_osymbols: \"", (read_osymbols ? "true" : "false"),
141142
"\" header: \"", (header ? "set" : "null"), "\" isymbols: \"",
142143
(isymbols ? "set" : "null"), "\" osymbols: \"",
143-
(osymbols ? "set" : "null"), "\"");
144+
(osymbols ? "set" : "null"), "\" verify: \"", (verify ? "true" : "false"),
145+
"\"");
144146
}
145147

146148
} // namespace fst

openfst/lib/fst.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ struct FstReadOptions {
8585
FileReadMode mode; // Read or map files (advisory, if possible)
8686
bool read_isymbols; // Read isymbols, if any (default: true).
8787
bool read_osymbols; // Read osymbols, if any (default: true).
88+
bool verify; // Perform FST type-specific light-weight sanity
89+
// check, e.g., check that the destinations of
90+
// all the arcs are valid.
8891

8992
explicit FstReadOptions(absl::string_view source = "<unspecified>",
9093
const FstHeader* absl_nullable header = nullptr,

openfst/lib/merge-fst.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,12 +495,42 @@ class MergeFstImpl : public FstImpl<A> {
495495
if (!state_map) return nullptr;
496496
FstReadOptions fopts(opts);
497497
fopts.header = nullptr; // Component Fst headers were written out.
498+
499+
// Read primary Fst.
498500
std::unique_ptr<ExpandedFst<A> > primary_fst(
499501
ExpandedFst<A>::Read(strm, fopts));
500502
if (!primary_fst) return nullptr;
503+
504+
// Read secondary Fst. Disable the verification on read because this Fst is
505+
// incomplete.
506+
fopts.verify = false;
501507
std::unique_ptr<ExpandedFst<A> > secondary_fst(
502508
ExpandedFst<A>::Read(strm, fopts));
503509
if (!secondary_fst) return nullptr;
510+
511+
// Make sure the destination states in the secondary Fst are sane.
512+
StateId max_next_state = std::numeric_limits<StateId>::min();
513+
for (StateIterator<Fst<A>> siter(*secondary_fst); !siter.Done();
514+
siter.Next()) {
515+
const auto& state = siter.Value();
516+
for (ArcIterator<Fst<A>> aiter(*secondary_fst, state); !aiter.Done();
517+
aiter.Next()) {
518+
const auto& arc = aiter.Value();
519+
if (arc.nextstate == kNoStateId) {
520+
LOG(ERROR) << "MergeFst::Read: Disallowed next state: " << kNoStateId;
521+
return nullptr;
522+
}
523+
max_next_state = std::max(arc.nextstate, max_next_state);
524+
}
525+
}
526+
const size_t max_states =
527+
primary_fst->NumStates() + secondary_fst->NumStates();
528+
if (max_next_state >= max_states) {
529+
LOG(ERROR) << "MergeFst::Read: Next state " << max_next_state << "in "
530+
<< "secondary FST is bigger than maximum possible number "
531+
<< "of states " << max_states;
532+
return nullptr;
533+
}
504534
return new MergeFstImpl<A, M>(
505535
*impl, *primary_fst, *secondary_fst, hdr.NumStates(), *state_map);
506536
}

openfst/lib/vector-fst.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <ios>
2828
#include <iosfwd>
2929
#include <istream>
30+
#include <limits>
3031
#include <memory>
3132
#include <optional>
3233
#include <ostream>
@@ -483,6 +484,8 @@ VectorFstImpl<S>* VectorFstImpl<S>::Read(std::istream& strm,
483484
impl->BaseImpl::SetStart(hdr.Start());
484485
if (hdr.NumStates() != kNoStateId) impl->ReserveStates(hdr.NumStates());
485486
StateId state = 0;
487+
StateId max_next_state = std::numeric_limits<StateId>::min();
488+
StateId min_next_state = std::numeric_limits<StateId>::max();
486489
for (; hdr.NumStates() == kNoStateId || state < hdr.NumStates(); ++state) {
487490
Weight weight;
488491
if (!weight.Read(strm)) break;
@@ -506,13 +509,35 @@ VectorFstImpl<S>* VectorFstImpl<S>::Read(std::istream& strm,
506509
LOG(ERROR) << "VectorFst::Read: Read failed: " << opts.source;
507510
return nullptr;
508511
}
512+
if (arc.nextstate == kNoStateId) {
513+
LOG(ERROR) << "VectorFst::Read: Disallowed next state: " << kNoStateId;
514+
return nullptr;
515+
}
516+
max_next_state = std::max(arc.nextstate, max_next_state);
517+
min_next_state = std::min(arc.nextstate, min_next_state);
509518
impl->BaseImpl::AddArc(state, std::move(arc));
510519
}
511520
}
512521
if (hdr.NumStates() != kNoStateId && state != hdr.NumStates()) {
513522
LOG(ERROR) << "VectorFst::Read: Unexpected end of file: " << opts.source;
514523
return nullptr;
515524
}
525+
// Sanity check for the start state.
526+
if (impl->Start() != kNoStateId && impl->Start() >= state) {
527+
LOG(ERROR) << "VectorFst::Read: start state " << impl->Start()
528+
<< " out of range [0, " << state << ")";
529+
return nullptr;
530+
}
531+
// Sanity check on next states.
532+
if (min_next_state < 0) {
533+
LOG(ERROR) << "VectorFst::Read: Next state is negative: " << min_next_state;
534+
return nullptr;
535+
}
536+
if (opts.verify && max_next_state >= state) {
537+
LOG(ERROR) << "VectorFst::Read: Next state " << max_next_state
538+
<< " is larger than total number of states " << state;
539+
return nullptr;
540+
}
516541
return impl.release();
517542
}
518543

0 commit comments

Comments
 (0)